diff --git a/src/access_service/access_service.py b/src/access_service/access_service.py index 36a205f..9db955e 100644 --- a/src/access_service/access_service.py +++ b/src/access_service/access_service.py @@ -7,6 +7,7 @@ from src.models.accesskey import AccessKey from src.models.agent import Agent from src.rag_service.dao.agent.base import AgentDAO +from src.utils.crypto_utils import hash_access_key, verify_access_key class AccessService(AbstractAccessService): @@ -38,27 +39,37 @@ def generate_accesskey( agent = self.try_get_agent(agent_id) key = secrets.token_bytes(32) # AES-256 32byte secret key_str = base64.urlsafe_b64encode(key).decode("ascii") - access_key = AccessKey( + hashed_key = hash_access_key(key_str) + + stored_access_key = AccessKey( id=self.get_unique_access_key_id(agent), - key=key_str, + key=hashed_key.decode("utf-8"), # Store the hashed key as string name=name, expiry_date=expiry_date, created=datetime.now(), last_use=None, ) - agent.access_key.append(access_key) + + agent.access_key.append(stored_access_key) self.agent_database.add_agent(agent) - return access_key + + # return the original key to user (not hashed) + return AccessKey( + id=stored_access_key.id, + key=key_str, + name=name, + expiry_date=expiry_date, + created=stored_access_key.created, + last_use=None, + ) def authenticate(self, agent_id: str, access_key: str) -> bool: agent = self.try_get_agent(agent_id) access_keys: list[AccessKey] = agent.access_key for ak in access_keys: - if ak.key == access_key: - if ak.expiry_date is None: - return True - return datetime.now() < ak.expiry_date + if verify_access_key(access_key, ak.key.encode("utf-8")): + return ak.expiry_date is None or ak.expiry_date > datetime.now() return False diff --git a/tests/test_access_keys.py b/tests/test_access_keys.py index 65cba36..980d01c 100644 --- a/tests/test_access_keys.py +++ b/tests/test_access_keys.py @@ -4,6 +4,7 @@ from src.access_service.factory import AccessServiceConfig, access_service_factory from src.models.agent import Agent, Role +from src.utils.crypto_utils import verify_access_key from tests.mocks.mock_agent_dao import MockAgentDAO @@ -57,7 +58,14 @@ def test_create_access_key(): key = access_service.generate_accesskey("test", DEFAULT_DATETIME, agent.id) assert key.name == "test" assert key.expiry_date == DEFAULT_DATETIME - assert database.get_agent_by_id(agent.id).access_key[0] == key + + stored_key = database.get_agent_by_id(agent.id).access_key[0] + assert stored_key.id == key.id + assert stored_key.name == key.name + assert stored_key.expiry_date == key.expiry_date + assert stored_key.key != key.key + + assert verify_access_key(key.key, stored_key.key.encode("utf-8")) @pytest.mark.unit @@ -66,19 +74,24 @@ def test_create_access_key_no_expiery(): key = access_service.generate_accesskey("test", None, agent.id) assert key.name == "test" assert key.expiry_date is None - assert database.get_agent_by_id(agent.id).access_key[0] == key + + stored_key = database.get_agent_by_id(agent.id).access_key[0] + assert stored_key.id == key.id + assert stored_key.key != key.key + assert verify_access_key(key.key, stored_key.key.encode("utf-8")) + assert access_service.authenticate(agent_id=agent.id, access_key=key.key) @pytest.mark.unit def test_revoke_access_key(): agent, access_service, database = get_access_service_and_agent() - key = access_service.generate_accesskey("test", DEFAULT_DATETIME, agent.id) - assert key.name == "test" - assert key.expiry_date == DEFAULT_DATETIME - assert database.get_agent_by_id(agent.id).access_key[0] == key + returned_key = access_service.generate_accesskey("test", DEFAULT_DATETIME, agent.id) - assert access_service.revoke_key(agent.id, key.id) + stored_key = database.get_agent_by_id(agent.id).access_key[0] + assert verify_access_key(returned_key.key, stored_key.key.encode("utf-8")) + + assert access_service.revoke_key(agent.id, returned_key.id) assert len(database.get_agent_by_id(agent.id).access_key) == 0 @@ -94,13 +107,19 @@ def test_bad_request(): @pytest.mark.unit def test_authenticate_key(): agent, access_service, database = get_access_service_and_agent() - key = access_service.generate_accesskey("test", DEFAULT_DATETIME, agent.id) - assert key.name == "test" - assert key.expiry_date == DEFAULT_DATETIME - assert database.get_agent_by_id(agent.id).access_key[0] == key + returned_key = access_service.generate_accesskey("test", DEFAULT_DATETIME, agent.id) - assert access_service.authenticate(agent.id, key.key) + assert access_service.authenticate(agent.id, returned_key.key) - assert access_service.revoke_key(agent.id, key.id) + assert access_service.revoke_key(agent.id, returned_key.id) + assert not access_service.authenticate(agent.id, returned_key.key) + + +@pytest.mark.unit +def test_authenticate_rejects_wrong_key(): + agent, access_service, _ = get_access_service_and_agent() + returned_key = access_service.generate_accesskey("test", DEFAULT_DATETIME, agent.id) - assert not access_service.authenticate(agent.id, key.key) + wrong = returned_key.key[:-2] + "AA" + assert wrong != returned_key.key + assert not access_service.authenticate(agent.id, wrong)