import itertools

import numpy as np
import pyproj
import rasterio.transform
import sparse
import xarray as xr
import xesmf as xe
from common import earthaccess_args, target_extent
from icechunk import IcechunkStore, StorageConfig
# Pre-generate weights for resampling with XESMF or sparse


def make_grid_ds(*, te, tilesize, dstSRS) -> xr.Dataset:
"""
Make a dataset representing a target grid
Returns
-------
xr.Dataset
Target grid dataset with the following variables:
- "x": X coordinate in Web Mercator projection (grid cell center)
- "y": Y coordinate in Web Mercator projection (grid cell center)
- "lat": latitude coordinate (grid cell center)
- "lon": longitude coordinate (grid cell center)
- "lat_b": latitude bounds for grid cell
- "lon_b": longitude bounds for grid cell
Notes
-----
Modified from ndpyramid - https://github.com/carbonplan/ndpyramid
"""
    transform = rasterio.transform.Affine.translation(
        te[0], te[3]
    ) * rasterio.transform.Affine.scale((te[2] * 2) / tilesize, (te[1] * 2) / tilesize)

    p = pyproj.Proj(dstSRS)

    grid_shape = (tilesize, tilesize)
    bounds_shape = (tilesize + 1, tilesize + 1)

    xs = np.empty(grid_shape)
    ys = np.empty(grid_shape)
    lat = np.empty(grid_shape)
    lon = np.empty(grid_shape)
    lat_b = np.zeros(bounds_shape)
    lon_b = np.zeros(bounds_shape)

    # calc grid cell center coordinates
    ii, jj = np.meshgrid(np.arange(tilesize) + 0.5, np.arange(tilesize) + 0.5)
    for i, j in itertools.product(range(grid_shape[0]), range(grid_shape[1])):
        locs = [ii[i, j], jj[i, j]]
        xs[i, j], ys[i, j] = transform * locs
        lon[i, j], lat[i, j] = p(xs[i, j], ys[i, j], inverse=True)

    # calc grid cell bounds
    iib, jjb = np.meshgrid(np.arange(tilesize + 1), np.arange(tilesize + 1))
    for i, j in itertools.product(range(bounds_shape[0]), range(bounds_shape[1])):
        locs = [iib[i, j], jjb[i, j]]
        x, y = transform * locs
        lon_b[i, j], lat_b[i, j] = p(x, y, inverse=True)

    latitude = xr.DataArray(
        lat[:, 0],
        dims="y",
        attrs=dict(
            standard_name="latitude",
            long_name="Latitude",
            units="degrees_north",
            axis="Y",
        ),
    )
    longitude = xr.DataArray(
        lon[0, :],
        dims="x",
        attrs=dict(
            standard_name="longitude",
            long_name="Longitude",
            units="degrees_east",
            axis="X",
        ),
    )

    return xr.Dataset(
        {
            "lat_b": xr.DataArray(lat_b, dims=["y_b", "x_b"]),
            "lon_b": xr.DataArray(lon_b, dims=["y_b", "x_b"]),
        },
        {
            "latitude": latitude,
            "longitude": longitude,
        },
    )
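
# Illustrative usage (hypothetical values, not taken from common.target_extent):
# a single 256x256 tile covering the full Web Mercator square would be
#
#   te = (-20037508.34, -20037508.34, 20037508.34, 20037508.34)  # xmin, ymin, xmax, ymax
#   grid = make_grid_ds(te=te, tilesize=256, dstSRS="EPSG:3857")
#
# Note that the transform maps pixel (0, 0) to the top-left corner (te[0], te[3])
# and assumes an extent symmetric about the origin, since the pixel size is
# computed as (te[2] * 2) / tilesize rather than (te[2] - te[0]) / tilesize.
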
def xesmf_weights_to_xarray(regridder) -> xr.Dataset:
"""
Construct an xarray dataset from XESMF weights
Notes
-----
From ndpyramid - https://github.com/carbonplan/ndpyramid
"""
= regridder.weights.data
w = "n_s"
dim = xr.Dataset(
ds
{"S": (dim, w.data),
"col": (dim, w.coords[1, :] + 1),
"row": (dim, w.coords[0, :] + 1),
}
)= {"n_in": regridder.n_in, "n_out": regridder.n_out}
ds.attrs return ds
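
# The sparse weight matrix is flattened into plain 1-D arrays ("S", "col", "row")
# because the sparse matrix itself cannot be written to Zarr directly; the
# 1-based "col"/"row" indices follow the ESMF weight-file convention.
# `_reconstruct_xesmf_weights` below inverts this transformation.
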
def _reconstruct_xesmf_weights(ds_w):
"""
Reconstruct weights into format that xESMF understands
Notes
-----
From ndpyramid - https://github.com/carbonplan/ndpyramid
"""
= ds_w["col"].values - 1
col = ds_w["row"].values - 1
row = ds_w["S"].values
s = ds_w.attrs["n_out"], ds_w.attrs["n_in"]
n_out, n_in = np.stack([row, col])
crds return xr.DataArray(
=("out_dim", "in_dim"), name="weights"
sparse.COO(crds, s, (n_out, n_in)), dims )
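
# Sketch of how the reconstructed weights could be applied at tile-rendering
# time (an assumed downstream step, not part of this script): flatten the
# source field, multiply by the sparse (n_out, n_in) matrix, and reshape to
# the 256x256 target tile.
#
#   tile = (_reconstruct_xesmf_weights(ds_w).data @ source.values.ravel()).reshape(256, 256)
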
def generate_weights(dataset, zoom):
    te = target_extent[zoom]

    # Define filepath, driver, and variable information
    args = earthaccess_args[dataset]
    # Create icechunk repos for caching weights and target grid
    weights_storage = StorageConfig.s3_from_env(
        bucket="nasa-veda-scratch",
        prefix=f"resampling/test-weight-caching/{dataset}-weights-{zoom}",
        region="us-west-2",
    )
    target_storage = StorageConfig.s3_from_env(
        bucket="nasa-veda-scratch",
        prefix=f"resampling/test-weight-caching/{dataset}-target-{zoom}",
        region="us-west-2",
    )
    weights_store = IcechunkStore.open_or_create(storage=weights_storage, mode="w")
    target_store = IcechunkStore.open_or_create(storage=target_storage, mode="w")
    # Create target grid
    target_grid = make_grid_ds(te=te, tilesize=256, dstSRS="EPSG:3857")
    # Open dataset
    storage = StorageConfig.s3_from_env(
        bucket="nasa-veda-scratch",
        prefix=f"resampling/icechunk/{dataset}",
        region="us-west-2",
    )
    store = IcechunkStore.open_existing(storage=storage, mode="r")
    da = xr.open_zarr(store, zarr_format=3, consolidated=False)[args["variable"]]
    # Chunk target grid for parallel weight generation
    output_chunk_size = 128
    target_grid = target_grid.chunk(
        {
            "x": output_chunk_size,
            "y": output_chunk_size,
            "y_b": output_chunk_size,
            "x_b": output_chunk_size,
        }
    )
    # Create XESMF regridder
    regridder = xe.Regridder(
        da,
        target_grid,
        "nearest_s2d",
        periodic=True,
        extrap_method="nearest_s2d",
        ignore_degenerate=True,
        parallel=True,
    )
    # Convert weights to a dataset
    weights = xesmf_weights_to_xarray(regridder)
    # Store weights using icechunk
    weights.to_zarr(weights_store, zarr_format=3, consolidated=False)
    # Commit data to icechunk stores
    weights_store.commit("Store weights")
    # Store target grid using icechunk
    target_grid.load().to_zarr(target_store, zarr_format=3, consolidated=False)
    target_store.commit("Generate target grid")
    # Store weights using Zarr
    output = f"s3://nasa-veda-scratch/resampling/test-weight-caching/{dataset}-weights-{zoom}.zarr"
    weights.to_zarr(output, mode="w", storage_options={"use_listings_cache": False})
    # Store target grid using Zarr
    output = f"s3://nasa-veda-scratch/resampling/test-weight-caching/{dataset}-target-{zoom}.zarr"
    target_grid.to_zarr(output, mode="w", storage_options={"use_listings_cache": False})


dataset = "gpm_imerg"
generate_weights(dataset, 0)
generate_weights(dataset, 1)
generate_weights(dataset, 2)
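
# A later benchmark run could load the cached weights back from icechunk and
# rebuild the sparse matrix without re-running xe.Regridder (a sketch assuming
# the same store layout written above):
#
#   storage = StorageConfig.s3_from_env(
#       bucket="nasa-veda-scratch",
#       prefix=f"resampling/test-weight-caching/{dataset}-weights-0",
#       region="us-west-2",
#   )
#   store = IcechunkStore.open_existing(storage=storage, mode="r")
#   ds_w = xr.open_zarr(store, zarr_format=3, consolidated=False).load()
#   weights = _reconstruct_xesmf_weights(ds_w)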