# Pre-generate weights for resampling with xESMF or sparse

import itertools

import numpy as np
import pyproj
import rasterio.transform
import sparse
import xarray as xr
import xesmf as xe
from common import earthaccess_args, target_extent
from icechunk import IcechunkStore, StorageConfig
def make_grid_ds(*, te, tilesize: int, dstSRS) -> xr.Dataset:
    """
    Make a dataset representing a target grid

    Parameters
    ----------
    te
        Target extent indexed as ``(xmin, ymin, xmax, ymax)``, expressed in
        the coordinates of ``dstSRS``.
    tilesize : int
        Number of grid cells along each axis of the (square) grid.
    dstSRS
        Projection definition handed to ``pyproj.Proj``.

    Returns
    -------
    xr.Dataset
        Target grid dataset with the following contents:
        - "latitude": coordinate along "y", latitude of the grid cell centers
          (taken from the first grid column)
        - "longitude": coordinate along "x", longitude of the grid cell centers
          (taken from the first grid row)
        - "lat_b": latitude of the grid cell corners, dims ("y_b", "x_b")
        - "lon_b": longitude of the grid cell corners, dims ("y_b", "x_b")

    Notes
    -----
    Modified from ndpyramid - https://github.com/carbonplan/ndpyramid
    """

    xmin, ymin, xmax, ymax = te[0], te[1], te[2], te[3]

    # Pixel -> projected coordinates. The origin is the top-left corner and the
    # pixel size is derived from the extent's width/height, so extents that are
    # not symmetric about the origin are handled too. The y scale is negative
    # because rows run from ymax down to ymin.
    transform = rasterio.transform.Affine.translation(
        xmin, ymax
    ) * rasterio.transform.Affine.scale(
        (xmax - xmin) / tilesize, (ymin - ymax) / tilesize
    )

    p = pyproj.Proj(dstSRS)

    # calc grid cell center coordinates (pixel centers sit at +0.5)
    centers = np.arange(tilesize) + 0.5
    cols, rows = np.meshgrid(centers, centers)
    xs, ys = transform * (cols, rows)
    lon, lat = p(xs, ys, inverse=True)

    # calc grid cell bounds (pixel corners sit at integer offsets)
    edges = np.arange(tilesize + 1)
    cols_b, rows_b = np.meshgrid(edges, edges)
    xs_b, ys_b = transform * (cols_b, rows_b)
    lon_b, lat_b = p(xs_b, ys_b, inverse=True)

    # CF conventions: latitude is the "Y" axis and longitude is the "X" axis
    latitude = xr.DataArray(
        lat[:, 0],
        dims="y",
        attrs=dict(
            standard_name="latitude",
            long_name="Latitude",
            units="degrees_north",
            axis="Y",
        ),
    )
    longitude = xr.DataArray(
        lon[0, :],
        dims="x",
        attrs=dict(
            standard_name="longitude",
            long_name="Longitude",
            units="degrees_east",
            axis="X",
        ),
    )

    return xr.Dataset(
        {
            "lat_b": xr.DataArray(lat_b, dims=["y_b", "x_b"]),
            "lon_b": xr.DataArray(lon_b, dims=["y_b", "x_b"]),
        },
        {
            "latitude": latitude,
            "longitude": longitude,
        },
    )


def xesmf_weights_to_xarray(regridder) -> xr.Dataset:
    """
    Flatten the sparse weights of an xESMF regridder into an xarray dataset

    The dataset holds three variables along the "n_s" dimension: "S" (the
    non-zero weight values) plus "col" and "row" (their sparse-matrix
    coordinates shifted up by one). The regridder's ``n_in`` and ``n_out`` are
    recorded as dataset attributes.

    Notes
    -----
    From ndpyramid - https://github.com/carbonplan/ndpyramid
    """
    sparse_weights = regridder.weights.data
    sparse_dim = "n_s"

    # Coordinate row 0 becomes "row" and coordinate row 1 becomes "col";
    # both are stored with a +1 offset.
    variables = {
        "S": (sparse_dim, sparse_weights.data),
        "col": (sparse_dim, sparse_weights.coords[1, :] + 1),
        "row": (sparse_dim, sparse_weights.coords[0, :] + 1),
    }
    sizes = {"n_in": regridder.n_in, "n_out": regridder.n_out}
    return xr.Dataset(variables, attrs=sizes)


