import argparse
import warnings
import numpy as np
import xarray as xr
from common import earthaccess_args
from icechunk import IcechunkStore, StorageConfig
# Resampling with sparse (S3 storage, Zarr V3 store, Zarr reader with icechunk)
def _reconstruct_xesmf_weights(ds_w):
    """
    Reconstruct weights into format that xESMF understands.

    Parameters
    ----------
    ds_w : xr.Dataset
        Weights dataset with 1-based ``col``/``row`` index variables, the
        ``S`` weight values, and ``n_out``/``n_in`` sizes in its attrs
        (the layout ESMF weight files use).

    Returns
    -------
    xr.DataArray
        Sparse COO weights array with dims ``("out_dim", "in_dim")``.

    Notes
    -----
    From ndpyramid - https://github.com/carbonplan/ndpyramid
    """
    import sparse
    import xarray as xr

    # ESMF weight files are 1-based; shift to 0-based indices
    col = ds_w["col"].values - 1
    row = ds_w["row"].values - 1
    s = ds_w["S"].values
    n_out, n_in = ds_w.attrs["n_out"], ds_w.attrs["n_in"]
    crds = np.stack([row, col])
    return xr.DataArray(
        sparse.COO(crds, s, (n_out, n_in)), dims=("out_dim", "in_dim"), name="weights"
    )
def xr_regridder(
    ds: xr.Dataset,
    grid: xr.Dataset,
    weights: xr.DataArray,
) -> xr.Dataset:
    """
    Xarray-aware regridding function that uses weights from xESMF but performs
    the regridding using sparse matrix multiplication.

    Parameters
    ----------
    ds : xr.Dataset
        Source data with ``lat``/``lon`` horizontal dims.
    grid : xr.Dataset
        Target grid; only its ``y``/``x`` sizes are used for the output shape.
    weights : xr.DataArray
        Sparse regridding weights with dims ``("out_dim", "in_dim")``,
        e.g. from ``_reconstruct_xesmf_weights``.

    Returns
    -------
    regridded_ds : xr.Dataset
        Input regridded onto the target grid.

    Notes
    -----
    Modified from https://github.com/carbonplan/ndpyramid/pull/130
    """
    latlon_dims = ["lat", "lon"]

    shape_in = (ds.sizes["lat"], ds.sizes["lon"])
    shape_out = (grid.sizes["y"], grid.sizes["x"])

    regridded_ds = xr.apply_ufunc(
        esmf_apply_weights,
        weights,
        ds,
        input_core_dims=[["out_dim", "in_dim"], latlon_dims],
        output_core_dims=[latlon_dims],
        exclude_dims=set(latlon_dims),
        kwargs={"shape_in": shape_in, "shape_out": shape_out},
        keep_attrs=True,
    )

    return regridded_ds
def esmf_apply_weights(weights, indata, shape_in, shape_out):
    """
    Apply regridding weights to data.

    Parameters
    ----------
    weights : scipy sparse COO matrix
        Regridding weights of shape ``(n_out, n_in)``.
    indata : numpy array of shape ``(..., n_lat, n_lon)`` or ``(..., n_y, n_x)``.
        Should be C-ordered. Will be then transposed to F-ordered.
    shape_in, shape_out : tuple of two integers
        Input/output data shape for unflatten operation.
        For rectilinear grid, it is just ``(n_lat, n_lon)``.

    Returns
    -------
    outdata : numpy array of shape ``(..., shape_out[0], shape_out[1])``.
        Extra dimensions are the same as `indata`.
        If input data is C-ordered, output will also be C-ordered.

    Notes
    -----
    From https://github.com/carbonplan/ndpyramid/pull/130
    """
    # COO matrix is fast with F-ordered array but slow with C-array, so we
    # take in a C-ordered and then transpose
    # (CSR or CRS matrix is fast with C-ordered array but slow with F-array)
    if not indata.flags["C_CONTIGUOUS"]:
        warnings.warn("Input array is not C_CONTIGUOUS. Will affect performance.")

    # get input shape information
    shape_horiz = indata.shape[-2:]
    extra_shape = indata.shape[0:-2]

    assert shape_horiz == shape_in, (
        "The horizontal shape of input data is {}, different from that of "
        "the regridder {}!".format(shape_horiz, shape_in)
    )
    assert (
        shape_in[0] * shape_in[1] == weights.shape[1]
    ), "ny_in * nx_in should equal to weights.shape[1]"
    assert (
        shape_out[0] * shape_out[1] == weights.shape[0]
    ), "ny_out * nx_out should equal to weights.shape[0]"

    # use flattened array for dot operation
    indata_flat = indata.reshape(-1, shape_in[0] * shape_in[1])
    outdata_flat = weights.dot(indata_flat.T).T

    # unflatten output array, restoring any leading (extra) dimensions
    outdata = outdata_flat.reshape([*extra_shape, shape_out[0], shape_out[1]])
    return outdata
def regrid(dataset, zoom=0):
    """
    Regrid an icechunk-stored dataset onto a cached target grid.

    Parameters
    ----------
    dataset : str
        Key into ``earthaccess_args`` (e.g. ``"gpm_imerg"``, ``"mursst"``).
    zoom : int, optional
        Zoom level selecting which pre-generated weights/target to use.

    Returns
    -------
    xr.Dataset
        The variable from ``earthaccess_args[dataset]["variable"]`` regridded
        onto the target grid.
    """
    args = earthaccess_args[dataset]
    # Load pre-generated weights and target dataset
    weights_storage = StorageConfig.s3_from_env(
        bucket="nasa-veda-scratch",
        prefix=f"resampling/test-weight-caching/{dataset}-weights-{zoom}",
        region="us-west-2",
    )
    target_storage = StorageConfig.s3_from_env(
        bucket="nasa-veda-scratch",
        prefix=f"resampling/test-weight-caching/{dataset}-target-{zoom}",
        region="us-west-2",
    )
    weights_store = IcechunkStore.open_existing(storage=weights_storage, mode="r")
    target_store = IcechunkStore.open_existing(storage=target_storage, mode="r")
    weights = _reconstruct_xesmf_weights(
        xr.open_zarr(weights_store, zarr_format=3, consolidated=False)
    )
    grid = xr.open_zarr(target_store, zarr_format=3, consolidated=False).load()
    # Open dataset
    storage = StorageConfig.s3_from_env(
        bucket="nasa-veda-scratch",
        prefix=f"resampling/icechunk/{dataset}",
        region="us-west-2",
    )
    store = IcechunkStore.open_existing(storage=storage, mode="r")
    da = xr.open_zarr(store, zarr_format=3, consolidated=False)[args["variable"]].load()
    return xr_regridder(da, grid, weights)
if __name__ == "__main__":
    if "get_ipython" in dir():
        # Just call regrid if running as a Jupyter Notebook
        da = regrid("gpm_imerg")
    else:
        # Configure dataset via argparse if running via CLI
        parser = argparse.ArgumentParser(description="Set environment for the script.")
        parser.add_argument(
            "--dataset",
            default="gpm_imerg",
            help="Dataset to resample.",
            choices=["gpm_imerg", "mursst"],
        )
        parser.add_argument(
            "--zoom",
            default=0,
            help="Zoom level for tile extent.",
        )
        user_args = parser.parse_args()
        da = regrid(user_args.dataset, int(user_args.zoom))