diff --git a/common/auth/tests/test_pep_middleware.py b/common/auth/tests/test_pep_middleware.py index cd5998a7..f086684e 100644 --- a/common/auth/tests/test_pep_middleware.py +++ b/common/auth/tests/test_pep_middleware.py @@ -13,7 +13,9 @@ def _request(path: str, method: str = "GET"): """Request mock for route matching""" req = MagicMock() - req.url.path = path.rstrip("/") or "/" + normalized = path.rstrip("/") or "/" + req.url.path = normalized + req.scope = {"path": normalized} req.method = method.upper() return req diff --git a/common/auth/tests/test_resource_extractors.py b/common/auth/tests/test_resource_extractors.py index a8422169..1f5d9f11 100644 --- a/common/auth/tests/test_resource_extractors.py +++ b/common/auth/tests/test_resource_extractors.py @@ -17,6 +17,14 @@ from fastapi import HTTPException, Request +def _request(path: str, method: str = "GET") -> MagicMock: + request = MagicMock(spec=Request) + request.url.path = path + request.scope = {"path": path} + request.method = method + return request + + class TestExtractTenantFromBody: """Tests for _extract_tenant_from_body function""" @@ -101,9 +109,7 @@ class TestExtractStacResourceId: @pytest.mark.asyncio async def test_get_collection_with_tenant(self): """Test extracting resource ID for GET collection with tenant""" - request = MagicMock(spec=Request) - request.url.path = "/collections/test-collection" - request.method = "GET" + request = _request("/collections/test-collection", "GET") request.state.tenant = "test-tenant" result = await extract_stac_resource_id(request) @@ -112,9 +118,7 @@ async def test_get_collection_with_tenant(self): @pytest.mark.asyncio async def test_get_collection_without_tenant(self): """Test extracting resource ID for GET collection without tenant (defaults to public)""" - request = MagicMock(spec=Request) - request.url.path = "/collections/test-collection" - request.method = "GET" + request = _request("/collections/test-collection", "GET") request.state = MagicMock() delattr(request.state, "tenant") @@ -127,9 +131,7 @@ async def test_put_collection_with_tenant_in_body(self): body_data = {"eic:tenant": "test-tenant", "id": "test-collection"} test_body = json.dumps(body_data).encode("utf-8") - request = MagicMock(spec=Request) - request.url.path = "/collections/test-collection" - request.method = "PUT" + request = _request("/collections/test-collection", "PUT") request.body = AsyncMock(return_value=test_body) result = await extract_stac_resource_id(request) @@ -141,9 +143,7 @@ async def test_post_collections_create_with_tenant_in_body(self): body_data = {"eic:tenant": "test-tenant", "id": "new-collection"} test_body = json.dumps(body_data).encode("utf-8") - request = MagicMock(spec=Request) - request.url.path = "/collections" - request.method = "POST" + request = _request("/collections", "POST") request.body = AsyncMock(return_value=test_body) result = await extract_stac_resource_id(request) @@ -155,9 +155,7 @@ async def test_post_collections_create_without_tenant_in_body(self): body_data = {"id": "new-collection", "type": "Collection"} test_body = json.dumps(body_data).encode("utf-8") - request = MagicMock(spec=Request) - request.url.path = "/collections" - request.method = "POST" + request = _request("/collections", "POST") request.body = AsyncMock(return_value=test_body) result = await extract_stac_resource_id(request) @@ -166,9 +164,7 @@ async def test_post_collections_create_without_tenant_in_body(self): @pytest.mark.asyncio async def test_get_item_with_tenant(self): """Test extracting resource ID for GET item with tenant""" - request = MagicMock(spec=Request) - request.url.path = "/collections/test-collection/items/test-item" - request.method = "GET" + request = _request("/collections/test-collection/items/test-item", "GET") request.state.tenant = "test-tenant" result = await extract_stac_resource_id(request) @@ -177,9 +173,7 @@ async def test_get_item_with_tenant(self): @pytest.mark.asyncio async def test_post_items_with_tenant(self): """Test extracting resource ID for POST items with tenant""" - request = MagicMock(spec=Request) - request.url.path = "/collections/test-collection/items" - request.method = "POST" + request = _request("/collections/test-collection/items", "POST") request.state.tenant = "test-tenant" result = await extract_stac_resource_id(request) @@ -188,9 +182,7 @@ async def test_post_items_with_tenant(self): @pytest.mark.asyncio async def test_post_bulk_items_with_tenant(self): """Test extracting resource ID for POST bulk_items with tenant""" - request = MagicMock(spec=Request) - request.url.path = "/collections/test-collection/bulk_items" - request.method = "POST" + request = _request("/collections/test-collection/bulk_items", "POST") request.state.tenant = "test-tenant" result = await extract_stac_resource_id(request) @@ -202,9 +194,7 @@ class TestExtractIngestResourceId: async def test_delete_collection_returns_collection_id(self): """DELETE /collections/{id} should return collection-specific resource ID""" - request = MagicMock(spec=Request) - request.url.path = "/collections/test-collection" - request.method = "DELETE" + request = _request("/collections/test-collection", "DELETE") request.state.tenant = "test-tenant" resource_id = await extract_ingest_resource_id(request) diff --git a/common/auth/veda_auth/pep_middleware.py b/common/auth/veda_auth/pep_middleware.py index 1ed01d9b..17c475fb 100644 --- a/common/auth/veda_auth/pep_middleware.py +++ b/common/auth/veda_auth/pep_middleware.py @@ -96,7 +96,7 @@ def _get_matching_scope_and_route( self, request: Request ) -> Optional[tuple[str, str]]: """Return (scope, method) for the route that matches, otherwise return None""" - path = request.url.path.rstrip("/") or "/" + path = (request.scope.get("path") or request.url.path).rstrip("/") or "/" method = request.method.upper() for pattern, route_method, scope in self._compiled: if route_method == method and pattern.search(path): @@ -112,12 +112,13 @@ def _get_bearer_token(self, request: Request) -> Optional[str]: async def dispatch(self, request: Request, call_next: Callable) -> Response: """Check UMA authorization for protected routes, pass through otherwise.""" + path = (request.scope.get("path") or request.url.path).rstrip("/") or "/" matched_request = self._get_matching_scope_and_route(request) if matched_request is None: logger.debug( "PEP: no protected route match for %s %s... continuing", request.method, - request.url.path, + path, ) return await call_next(request) @@ -125,7 +126,7 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: logger.info( "PEP: matched protected route %s %s and scope=%s", _method, - request.url.path, + path, scope, ) @@ -133,9 +134,7 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: token = self._get_bearer_token(request) if not token: - logger.warning( - "PEP: missing Bearer token for %s %s", _method, request.url.path - ) + logger.warning("PEP: missing Bearer token for %s %s", _method, path) return pep_error_response( 401, "Missing or invalid Authorization header (Bearer token required)", @@ -144,7 +143,7 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: resource_id = await self._extract_resource_id(request) if not resource_id: - logger.warning("PEP: no resource ID for %s %s", _method, request.url.path) + logger.warning("PEP: no resource ID for %s %s", _method, path) return pep_error_response( 403, "Could not determine resource for authorization" ) @@ -153,7 +152,7 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: "PEP: checking permission resource_id=%s, scope=%s, path=%s", resource_id, scope, - request.url.path, + path, ) try: @@ -163,9 +162,7 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: scope=scope, ) except TokenError as e: - logger.warning( - "PEP: token error for %s %s: %s", _method, request.url.path, e.detail - ) + logger.warning("PEP: token error for %s %s: %s", _method, path, e.detail) return pep_error_response( 401, e.detail, {"WWW-Authenticate": 'Bearer error="invalid_token"'} ) @@ -173,7 +170,7 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: logger.warning( "PEP: resource not found for %s %s: %s", _method, - request.url.path, + path, e.resource_id, ) return pep_error_response( @@ -185,7 +182,7 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: logger.warning( "PEP: denied %s %s resource_id=%s, scope=%s", _method, - request.url.path, + path, e.resource_id, e.scope, ) @@ -212,7 +209,7 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: "PEP: authorized for resource_id=%s, scope=%s, path=%s", resource_id, scope, - request.url.path, + path, ) return await call_next(request) diff --git a/common/auth/veda_auth/resource_extractors.py b/common/auth/veda_auth/resource_extractors.py index 9901f713..69580026 100644 --- a/common/auth/veda_auth/resource_extractors.py +++ b/common/auth/veda_auth/resource_extractors.py @@ -93,7 +93,7 @@ async def extract_stac_resource_id(request: Request) -> Optional[str]: - Collections: STAC_COLLECTION_TEMPLATE or STAC_COLLECTION_PUBLIC - Items: STAC_ITEM_TEMPLATE or STAC_ITEM_PUBLIC """ - path = request.url.path + path = request.scope.get("path") or request.url.path method = request.method if _COLLECTIONS_CREATE_PATH_PATTERN.match(path) and method == "POST": @@ -120,7 +120,7 @@ async def extract_stac_resource_id(request: Request) -> Optional[str]: async def extract_ingest_resource_id(request: Request) -> Optional[str]: """Extract resource ID for Ingest API requests""" - path = request.url.path + path = request.scope.get("path") or request.url.path method = request.method if path.endswith("/collections") and method == "POST": diff --git a/stac_api/runtime/setup.py b/stac_api/runtime/setup.py index e330b299..793d9362 100644 --- a/stac_api/runtime/setup.py +++ b/stac_api/runtime/setup.py @@ -10,6 +10,7 @@ inst_reqs = [ "boto3", "async-lru>=2.0.5", + "starlette==1.0.1", "stac-fastapi.api~=6.1", "stac-fastapi.types~=6.1", "stac-fastapi.extensions~=6.1", @@ -22,7 +23,7 @@ "aws_xray_sdk>=2.6.0,<3", "pystac[validation]>=1.14.0", "pydantic>2", - "stac-auth-proxy==0.11.1rc2", + "stac-auth-proxy==1.1.1", ] extra_reqs = {