import argparse
import itertools
import numpy as np
import pyproj
import rasterio.transform
import xarray as xr
import xesmf as xe
from icechunk import IcechunkStore, StorageConfig
# Resampling with XESMF (S3 storage, NetCDF file, Zarr reader, icechunk virtualization)
def make_grid_ds(*, te, tilesize, dstSRS) -> xr.Dataset:
    """
    Make a dataset representing a target grid

    Parameters
    ----------
    te
        Target extent indexed as ``(xmin, ymin, xmax, ymax)`` in the units of
        ``dstSRS``. The grid origin is the upper-left corner ``(te[0], te[3])``.
    tilesize
        Number of grid cells along each side of the (square) grid.
    dstSRS
        Projection definition handed to ``pyproj.Proj`` (e.g. ``"EPSG:3857"``).

    Returns
    -------
    xr.Dataset
        Target grid dataset with the following variables:
        - "x": X coordinate in the ``dstSRS`` projection (grid cell center)
        - "y": Y coordinate in the ``dstSRS`` projection (grid cell center)
        - "lat": latitude coordinate (grid cell center)
        - "lon": longitude coordinate (grid cell center)
        - "lat_b": latitude bounds for grid cell
        - "lon_b": longitude bounds for grid cell

    Notes
    -----
    Modified from ndpyramid - https://github.com/carbonplan/ndpyramid
    """
    # Pixel -> projected-coordinate affine. The cell size is derived from the
    # extent span so that extents which are not symmetric about the origin are
    # handled too; for a symmetric extent this equals the former
    # ``te[2] * 2`` / ``te[1] * 2`` exactly. The y scale is ``ymin - ymax`` so
    # that rows run from the top of the extent downwards.
    transform = rasterio.transform.Affine.translation(
        te[0], te[3]
    ) * rasterio.transform.Affine.scale(
        (te[2] - te[0]) / tilesize, (te[1] - te[3]) / tilesize
    )

    p = pyproj.Proj(dstSRS)

    grid_shape = (tilesize, tilesize)
    bounds_shape = (tilesize + 1, tilesize + 1)

    xs = np.empty(grid_shape)
    ys = np.empty(grid_shape)
    lat = np.empty(grid_shape)
    lon = np.empty(grid_shape)
    lat_b = np.zeros(bounds_shape)
    lon_b = np.zeros(bounds_shape)

    # calc grid cell center coordinates (the +0.5 offset addresses cell centers)
    ii, jj = np.meshgrid(np.arange(tilesize) + 0.5, np.arange(tilesize) + 0.5)
    for i, j in itertools.product(range(grid_shape[0]), range(grid_shape[1])):
        locs = [ii[i, j], jj[i, j]]
        xs[i, j], ys[i, j] = transform * locs
        # inverse=True converts projected x/y back to lon/lat
        lon[i, j], lat[i, j] = p(xs[i, j], ys[i, j], inverse=True)

    # calc grid cell bounds (integer pixel indices address cell corners)
    iib, jjb = np.meshgrid(np.arange(tilesize + 1), np.arange(tilesize + 1))
    for i, j in itertools.product(range(bounds_shape[0]), range(bounds_shape[1])):
        locs = [iib[i, j], jjb[i, j]]
        x, y = transform * locs
        lon_b[i, j], lat_b[i, j] = p(x, y, inverse=True)

    return xr.Dataset(
        {
            # x varies only along columns and y only along rows, so the first
            # row / first column carry the full 1-D coordinate vectors.
            "x": xr.DataArray(xs[0, :], dims=["x"]),
            "y": xr.DataArray(ys[:, 0], dims=["y"]),
            "lat": xr.DataArray(lat, dims=["y", "x"]),
            "lon": xr.DataArray(lon, dims=["y", "x"]),
            "lat_b": xr.DataArray(lat_b, dims=["y_b", "x_b"]),
            "lon_b": xr.DataArray(lon_b, dims=["y_b", "x_b"]),
        },
    )
def regrid(dataset, zoom=0):
    """
    Regrid one variable of a dataset onto a 256x256 EPSG:3857 target grid.

    Parameters
    ----------
    dataset
        Key into ``earthaccess_args`` (from the project's ``common`` module);
        also used to build the icechunk prefix
        ``resampling/icechunk/{dataset}-reference``.
    zoom
        Key into ``target_extent`` (from ``common``) selecting the target
        extent passed to ``make_grid_ds``.

    Returns
    -------
    The result of applying the xESMF regridder to the selected variable,
    after calling ``.load()`` on it.
    """
    from common import earthaccess_args, target_extent

    te = target_extent[zoom]

    # Define filepath, driver, and variable information
    args = earthaccess_args[dataset]
    # Create grid to hold result
    target_grid = make_grid_ds(te=te, tilesize=256, dstSRS="EPSG:3857")
    # Open dataset
    storage = StorageConfig.s3_from_env(
        bucket="nasa-veda-scratch",
        prefix=f"resampling/icechunk/{dataset}-reference",
        region="us-west-2",
    )
    store = IcechunkStore.open_existing(storage=storage, mode="r")
    da = xr.open_zarr(store, zarr_format=3, consolidated=False)[args["variable"]]
    # Create XESMF regridder
    regridder = xe.Regridder(
        da,
        target_grid,
        "nearest_s2d",
        periodic=True,
        extrap_method="nearest_s2d",
        ignore_degenerate=True,
    )
    # Regrid dataset
    return regridder(da).load()
if __name__ == "__main__":
    if "get_ipython" in dir():
        # Just call regrid with a fixed dataset if running as a Jupyter Notebook
        da = regrid("gpm_imerg")
    else:
        # Configure dataset via argparse if running via CLI
        parser = argparse.ArgumentParser(description="Set environment for the script.")
        parser.add_argument(
            "--dataset",
            default="mursst",
            help="Dataset to resample.",
            choices=["gpm_imerg", "mursst"],
        )
        parser.add_argument(
            "--zoom",
            default=0,
            # Parse as int up front so a bad value yields an argparse usage
            # error instead of a ValueError traceback.
            type=int,
            help="Zoom level for tile extent.",
        )
        user_args = parser.parse_args()
        da = regrid(user_args.dataset, int(user_args.zoom))