Skip to content

STAC + Xarray

Goal: Create a custom STAC Reader supporting both COG and NetCDF/Zarr dataset

requirements:

  • titiler.core
  • titiler.xarray
  • fsspec
  • zarr
  • h5netcdf
  • aiohttp (optional)
  • s3fs (optional)

links:

1. Custom STACReader

First, we need to create a custom STACReader which will support both COG and NetCDF/Zarr dataset. The custom parts will be:

  • add netcdf and zarr as valid asset media types
  • introduce a new md:// prefixed asset form, so users can pass assets=md://{netcdf asset name}?variable={variable name} as we do for the GDAL vrt string connection support.
stac.py
from typing import Set, Type, Tuple, Dict, Optional

import attr
from urllib.parse import urlparse, parse_qsl
from rio_tiler.types import AssetInfo
from rio_tiler.io import BaseReader, Reader
from rio_tiler.io import stac

from titiler.xarray.io import Reader as XarrayReader

valid_types = {
    *stac.DEFAULT_VALID_TYPE,
    "application/x-netcdf",
    "application/vnd+zarr",
}


@attr.s
class STACReader(stac.STACReader):
    """Custom STACReader which adds support for `md://` prefixed assets.

    Example:
    >>> with STACReader("https://raw.githubusercontent.com/cogeotiff/rio-tiler/refs/heads/main/tests/fixtures/stac_netcdf.json") as src:
            print(src.assets)
            print(src._get_asset_info("md://netcdf?variable=dataset"))

        ['geotiff', 'netcdf']
        {'url': 'https://raw.githubusercontent.com/cogeotiff/rio-tiler/refs/heads/main/tests/fixtures/dataset_2d.nc', 'metadata': {}, 'reader_options': {'variable': 'dataset'}, 'media_type': 'application/x-netcdf'}

    """
    include_asset_types: Set[str] = attr.ib(default=valid_types)

    def _get_reader(self, asset_info: AssetInfo) -> Tuple[Type[BaseReader], Dict]:
        """Get Asset Reader."""
        asset_type = asset_info.get("media_type", None)
        if asset_type and asset_type in [
            "application/x-netcdf",
            "application/vnd+zarr",
            "application/x-hdf5",
            "application/x-hdf",
        ]:
            return XarrayReader, asset_info.get("reader_options", {})

        return Reader, asset_info.get("reader_options", {})

    def _parse_md_asset(self, asset: str) -> Tuple[str, Optional[Dict]]:
        """Parse md:// asset string and return both asset name and reader options"""
        if asset.startswith("md://") and asset not in self.assets:
            parsed = urlparse(asset)
            if not parsed.netloc or parsed.netloc not in self.assets:
                raise InvalidAssetName(
                    f"'{parsed.netloc}' is not valid, should be one of {self.assets}"
                )

            # NOTE: by using `parse_qsl` we assume the
            # reader_options are in form of `key=single_value`
            # reader_options for XarrayReader are:
            # - variable: str
            # - group: Optional[str]
            # - decode_times: bool = True
            # - datetime: Optional[str]
            # - drop_dim: Optional[str]
            return parsed.netloc, dict(parse_qsl(parsed.query))

        return asset, None

    def _get_asset_info(self, asset: str) -> AssetInfo:
        """Validate asset names and return asset's info.

        Args:
            asset (str): STAC asset name.

        Returns:
            AssetInfo: STAC asset info.

        """
        vrt_options = None
        reader_options = None
        if asset.startswith("vrt://"):
            asset, vrt_options = self._parse_vrt_asset(asset)

        # not part of the original STACReader
        elif asset.startswith("md://"):
            asset, reader_options = self._parse_md_asset(asset)

        if asset not in self.assets:
            raise InvalidAssetName(
                f"'{asset}' is not valid, should be one of {self.assets}"
            )

        asset_info = self.item.assets[asset]
        extras = asset_info.extra_fields

        info = AssetInfo(
            url=asset_info.get_absolute_href() or asset_info.href,
            metadata=extras if not vrt_options else None,
            reader_options=reader_options or {}
        )

        if stac.STAC_ALTERNATE_KEY and extras.get("alternate"):
            if alternate := extras["alternate"].get(stac.STAC_ALTERNATE_KEY):
                info["url"] = alternate["href"]

        if asset_info.media_type:
            info["media_type"] = asset_info.media_type

        # https://github.com/stac-extensions/file
        if head := extras.get("file:header_size"):
            info["env"] = {"GDAL_INGESTED_BYTES_AT_OPEN": head}

        # https://github.com/stac-extensions/raster
        if extras.get("raster:bands") and not vrt_options:
            bands = extras.get("raster:bands")
            stats = [
                (b["statistics"]["minimum"], b["statistics"]["maximum"])
                for b in bands
                if {"minimum", "maximum"}.issubset(b.get("statistics", {}))
            ]
            # check that stats data are all double and make warning if not
            if (
                stats
                and all(isinstance(v, (int, float)) for stat in stats for v in stat)
                and len(stats) == len(bands)
            ):
                info["dataset_statistics"] = stats
            else:
                warnings.warn(
                    "Some statistics data in STAC are invalid, they will be ignored."
                )

        if vrt_options:
            info["url"] = f"vrt://{info['url']}?{vrt_options}"

        return info

