# Resampling with sparse (S3 storage, Zarr V3 store, Zarr reader with icechunk)

import argparse
import warnings

import numpy as np
import xarray as xr
from common import earthaccess_args
from icechunk import IcechunkStore, StorageConfig
def _reconstruct_xesmf_weights(ds_w):
    """
    Rebuild a sparse weight matrix, in the form xESMF understands, from a stored weights dataset.

    Parameters
    ----------
    ds_w : xarray.Dataset
        Dataset with index variables ``row`` and ``col``, weight values in
        ``S``, and the matrix dimensions in the ``n_out`` / ``n_in`` attrs.
        The stored indices are shifted down by one before use (i.e. they are
        treated as 1-based).

    Returns
    -------
    xarray.DataArray
        DataArray named ``"weights"`` with dims ``("out_dim", "in_dim")``
        wrapping a ``sparse.COO`` matrix of shape ``(n_out, n_in)``.

    Notes
    -----
    From ndpyramid - https://github.com/carbonplan/ndpyramid
    """
    import sparse
    import xarray as xr

    # Shift stored indices to the 0-based coordinates sparse.COO expects.
    col_idx = ds_w["col"].values - 1
    row_idx = ds_w["row"].values - 1
    values = ds_w["S"].values
    matrix_shape = (ds_w.attrs["n_out"], ds_w.attrs["n_in"])

    matrix = sparse.COO(np.stack([row_idx, col_idx]), values, matrix_shape)
    return xr.DataArray(matrix, dims=("out_dim", "in_dim"), name="weights")


def xr_regridder(
    ds: "xr.Dataset | xr.DataArray",
    grid: xr.Dataset,
    weights: xr.DataArray,
) -> "xr.Dataset | xr.DataArray":
    """
    Regrid ``ds`` onto ``grid`` using precomputed xESMF weights.

    The weights come from xESMF, but the regridding itself is a sparse
    matrix multiplication performed by ``esmf_apply_weights`` through
    ``xarray.apply_ufunc``.

    Parameters
    ----------
    ds
        Source data with horizontal dimensions ``lat`` and ``lon``; either a
        Dataset or a DataArray (``regrid`` passes a DataArray).
    grid
        Target grid. Only ``grid.sizes["y"]`` and ``grid.sizes["x"]`` are
        used, to size the output.
    weights
        Weight matrix with dims ``("out_dim", "in_dim")``. Its shape must be
        ``(n_y * n_x, n_lat * n_lon)``; ``esmf_apply_weights`` checks this.

    Returns
    -------
    regridded_ds
        Regridded data of the same kind as ``ds``. The horizontal dimensions
        keep the names ``lat`` and ``lon`` but have the target grid's
        ``(y, x)`` sizes. Attributes (and, for a DataArray, the name) come
        from ``ds``.

    Notes
    -----
    Modified from https://github.com/carbonplan/ndpyramid/pull/130
    """

    latlon_dims = ["lat", "lon"]

    shape_in = (ds.sizes["lat"], ds.sizes["lon"])
    shape_out = (grid.sizes["y"], grid.sizes["x"])

    def _apply(data, w, **kwargs):
        # apply_ufunc passes the inputs positionally, in the order given below;
        # esmf_apply_weights expects the weights first.
        return esmf_apply_weights(w, data, **kwargs)

    # ``ds`` is deliberately the FIRST input: with keep_attrs=True, apply_ufunc
    # copies attrs (and the DataArray name) from the first input. Passing
    # ``weights`` first would discard the metadata of ``ds``.
    regridded_ds = xr.apply_ufunc(
        _apply,
        ds,
        weights,
        input_core_dims=[latlon_dims, ["out_dim", "in_dim"]],
        output_core_dims=[latlon_dims],
        # lat/lon change size (source grid -> target grid), so they must be
        # excluded from apply_ufunc's size-consistency check.
        exclude_dims=set(latlon_dims),
        kwargs={"shape_in": shape_in, "shape_out": shape_out},
        keep_attrs=True,
    )

    if isinstance(ds, xr.DataArray):
        # Set the name explicitly so it does not depend on apply_ufunc's
        # name-resolution rules in a given xarray version.
        regridded_ds.name = ds.name

    return regridded_ds


