diff --git a/app/service/auth_svc.py b/app/service/auth_svc.py index 17b2c1aea..b61458657 100644 --- a/app/service/auth_svc.py +++ b/app/service/auth_svc.py @@ -62,6 +62,13 @@ def __init__(self): self._login_handler = None self._default_login_handler = None + @staticmethod + def _ensure_str(value) -> str: + """Convert a value to str safely, decoding bytes instead of using str() repr.""" + if isinstance(value, bytes): + return value.decode('utf-8', errors='replace') + return str(value) + @property def default_login_handler(self): return self._default_login_handler @@ -170,9 +177,12 @@ async def get_permissions(self, request): identity = await identity_policy.identify(request) if identity in self.user_map: return [self.Access[p.upper()] for p in self.user_map[identity].permissions] - elif request.headers.get(HEADER_API_KEY) == self.get_config(CONFIG_API_KEY_RED): + request_key = self._ensure_str(request.headers.get(HEADER_API_KEY) or '') + red_key = self._ensure_str(self.get_config(CONFIG_API_KEY_RED) or '') + blue_key = self._ensure_str(self.get_config(CONFIG_API_KEY_BLUE) or '') + if red_key and compare_digest(request_key, red_key): return self.Access.RED, self.Access.APP - elif request.headers.get(HEADER_API_KEY) == self.get_config(CONFIG_API_KEY_BLUE): + elif blue_key and compare_digest(request_key, blue_key): return self.Access.BLUE, self.Access.APP return () diff --git a/tests/test_compare_digest_auth.py b/tests/test_compare_digest_auth.py new file mode 100644 index 000000000..4cb50cda4 --- /dev/null +++ b/tests/test_compare_digest_auth.py @@ -0,0 +1,85 @@ +import pytest +from unittest.mock import MagicMock, AsyncMock +from app.service.auth_svc import AuthService, HEADER_API_KEY +from app.utility.base_world import BaseWorld + + +@pytest.fixture(autouse=True) +def setup_config(): + BaseWorld.apply_config('main', { + 'api_key_red': 'RED_KEY_123', + 'api_key_blue': 'BLUE_KEY_456', + }) + yield + BaseWorld.clear_config() + + +class MockHeaders: + def __init__(self, data): + self._data = data + + def get(self, key, default=None): + return self._data.get(key, default) + + +class TestCompareDigestAuth: + def _make_request(self, api_key=None): + request = MagicMock() + headers_data = {HEADER_API_KEY: api_key} if api_key else {} + request.headers = MockHeaders(headers_data) + identity_policy = AsyncMock() + identity_policy.identify = AsyncMock(return_value=None) + request.config_dict = {'aiohttp_security_identity_policy': identity_policy} + return request + + @pytest.mark.asyncio + async def test_red_key_returns_red_access(self): + svc = AuthService.__new__(AuthService) + svc.user_map = {} + request = self._make_request('RED_KEY_123') + result = await svc.get_permissions(request) + assert BaseWorld.Access.RED in result + assert BaseWorld.Access.APP in result + + @pytest.mark.asyncio + async def test_blue_key_returns_blue_access(self): + svc = AuthService.__new__(AuthService) + svc.user_map = {} + request = self._make_request('BLUE_KEY_456') + result = await svc.get_permissions(request) + assert BaseWorld.Access.BLUE in result + assert BaseWorld.Access.APP in result + + @pytest.mark.asyncio + async def test_wrong_key_returns_empty(self): + svc = AuthService.__new__(AuthService) + svc.user_map = {} + request = self._make_request('WRONG_KEY') + result = await svc.get_permissions(request) + assert result == () + + @pytest.mark.asyncio + async def test_no_key_returns_empty(self): + svc = AuthService.__new__(AuthService) + svc.user_map = {} + request = self._make_request(None) + result = await svc.get_permissions(request) + assert result == () + + def test_ensure_str_decodes_bytes(self): + """bytes values should be decoded, not repr'd as 'b\"...\"'.""" + assert AuthService._ensure_str(b'KEY123') == 'KEY123' + assert AuthService._ensure_str('KEY123') == 'KEY123' + + @pytest.mark.asyncio + async def test_bytes_config_key_matches(self): + """If config returns bytes, compare_digest should still match the str header.""" + BaseWorld.apply_config('main', { + 'api_key_red': b'RED_KEY_123', + 'api_key_blue': b'BLUE_KEY_456', + }) + svc = AuthService.__new__(AuthService) + svc.user_map = {} + request = self._make_request('RED_KEY_123') + result = await svc.get_permissions(request) + assert BaseWorld.Access.RED in result