xESMF with h5netcdf and earthaccess

Requires the upcoming ESMF 8.7 release (see https://github.com/pangeo-data/xESMF/issues/380).

import itertools

import numpy as np
import pyproj
import rasterio.transform
import xarray as xr
import xesmf as xe
def configure_fs_auth():
    """Log in to NASA Earthdata and return an S3 filesystem for PO.DAAC data.

    ``earthaccess.login()`` performs the Earthdata authentication. The PO.DAAC
    S3 session credentials it yields (access key, secret key and session
    token) are then handed to ``s3fs.S3FileSystem``.

    Returns
    -------
    s3fs.S3FileSystem
        A non-anonymous filesystem authenticated with the PO.DAAC credentials.
    """
    # Imported lazily so the rest of the module can be used without these
    # optional dependencies installed.
    import earthaccess
    import s3fs

    creds = earthaccess.login().get_s3_credentials("PODAAC")
    return s3fs.S3FileSystem(
        anon=False,
        key=creds["accessKeyId"],
        secret=creds["secretAccessKey"],
        token=creds["sessionToken"],
    )
def make_grid_ds() -> xr.Dataset:
    """
    Build a 256x256 EPSG:3857 (Web Mercator) target grid for xESMF.

    Modified from ndpyramid - https://github.com/carbonplan/ndpyramid

    The grid covers the full Web Mercator extent. Projected (metre) coordinates
    of cell centers and cell corners are computed with an affine transform.
    They are then inverse-projected to geographic longitude/latitude.

    Returns
    -------
    xr.Dataset
        Dataset with these variables:

        - ``x``, ``y``: 1-D projected center coordinates (dims ``x`` / ``y``).
        - ``lon``, ``lat``: 2-D cell-center coordinates (dims ``y``, ``x``).
        - ``lon_b``, ``lat_b``: 2-D cell-corner coordinates, one larger in
          each direction (dims ``y_b``, ``x_b``).
    """
    dstSRS = "EPSG:3857"
    width = height = 256
    te = [
        -20037508.342789244,
        -20037508.342789244,
        20037508.342789244,
        20037508.342789244,
    ]

    # Maps (column, row) pixel coordinates to projected x/y. The origin is the
    # top-left corner; te[1] is negative, so y decreases as the row increases.
    transform = rasterio.transform.Affine.translation(
        te[0], te[3]
    ) * rasterio.transform.Affine.scale((te[2] * 2) / width, (te[1] * 2) / height)

    p = pyproj.Proj(dstSRS)

    # Both `Affine * (cols, rows)` and `Proj(..., inverse=True)` operate
    # element-wise on numpy arrays. Transforming all points at once avoids one
    # Python-level pyproj call per pixel (~130k calls for this grid).
    # indexing="ij" puts rows (y) on axis 0 and columns (x) on axis 1. That
    # matches the ("y", "x") dims below even when width != height.

    # calc grid cell center coordinates
    rows, cols = np.meshgrid(
        np.arange(height) + 0.5, np.arange(width) + 0.5, indexing="ij"
    )
    xs, ys = transform * (cols, rows)
    lon, lat = p(xs, ys, inverse=True)

    # calc grid cell bounds (corners), shape (height + 1, width + 1)
    rows_b, cols_b = np.meshgrid(
        np.arange(height + 1), np.arange(width + 1), indexing="ij"
    )
    x_b, y_b = transform * (cols_b, rows_b)
    lon_b, lat_b = p(x_b, y_b, inverse=True)

    return xr.Dataset(
        {
            "x": xr.DataArray(xs[0, :], dims=["x"]),
            "y": xr.DataArray(ys[:, 0], dims=["y"]),
            "lat": xr.DataArray(lat, dims=["y", "x"]),
            "lon": xr.DataArray(lon, dims=["y", "x"]),
            "lat_b": xr.DataArray(lat_b, dims=["y_b", "x_b"]),
            "lon_b": xr.DataArray(lon_b, dims=["y_b", "x_b"]),
        },
    )
def regrid():
    """
    Regrid one MUR SST granule from PO.DAAC S3 onto the Web Mercator grid.

    Steps:

    1. Open the granule directly from S3 with the h5netcdf engine.
    2. Build an xESMF ``nearest_s2d`` regridder from its grid to the grid
       returned by ``make_grid_ds()``.
    3. Apply the regridder to ``analysed_sst``.

    Returns
    -------
    The regridded ``analysed_sst`` variable (the object returned by the
    xESMF regridder).
    """
    bucket = "podaac-ops-cumulus-protected"
    input_uri = "MUR-JPL-L4-GLOB-v4.1/20020601090000-JPL-L4_GHRSST-SSTfnd-MUR-GLOB-v02.0-fv04.1.nc"
    variable = "analysed_sst"
    src = f"s3://{bucket}/{input_uri}"
    # Call the helper. Assigning the bare function object (without the call)
    # would make `fs.open` below fail with AttributeError.
    fs = configure_fs_auth()
    # Open the S3 object with fsspec's "none" cache type (no block caching).
    fsspec_caching = {
        "cache_type": "none",
    }
    target_grid = make_grid_ds()
    with fs.open(src, **fsspec_caching) as f:
        da = xr.open_dataset(f, engine="h5netcdf")[variable]
        regridder = xe.Regridder(
            da,
            target_grid,
            "nearest_s2d",
            periodic=True,
            extrap_method="nearest_s2d",
            ignore_degenerate=True,
        )
        # `da` is backed by the open handle `f`. Regrid before leaving the
        # `with` block closes it.
        return regridder(da)
if __name__ == "__main__":
    # Run the end-to-end example when executed as a script. This needs
    # Earthdata credentials (see configure_fs_auth). The regridded result is
    # bound to the module-level name `ds`.
    ds = regrid()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[4], line 2
      1 if __name__ == "__main__":
----> 2     ds = regrid()

Cell In[3], line 14, in regrid()
     12 with fs.open(src, **fsspec_caching) as f:
     13     da = xr.open_dataset(f, engine="h5netcdf")[variable]
---> 14     regridder = xe.Regridder(da, target_grid, "nearest_s2d", periodic= True, extrap_method="nearest_s2d", ignore_degenerate=True)
     15     return regridder(da)

File /opt/conda/lib/python3.11/site-packages/xesmf/frontend.py:919, in Regridder.__init__(self, ds_in, ds_out, method, locstream_in, locstream_out, periodic, parallel, **kwargs)
    917     grid_in, shape_in, input_dims = ds_to_ESMFlocstream(ds_in)
    918 else:
--> 919     grid_in, shape_in, input_dims = ds_to_ESMFgrid(
    920         ds_in, need_bounds=need_bounds, periodic=periodic
    921     )
    922 if locstream_out:
    923     grid_out, shape_out, output_dims = ds_to_ESMFlocstream(ds_out)

File /opt/conda/lib/python3.11/site-packages/xesmf/frontend.py:164, in ds_to_ESMFgrid(ds, need_bounds, periodic, append)
    162     grid = Grid.from_xarray(lon.T, lat.T, periodic=periodic, mask=mask.T)
    163 else:
--> 164     grid = Grid.from_xarray(lon.T, lat.T, periodic=periodic, mask=None)
    166 if need_bounds:
    167     lon_b, lat_b = _get_lon_lat_bounds(ds)

File /opt/conda/lib/python3.11/site-packages/xesmf/backend.py:114, in Grid.from_xarray(cls, lon, lat, periodic, mask)
    108     num_peri_dims = None
    110 # ESMPy documentation claims that if staggerloc and coord_sys are None,
    111 # they will be set to default values (CENTER and SPH_DEG).
    112 # However, they actually need to be set explicitly,
    113 # otherwise grid._coord_sys and grid._staggerloc will still be None.
--> 114 grid = cls(
    115     np.array(lon.shape),
    116     staggerloc=staggerloc,
    117     coord_sys=ESMF.CoordSys.SPH_DEG,
    118     num_peri_dims=num_peri_dims,
    119 )
    121 # The grid object points to the underlying Fortran arrays in ESMF.
    122 # To modify lat/lon coordinates, need to get pointers to them
    123 lon_pointer = grid.get_coords(coord_dim=0, staggerloc=staggerloc)

File /opt/conda/lib/python3.11/site-packages/esmpy/util/decorators.py:59, in initialize.<locals>.new_func(*args, **kwargs)
     56 from esmpy.api import esmpymanager
     58 esmp = esmpymanager.Manager(debug = False)
---> 59 return func(*args, **kwargs)

File /opt/conda/lib/python3.11/site-packages/esmpy/api/grid.py:479, in Grid.__init__(self, max_index, num_peri_dims, periodic_dim, pole_dim, coord_sys, coord_typekind, staggerloc, pole_kind, filename, filetype, reg_decomp, decompflag, is_sphere, add_corner_stagger, add_user_area, add_mask, varname, coord_names, tilesize, regDecompPTile, name)
    454 # if self.decount == 1:
    455 # elif self.decount > 1:
    456 #     # lower_bounds[de][staggerLoc]
   (...)
    476 
    477 # Add coordinates if a staggerloc is specified
    478 if not isinstance(staggerloc, type(None)):
--> 479     self.add_coords(staggerloc=staggerloc, from_file=from_file)
    481 # Add items if they are specified, this is done after the
    482 # mask and area are initialized
    483 if add_user_area:

File /opt/conda/lib/python3.11/site-packages/esmpy/api/grid.py:837, in Grid.add_coords(self, staggerloc, coord_dim, from_file)
    834     ESMP_GridAddCoord(self, staggerloc=stagger)
    836 # and now for Python
--> 837 self._allocate_coords_(stagger, from_file=from_file)
    839 # set the staggerlocs to be done
    840 self.staggerloc[stagger] = True

File /opt/conda/lib/python3.11/site-packages/esmpy/api/grid.py:1062, in Grid._allocate_coords_(self, stagger, localde, from_file)
   1060 if (self.ndims == self.rank) or (self.ndims == 0):
   1061     for xyz in range(self.rank):
-> 1062         self._link_coord_buffer_(xyz, stagger, localde)
   1063 # and this way if we have 1d coordinates
   1064 elif self.ndims < self.rank:

File /opt/conda/lib/python3.11/site-packages/esmpy/api/grid.py:1112, in Grid._link_coord_buffer_(self, coord_dim, stagger, localde)
   1109 data = ESMP_GridGetCoordPtr(self, coord_dim, staggerloc=stagger, localde=localde)
   1110 lb, ub = ESMP_GridGetCoordBounds(self, staggerloc=stagger, localde=localde)
-> 1112 gridCoordP = ndarray_from_esmf(data, self.type, ub-lb)
   1114 # alias the coordinates to a grid property
   1115 self._coords[stagger][coord_dim] = gridCoordP

File /opt/conda/lib/python3.11/site-packages/esmpy/util/esmpyarray.py:38, in ndarray_from_esmf(data, dtype, shape)
     33 else:
     34     buffer = np.core.multiarray.int_asbuffer(
     35         ct.addressof(data.contents), size)
---> 38 esmfarray = np.ndarray(tuple(shape[:]), constants._ESMF2PythonType[dtype],
     39                        buffer, order="F")
     41 return esmfarray

TypeError: buffer is too small for requested array