import argparse
import warnings
import numpy as np
import xarray as xr
from common import earthaccess_args
from icechunk import IcechunkStore, StorageConfig

# Resampling with sparse (S3 storage, Zarr V3 store, Zarr reader with icechunk)
def _reconstruct_xesmf_weights(ds_w):
    """
    Reconstruct regridding weights into the sparse COO format that xESMF
    understands.

    Parameters
    ----------
    ds_w : xr.Dataset
        Weights dataset with 1-based "col" and "row" index variables, an
        "S" variable holding the weight values, and "n_out"/"n_in" attrs
        giving the dense matrix shape.

    Returns
    -------
    xr.DataArray
        Sparse ``(out_dim, in_dim)`` weights matrix named "weights".

    Notes
    -----
    From ndpyramid - https://github.com/carbonplan/ndpyramid
    """
    # Imported lazily so the module can be used without sparse installed.
    import sparse

    # ESMF row/col indices are 1-based; shift to 0-based COO coordinates.
    col = ds_w["col"].values - 1
    row = ds_w["row"].values - 1
    s = ds_w["S"].values
    n_out, n_in = ds_w.attrs["n_out"], ds_w.attrs["n_in"]
    crds = np.stack([row, col])
    return xr.DataArray(
        sparse.COO(crds, s, (n_out, n_in)), dims=("out_dim", "in_dim"), name="weights"
    )
def xr_regridder(
    ds: xr.Dataset,
    grid: xr.Dataset,
    weights: xr.DataArray,
) -> xr.Dataset:
    """
    Xarray-aware regridding function that uses weights from xESMF but
    performs the regridding using sparse matrix multiplication.

    Parameters
    ----------
    ds : xr.Dataset
        Source data with "lat" and "lon" dimensions.
        NOTE(review): ``regrid`` actually passes a DataArray here;
        ``apply_ufunc`` accepts both — confirm the intended annotation.
    grid : xr.Dataset
        Target grid with "y" and "x" dimensions.
    weights : xr.DataArray
        Sparse ``(out_dim, in_dim)`` weights matrix, e.g. from
        ``_reconstruct_xesmf_weights``.

    Returns
    -------
    regridded_ds : xr.Dataset
        ``ds`` regridded onto the target grid.

    Notes
    -----
    Modified from https://github.com/carbonplan/ndpyramid/pull/130
    """
    latlon_dims = ["lat", "lon"]
    shape_in = (ds.sizes["lat"], ds.sizes["lon"])
    shape_out = (grid.sizes["y"], grid.sizes["x"])
    # apply_ufunc broadcasts esmf_apply_weights over any extra
    # (non-spatial) dimensions; lat/lon are consumed as core dims and
    # re-emitted with the output grid's sizes, hence exclude_dims.
    regridded_ds = xr.apply_ufunc(
        esmf_apply_weights,
        weights,
        ds,
        input_core_dims=[["out_dim", "in_dim"], latlon_dims],
        output_core_dims=[latlon_dims],
        exclude_dims=set(latlon_dims),
        kwargs={"shape_in": shape_in, "shape_out": shape_out},
        keep_attrs=True,
    )
    return regridded_ds
