# Resampling with xESMF (S3 storage, NetCDF file, Zarr reader, kerchunk virtualization, and pre-generated weights)

import argparse

import earthaccess
import fsspec
import numpy as np
import xarray as xr
import xesmf as xe
def _reconstruct_xesmf_weights(ds_w):
    """
    Rebuild a sparse xESMF weight matrix from a stored weights dataset.

    Parameters
    ----------
    ds_w : xarray.Dataset
        Dataset providing ``col`` and ``row`` index variables, an ``S``
        variable with the weight values, and ``n_out`` / ``n_in`` attributes
        giving the matrix shape. The stored indices are shifted down by one
        before use (presumably 1-based as written by ESMF -- confirm against
        the code that wrote the weights).

    Returns
    -------
    xarray.DataArray
        ``sparse.COO``-backed array named ``"weights"`` with dims
        ``("out_dim", "in_dim")`` and shape ``(n_out, n_in)``.

    Notes
    -----
    From ndpyramid - https://github.com/carbonplan/ndpyramid
    """
    import sparse
    import xarray as xr

    # Shift the stored indices down by one to get 0-based offsets for sparse.COO.
    cols = ds_w["col"].values - 1
    rows = ds_w["row"].values - 1
    values = ds_w["S"].values
    meta = ds_w.attrs
    shape = (meta["n_out"], meta["n_in"])
    # sparse.COO takes coordinates as a (ndim, nnz) array: first row indices, then column indices.
    coords = np.stack([rows, cols])
    matrix = sparse.COO(coords, values, shape)
    return xr.DataArray(matrix, dims=("out_dim", "in_dim"), name="weights")


def reconstruct_weights(weights_fp):
    """
    Open a Zarr weights store and rebuild the xESMF weight matrix from it.

    Parameters
    ----------
    weights_fp : str
        Path or URL of the Zarr store; passed unchanged to ``xarray.open_zarr``.

    Returns
    -------
    xarray.DataArray
        Sparse weight matrix built by ``_reconstruct_xesmf_weights``.

    Notes
    -----
    From ndpyramid - https://github.com/carbonplan/ndpyramid
    """
    weights_ds = xr.open_zarr(weights_fp)
    return _reconstruct_xesmf_weights(weights_ds)
def regrid(dataset, zoom=0):
    """
    Regrid one variable of ``dataset`` onto a pre-generated target grid with xESMF.

    The source file is opened lazily through a kerchunk reference JSON, and
    regridding reuses weights generated ahead of time and stored as Zarr on
    S3, so no weights are computed here.

    Parameters
    ----------
    dataset : str
        Key into ``common.earthaccess_args`` (the CLI offers ``"gpm_imerg"``
        and ``"mursst"``). The entry must provide ``"filename"``, ``"daac"``
        and ``"variable"``.
    zoom : int, optional
        Zoom level selecting which cached weights / target grid to load.
        Default 0.

    Returns
    -------
    xarray.DataArray
        The regridded variable, loaded into memory.
    """
    from common import earthaccess_args  # noqa: E402

    args = earthaccess_args[dataset]
    # Load pre-generated weights and target grid for this dataset/zoom level
    cache_root = "s3://nasa-veda-scratch/resampling/test-weight-caching"
    weights = reconstruct_weights(f"{cache_root}/{dataset}-weights-{zoom}.zarr")
    grid = xr.open_zarr(f"{cache_root}/{dataset}-target-{zoom}.zarr")

    # The kerchunk reference JSON is named after the source file with its
    # extension stripped: 4 characters for gpm_imerg, 3 otherwise (presumably
    # ".nc4" vs ".nc" -- confirm against common.earthaccess_args).
    suffix_len = 4 if dataset == "gpm_imerg" else 3
    src = f'earthaccess_data/{args["filename"][:-suffix_len]}.json'

    # Authenticate with earthaccess to get S3 credentials for the DAAC bucket
    s3_fs = earthaccess.get_s3fs_session(daac=args["daac"])
    storage_options = s3_fs.storage_options.copy()
    # Specify fsspec caching since default options don't work well for raster data.
    # NOTE(review): confirm ReferenceFileSystem actually honours ``cache_type``;
    # chunk bytes are fetched via the remote filesystem, not via file objects.
    fsspec_caching = {
        "cache_type": "none",
    }
    # Open dataset using kerchunk. The credentials go to the reference
    # filesystem as ``remote_options`` because that is where the S3 filesystem
    # used to fetch the referenced chunk bytes is created (assumes the
    # references point at s3:// URLs).
    fs = fsspec.filesystem(
        "reference",
        fo=src,
        remote_protocol="s3",
        remote_options=storage_options,
        **fsspec_caching,
    )
    m = fs.get_mapper("")
    da = xr.open_dataset(
        m,
        engine="kerchunk",
        chunks={},
        storage_options=storage_options,
    )[args["variable"]]

    # Create XESMF regridder from the cached weights instead of computing them
    regridder = xe.Regridder(
        da,
        grid,
        "nearest_s2d",
        periodic=True,
        extrap_method="nearest_s2d",
        ignore_degenerate=True,
        reuse_weights=True,
        weights=weights,
    )
    # Regrid dataset and materialize the result
    return regridder(da).load()
if __name__ == "__main__":
    if "get_ipython" in dir():
        # Running as a Jupyter Notebook: just call regrid with its defaults
        da = regrid("gpm_imerg")
    else:
        # Configure dataset and zoom via argparse when running from the CLI
        parser = argparse.ArgumentParser(description="Set environment for the script.")
        parser.add_argument(
            "--dataset",
            default="gpm_imerg",
            help="Dataset to resample.",
            choices=["gpm_imerg", "mursst"],
        )
        # type=int lets argparse reject non-integer zoom levels with a usage error
        parser.add_argument(
            "--zoom",
            default=0,
            type=int,
            help="Zoom level for tile extent.",
        )
        user_args = parser.parse_args()
        da = regrid(user_args.dataset, user_args.zoom)