diff --git a/src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py b/src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py index 1909bfd..dff36fb 100644 --- a/src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py +++ b/src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py @@ -26,12 +26,13 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: return await self.app(scope, receive, send) request = Request(scope) - original_filter = request.query_params.get("filter") cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None) if cql2_filter is None: # No filter set, just pass through return await self.app(scope, receive, send) + user_filter, receive = await self._extract_user_filter(request, receive) + # Intercept the response response_start = None body_chunks = [] @@ -46,19 +47,74 @@ async def send_wrapper(message: Message): more_body = message.get("more_body", False) if not more_body: await self._process_and_send_response( - response_start, body_chunks, send, original_filter + response_start, body_chunks, send, user_filter ) else: await send(message) await self.app(scope, receive, send_wrapper) + async def _extract_user_filter( + self, request: Request, receive: Receive + ) -> tuple[Optional[Expr], Receive]: + """ + Recover the user's original filter from either the query string or JSON body. + + For methods that may carry a JSON body (POST/PUT/PATCH), the body is buffered + and a replacement ``receive`` is returned so downstream consumers still see it. + """ + query_filter = request.query_params.get("filter") + if query_filter: + try: + return Expr(query_filter), receive + except Exception: + logger.warning("Failed to parse user filter from query string") + return None, receive + + if request.method not in ("POST", "PUT", "PATCH"): + return None, receive + + body = b"" + more_body = True + while more_body: + message = await receive() + if message["type"] == "http.request": + body += message.get("body", b"") + more_body = message.get("more_body", False) + else: + # e.g. http.disconnect - stop reading; downstream will get the same. + break + + async def replay_receive() -> Message: + return {"type": "http.request", "body": body, "more_body": False} + + if not body: + return None, replay_receive + + try: + body_json = json.loads(body) + except json.JSONDecodeError: + return None, replay_receive + + if not isinstance(body_json, dict): + return None, replay_receive + + body_filter = body_json.get("filter") + if body_filter is None: + return None, replay_receive + + try: + return Expr(body_filter), replay_receive + except Exception: + logger.warning("Failed to parse user filter from request body") + return None, replay_receive + async def _process_and_send_response( self, response_start: Message, body_chunks: list[bytes], send: Send, - original_filter: Optional[str], + user_filter: Optional[Expr], ): body = b"".join(body_chunks) try: @@ -68,7 +124,6 @@ async def _process_and_send_response( await send({"type": "http.response.body", "body": body, "more_body": False}) return - cql2_filter = Expr(original_filter) if original_filter else None links = data.get("links") if isinstance(links, list): for link in links: @@ -77,19 +132,24 @@ async def _process_and_send_response( url = urlparse(link["href"]) qs = parse_qs(url.query) if "filter" in qs: - if cql2_filter: - qs["filter"] = [cql2_filter.to_text()] + if user_filter is not None: + qs["filter"] = [user_filter.to_text()] else: qs.pop("filter", None) qs.pop("filter-lang", None) new_query = urlencode(qs, doseq=True) link["href"] = urlunparse(url._replace(query=new_query)) - # Handle filter in body (for POST links) + # Handle filter in body (for POST links). The spec only + # requires cql2-json for POST bodies, but if the link advertises + # cql2-text we preserve that lang on the way out. if "body" in link and isinstance(link["body"], dict): if "filter" in link["body"]: - if cql2_filter: - link["body"]["filter"] = cql2_filter.to_json() + if user_filter is not None: + if link["body"].get("filter-lang") == "cql2-text": + link["body"]["filter"] = user_filter.to_text() + else: + link["body"]["filter"] = user_filter.to_json() else: link["body"].pop("filter", None) link["body"].pop("filter-lang", None) diff --git a/tests/test_cql2_rewrite_links_filter_middleware.py b/tests/test_cql2_rewrite_links_filter_middleware.py index d5b07cd..4dabda5 100644 --- a/tests/test_cql2_rewrite_links_filter_middleware.py +++ b/tests/test_cql2_rewrite_links_filter_middleware.py @@ -1,5 +1,6 @@ """Test Cql2RewriteLinksFilterMiddleware.""" +import json import re import pytest @@ -12,6 +13,24 @@ ) +def _install_middlewares(app: FastAPI, system_filter: str) -> None: + """Attach the rewrite middleware behind a mock build-filter middleware.""" + + class MockBuildFilterMiddleware: + def __init__(self, app, state_key="cql2_filter"): + self.app = app + self.state_key = state_key + + async def __call__(self, scope, receive, send): + if scope["type"] == "http": + request = Request(scope) + setattr(request.state, self.state_key, Expr(system_filter)) + await self.app(scope, receive, send) + + app.add_middleware(Cql2RewriteLinksFilterMiddleware) + app.add_middleware(MockBuildFilterMiddleware) + + def test_non_json_response(): """Test middleware behavior with non-JSON responses.""" app = FastAPI() @@ -335,3 +354,123 @@ async def test_endpoint(request: Request): # Other data should be preserved assert body["other_data"] == "preserved" + + +class TestPostBodyClientFilterPreservation: + """ + Client filters sent in a POST search body must be preserved in next-link bodies. + + Regression: the middleware previously read the user's filter only from the + query string, silently dropping filters supplied via POST body. + """ + + @pytest.mark.parametrize( + "client_filter,client_filter_lang", + [ + ( + {"op": "<", "args": [{"property": "cloud_coverage"}, 50]}, + "cql2-json", + ), + ("cloud_coverage < 30", "cql2-text"), + (None, None), + ], + ) + def test_preserves_client_filter_from_post_body( + self, client_filter, client_filter_lang + ): + """Filter supplied in the POST body must be preserved in the next link.""" + app = FastAPI() + _install_middlewares(app, system_filter="private = false") + + @app.post("/search") + async def search_endpoint(request: Request): + body_json = await request.json() + system_expr = getattr(request.state, "cql2_filter", None) + user_filter = body_json.get("filter") + user_filter_lang = body_json.get("filter-lang") + + combined = None + if system_expr is not None and user_filter is not None: + combined = system_expr + Expr(user_filter) + elif system_expr is not None: + combined = system_expr + elif user_filter is not None: + combined = Expr(user_filter) + + next_body = {"token": "next-token"} + if combined is not None: + lang = user_filter_lang or "cql2-json" + next_body["filter-lang"] = lang + next_body["filter"] = ( + combined.to_text() if lang == "cql2-text" else combined.to_json() + ) + + return { + "links": [ + { + "rel": "next", + "method": "POST", + "href": "http://example.com/search", + "body": next_body, + } + ], + } + + request_body = {} + if client_filter is not None: + request_body["filter"] = client_filter + request_body["filter-lang"] = client_filter_lang + + response = TestClient(app).post("/search", json=request_body) + assert response.status_code == 200, response.text + body = response.json()["links"][0]["body"] + + assert body["token"] == "next-token" + + if client_filter is None: + # No client filter → system filter must not leak into next link. + assert "filter" not in body + assert "filter-lang" not in body + else: + # Compare semantically: cql2-python may re-emit equivalent text + # with different formatting (e.g. added parens). + assert Expr(body["filter"]).to_json() == Expr(client_filter).to_json() + assert body["filter-lang"] == client_filter_lang + + def test_request_body_is_intact_for_inner_app(self): + """Body capture must replay the exact original bytes to the inner app.""" + app = FastAPI() + _install_middlewares(app, system_filter="private = false") + + @app.post("/search") + async def search_endpoint(request: Request): + return {"echo": json.loads(await request.body())} + + request_body = { + "collections": ["a", "b"], + "filter": {"op": "=", "args": [{"property": "x"}, 1]}, + "filter-lang": "cql2-json", + } + response = TestClient(app).post("/search", json=request_body) + assert response.status_code == 200, response.text + assert response.json()["echo"] == request_body + + def test_malformed_json_body_does_not_break_middleware(self): + """An unparseable body must pass through without the middleware crashing.""" + app = FastAPI() + _install_middlewares(app, system_filter="private = false") + + @app.post("/search") + async def search_endpoint(request: Request): + return Response( + content=await request.body(), + media_type="application/octet-stream", + ) + + response = TestClient(app).post( + "/search", + content=b"not json", + headers={"content-type": "application/json"}, + ) + assert response.status_code == 200 + assert response.content == b"not json"