2. Endpoint Factory

Custom MultiBaseTilerFactory which removes some endpoints (/preview) and adapt dependencies to work with both COG and Xarray Datasets.

factory.py
"""Custom MultiBaseTilerFactory."""
from dataclasses import dataclass
from typing import Annotated, Type, Union, Optional, List
from attrs import define, field
from geojson_pydantic.features import Feature, FeatureCollection
from fastapi import Body, Depends, Query
from titiler.core import factory
from titiler.core.dependencies import (
    DefaultDependency,
    BidxParams,
    AssetsParams,
    AssetsBidxExprParamsOptional,
    CoordCRSParams,
    DstCRSParams,
)
from titiler.core.models.responses import MultiBaseStatisticsGeoJSON
from titiler.core.resources.responses import GeoJSONResponse
from rio_tiler.constants import WGS84_CRS
from rio_tiler.io import MultiBaseReader

from stac import STACReader


# Simple Asset dependency (1 asset, no expression)
@dataclass
class SingleAssetsParams(DefaultDependency):
    """Custom Assets parameters which only accept ONE asset and make it required."""

    assets: Annotated[
        str,
        Query(title="Asset names", description="Asset's name."),
    ]

    indexes: Annotated[
        Optional[List[int]],
        Query(
            title="Band indexes",
            alias="bidx",
            description="Dataset band indexes",
            openapi_examples={
                "user-provided": {"value": None},
                "one-band": {"value": [1]},
                "multi-bands": {"value": [1, 2, 3]},
            },
        ),
    ] = None


