# Resampling with XESMF (S3 storage, NetCDF file, H5NetCDF driver, earthaccess auth)

import argparse
import itertools

import earthaccess
import numpy as np
import pyproj
import rasterio.transform
import xarray as xr
import xesmf as xe
def make_grid_ds(*, te, tilesize, dstSRS) -> xr.Dataset:
    """
    Make a dataset representing a target grid

    Parameters
    ----------
    te
        Target extent, indexed as ``(xmin, ymin, xmax, ymax)`` in the
        coordinates of ``dstSRS``. The grid origin is the upper-left corner
        ``(te[0], te[3])``.
    tilesize
        Number of grid cells along each axis of the (square) grid.
    dstSRS
        Projection definition passed to ``pyproj.Proj`` (e.g. "EPSG:3857").

    Returns
    -------
    xr.Dataset
        Target grid dataset with the following variables:
        - "x": X coordinate in the ``dstSRS`` projection (grid cell center)
        - "y": Y coordinate in the ``dstSRS`` projection (grid cell center)
        - "lat": latitude coordinate (grid cell center)
        - "lon": longitude coordinate (grid cell center)
        - "lat_b": latitude bounds for grid cell
        - "lon_b": longitude bounds for grid cell

    Raises
    ------
    ValueError
        If ``tilesize`` is not positive.

    Notes
    -----
    Modified from ndpyramid - https://github.com/carbonplan/ndpyramid
    """
    if tilesize <= 0:
        raise ValueError(f"tilesize must be a positive integer, got {tilesize!r}")

    xmin, ymin, xmax, ymax = te[0], te[1], te[2], te[3]

    # Pixel size comes from the full extent width/height. Doubling a single
    # edge (te[2] * 2, te[1] * 2) is only correct for extents that are
    # symmetric about the origin; the differences below give the same values
    # in that case and stay correct for any other extent.
    x_res = (xmax - xmin) / tilesize
    # Negative when ymin < ymax, so row 0 is the top (ymax) edge.
    y_res = (ymin - ymax) / tilesize

    p = pyproj.Proj(dstSRS)

    # Fractional pixel indices: cell centers sit at +0.5, cell edges at integers.
    center_idx = np.arange(tilesize) + 0.5
    edge_idx = np.arange(tilesize + 1)

    # Apply the (rotation-free) affine transform: origin + index * pixel size.
    x_centers = xmin + center_idx * x_res
    y_centers = ymax + center_idx * y_res
    x_edges = xmin + edge_idx * x_res
    y_edges = ymax + edge_idx * y_res

    # Inverse-project whole 2D arrays at once rather than cell by cell.
    xx, yy = np.meshgrid(x_centers, y_centers)
    lon, lat = p(xx, yy, inverse=True)

    xx_b, yy_b = np.meshgrid(x_edges, y_edges)
    lon_b, lat_b = p(xx_b, yy_b, inverse=True)

    return xr.Dataset(
        {
            "x": xr.DataArray(x_centers, dims=["x"]),
            "y": xr.DataArray(y_centers, dims=["y"]),
            "lat": xr.DataArray(lat, dims=["y", "x"]),
            "lon": xr.DataArray(lon, dims=["y", "x"]),
            "lat_b": xr.DataArray(lat_b, dims=["y_b", "x_b"]),
            "lon_b": xr.DataArray(lon_b, dims=["y_b", "x_b"]),
        },
    )
def regrid(dataset, zoom=0):
    """
    Regrid one variable of a remote NetCDF file onto a 256x256 EPSG:3857 grid

    Parameters
    ----------
    dataset
        Key into ``common.earthaccess_args``; the selected entry supplies the
        "bucket", "folder", "filename", "daac" and "variable" values used here.
    zoom
        Key into ``common.target_extent`` selecting the target extent.

    Returns
    -------
    The regridded variable, loaded into memory before the remote file and
    dataset are closed.
    """
    from common import earthaccess_args, target_extent

    te = target_extent[zoom]

    # Define filepath, driver, and variable information
    args = earthaccess_args[dataset]
    input_uri = f'{args["folder"]}/{args["filename"]}'
    src = f's3://{args["bucket"]}/{input_uri}'
    # Create grid to hold result
    target_grid = make_grid_ds(te=te, tilesize=256, dstSRS="EPSG:3857")
    # Authenticate with earthaccess
    fs = earthaccess.get_s3_filesystem(daac=args["daac"])
    # Specify fsspec caching since default options don't work well for raster data
    fsspec_caching = {
        "cache_type": "none",
    }
    with fs.open(src, **fsspec_caching) as f:
        # Open the dataset as a context manager so its backend handle is
        # released explicitly instead of lingering over a closed file object.
        with xr.open_dataset(
            f, engine="h5netcdf", chunks={}, mask_and_scale=True
        ) as ds:
            da = ds[args["variable"]]
            # Create XESMF regridder
            regridder = xe.Regridder(
                da,
                target_grid,
                "nearest_s2d",
                periodic=True,
                extrap_method="nearest_s2d",
                ignore_degenerate=True,
            )
            # Regrid and load into memory while the file is still open
            return regridder(da).load()
if __name__ == "__main__":
    if "get_ipython" in dir():
        # `get_ipython` is present in the namespace (IPython/Jupyter session):
        # there is no command line to parse, so just call regrid directly
        da = regrid("gpm_imerg")
    else:
        # Configure dataset via argparse if running via CLI
        parser = argparse.ArgumentParser(description="Set environment for the script.")
        parser.add_argument(
            "--dataset",
            default="mursst",
            help="Dataset to resample.",
            choices=["gpm_imerg", "mursst"],
        )
        parser.add_argument(
            "--zoom",
            # Let argparse convert and validate, so a non-integer value gives
            # a usage error instead of a ValueError traceback after parsing.
            type=int,
            default=0,
            help="Zoom level for tile extent.",
        )
        user_args = parser.parse_args()
        da = regrid(user_args.dataset, user_args.zoom)