Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions src/access_service/access_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from src.models.accesskey import AccessKey
from src.models.agent import Agent
from src.rag_service.dao.agent.base import AgentDAO
from src.utils.crypto_utils import hash_access_key, verify_access_key


class AccessService(AbstractAccessService):
Expand Down Expand Up @@ -38,27 +39,37 @@ def generate_accesskey(
agent = self.try_get_agent(agent_id)
key = secrets.token_bytes(32) # AES-256 32byte secret
key_str = base64.urlsafe_b64encode(key).decode("ascii")
access_key = AccessKey(
hashed_key = hash_access_key(key_str)

stored_access_key = AccessKey(
id=self.get_unique_access_key_id(agent),
key=key_str,
key=hashed_key.decode("utf-8"), # Store the hashed key as string
name=name,
expiry_date=expiry_date,
created=datetime.now(),
last_use=None,
)
agent.access_key.append(access_key)

agent.access_key.append(stored_access_key)
self.agent_database.add_agent(agent)
return access_key

# return the original key to user (not hashed)
return AccessKey(
id=stored_access_key.id,
key=key_str,
name=name,
expiry_date=expiry_date,
created=stored_access_key.created,
last_use=None,
)

def authenticate(self, agent_id: str, access_key: str) -> bool:
agent = self.try_get_agent(agent_id)
access_keys: list[AccessKey] = agent.access_key

for ak in access_keys:
if ak.key == access_key:
if ak.expiry_date is None:
return True
return datetime.now() < ak.expiry_date
if verify_access_key(access_key, ak.key.encode("utf-8")):
return ak.expiry_date is None or ak.expiry_date > datetime.now()
Comment thread
anettkva marked this conversation as resolved.

return False

Expand Down
47 changes: 33 additions & 14 deletions tests/test_access_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from src.access_service.factory import AccessServiceConfig, access_service_factory
from src.models.agent import Agent, Role
from src.utils.crypto_utils import verify_access_key
from tests.mocks.mock_agent_dao import MockAgentDAO


Expand Down Expand Up @@ -57,7 +58,14 @@ def test_create_access_key():
key = access_service.generate_accesskey("test", DEFAULT_DATETIME, agent.id)
assert key.name == "test"
assert key.expiry_date == DEFAULT_DATETIME
assert database.get_agent_by_id(agent.id).access_key[0] == key

stored_key = database.get_agent_by_id(agent.id).access_key[0]
assert stored_key.id == key.id
assert stored_key.name == key.name
assert stored_key.expiry_date == key.expiry_date
assert stored_key.key != key.key

assert verify_access_key(key.key, stored_key.key.encode("utf-8"))


@pytest.mark.unit
Expand All @@ -66,19 +74,24 @@ def test_create_access_key_no_expiery():
key = access_service.generate_accesskey("test", None, agent.id)
assert key.name == "test"
assert key.expiry_date is None
assert database.get_agent_by_id(agent.id).access_key[0] == key

stored_key = database.get_agent_by_id(agent.id).access_key[0]
assert stored_key.id == key.id
assert stored_key.key != key.key
assert verify_access_key(key.key, stored_key.key.encode("utf-8"))

assert access_service.authenticate(agent_id=agent.id, access_key=key.key)


@pytest.mark.unit
def test_revoke_access_key():
agent, access_service, database = get_access_service_and_agent()
key = access_service.generate_accesskey("test", DEFAULT_DATETIME, agent.id)
assert key.name == "test"
assert key.expiry_date == DEFAULT_DATETIME
assert database.get_agent_by_id(agent.id).access_key[0] == key
returned_key = access_service.generate_accesskey("test", DEFAULT_DATETIME, agent.id)

assert access_service.revoke_key(agent.id, key.id)
stored_key = database.get_agent_by_id(agent.id).access_key[0]
assert verify_access_key(returned_key.key, stored_key.key.encode("utf-8"))

assert access_service.revoke_key(agent.id, returned_key.id)
assert len(database.get_agent_by_id(agent.id).access_key) == 0


Expand All @@ -94,13 +107,19 @@ def test_bad_request():
@pytest.mark.unit
def test_authenticate_key():
agent, access_service, database = get_access_service_and_agent()
key = access_service.generate_accesskey("test", DEFAULT_DATETIME, agent.id)
assert key.name == "test"
assert key.expiry_date == DEFAULT_DATETIME
assert database.get_agent_by_id(agent.id).access_key[0] == key
returned_key = access_service.generate_accesskey("test", DEFAULT_DATETIME, agent.id)

assert access_service.authenticate(agent.id, key.key)
assert access_service.authenticate(agent.id, returned_key.key)

assert access_service.revoke_key(agent.id, key.id)
assert access_service.revoke_key(agent.id, returned_key.id)
assert not access_service.authenticate(agent.id, returned_key.key)


@pytest.mark.unit
def test_authenticate_rejects_wrong_key():
agent, access_service, _ = get_access_service_and_agent()
returned_key = access_service.generate_accesskey("test", DEFAULT_DATETIME, agent.id)

assert not access_service.authenticate(agent.id, key.key)
wrong = returned_key.key[:-2] + "AA"
assert wrong != returned_key.key
assert not access_service.authenticate(agent.id, wrong)