@define(kw_only=True)
class MultiBaseTilerFactory(factory.MultiBaseTilerFactory):

    reader: Type[MultiBaseReader] = STACReader

    # Assets/Indexes/Expression dependency
    layer_dependency: Type[DefaultDependency] = SingleAssetsParams

    # Assets dependency (for /info endpoints)
    assets_dependency: Type[DefaultDependency] = AssetsParams

    # remove preview endpoints
    img_preview_dependency: Type[DefaultDependency] = field(init=False)
    add_preview: bool = field(init=False, default=False)

    # Overwrite the `/statistics` endpoint to remove `full` dataset statistics (which could be unusable for NetCDF dataset)
    def statistics(self):  # noqa: C901
        """Register /statistics endpoint."""

        @self.router.post(
            "/statistics",
            response_model=MultiBaseStatisticsGeoJSON,
            response_model_exclude_none=True,
            response_class=GeoJSONResponse,
            responses={
                200: {
                    "content": {"application/geo+json": {}},
                    "description": "Return dataset's statistics from feature or featureCollection.",
                }
            },
        )
        def geojson_statistics(
            geojson: Annotated[
                Union[FeatureCollection, Feature],
                Body(description="GeoJSON Feature or FeatureCollection."),
            ],
            src_path=Depends(self.path_dependency),
            reader_params=Depends(self.reader_dependency),
            layer_params=Depends(AssetsBidxExprParamsOptional),
            dataset_params=Depends(self.dataset_dependency),
            coord_crs=Depends(CoordCRSParams),
            dst_crs=Depends(DstCRSParams),
            post_process=Depends(self.process_dependency),
            image_params=Depends(self.img_part_dependency),
            stats_params=Depends(self.stats_dependency),
            histogram_params=Depends(self.histogram_dependency),
            env=Depends(self.environment_dependency),
        ):
            """Get Statistics from a geojson feature or featureCollection."""
            fc = geojson
            if isinstance(fc, Feature):
                fc = FeatureCollection(type="FeatureCollection", features=[geojson])

            with rasterio.Env(**env):
                with self.reader(src_path, **reader_params.as_dict()) as src_dst:
                    # Default to all available assets
                    if not layer_params.assets and not layer_params.expression:
                        layer_params.assets = src_dst.assets

                    for feature in fc:
                        image = src_dst.feature(
                            feature.model_dump(exclude_none=True),
                            shape_crs=coord_crs or WGS84_CRS,
                            dst_crs=dst_crs,
                            align_bounds_with_dataset=True,
                            **layer_params.as_dict(),
                            **image_params.as_dict(),
                            **dataset_params.as_dict(),
                        )

                        if post_process:
                            image = post_process(image)

                        stats = image.statistics(
                            **stats_params.as_dict(),
                            hist_options=histogram_params.as_dict(),
                        )

                    feature.properties = feature.properties or {}
                    # NOTE: because we use `src_dst.feature` the statistics will be in form of
                    # `Dict[str, BandStatistics]` and not `Dict[str, Dict[str, BandStatistics]]`
                    feature.properties.update({"statistics": stats})

            return fc.features[0] if isinstance(geojson, Feature) else fc

3. Application

main.py
"""FastAPI application."""

from fastapi import FastAPI

from titiler.core.dependencies import DatasetPathParams
from titiler.core.errors import DEFAULT_STATUS_CODES, add_exception_handlers

from factory import MultiBaseTilerFactory

# STAC uses MultiBaseReader so we use MultiBaseTilerFactory to built the default endpoints
stac = MultiBaseTilerFactory(router_prefix="stac")

# Create FastAPI application
app = FastAPI()
app.include_router(stac.router, tags=["STAC"])
add_exception_handlers(app, DEFAULT_STATUS_CODES)
uvicorn app:app --port 8080 --reload

Screenshot 2024-11-07 at 4 42 21 PM

curl http://127.0.0.1:8080/assets\?url\=https%3A%2F%2Fraw.githubusercontent.com%2Fcogeotiff%2Frio-tiler%2Frefs%2Fheads%2Fmain%2Ftests%2Ffixtures%2Fstac_netcdf.json | jq

[
  "geotiff",
  "netcdf"
]
curl http://127.0.0.1:8080/info?url=https://raw.githubusercontent.com/cogeotiff/rio-tiler/refs/heads/main/tests/fixtures/stac_netcdf.json&assets=md://netcdf?variable=dataset  | jq
{
  "md://netcdf?variable=dataset": {
    "bounds": [
      -170.085,
      -80.08,
      169.914999999975,
      79.91999999999659
    ],
    "crs": "http://www.opengis.net/def/crs/EPSG/0/4326",
    "band_metadata": [
      [
        "b1",
        {}
      ]
    ],
    "band_descriptions": [
      [
        "b1",
        "value"
      ]
    ],
    "dtype": "float64",
    "nodata_type": "Nodata",
    "name": "dataset",
    "count": 1,
    "width": 2000,
    "height": 1000,
    "attrs": {
      "valid_min": 1.0,
      "valid_max": 1000.0,
      "fill_value": 0
    }
  }
}
curl http://127.0.0.1:8080/tiles/WebMercatorQuad/1/0/0?url=https://raw.githubusercontent.com/cogeotiff/rio-tiler/refs/heads/main/tests/fixtures/stac_netcdf.json&assets=md://netcdf?variable=dataset&rescale=0,1000