Resampling with xESMF (local storage, NetCDF file, H5NetCDF driver)

import argparse
import itertools

import fsspec
import numpy as np
import pyproj
import rasterio.transform
import xarray as xr
import xesmf as xe
def make_grid_ds(*, te, tilesize, dstSRS) -> xr.Dataset:
    """
    Make a dataset representing a target grid

    Parameters
    ----------
    te : sequence of float
        Target extent in ``dstSRS`` coordinates, ordered
        ``(xmin, ymin, xmax, ymax)``. The grid starts at the
        ``(te[0], te[3])`` corner and ends at the ``(te[2], te[1])`` corner.
    tilesize : int
        Number of grid cells along each axis of the (square) grid.
    dstSRS : str
        Projection definition passed to ``pyproj.Proj`` (e.g. "EPSG:3857").

    Returns
    -------
    xr.Dataset
        Target grid dataset with the following variables:
        - "x": X coordinate in the ``dstSRS`` projection (grid cell center)
        - "y": Y coordinate in the ``dstSRS`` projection (grid cell center)
        - "lat": latitude coordinate (grid cell center)
        - "lon": longitude coordinate (grid cell center)
        - "lat_b": latitude bounds for grid cell
        - "lon_b": longitude bounds for grid cell

    Notes
    -----
    Modified from ndpyramid - https://github.com/carbonplan/ndpyramid
    """

    # Pixel -> projected coordinate transform. The cell size is derived from
    # the full width/height of the extent so that extents which are not
    # symmetric about the origin are also covered exactly.
    transform = rasterio.transform.Affine.translation(
        te[0], te[3]
    ) * rasterio.transform.Affine.scale(
        (te[2] - te[0]) / tilesize, (te[1] - te[3]) / tilesize
    )

    p = pyproj.Proj(dstSRS)

    # calc grid cell center coordinates (half-integer pixel positions);
    # the transform and the inverse projection are applied to whole arrays
    centers = np.arange(tilesize) + 0.5
    ii, jj = np.meshgrid(centers, centers)
    xs, ys = transform * (ii, jj)
    lon, lat = p(xs, ys, inverse=True)

    # calc grid cell bounds (integer pixel positions, one more than the cells)
    edges = np.arange(tilesize + 1)
    iib, jjb = np.meshgrid(edges, edges)
    xb, yb = transform * (iib, jjb)
    lon_b, lat_b = p(xb, yb, inverse=True)

    return xr.Dataset(
        {
            # x only varies along columns and y only along rows, so a single
            # row / column is enough for the 1-D coordinates
            "x": xr.DataArray(xs[0, :], dims=["x"]),
            "y": xr.DataArray(ys[:, 0], dims=["y"]),
            "lat": xr.DataArray(lat, dims=["y", "x"]),
            "lon": xr.DataArray(lon, dims=["y", "x"]),
            "lat_b": xr.DataArray(lat_b, dims=["y_b", "x_b"]),
            "lon_b": xr.DataArray(lon_b, dims=["y_b", "x_b"]),
        },
    )
def regrid(dataset, zoom=0):
    """
    Regrid one variable of a local NetCDF file onto a 256 x 256 EPSG:3857 grid

    Parameters
    ----------
    dataset
        Key into ``common.earthaccess_args``; the selected entry supplies the
        ``"filename"`` and ``"variable"`` to read.
    zoom
        Key into ``common.target_extent`` selecting the target extent.

    Returns
    -------
    The output of the xESMF regridder applied to the selected variable,
    after ``.load()`` has been called on it.
    """
    from common import earthaccess_args, target_extent

    extent = target_extent[zoom]

    # Look up the file name and variable name configured for this dataset
    dataset_args = earthaccess_args[dataset]
    source_path = f'earthaccess_data/{dataset_args["filename"]}'

    # Build the grid that will hold the result
    destination = make_grid_ds(te=extent, tilesize=256, dstSRS="EPSG:3857")

    # Turn fsspec caching off since the default options don't work well for
    # raster data
    local_fs = fsspec.filesystem("file")
    with local_fs.open(source_path, cache_type="none") as handle:
        # Open the file and pick out the variable of interest
        source_ds = xr.open_dataset(
            handle, engine="h5netcdf", chunks={}, mask_and_scale=True
        )
        variable = source_ds[dataset_args["variable"]]

        # Set up the xESMF regridder from the source variable to the grid
        remap = xe.Regridder(
            variable,
            destination,
            "nearest_s2d",
            periodic=True,
            extrap_method="nearest_s2d",
            ignore_degenerate=True,
        )

        # Apply it and load the result while the file is still open
        return remap(variable).load()
%%time
# Script entry point: run the regridding either with notebook defaults or
# with options taken from the command line.
if __name__ == "__main__":
    if "get_ipython" in dir():
        # Just call regrid if running as a Jupyter Notebook
        da = regrid("gpm_imerg")
    else:
        # Configure dataset via argparse if running via CLI
        parser = argparse.ArgumentParser(description="Set environment for the script.")
        parser.add_argument(
            "--dataset",
            default="mursst",
            help="Dataset to resample.",
            choices=["gpm_imerg", "mursst"],
        )
        parser.add_argument(
            "--zoom",
            default=0,
            # Let argparse validate the value so a non-integer zoom produces a
            # usage error instead of a ValueError traceback
            type=int,
            help="Zoom level for tile extent.",
        )
        user_args = parser.parse_args()
        da = regrid(user_args.dataset, user_args.zoom)