Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion common/auth/tests/test_pep_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
def _request(path: str, method: str = "GET"):
"""Request mock for route matching"""
req = MagicMock()
req.url.path = path.rstrip("/") or "/"
normalized = path.rstrip("/") or "/"
req.url.path = normalized
req.scope = {"path": normalized}
req.method = method.upper()
return req

Expand Down
44 changes: 17 additions & 27 deletions common/auth/tests/test_resource_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@
from fastapi import HTTPException, Request


def _request(path: str, method: str = "GET") -> MagicMock:
request = MagicMock(spec=Request)
request.url.path = path
request.scope = {"path": path}
request.method = method
return request


class TestExtractTenantFromBody:
"""Tests for _extract_tenant_from_body function"""

Expand Down Expand Up @@ -101,9 +109,7 @@ class TestExtractStacResourceId:
@pytest.mark.asyncio
async def test_get_collection_with_tenant(self):
"""Test extracting resource ID for GET collection with tenant"""
request = MagicMock(spec=Request)
request.url.path = "/collections/test-collection"
request.method = "GET"
request = _request("/collections/test-collection", "GET")
request.state.tenant = "test-tenant"

result = await extract_stac_resource_id(request)
Expand All @@ -112,9 +118,7 @@ async def test_get_collection_with_tenant(self):
@pytest.mark.asyncio
async def test_get_collection_without_tenant(self):
"""Test extracting resource ID for GET collection without tenant (defaults to public)"""
request = MagicMock(spec=Request)
request.url.path = "/collections/test-collection"
request.method = "GET"
request = _request("/collections/test-collection", "GET")
request.state = MagicMock()
delattr(request.state, "tenant")

Expand All @@ -127,9 +131,7 @@ async def test_put_collection_with_tenant_in_body(self):
body_data = {"eic:tenant": "test-tenant", "id": "test-collection"}
test_body = json.dumps(body_data).encode("utf-8")

request = MagicMock(spec=Request)
request.url.path = "/collections/test-collection"
request.method = "PUT"
request = _request("/collections/test-collection", "PUT")
request.body = AsyncMock(return_value=test_body)

result = await extract_stac_resource_id(request)
Expand All @@ -141,9 +143,7 @@ async def test_post_collections_create_with_tenant_in_body(self):
body_data = {"eic:tenant": "test-tenant", "id": "new-collection"}
test_body = json.dumps(body_data).encode("utf-8")

request = MagicMock(spec=Request)
request.url.path = "/collections"
request.method = "POST"
request = _request("/collections", "POST")
request.body = AsyncMock(return_value=test_body)

result = await extract_stac_resource_id(request)
Expand All @@ -155,9 +155,7 @@ async def test_post_collections_create_without_tenant_in_body(self):
body_data = {"id": "new-collection", "type": "Collection"}
test_body = json.dumps(body_data).encode("utf-8")

request = MagicMock(spec=Request)
request.url.path = "/collections"
request.method = "POST"
request = _request("/collections", "POST")
request.body = AsyncMock(return_value=test_body)

result = await extract_stac_resource_id(request)
Expand All @@ -166,9 +164,7 @@ async def test_post_collections_create_without_tenant_in_body(self):
@pytest.mark.asyncio
async def test_get_item_with_tenant(self):
"""Test extracting resource ID for GET item with tenant"""
request = MagicMock(spec=Request)
request.url.path = "/collections/test-collection/items/test-item"
request.method = "GET"
request = _request("/collections/test-collection/items/test-item", "GET")
request.state.tenant = "test-tenant"

result = await extract_stac_resource_id(request)
Expand All @@ -177,9 +173,7 @@ async def test_get_item_with_tenant(self):
@pytest.mark.asyncio
async def test_post_items_with_tenant(self):
"""Test extracting resource ID for POST items with tenant"""
request = MagicMock(spec=Request)
request.url.path = "/collections/test-collection/items"
request.method = "POST"
request = _request("/collections/test-collection/items", "POST")
request.state.tenant = "test-tenant"

result = await extract_stac_resource_id(request)
Expand All @@ -188,9 +182,7 @@ async def test_post_items_with_tenant(self):
@pytest.mark.asyncio
async def test_post_bulk_items_with_tenant(self):
"""Test extracting resource ID for POST bulk_items with tenant"""
request = MagicMock(spec=Request)
request.url.path = "/collections/test-collection/bulk_items"
request.method = "POST"
request = _request("/collections/test-collection/bulk_items", "POST")
request.state.tenant = "test-tenant"

result = await extract_stac_resource_id(request)
Expand All @@ -202,9 +194,7 @@ class TestExtractIngestResourceId:

