Skip to content

Cql2ApplyFilterBodyMiddleware

Middleware to augment the request body with a CQL2 filter for search requests.

Cql2ApplyFilterBodyMiddleware dataclass

Middleware to augment the request body with a CQL2 filter for search requests.

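Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `app` | `ASGIApp` | Downstream ASGI application that receives the (possibly rewritten) request. | *required* |
| `state_key` | `str` | Name of the `request.state` attribute that holds the CQL2 filter expression. | `'cql2_filter'` |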
Source code in src/stac_auth_proxy/middleware/Cql2ApplyFilterBodyMiddleware.py
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
@required_conformance(
    r"http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2",
    r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-text",
    r"http://www.opengis.net/spec/cql2/1.0/conf/cql2-json",
)
@dataclass(frozen=True)
class Cql2ApplyFilterBodyMiddleware:
    """Middleware to augment the request body with a CQL2 filter for search requests.

    When ``request.state.<state_key>`` holds a CQL2 expression and the request is
    a POST/PUT/PATCH to a path matching one of ``search_body_endpoints``, the JSON
    body is buffered, merged with the filter via ``filters.append_body_filter``,
    and replayed to the wrapped app with a corrected ``content-length`` header.

    Attributes:
        app: The downstream ASGI application.
        state_key: Name of the ``request.state`` attribute holding the CQL2 filter.
    """

    app: ASGIApp
    state_key: str = "cql2_filter"

    # Not annotated, so the dataclass treats it as a class attribute, not a field.
    search_body_endpoints = [
        r"^/search$",
    ]

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Apply the CQL2 filter to the request body.

        Requests are passed through untouched when any of these hold:
        - the connection is not HTTP;
        - no filter is stored on ``request.state``;
        - the method is not POST/PUT/PATCH;
        - the path does not match ``search_body_endpoints``.

        Otherwise the body must be a JSON object (an empty body counts as
        ``{}``). If it is not, a 400 JSON error is sent and the wrapped app is
        not called.
        """
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        request = Request(scope)
        cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None)
        if not cql2_filter:
            return await self.app(scope, receive, send)

        if request.method not in ["POST", "PUT", "PATCH"]:
            return await self.app(scope, receive, send)

        if not any(
            re.match(expr, request.url.path) for expr in self.search_body_endpoints
        ):
            return await self.app(scope, receive, send)

        body = await self._read_body(receive)
        if body is None:
            # The client disconnected before the body was complete; there is
            # nobody left to answer.
            return

        try:
            body_json = json.loads(body) if body else {}
        except ValueError:
            # Covers JSONDecodeError and UnicodeDecodeError (undecodable bytes).
            # Both are ValueError subclasses, and either means "not valid JSON".
            logger.warning("Failed to parse request body as JSON")
            return await self._reject(
                scope, receive, send, "ParseError", "Request body must be valid JSON."
            )

        if not isinstance(body_json, dict):
            logger.warning("Request body must be a JSON object")
            return await self._reject(
                scope,
                receive,
                send,
                "TypeError",
                "Request body must be a JSON object.",
            )

        new_body = json.dumps(
            filters.append_body_filter(body_json, cql2_filter)
        ).encode("utf-8")

        # Replace content-length and keep every other header, including
        # repeated ones. A dict round-trip would silently drop duplicates.
        headers = [
            (name, value)
            for name, value in scope["headers"]
            if name.lower() != b"content-length"
        ]
        headers.append((b"content-length", str(len(new_body)).encode("latin1")))
        scope = {**scope, "headers": headers}

        body_sent = False

        async def new_receive():
            # Replay the rewritten body exactly once, then defer to the real
            # receive(). Downstream code waiting for http.disconnect would
            # otherwise spin forever on repeated body messages.
            nonlocal body_sent
            if body_sent:
                return await receive()
            body_sent = True
            return {
                "type": "http.request",
                "body": new_body,
                "more_body": False,
            }

        await self.app(scope, new_receive, send)

    @staticmethod
    async def _read_body(receive: Receive) -> Optional[bytes]:
        """Drain the full request body from ``receive``.

        Returns the concatenated body bytes, or ``None`` if an
        ``http.disconnect`` message arrives before the body is complete.
        """
        chunks = []
        more_body = True
        while more_body:
            message = await receive()
            if message["type"] == "http.disconnect":
                return None
            if message["type"] == "http.request":
                chunks.append(message.get("body", b""))
                more_body = message.get("more_body", False)
        return b"".join(chunks)

    @staticmethod
    async def _reject(
        scope: Scope, receive: Receive, send: Send, code: str, description: str
    ) -> None:
        """Send a 400 JSON error response with the given ``code`` and ``description``."""
        from starlette.responses import JSONResponse

        response = JSONResponse(
            {
                "code": code,
                "description": description,
            },
            status_code=400,
        )
        await response(scope, receive, send)

__call__(scope: Scope, receive: Receive, send: Send) -> None async

Apply the CQL2 filter to the request body.

Source code in src/stac_auth_proxy/middleware/Cql2ApplyFilterBodyMiddleware.py
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
    """Apply the CQL2 filter to the request body.

    Requests are passed through untouched when any of these hold:
    - the connection is not HTTP;
    - no filter is stored on ``request.state``;
    - the method is not POST/PUT/PATCH;
    - the path does not match ``self.search_body_endpoints``.

    Otherwise the JSON body (an empty body counts as ``{}``) is merged with the
    filter via ``filters.append_body_filter`` and replayed to ``self.app`` with
    a corrected ``content-length`` header. A body that is not a JSON object
    gets a 400 JSON error, and ``self.app`` is not called.
    """
    if scope["type"] != "http":
        return await self.app(scope, receive, send)

    request = Request(scope)
    cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None)
    if not cql2_filter:
        return await self.app(scope, receive, send)

    if request.method not in ["POST", "PUT", "PATCH"]:
        return await self.app(scope, receive, send)

    if not any(
        re.match(expr, request.url.path) for expr in self.search_body_endpoints
    ):
        return await self.app(scope, receive, send)

    async def reject(code: str, description: str) -> None:
        # Answer with a 400 JSON error; the wrapped app is never invoked.
        from starlette.responses import JSONResponse

        response = JSONResponse(
            {
                "code": code,
                "description": description,
            },
            status_code=400,
        )
        await response(scope, receive, send)

    chunks = []
    more_body = True
    while more_body:
        message = await receive()
        if message["type"] == "http.disconnect":
            # The client went away mid-upload. Without this check the loop
            # would never terminate, and there is nobody left to answer.
            return
        if message["type"] == "http.request":
            chunks.append(message.get("body", b""))
            more_body = message.get("more_body", False)
    body = b"".join(chunks)

    try:
        body_json = json.loads(body) if body else {}
    except ValueError:
        # Covers JSONDecodeError and UnicodeDecodeError (undecodable bytes).
        # Both are ValueError subclasses, and either means "not valid JSON".
        logger.warning("Failed to parse request body as JSON")
        return await reject("ParseError", "Request body must be valid JSON.")

    if not isinstance(body_json, dict):
        logger.warning("Request body must be a JSON object")
        return await reject("TypeError", "Request body must be a JSON object.")

    new_body = json.dumps(
        filters.append_body_filter(body_json, cql2_filter)
    ).encode("utf-8")

    # Replace content-length and keep every other header, including repeated
    # ones. A dict round-trip would silently drop duplicates.
    headers = [
        (name, value)
        for name, value in scope["headers"]
        if name.lower() != b"content-length"
    ]
    headers.append((b"content-length", str(len(new_body)).encode("latin1")))
    scope = {**scope, "headers": headers}

    body_sent = False

    async def new_receive():
        # Replay the rewritten body exactly once, then defer to the real
        # receive(). Downstream code waiting for http.disconnect would
        # otherwise spin forever on repeated body messages.
        nonlocal body_sent
        if body_sent:
            return await receive()
        body_sent = True
        return {
            "type": "http.request",
            "body": new_body,
            "more_body": False,
        }

    await self.app(scope, new_receive, send)