import argparse
import itertools
import fsspec
import numpy as np
import pyproj
import rasterio.transform
import xarray as xr
import xesmf as xe
# Resampling with XESMF (local storage, NetCDF file, H5NetCDF driver)
def make_grid_ds(*, te, tilesize, dstSRS) -> xr.Dataset:
    """
    Make a dataset representing a target grid

    Parameters
    ----------
    te
        Target extent, read as ``te[0]`` .. ``te[3]``. The grid origin is
        placed at ``(te[0], te[3])`` and the cell size is derived from
        ``te[2] * 2`` (x) and ``te[1] * 2`` (y).
        NOTE(review): this presumably expects ``(xmin, ymin, xmax, ymax)``
        with an extent symmetric about the origin -- confirm against callers.
    tilesize
        Number of grid cells along each axis of the (square) grid.
    dstSRS
        Projection definition handed to ``pyproj.Proj``; projected x/y values
        are converted to lon/lat with its inverse transform.

    Returns
    -------
    xr.Dataset
        Target grid dataset with the following variables:
        - "x": X coordinate in the target projection (grid cell center)
        - "y": Y coordinate in the target projection (grid cell center)
        - "lat": latitude coordinate (grid cell center)
        - "lon": longitude coordinate (grid cell center)
        - "lat_b": latitude bounds for grid cell
        - "lon_b": longitude bounds for grid cell

    Notes
    -----
    Modified from ndpyramid - https://github.com/carbonplan/ndpyramid
    """
    # Affine mapping from (column, row) grid indices to projected (x, y).
    transform = rasterio.transform.Affine.translation(
        te[0], te[3]
    ) * rasterio.transform.Affine.scale((te[2] * 2) / tilesize, (te[1] * 2) / tilesize)

    p = pyproj.Proj(dstSRS)

    # Cell centers form a tilesize x tilesize grid; cell corners need one
    # extra row and column.
    grid_shape = (tilesize, tilesize)
    bounds_shape = (tilesize + 1, tilesize + 1)

    xs = np.empty(grid_shape)
    ys = np.empty(grid_shape)
    lat = np.empty(grid_shape)
    lon = np.empty(grid_shape)
    lat_b = np.zeros(bounds_shape)
    lon_b = np.zeros(bounds_shape)

    # calc grid cell center coordinates
    # (the +0.5 offset moves from the cell corner index to the cell center)
    ii, jj = np.meshgrid(np.arange(tilesize) + 0.5, np.arange(tilesize) + 0.5)
    for i, j in itertools.product(range(grid_shape[0]), range(grid_shape[1])):
        locs = [ii[i, j], jj[i, j]]
        xs[i, j], ys[i, j] = transform * locs
        lon[i, j], lat[i, j] = p(xs[i, j], ys[i, j], inverse=True)

    # calc grid cell bounds
    iib, jjb = np.meshgrid(np.arange(tilesize + 1), np.arange(tilesize + 1))
    for i, j in itertools.product(range(bounds_shape[0]), range(bounds_shape[1])):
        locs = [iib[i, j], jjb[i, j]]
        x, y = transform * locs
        lon_b[i, j], lat_b[i, j] = p(x, y, inverse=True)

    return xr.Dataset(
        {
            # 1-D axes: x is read from the first row, y from the first column.
            "x": xr.DataArray(xs[0, :], dims=["x"]),
            "y": xr.DataArray(ys[:, 0], dims=["y"]),
            "lat": xr.DataArray(lat, dims=["y", "x"]),
            "lon": xr.DataArray(lon, dims=["y", "x"]),
            "lat_b": xr.DataArray(lat_b, dims=["y_b", "x_b"]),
            "lon_b": xr.DataArray(lon_b, dims=["y_b", "x_b"]),
        },
    )
def regrid(dataset, zoom=0):
    """
    Regrid one variable of a locally stored NetCDF file onto a 256x256 grid.

    Parameters
    ----------
    dataset
        Key into ``common.earthaccess_args``; the selected entry supplies the
        ``"filename"`` and ``"variable"`` used below.
    zoom
        Key into ``common.target_extent`` selecting the target extent.

    Returns
    -------
    The output of calling the xESMF regridder on the selected variable,
    after ``.load()`` has been called on it.
    """
    # Imported lazily so that importing this module does not require the
    # project-local ``common`` module to be importable.
    from common import earthaccess_args, target_extent

    te = target_extent[zoom]

    # Define filepath, driver, and variable information
    args = earthaccess_args[dataset]
    src = f'earthaccess_data/{args["filename"]}'
    # Create grid to hold result
    target_grid = make_grid_ds(te=te, tilesize=256, dstSRS="EPSG:3857")
    # Specify fsspec caching since default options don't work well for raster data
    fsspec_caching = {
        "cache_type": "none",
    }
    fs = fsspec.filesystem("file")
    with fs.open(src, **fsspec_caching) as f:
        # Open dataset
        da = xr.open_dataset(f, engine="h5netcdf", chunks={}, mask_and_scale=True)[
            args["variable"]
        ]
        # Create XESMF regridder
        regridder = xe.Regridder(
            da,
            target_grid,
            "nearest_s2d",
            periodic=True,
            extrap_method="nearest_s2d",
            ignore_degenerate=True,
        )
        # Regrid dataset
        # .load() is called while the file handle is still open, so the
        # result is read before the ``with`` block closes the file.
        return regridder(da).load()
# Notebook-only cell magic (not valid in a plain Python script): %%time
if __name__ == "__main__":
    if "get_ipython" in dir():
        # Just call regrid with a fixed dataset if running as a Jupyter Notebook
        da = regrid("gpm_imerg")
    else:
        # Configure dataset via argparse if running via CLI
        parser = argparse.ArgumentParser(description="Set environment for the script.")
        parser.add_argument(
            "--dataset",
            default="mursst",
            help="Dataset to resample.",
            choices=["gpm_imerg", "mursst"],
        )
        parser.add_argument(
            "--zoom",
            default=0,
            # Let argparse reject non-integer input with a usage error
            # instead of failing later in int().
            type=int,
            help="Zoom level for tile extent.",
        )
        user_args = parser.parse_args()
        da = regrid(user_args.dataset, int(user_args.zoom))