Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions app/service/auth_svc.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ def __init__(self):
self._login_handler = None
self._default_login_handler = None

@staticmethod
def _ensure_str(value) -> str:
"""Convert a value to str safely, decoding bytes instead of using str() repr."""
if isinstance(value, bytes):
return value.decode('utf-8', errors='replace')
return str(value)

@property
def default_login_handler(self):
return self._default_login_handler
Expand Down Expand Up @@ -170,9 +177,12 @@ async def get_permissions(self, request):
identity = await identity_policy.identify(request)
if identity in self.user_map:
return [self.Access[p.upper()] for p in self.user_map[identity].permissions]
elif request.headers.get(HEADER_API_KEY) == self.get_config(CONFIG_API_KEY_RED):
request_key = self._ensure_str(request.headers.get(HEADER_API_KEY) or '')
red_key = self._ensure_str(self.get_config(CONFIG_API_KEY_RED) or '')
blue_key = self._ensure_str(self.get_config(CONFIG_API_KEY_BLUE) or '')
if red_key and compare_digest(request_key, red_key):
return self.Access.RED, self.Access.APP
elif request.headers.get(HEADER_API_KEY) == self.get_config(CONFIG_API_KEY_BLUE):
elif blue_key and compare_digest(request_key, blue_key):
return self.Access.BLUE, self.Access.APP
return ()

Expand Down
85 changes: 85 additions & 0 deletions tests/test_compare_digest_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import pytest
from unittest.mock import MagicMock, AsyncMock
from app.service.auth_svc import AuthService, HEADER_API_KEY
from app.utility.base_world import BaseWorld
Comment on lines +1 to +4


@pytest.fixture(autouse=True)
def setup_config():
BaseWorld.apply_config('main', {
'api_key_red': 'RED_KEY_123',
'api_key_blue': 'BLUE_KEY_456',
})
yield
BaseWorld.clear_config()


class MockHeaders:
def __init__(self, data):
self._data = data

def get(self, key, default=None):
return self._data.get(key, default)


class TestCompareDigestAuth:
def _make_request(self, api_key=None):
request = MagicMock()
headers_data = {HEADER_API_KEY: api_key} if api_key else {}
request.headers = MockHeaders(headers_data)
identity_policy = AsyncMock()
identity_policy.identify = AsyncMock(return_value=None)
request.config_dict = {'aiohttp_security_identity_policy': identity_policy}
return request

@pytest.mark.asyncio
async def test_red_key_returns_red_access(self):
svc = AuthService.__new__(AuthService)
svc.user_map = {}
request = self._make_request('RED_KEY_123')
result = await svc.get_permissions(request)
assert BaseWorld.Access.RED in result
assert BaseWorld.Access.APP in result

@pytest.mark.asyncio
async def test_blue_key_returns_blue_access(self):
svc = AuthService.__new__(AuthService)
svc.user_map = {}
request = self._make_request('BLUE_KEY_456')
result = await svc.get_permissions(request)
assert BaseWorld.Access.BLUE in result
assert BaseWorld.Access.APP in result

@pytest.mark.asyncio
async def test_wrong_key_returns_empty(self):
svc = AuthService.__new__(AuthService)
svc.user_map = {}
request = self._make_request('WRONG_KEY')
result = await svc.get_permissions(request)
assert result == ()

@pytest.mark.asyncio
async def test_no_key_returns_empty(self):
svc = AuthService.__new__(AuthService)
svc.user_map = {}
request = self._make_request(None)
result = await svc.get_permissions(request)
assert result == ()

def test_ensure_str_decodes_bytes(self):
"""bytes values should be decoded, not repr'd as 'b\"...\"'."""
assert AuthService._ensure_str(b'KEY123') == 'KEY123'
assert AuthService._ensure_str('KEY123') == 'KEY123'

@pytest.mark.asyncio
async def test_bytes_config_key_matches(self):
"""If config returns bytes, compare_digest should still match the str header."""
BaseWorld.apply_config('main', {
'api_key_red': b'RED_KEY_123',
'api_key_blue': b'BLUE_KEY_456',
})
svc = AuthService.__new__(AuthService)
svc.user_map = {}
request = self._make_request('RED_KEY_123')
result = await svc.get_permissions(request)
assert BaseWorld.Access.RED in result
Loading