def esmf_apply_weights(weights, indata, shape_in, shape_out):
    """
    Apply regridding weights to data.

    Parameters
    ----------
    weights : scipy sparse COO matrix
        Weight matrix of shape ``(n_out, n_in)``. Anything with a ``dot``
        method and a ``shape`` attribute works.
    indata : numpy array of shape ``(..., n_lat, n_lon)`` or ``(..., n_y, n_x)``.
        Should be C-ordered. Will be then transposed to F-ordered.
    shape_in, shape_out : tuple of two integers
        Input/output data shape for unflatten operation.
        For rectilinear grid, it is just ``(n_lat, n_lon)``.

    Returns
    -------
    outdata : numpy array of shape ``(..., shape_out[0], shape_out[1])``.
        Extra dimensions are the same as `indata`.
        If input data is C-ordered, output will also be C-ordered.

    Notes
    -----
    From https://github.com/carbonplan/ndpyramid/pull/130
    """
    # COO matrix is fast with F-ordered array but slow with C-array, so we
    # take in a C-ordered and then transpose)
    # (CSR or CRS matrix is fast with C-ordered array but slow with F-array)
    if not indata.flags["C_CONTIGUOUS"]:
        warnings.warn("Input array is not C_CONTIGUOUS. Will affect performance.")

    # get input shape information
    shape_horiz = indata.shape[-2:]
    extra_shape = indata.shape[0:-2]

    assert shape_horiz == shape_in, (
        "The horizontal shape of input data is {}, different from that of "
        "the regridder {}!".format(shape_horiz, shape_in)
    )
    assert (
        shape_in[0] * shape_in[1] == weights.shape[1]
    ), "ny_in * nx_in should equal to weights.shape[1]"
    assert (
        shape_out[0] * shape_out[1] == weights.shape[0]
    ), "ny_out * nx_out should equal to weights.shape[0]"

    # use flattened array for dot operation
    indata_flat = indata.reshape(-1, shape_in[0] * shape_in[1])
    outdata_flat = weights.dot(indata_flat.T).T

    # unflatten the output array, restoring any leading (extra) dimensions
    outdata = outdata_flat.reshape([*extra_shape, shape_out[0], shape_out[1]])
    return outdata
def regrid(dataset, zoom=0):
    """
    Regrid a dataset using pre-generated xESMF weights cached in Icechunk.

    Parameters
    ----------
    dataset : str
        Dataset key in ``earthaccess_args`` (e.g. "gpm_imerg", "mursst");
        also selects the S3 prefixes for the cached weights/target grids.
    zoom : int, optional
        Zoom level for the tile extent; selects which cached weights and
        target grid to load. Defaults to 0.

    Returns
    -------
    The dataset's variable regridded onto the target grid
    (see ``xr_regridder``).
    """
    args = earthaccess_args[dataset]
    # Load pre-generated weights and target dataset
    weights_storage = StorageConfig.s3_from_env(
        bucket="nasa-veda-scratch",
        prefix=f"resampling/test-weight-caching/{dataset}-weights-{zoom}",
        region="us-west-2",
    )
    target_storage = StorageConfig.s3_from_env(
        bucket="nasa-veda-scratch",
        prefix=f"resampling/test-weight-caching/{dataset}-target-{zoom}",
        region="us-west-2",
    )
    weights_store = IcechunkStore.open_existing(storage=weights_storage, mode="r")
    target_store = IcechunkStore.open_existing(storage=target_storage, mode="r")
    weights = _reconstruct_xesmf_weights(
        xr.open_zarr(weights_store, zarr_format=3, consolidated=False)
    )
    grid = xr.open_zarr(target_store, zarr_format=3, consolidated=False).load()
    # Open dataset
    storage = StorageConfig.s3_from_env(
        bucket="nasa-veda-scratch",
        prefix=f"resampling/icechunk/{dataset}",
        region="us-west-2",
    )
    store = IcechunkStore.open_existing(storage=storage, mode="r")
    da = xr.open_zarr(store, zarr_format=3, consolidated=False)[args["variable"]].load()
    return xr_regridder(da, grid, weights)


if __name__ == "__main__":
    if "get_ipython" in dir():
        # Just call regrid with defaults if running as a Jupyter Notebook
        da = regrid("gpm_imerg")
    else:
        # Configure dataset via argparse if running via CLI
        parser = argparse.ArgumentParser(description="Set environment for the script.")
        parser.add_argument(
            "--dataset",
            default="gpm_imerg",
            help="Dataset to resample.",
            choices=["gpm_imerg", "mursst"],
        )
        parser.add_argument(
            "--zoom",
            default=0,
            help="Zoom level for tile extent.",
        )
        user_args = parser.parse_args()
        # --zoom arrives as a string from the CLI; coerce before use.
        da = regrid(user_args.dataset, int(user_args.zoom))