# Resampling with XESMF (S3 storage, NetCDF file, Zarr reader, icechunk virtualization)

import argparse
import itertools

import numpy as np
import pyproj
import rasterio.transform
import xarray as xr
import xesmf as xe
from icechunk import IcechunkStore, StorageConfig
def make_grid_ds(*, te, tilesize, dstSRS) -> xr.Dataset:
    """
    Make a dataset representing a target grid

    Parameters
    ----------
    te : sequence of float
        Target extent as ``(xmin, ymin, xmax, ymax)`` in the units of
        ``dstSRS``. The extent does not need to be centered on the origin.
    tilesize : int
        Number of grid cells along each axis of the (square) target grid.
    dstSRS : str
        Projection definition passed to ``pyproj.Proj`` (e.g. "EPSG:3857").

    Returns
    -------
    xr.Dataset
        Target grid dataset with the following variables:
        - "x": X coordinate in the ``dstSRS`` projection (grid cell center)
        - "y": Y coordinate in the ``dstSRS`` projection (grid cell center)
        - "lat": latitude coordinate (grid cell center)
        - "lon": longitude coordinate (grid cell center)
        - "lat_b": latitude bounds for grid cell
        - "lon_b": longitude bounds for grid cell

    Notes
    -----
    Modified from ndpyramid - https://github.com/carbonplan/ndpyramid
    """

    xmin, ymin, xmax, ymax = te[0], te[1], te[2], te[3]

    # Pixel (col, row) -> projected (x, y). The origin is the upper-left corner
    # of the extent, so the y scale (ymin - ymax) is negative for a north-up
    # grid. Using the extent's width/height (rather than doubling one edge)
    # keeps the grid correct for extents that are not symmetric about zero.
    transform = rasterio.transform.Affine.translation(
        xmin, ymax
    ) * rasterio.transform.Affine.scale(
        (xmax - xmin) / tilesize, (ymin - ymax) / tilesize
    )

    p = pyproj.Proj(dstSRS)

    # calc grid cell center coordinates (pixel centers sit at +0.5 offsets);
    # with the default "xy" meshgrid indexing, columns vary along axis 1 and
    # rows along axis 0, matching the ("y", "x") dims used below
    centers = np.arange(tilesize) + 0.5
    cols, rows = np.meshgrid(centers, centers)
    xs, ys = transform * (cols, rows)
    lon, lat = p(xs, ys, inverse=True)

    # calc grid cell bounds (pixel corners sit at integer offsets)
    edges = np.arange(tilesize + 1)
    cols_b, rows_b = np.meshgrid(edges, edges)
    xs_b, ys_b = transform * (cols_b, rows_b)
    lon_b, lat_b = p(xs_b, ys_b, inverse=True)

    return xr.Dataset(
        {
            # x only varies along columns and y only along rows, so a single
            # row / column of the 2D arrays gives the 1D coordinate vectors
            "x": xr.DataArray(xs[0, :], dims=["x"]),
            "y": xr.DataArray(ys[:, 0], dims=["y"]),
            "lat": xr.DataArray(lat, dims=["y", "x"]),
            "lon": xr.DataArray(lon, dims=["y", "x"]),
            "lat_b": xr.DataArray(lat_b, dims=["y_b", "x_b"]),
            "lon_b": xr.DataArray(lon_b, dims=["y_b", "x_b"]),
        },
    )
def regrid(dataset, zoom=0):
    """
    Regrid one variable of an icechunk-backed dataset onto a target grid

    Parameters
    ----------
    dataset
        Key into ``common.earthaccess_args``; also used to build the S3 prefix
        of the icechunk store.
    zoom
        Key into ``common.target_extent`` selecting the target extent.

    Returns
    -------
    The regridded variable, loaded into memory.
    """
    from common import earthaccess_args, target_extent

    extent = target_extent[zoom]

    # Per-dataset settings (the variable name is read from here below)
    dataset_args = earthaccess_args[dataset]

    # Destination: 256x256 grid in EPSG:3857 covering the chosen extent
    destination = make_grid_ds(te=extent, tilesize=256, dstSRS="EPSG:3857")

    # Source: read-only icechunk store on S3, opened as Zarr v3
    s3_config = StorageConfig.s3_from_env(
        bucket="nasa-veda-scratch",
        prefix=f"resampling/icechunk/{dataset}-reference",
        region="us-west-2",
    )
    icechunk_store = IcechunkStore.open_existing(storage=s3_config, mode="r")
    source_ds = xr.open_zarr(icechunk_store, zarr_format=3, consolidated=False)
    source = source_ds[dataset_args["variable"]]

    # Nearest-neighbour (source-to-destination) regridder
    regridder_options = {
        "periodic": True,
        "extrap_method": "nearest_s2d",
        "ignore_degenerate": True,
    }
    regridder = xe.Regridder(source, destination, "nearest_s2d", **regridder_options)

    # Apply the regridder and pull the result into memory
    resampled = regridder(source)
    return resampled.load()
if __name__ == "__main__":
    # NOTE: this check must stay at module level; dir() inspects the current
    # scope, and "get_ipython" is only looked for among module-level names.
    if "get_ipython" in dir():
        # Just call regrid with a fixed dataset if running in an IPython
        # session (e.g. as a Jupyter Notebook)
        da = regrid("gpm_imerg")
    else:
        # Configure dataset via argparse if running via CLI
        parser = argparse.ArgumentParser(description="Set environment for the script.")
        parser.add_argument(
            "--dataset",
            default="mursst",
            help="Dataset to resample.",
            choices=["gpm_imerg", "mursst"],
        )
        parser.add_argument(
            "--zoom",
            default=0,
            # Let argparse convert and validate, so a non-integer value gives
            # a usage error instead of a ValueError traceback
            type=int,
            help="Zoom level for tile extent.",
        )
        user_args = parser.parse_args()
        da = regrid(user_args.dataset, user_args.zoom)