def esmf_apply_weights(weights, indata, shape_in, shape_out):
    """
    Apply a regridding weight matrix to the trailing two axes of an array.

    Parameters
    ----------
    weights : matrix of shape ``(n_out, n_in)``
        Any object with ``.shape`` and a ``.dot`` method that multiplies it by
        a 2-D dense array. In this file it is the ``sparse.COO`` built by
        ``_reconstruct_xesmf_weights``. A scipy sparse matrix or a dense numpy
        array also works.
    indata : numpy array of shape ``(..., n_lat, n_lon)`` or ``(..., n_y, n_x)``
        Should be C-ordered; a warning is issued otherwise. The flattened
        array is transposed (making it F-ordered) before the product.
    shape_in, shape_out : sequence of two integers
        Input/output horizontal shape, used to flatten and unflatten the data.
        For a rectilinear grid this is just ``(n_lat, n_lon)``.

    Returns
    -------
    outdata : numpy array of shape ``(..., shape_out[0], shape_out[1])``
        Leading (non-horizontal) dimensions are the same as in ``indata``.

    Raises
    ------
    ValueError
        If the horizontal shape of ``indata`` differs from ``shape_in``, or if
        ``shape_in`` / ``shape_out`` do not match ``weights.shape``.

    Notes
    -----
    From https://github.com/carbonplan/ndpyramid/pull/130
    """

    # A COO matrix multiplies fast against an F-ordered array but slowly
    # against a C-ordered one (CSR/CRS is the opposite), so we take C-ordered
    # input and transpose the flattened view, which makes it F-ordered.
    if not indata.flags["C_CONTIGUOUS"]:
        warnings.warn("Input array is not C_CONTIGUOUS. " "Will affect performance.")

    # Normalize so that list inputs compare equal to ``ndarray.shape`` tuples.
    shape_in = tuple(shape_in)
    shape_out = tuple(shape_out)

    shape_horiz = indata.shape[-2:]
    extra_shape = indata.shape[:-2]

    # Explicit raises instead of ``assert``: asserts vanish under ``python -O``.
    if shape_horiz != shape_in:
        raise ValueError(
            f"The horizontal shape of input data is {shape_horiz}, different "
            f"from that of the regridder {shape_in}!"
        )

    n_in = shape_in[0] * shape_in[1]
    n_out = shape_out[0] * shape_out[1]
    if n_in != weights.shape[1]:
        raise ValueError("ny_in * nx_in should equal to weights.shape[1]")
    if n_out != weights.shape[0]:
        raise ValueError("ny_out * nx_out should equal to weights.shape[0]")

    # Flatten the horizontal axes so regridding is a single matrix product.
    indata_flat = indata.reshape(-1, n_in)
    outdata_flat = weights.dot(indata_flat.T).T

    # Restore the leading dimensions and unflatten to the output grid shape.
    return outdata_flat.reshape((*extra_shape, *shape_out))


def _open_icechunk_dataset(prefix, *, bucket, region):
    """Open the Zarr V3 dataset in an existing Icechunk repository on S3, read-only."""
    storage = StorageConfig.s3_from_env(bucket=bucket, prefix=prefix, region=region)
    store = IcechunkStore.open_existing(storage=storage, mode="r")
    return xr.open_zarr(store, zarr_format=3, consolidated=False)


def regrid(dataset, zoom=0, *, bucket="nasa-veda-scratch", region="us-west-2"):
    """
    Regrid one variable of ``dataset`` onto the cached target grid for ``zoom``.

    Parameters
    ----------
    dataset : str
        Key into ``earthaccess_args`` (the CLI offers ``"gpm_imerg"`` and
        ``"mursst"``). Its ``"variable"`` entry selects the variable to regrid.
    zoom : int
        Zoom level; selects which cached weights and target grid are used.
    bucket, region : str
        S3 bucket and AWS region holding the Icechunk repositories.

    Returns
    -------
    The result of ``xr_regridder`` for the selected variable.
    """
    args = earthaccess_args[dataset]
    cache_prefix = "resampling/test-weight-caching"

    # Load the pre-generated weights and target grid for this zoom level.
    weights = _reconstruct_xesmf_weights(
        _open_icechunk_dataset(
            f"{cache_prefix}/{dataset}-weights-{zoom}", bucket=bucket, region=region
        )
    )
    grid = _open_icechunk_dataset(
        f"{cache_prefix}/{dataset}-target-{zoom}", bucket=bucket, region=region
    ).load()

    # Load the source variable eagerly: xr_regridder calls apply_ufunc without
    # a ``dask=`` option, so it expects in-memory arrays.
    da = _open_icechunk_dataset(
        f"resampling/icechunk/{dataset}", bucket=bucket, region=region
    )[args["variable"]].load()
    return xr_regridder(da, grid, weights)


if __name__ == "__main__":
    if "get_ipython" in dir():
        # Running inside IPython/Jupyter: no CLI arguments, so regrid the default dataset.
        da = regrid("gpm_imerg")
    else:
        # Running as a script: choose the dataset and zoom level via argparse.
        parser = argparse.ArgumentParser(description="Set environment for the script.")
        parser.add_argument(
            "--dataset",
            default="gpm_imerg",
            help="Dataset to resample.",
            choices=["gpm_imerg", "mursst"],
        )
        parser.add_argument(
            "--zoom",
            default=0,
            # Let argparse validate the value and report a usage error on bad input.
            type=int,
            help="Zoom level for tile extent.",
        )
        user_args = parser.parse_args()
        da = regrid(user_args.dataset, user_args.zoom)