import itertools
import numpy as np
import pyproj
import rasterio.transform
import sparse
import xarray as xr
import xesmf as xe
from common import earthaccess_args, target_extent
from icechunk import IcechunkStore, StorageConfig

# Pre-generate weights for resampling with xESMF or sparse
def make_grid_ds(*, te, tilesize, dstSRS) -> xr.Dataset:
"""
Make a dataset representing a target grid
Returns
-------
xr.Dataset
Target grid dataset with the following variables:
- "x": X coordinate in Web Mercator projection (grid cell center)
- "y": Y coordinate in Web Mercator projection (grid cell center)
- "lat": latitude coordinate (grid cell center)
- "lon": longitude coordinate (grid cell center)
- "lat_b": latitude bounds for grid cell
- "lon_b": longitude bounds for grid cell
Notes
-----
Modified from ndpyramid - https://github.com/carbonplan/ndpyramid
"""
    # te is the target extent [xmin, ymin, xmax, ymax]; the scale below assumes
    # an extent symmetric about the origin (xmin == -xmax, ymin == -ymax),
    # matching the ndpyramid helper this is adapted from
    transform = rasterio.transform.Affine.translation(
        te[0], te[3]
    ) * rasterio.transform.Affine.scale((te[2] * 2) / tilesize, (te[1] * 2) / tilesize)
p = pyproj.Proj(dstSRS)
grid_shape = (tilesize, tilesize)
bounds_shape = (tilesize + 1, tilesize + 1)
xs = np.empty(grid_shape)
ys = np.empty(grid_shape)
lat = np.empty(grid_shape)
lon = np.empty(grid_shape)
lat_b = np.zeros(bounds_shape)
lon_b = np.zeros(bounds_shape)
# calc grid cell center coordinates
ii, jj = np.meshgrid(np.arange(tilesize) + 0.5, np.arange(tilesize) + 0.5)
for i, j in itertools.product(range(grid_shape[0]), range(grid_shape[1])):
locs = [ii[i, j], jj[i, j]]
xs[i, j], ys[i, j] = transform * locs
lon[i, j], lat[i, j] = p(xs[i, j], ys[i, j], inverse=True)
# calc grid cell bounds
iib, jjb = np.meshgrid(np.arange(tilesize + 1), np.arange(tilesize + 1))
for i, j in itertools.product(range(bounds_shape[0]), range(bounds_shape[1])):
locs = [iib[i, j], jjb[i, j]]
x, y = transform * locs
lon_b[i, j], lat_b[i, j] = p(x, y, inverse=True)
latitude = xr.DataArray(
lat[:, 0],
dims="y",
attrs=dict(
standard_name="latitude",
long_name="Latitude",
units="degrees_north",
axis="X",
),
)
longitude = xr.DataArray(
lon[0, :],
dims="x",
attrs=dict(
standard_name="longitude",
long_name="Longitude",
units="degrees_east",
axis="Y",
),
)
return xr.Dataset(
{
"lat_b": xr.DataArray(lat_b, dims=["y_b", "x_b"]),
"lon_b": xr.DataArray(lon_b, dims=["y_b", "x_b"]),
},
{
"latitude": latitude,
"longitude": longitude,
},
)
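
# A minimal usage sketch (the bound below is the standard Web Mercator
# half-width, included purely for illustration; the function name is
# hypothetical): build a single global 256x256 grid with cell-center
# latitude/longitude plus the cell-corner bounds that xESMF's
# conservative methods require
def _example_global_grid() -> xr.Dataset:
    half_width = 20037508.342789244  # Web Mercator half-width in metres
    return make_grid_ds(
        te=[-half_width, -half_width, half_width, half_width],
        tilesize=256,
        dstSRS="EPSG:3857",
    )
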
def xesmf_weights_to_xarray(regridder) -> xr.Dataset:
"""
Construct an xarray dataset from XESMF weights
Notes
-----
From ndpyramid - https://github.com/carbonplan/ndpyramid
"""
w = regridder.weights.data
dim = "n_s"
    ds = xr.Dataset(
        {
            "S": (dim, w.data),
            # shift to 1-based indices, matching ESMF's netCDF weight-file convention
            "col": (dim, w.coords[1, :] + 1),
            "row": (dim, w.coords[0, :] + 1),
        }
    )
ds.attrs = {"n_in": regridder.n_in, "n_out": regridder.n_out}
return ds
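
# A minimal sketch (hypothetical helper, not part of the benchmark) showing
# what the (row, col, S) triplets above mean: applying the weights is a
# sparse matrix-vector product mapping n_in source cells to n_out target
# cells, which xESMF performs internally
def _example_apply_weights(ds_w: xr.Dataset, src: np.ndarray) -> np.ndarray:
    import scipy.sparse

    matrix = scipy.sparse.coo_matrix(
        (ds_w["S"].values, (ds_w["row"].values - 1, ds_w["col"].values - 1)),
        shape=(ds_w.attrs["n_out"], ds_w.attrs["n_in"]),
    )
    return matrix @ src.ravel()
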
def _reconstruct_xesmf_weights(ds_w):
"""
Reconstruct weights into format that xESMF understands
Notes
-----
From ndpyramid - https://github.com/carbonplan/ndpyramid
"""
    # shift back to 0-based indices
    col = ds_w["col"].values - 1
    row = ds_w["row"].values - 1
s = ds_w["S"].values
n_out, n_in = ds_w.attrs["n_out"], ds_w.attrs["n_in"]
crds = np.stack([row, col])
return xr.DataArray(
sparse.COO(crds, s, (n_out, n_in)), dims=("out_dim", "in_dim"), name="weights"
    )


def generate_weights(dataset, zoom):
    """Generate and cache xESMF regridding weights and the target grid for one zoom level."""
te = target_extent[zoom]
    # Look up dataset-specific access information (variable name, etc.)
args = earthaccess_args[dataset]
# Create icechunk repos for caching weights and target grid
weights_storage = StorageConfig.s3_from_env(
bucket="nasa-veda-scratch",
prefix=f"resampling/test-weight-caching/{dataset}-weights-{zoom}",
region="us-west-2",
)
target_storage = StorageConfig.s3_from_env(
bucket="nasa-veda-scratch",
prefix=f"resampling/test-weight-caching/{dataset}-target-{zoom}",
region="us-west-2",
)
weights_store = IcechunkStore.open_or_create(storage=weights_storage, mode="w")
target_store = IcechunkStore.open_or_create(storage=target_storage, mode="w")
# Create target grid
target_grid = make_grid_ds(te=te, tilesize=256, dstSRS="EPSG:3857")
# Open dataset
storage = StorageConfig.s3_from_env(
bucket="nasa-veda-scratch",
prefix=f"resampling/icechunk/{dataset}",
region="us-west-2",
)
store = IcechunkStore.open_existing(storage=storage, mode="r")
da = xr.open_zarr(store, zarr_format=3, consolidated=False)[args["variable"]]
    # Chunk target grid for parallel weights generation
output_chunk_size = 128
target_grid = target_grid.chunk(
{
"x": output_chunk_size,
"y": output_chunk_size,
"y_b": output_chunk_size,
"x_b": output_chunk_size,
}
)
# Create XESMF regridder
regridder = xe.Regridder(
da,
target_grid,
"nearest_s2d",
periodic=True,
extrap_method="nearest_s2d",
ignore_degenerate=True,
parallel=True,
)
    # Convert weights to a dataset
weights = xesmf_weights_to_xarray(regridder)
# Store weights using icechunk
weights.to_zarr(weights_store, zarr_format=3, consolidated=False)
# Commit data to icechunk stores
weights_store.commit("Store weights")
# Store target grid using icechunk
target_grid.load().to_zarr(target_store, zarr_format=3, consolidated=False)
target_store.commit("Generate target grid")
# Store weights using Zarr
output = f"s3://nasa-veda-scratch/resampling/test-weight-caching/{dataset}-weights-{zoom}.zarr"
weights.to_zarr(output, mode="w", storage_options={"use_listings_cache": False})
# Store target grid using Zarr
output = f"s3://nasa-veda-scratch/resampling/test-weight-caching/{dataset}-target-{zoom}.zarr"
    target_grid.to_zarr(output, mode="w", storage_options={"use_listings_cache": False})


dataset = "gpm_imerg"
generate_weights(dataset, 0)
generate_weights(dataset, 1)
generate_weights(dataset, 2)
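
# A minimal reuse sketch (hypothetical follow-up run, assuming an xESMF
# version that accepts precomputed weights via the ``weights=`` argument):
# reopen the cached weights and target grid, rebuild the sparse weight
# matrix, and construct a Regridder without recomputing weights in ESMF
def _example_regrid_from_cache(dataset: str, zoom: int) -> xr.DataArray:
    args = earthaccess_args[dataset]
    # Reopen the cached weights
    weights_storage = StorageConfig.s3_from_env(
        bucket="nasa-veda-scratch",
        prefix=f"resampling/test-weight-caching/{dataset}-weights-{zoom}",
        region="us-west-2",
    )
    weights_store = IcechunkStore.open_existing(storage=weights_storage, mode="r")
    weights = _reconstruct_xesmf_weights(
        xr.open_zarr(weights_store, zarr_format=3, consolidated=False)
    )
    # Reopen the cached target grid
    target_storage = StorageConfig.s3_from_env(
        bucket="nasa-veda-scratch",
        prefix=f"resampling/test-weight-caching/{dataset}-target-{zoom}",
        region="us-west-2",
    )
    target_store = IcechunkStore.open_existing(storage=target_storage, mode="r")
    target_grid = xr.open_zarr(target_store, zarr_format=3, consolidated=False)
    # Reopen the source data
    storage = StorageConfig.s3_from_env(
        bucket="nasa-veda-scratch",
        prefix=f"resampling/icechunk/{dataset}",
        region="us-west-2",
    )
    store = IcechunkStore.open_existing(storage=storage, mode="r")
    da = xr.open_zarr(store, zarr_format=3, consolidated=False)[args["variable"]]
    # Passing precomputed weights skips the expensive ESMF weight generation
    regridder = xe.Regridder(
        da,
        target_grid,
        "nearest_s2d",
        periodic=True,
        extrap_method="nearest_s2d",
        ignore_degenerate=True,
        weights=weights,
    )
    return regridder(da)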