async def test_delete_collection_returns_collection_id(self):
"""DELETE /collections/{id} should return collection-specific resource ID"""
request = MagicMock(spec=Request)
request.url.path = "/collections/test-collection"
request.method = "DELETE"
request = _request("/collections/test-collection", "DELETE")
request.state.tenant = "test-tenant"

resource_id = await extract_ingest_resource_id(request)
Expand Down
25 changes: 11 additions & 14 deletions common/auth/veda_auth/pep_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _get_matching_scope_and_route(
self, request: Request
) -> Optional[tuple[str, str]]:
"""Return (scope, method) for the route that matches, otherwise return None"""
path = request.url.path.rstrip("/") or "/"
path = (request.scope.get("path") or request.url.path).rstrip("/") or "/"
method = request.method.upper()
for pattern, route_method, scope in self._compiled:
if route_method == method and pattern.search(path):
Expand All @@ -112,30 +112,29 @@ def _get_bearer_token(self, request: Request) -> Optional[str]:

async def dispatch(self, request: Request, call_next: Callable) -> Response:
"""Check UMA authorization for protected routes, pass through otherwise."""
path = (request.scope.get("path") or request.url.path).rstrip("/") or "/"
matched_request = self._get_matching_scope_and_route(request)
if matched_request is None:
logger.debug(
"PEP: no protected route match for %s %s... continuing",
request.method,
request.url.path,
path,
)
return await call_next(request)

scope, _method = matched_request
logger.info(
"PEP: matched protected route %s %s and scope=%s",
_method,
request.url.path,
path,
scope,
)

pdp_client = self._get_pdp_client()

token = self._get_bearer_token(request)
if not token:
logger.warning(
"PEP: missing Bearer token for %s %s", _method, request.url.path
)
logger.warning("PEP: missing Bearer token for %s %s", _method, path)
return pep_error_response(
401,
"Missing or invalid Authorization header (Bearer token required)",
Expand All @@ -144,7 +143,7 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:

resource_id = await self._extract_resource_id(request)
if not resource_id:
logger.warning("PEP: no resource ID for %s %s", _method, request.url.path)
logger.warning("PEP: no resource ID for %s %s", _method, path)
return pep_error_response(
403, "Could not determine resource for authorization"
)
Expand All @@ -153,7 +152,7 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
"PEP: checking permission resource_id=%s, scope=%s, path=%s",
resource_id,
scope,
request.url.path,
path,
)

try:
Expand All @@ -163,17 +162,15 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
scope=scope,
)
except TokenError as e:
logger.warning(
"PEP: token error for %s %s: %s", _method, request.url.path, e.detail
)
logger.warning("PEP: token error for %s %s: %s", _method, path, e.detail)
return pep_error_response(
401, e.detail, {"WWW-Authenticate": 'Bearer error="invalid_token"'}
)
except ResourceNotFoundError as e:
logger.warning(
"PEP: resource not found for %s %s: %s",
_method,
request.url.path,
path,
e.resource_id,
)
return pep_error_response(
Expand All @@ -185,7 +182,7 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
logger.warning(
"PEP: denied %s %s resource_id=%s, scope=%s",
_method,
request.url.path,
path,
e.resource_id,
e.scope,
)
Expand All @@ -212,7 +209,7 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
"PEP: authorized for resource_id=%s, scope=%s, path=%s",
resource_id,
scope,
request.url.path,
path,
)

return await call_next(request)
4 changes: 2 additions & 2 deletions common/auth/veda_auth/resource_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async def extract_stac_resource_id(request: Request) -> Optional[str]:
- Collections: STAC_COLLECTION_TEMPLATE or STAC_COLLECTION_PUBLIC
- Items: STAC_ITEM_TEMPLATE or STAC_ITEM_PUBLIC
"""
path = request.url.path
path = request.scope.get("path") or request.url.path
method = request.method

if _COLLECTIONS_CREATE_PATH_PATTERN.match(path) and method == "POST":
Expand All @@ -120,7 +120,7 @@ async def extract_stac_resource_id(request: Request) -> Optional[str]:

async def extract_ingest_resource_id(request: Request) -> Optional[str]:
"""Extract resource ID for Ingest API requests"""
path = request.url.path
path = request.scope.get("path") or request.url.path
method = request.method

if path.endswith("/collections") and method == "POST":
Expand Down
3 changes: 2 additions & 1 deletion stac_api/runtime/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
inst_reqs = [
"boto3",
"async-lru>=2.0.5",
"starlette==1.0.1",
"stac-fastapi.api~=6.1",
"stac-fastapi.types~=6.1",
"stac-fastapi.extensions~=6.1",
Expand All @@ -22,7 +23,7 @@
"aws_xray_sdk>=2.6.0,<3",
"pystac[validation]>=1.14.0",
"pydantic>2",
"stac-auth-proxy==0.11.1rc2",
"stac-auth-proxy==1.1.1",
]

extra_reqs = {
Expand Down
Loading