# Resampling with XESMF (S3 storage, Zarr V3 store, Zarr reader with icechunk)

import argparse
import itertools

import numpy as np
import pyproj
import rasterio.transform
import xarray as xr
import xesmf as xe
from icechunk import IcechunkStore, StorageConfig
def make_grid_ds(*, te, tilesize, dstSRS) -> xr.Dataset:
    """
    Make a dataset representing a square target grid.

    Parameters
    ----------
    te : sequence of float
        Target extent, indexed as ``(xmin, ymin, xmax, ymax)`` in the units of
        ``dstSRS`` (the ordering used by GDAL's ``-te`` option). The grid's
        top-left corner is ``(xmin, ymax)``.
    tilesize : int
        Number of grid cells along each axis (the grid is ``tilesize`` x
        ``tilesize``).
    dstSRS : str
        CRS definition passed to ``pyproj.Proj`` (e.g. ``"EPSG:3857"``).

    Returns
    -------
    xr.Dataset
        Target grid dataset with the following variables:
        - "x": X coordinate in ``dstSRS`` (grid cell center), dim ``x``
        - "y": Y coordinate in ``dstSRS`` (grid cell center), dim ``y``
        - "lat": latitude coordinate (grid cell center), dims ``(y, x)``
        - "lon": longitude coordinate (grid cell center), dims ``(y, x)``
        - "lat_b": latitude bounds for grid cell, dims ``(y_b, x_b)``
        - "lon_b": longitude bounds for grid cell, dims ``(y_b, x_b)``

    Notes
    -----
    Modified from ndpyramid - https://github.com/carbonplan/ndpyramid
    """
    xmin, ymin, xmax, ymax = te[0], te[1], te[2], te[3]
    # Pixel size comes from the full extent width/height. The Y step is
    # negative because row 0 is at the top (ymax) and rows advance toward ymin.
    # (For an origin-centred extent this equals the previous 2*xmax / 2*ymin.)
    transform = rasterio.transform.Affine.translation(
        xmin, ymax
    ) * rasterio.transform.Affine.scale(
        (xmax - xmin) / tilesize, (ymin - ymax) / tilesize
    )
    proj = pyproj.Proj(dstSRS)

    # Grid cell centers: pixel coordinates (col + 0.5, row + 0.5), mapped to
    # dstSRS with the affine transform and then to lon/lat in one vectorised call.
    cols, rows = np.meshgrid(np.arange(tilesize) + 0.5, np.arange(tilesize) + 0.5)
    xs, ys = transform * (cols, rows)
    lon, lat = proj(xs, ys, inverse=True)

    # Grid cell corners: integer pixel coordinates 0..tilesize on both axes.
    cols_b, rows_b = np.meshgrid(np.arange(tilesize + 1), np.arange(tilesize + 1))
    x_b, y_b = transform * (cols_b, rows_b)
    lon_b, lat_b = proj(x_b, y_b, inverse=True)

    return xr.Dataset(
        {
            # The grid is axis-aligned, so x varies only along columns and
            # y only along rows; keep one 1-D vector of each.
            "x": xr.DataArray(xs[0, :], dims=["x"]),
            "y": xr.DataArray(ys[:, 0], dims=["y"]),
            "lat": xr.DataArray(np.asarray(lat), dims=["y", "x"]),
            "lon": xr.DataArray(np.asarray(lon), dims=["y", "x"]),
            "lat_b": xr.DataArray(np.asarray(lat_b), dims=["y_b", "x_b"]),
            "lon_b": xr.DataArray(np.asarray(lon_b), dims=["y_b", "x_b"]),
        },
    )


def regrid(dataset, zoom=0):
    """
    Regrid one variable of an icechunk-backed dataset onto a 256x256
    EPSG:3857 grid using xESMF nearest-neighbour (source-to-destination).

    Parameters
    ----------
    dataset : str
        Key into ``common.earthaccess_args``; also names the icechunk
        repository under ``resampling/icechunk/`` in the ``nasa-veda-scratch``
        bucket.
    zoom : int, default 0
        Key into ``common.target_extent`` selecting the target extent.

    Returns
    -------
    xr.DataArray
        The regridded variable, loaded into memory.
    """
    # Project-local configuration module (not shown in this file).
    from common import earthaccess_args, target_extent

    te = target_extent[zoom]
    # Filepath, driver, and variable information for this dataset
    args = earthaccess_args[dataset]

    # Grid to hold the result
    target_grid = make_grid_ds(te=te, tilesize=256, dstSRS="EPSG:3857")

    # Open the dataset from the icechunk store on S3 (read-only)
    storage = StorageConfig.s3_from_env(
        bucket="nasa-veda-scratch",
        prefix=f"resampling/icechunk/{dataset}",
        region="us-west-2",
    )
    store = IcechunkStore.open_existing(storage=storage, mode="r")
    da = xr.open_zarr(store, zarr_format=3, consolidated=False)[args["variable"]]

    # Create XESMF regridder
    regridder = xe.Regridder(
        da,
        target_grid,
        "nearest_s2d",
        periodic=True,
        extrap_method="nearest_s2d",
        ignore_degenerate=True,
    )
    # Regrid dataset and materialise the result
    return regridder(da).load()


if __name__ == "__main__":
    if "get_ipython" in dir():
        # Just call regrid if running as a Jupyter Notebook
        da = regrid("gpm_imerg")
    else:
        # Configure dataset via argparse if running via CLI
        parser = argparse.ArgumentParser(description="Set environment for the script.")
        parser.add_argument(
            "--dataset",
            default="mursst",
            help="Dataset to resample.",
            choices=["gpm_imerg", "mursst"],
        )
        parser.add_argument(
            "--zoom",
            default=0,
            type=int,  # let argparse validate/convert instead of int() later
            help="Zoom level for tile extent.",
        )
        user_args = parser.parse_args()
        da = regrid(user_args.dataset, user_args.zoom)