def _reconstruct_xesmf_weights(ds_w):
    """
    Rebuild a sparse weights DataArray from a stored weights dataset

    Reads the "col", "row" and "S" variables and the "n_out"/"n_in" attributes
    of ``ds_w`` and returns a ``sparse.COO``-backed DataArray named "weights"
    with dims ("out_dim", "in_dim").

    Notes
    -----
    From ndpyramid - https://github.com/carbonplan/ndpyramid
    """
    # Shift the stored indices down by one before using them as coordinates
    in_idx = ds_w["col"].values - 1
    out_idx = ds_w["row"].values - 1
    values = ds_w["S"].values

    shape = (ds_w.attrs["n_out"], ds_w.attrs["n_in"])
    coordinates = np.stack([out_idx, in_idx])
    matrix = sparse.COO(coordinates, values, shape)

    return xr.DataArray(matrix, dims=("out_dim", "in_dim"), name="weights")
def generate_weights(dataset, zoom):
    """
    Generate and cache regridding weights and the target grid for one zoom level

    Builds the target grid for ``zoom``, opens the dataset's variable from its
    existing icechunk store, computes xESMF ``nearest_s2d`` weights, and writes
    both the weights and the target grid to icechunk repos and to Zarr stores
    on S3.

    Parameters
    ----------
    dataset
        Key into ``earthaccess_args``; also used to build the storage prefixes.
    zoom
        Key into ``target_extent``; also used to build the storage prefixes.
    """

    def s3_storage(prefix):
        # Every store used here lives in the same bucket and region
        return StorageConfig.s3_from_env(
            bucket="nasa-veda-scratch",
            prefix=prefix,
            region="us-west-2",
        )

    extent = target_extent[zoom]

    # Look up variable information for the dataset
    dataset_args = earthaccess_args[dataset]

    # Icechunk repos that cache the weights and the target grid
    weights_config = s3_storage(
        f"resampling/test-weight-caching/{dataset}-weights-{zoom}"
    )
    target_config = s3_storage(
        f"resampling/test-weight-caching/{dataset}-target-{zoom}"
    )
    weights_repo = IcechunkStore.open_or_create(storage=weights_config, mode="w")
    target_repo = IcechunkStore.open_or_create(storage=target_config, mode="w")

    # Target grid
    grid = make_grid_ds(te=extent, tilesize=256, dstSRS="EPSG:3857")

    # Source data
    source_config = s3_storage(f"resampling/icechunk/{dataset}")
    source_store = IcechunkStore.open_existing(storage=source_config, mode="r")
    source = xr.open_zarr(source_store, zarr_format=3, consolidated=False)
    da = source[dataset_args["variable"]]

    # Chunk the target grid so weights can be generated in parallel
    output_chunk_size = 128
    grid = grid.chunk(
        {dim: output_chunk_size for dim in ("x", "y", "y_b", "x_b")}
    )

    # Build the xESMF regridder
    regridder = xe.Regridder(
        da,
        grid,
        "nearest_s2d",
        periodic=True,
        extrap_method="nearest_s2d",
        ignore_degenerate=True,
        parallel=True,
    )

    # Convert weights to a dataset
    weights_ds = xesmf_weights_to_xarray(regridder)

    # Weights -> icechunk, then commit
    weights_ds.to_zarr(weights_repo, zarr_format=3, consolidated=False)
    weights_repo.commit("Store weights")

    # Target grid -> icechunk (loaded into memory first), then commit
    grid.load().to_zarr(target_repo, zarr_format=3, consolidated=False)
    target_repo.commit("Generate target grid")

    # Weights -> Zarr
    weights_url = f"s3://nasa-veda-scratch/resampling/test-weight-caching/{dataset}-weights-{zoom}.zarr"
    weights_ds.to_zarr(
        weights_url, mode="w", storage_options={"use_listings_cache": False}
    )

    # Target grid -> Zarr
    target_url = f"s3://nasa-veda-scratch/resampling/test-weight-caching/{dataset}-target-{zoom}.zarr"
    grid.to_zarr(
        target_url, mode="w", storage_options={"use_listings_cache": False}
    )
# Generate and cache weights for zoom levels 0 through 2 of the "gpm_imerg" dataset
dataset = "gpm_imerg"
for zoom_level in (0, 1, 2):
    generate_weights(dataset, zoom_level)