diff --git a/DEVIATION.md b/DEVIATION.md new file mode 100644 index 00000000000..b22bb99744d --- /dev/null +++ b/DEVIATION.md @@ -0,0 +1,5 @@ +# Deviation Report: BA-4905 + +| Item | Type | Reason / Alternative | +|------|------|----------------------| +| Task 1: `LoginSecurityPolicy(BaseModel)` placed in `data/login_session/types.py` | Alternative applied | `data/` CLAUDE.md prohibits Pydantic imports. Pydantic models used with PydanticColumn follow the `models/{domain}/types.py` pattern (see `models/scaling_group/types.py`, `models/resource_slot/types.py`). `LoginSecurityPolicy` placed in `models/login_session/types.py` instead. | diff --git a/changes/9720.feature.md b/changes/9720.feature.md new file mode 100644 index 00000000000..493408cecef --- /dev/null +++ b/changes/9720.feature.md @@ -0,0 +1 @@ +Add LoginSecurityPolicy model, login_sessions table, LoginSessionRepository (Valkey Sorted Set + DB), and LoginSessionService for concurrent login session management. \ No newline at end of file diff --git a/src/ai/backend/common/data/permission/types.py b/src/ai/backend/common/data/permission/types.py index 8b1c0330ea8..fdbe06f7946 100644 --- a/src/ai/backend/common/data/permission/types.py +++ b/src/ai/backend/common/data/permission/types.py @@ -109,6 +109,7 @@ class EntityType(enum.StrEnum): RESOURCE_ALLOCATION = "resource_allocation" RESOURCE_GROUP = "resource_group" PROMETHEUS_QUERY_PRESET = "prometheus_query_preset" + LOGIN_SESSION = "login_session" RESOURCE_PRESET = "resource_preset" ROLE = "role" DOTFILE = "dotfile" diff --git a/src/ai/backend/common/metrics/metric.py b/src/ai/backend/common/metrics/metric.py index 3eb3435e266..78b0798b717 100644 --- a/src/ai/backend/common/metrics/metric.py +++ b/src/ai/backend/common/metrics/metric.py @@ -430,6 +430,7 @@ class LayerType(enum.StrEnum): PROMETHEUS_QUERY_PRESET_REPOSITORY = "prometheus_query_preset_repository" PROJECT_RESOURCE_POLICY_REPOSITORY = "project_resource_policy_repository" 
RESERVOIR_REGISTRY_REPOSITORY = "reservoir_registry_repository" + LOGIN_SESSION_REPOSITORY = "login_session_repository" RESOURCE_PRESET_REPOSITORY = "resource_preset_repository" SCALING_GROUP_REPOSITORY = "scaling_group_repository" SCHEDULE_REPOSITORY = "schedule_repository" diff --git a/src/ai/backend/manager/data/login_session/__init__.py b/src/ai/backend/manager/data/login_session/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/ai/backend/manager/data/login_session/types.py b/src/ai/backend/manager/data/login_session/types.py new file mode 100644 index 00000000000..8c5eb75fe14 --- /dev/null +++ b/src/ai/backend/manager/data/login_session/types.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import enum +from dataclasses import dataclass, field +from datetime import datetime +from uuid import UUID + + +class LoginSessionExpiryReason(enum.StrEnum): + LOGOUT = "logout" + EVICTED = "evicted" + EXPIRED = "expired" + + +@dataclass(frozen=True) +class LoginSessionData: + id: UUID + user_uuid: UUID + session_token: str + client_ip: str + created_at: datetime + expired_at: datetime | None = field(default=None) + reason: LoginSessionExpiryReason | None = field(default=None) diff --git a/src/ai/backend/manager/data/user/types.py b/src/ai/backend/manager/data/user/types.py index af262ec9c18..da69ac8563d 100644 --- a/src/ai/backend/manager/data/user/types.py +++ b/src/ai/backend/manager/data/user/types.py @@ -87,6 +87,7 @@ class UserData: container_uid: int | None = field(compare=False) container_main_gid: int | None = field(compare=False) container_gids: list[int] | None = field(compare=False) + login_security_policy: dict[str, Any] | None = field(default=None, compare=False) def scope_id(self) -> ScopeId: return ScopeId( @@ -134,6 +135,7 @@ def from_row(cls, row: Row[Any]) -> Self: container_uid=row.container_uid, container_main_gid=row.container_main_gid, container_gids=row.container_gids, + login_security_policy=None, ) diff 
def upgrade() -> None:
    """Create the ``login_sessions`` table and add ``users.login_security_policy``."""
    op.create_table(
        "login_sessions",
        sa.Column(
            "id",
            GUID(),
            # The primary key is generated by the database.
            server_default=sa.text("uuid_generate_v4()"),
            nullable=False,
        ),
        sa.Column("user_uuid", GUID(), nullable=False),
        sa.Column("session_token", sa.String(length=512), nullable=False),
        sa.Column("client_ip", sa.String(length=64), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            nullable=False,
        ),
        # expired_at / reason are nullable: they are only set once a session is expired.
        sa.Column("expired_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("reason", sa.String(length=64), nullable=True),
        sa.ForeignKeyConstraint(
            ["user_uuid"],
            ["users.uuid"],
            name=op.f("fk_login_sessions_user_uuid_users"),
            # Deleting a user removes that user's login session rows.
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("id", name=op.f("pk_login_sessions")),
        sa.UniqueConstraint("session_token", name=op.f("uq_login_sessions_session_token")),
    )
    # Non-unique index supporting per-user lookups.
    op.create_index(
        op.f("ix_login_sessions_user_uuid"),
        "login_sessions",
        ["user_uuid"],
        unique=False,
    )
    op.add_column(
        "users",
        sa.Column(
            "login_security_policy",
            # NOTE(review): `sa.dialects.postgresql` is reached by attribute access and
            # only resolves if that submodule was already imported — presumably by the
            # `models.base` import at the top of this file; confirm.
            sa.dialects.postgresql.JSONB(none_as_null=True),
            nullable=True,
        ),
    )


def downgrade() -> None:
    """Revert ``upgrade()`` in reverse order: column first, then index, then table."""
    op.drop_column("users", "login_security_policy")
    op.drop_index(op.f("ix_login_sessions_user_uuid"), table_name="login_sessions")
    op.drop_table("login_sessions")
class LoginSecurityPolicy(BaseModel):
    """Login security policy for controlling concurrent session limits.

    Stored as JSONB in the users table via PydanticColumn.
    The model is configured as frozen, so fields cannot be reassigned after
    construction.
    """

    model_config = ConfigDict(frozen=True)

    # NOTE(review): no lower bound is declared, so 0 or a negative number passes
    # validation — confirm how consumers of this limit should treat such values.
    max_concurrent_logins: int | None = None
    """Maximum number of concurrent login sessions allowed.

    None means unlimited.
    """
_KEY_PREFIX = "login_session"


class LoginSessionCacheSource:
    """
    Cache source for login session operations.

    Uses one Valkey Sorted Set per user, keyed by `login_session:{user_uuid}`.
    Score = value supplied by the caller (the repository passes the UNIX
    timestamp of session creation).
    Member = session_token.
    """

    _valkey_stat: ValkeyStatClient

    def __init__(self, valkey_stat: ValkeyStatClient) -> None:
        self._valkey_stat = valkey_stat

    def _key(self, user_uuid: UUID) -> str:
        """Build the sorted-set key for *user_uuid*."""
        return f"{_KEY_PREFIX}:{user_uuid}"

    @staticmethod
    def _first_member(reply: object) -> object | None:
        """Extract the first popped member from a ZPOPMIN reply.

        The reply shape is not fixed: depending on protocol version and client
        library it may be a flat ``[member, score]`` list, a nested
        ``[[member, score]]`` list, or a ``{member: score}`` mapping. All three
        are accepted; anything empty or unrecognised yields ``None``.
        """
        if isinstance(reply, dict):
            return next(iter(reply), None)
        if isinstance(reply, (list, tuple)):
            if not reply:
                return None
            head = reply[0]
            if isinstance(head, (list, tuple)):
                # Nested form: the first pair is [member, score].
                return head[0] if head else None
            return head
        return None

    async def add_session(self, user_uuid: UUID, session_token: str, score: float) -> None:
        """Register a session in the sorted set (ZADD)."""
        await self._valkey_stat.execute_command([
            "ZADD",
            self._key(user_uuid),
            str(score),
            session_token,
        ])

    async def session_score(self, user_uuid: UUID, session_token: str) -> float | None:
        """Return the score of a session (ZSCORE), or None if it is not in the set."""
        result = await self._valkey_stat.execute_command([
            "ZSCORE",
            self._key(user_uuid),
            session_token,
        ])
        if result is None:
            return None
        return float(result)

    async def count_sessions(self, user_uuid: UUID) -> int:
        """Return the number of members in the user's sorted set (ZCARD)."""
        result = await self._valkey_stat.execute_command(["ZCARD", self._key(user_uuid)])
        return int(result) if result is not None else 0

    async def pop_oldest_session(self, user_uuid: UUID) -> str | None:
        """Remove the lowest-scored session and return its token (ZPOPMIN).

        Returns None when the sorted set is empty or the reply cannot be
        interpreted. Byte members are decoded to ``str``.
        """
        reply = await self._valkey_stat.execute_command(["ZPOPMIN", self._key(user_uuid)])
        member = self._first_member(reply)
        if member is None:
            return None
        if isinstance(member, bytes):
            return member.decode()
        return str(member)

    async def remove_session(self, user_uuid: UUID, session_token: str) -> None:
        """Remove a session from the sorted set (ZREM)."""
        await self._valkey_stat.execute_command(["ZREM", self._key(user_uuid), session_token])
class LoginSessionDBSource:
    """Reads and writes login session rows through the manager's SQLAlchemy engine."""

    _db: ExtendedAsyncSAEngine

    def __init__(self, db: ExtendedAsyncSAEngine) -> None:
        self._db = db

    async def create_session(
        self,
        user_uuid: UUID,
        session_token: str,
        client_ip: str,
    ) -> LoginSessionData:
        """Insert one login session row and return it as a dataclass."""
        async with self._db.begin_session() as db_sess:
            new_row = LoginSessionRow(
                user_uuid=user_uuid,
                session_token=session_token,
                client_ip=client_ip,
            )
            db_sess.add(new_row)
            await db_sess.flush()
            # Reload the row so server-side defaults are populated before conversion.
            await db_sess.refresh(new_row)
            return new_row.to_dataclass()

    async def expire_session(
        self,
        session_token: str,
        reason: LoginSessionExpiryReason,
    ) -> LoginSessionData | None:
        """Stamp ``expired_at`` (current UTC time) and ``reason`` on an active session.

        Returns the updated session, or None when no active session carries
        *session_token*.
        """
        async with self._db.begin_session() as db_sess:
            target = await self._get_active_session(db_sess, session_token)
            if target is not None:
                target.expired_at = datetime.now(tz=UTC)
                target.reason = reason
                return target.to_dataclass()
            return None

    async def list_active_sessions(self, user_uuid: UUID) -> list[LoginSessionData]:
        """Return every session of *user_uuid* whose ``expired_at`` is NULL."""
        stmt = sa.select(LoginSessionRow).where(
            sa.and_(
                LoginSessionRow.user_uuid == user_uuid,
                LoginSessionRow.expired_at.is_(None),
            )
        )
        async with self._db.begin_readonly_session_read_committed() as db_sess:
            found = await db_sess.scalars(stmt)
            return [item.to_dataclass() for item in found]

    async def count_active_sessions(self, user_uuid: UUID) -> int:
        """Return how many sessions of *user_uuid* have ``expired_at`` NULL."""
        stmt = (
            sa.select(sa.func.count())
            .select_from(LoginSessionRow)
            .where(
                sa.and_(
                    LoginSessionRow.user_uuid == user_uuid,
                    LoginSessionRow.expired_at.is_(None),
                )
            )
        )
        async with self._db.begin_readonly_session_read_committed() as db_sess:
            total = await db_sess.scalar(stmt)
            if total is None:
                return 0
            return int(total)

    async def _get_active_session(
        self, db_sess: SASession, session_token: str
    ) -> LoginSessionRow | None:
        """Fetch the non-expired row matching *session_token* within *db_sess*."""
        stmt = sa.select(LoginSessionRow).where(
            sa.and_(
                LoginSessionRow.session_token == session_token,
                LoginSessionRow.expired_at.is_(None),
            )
        )
        return cast(LoginSessionRow | None, await db_sess.scalar(stmt))
class LoginSessionRepository:
    """Repository that orchestrates between DB and cache sources for login session operations.

    Writes go to the database first. Cache writes are best-effort: failures
    are logged through ``suppress_with_log`` and do not fail the operation.
    """

    _db_source: LoginSessionDBSource
    _cache_source: LoginSessionCacheSource

    def __init__(
        self,
        db: ExtendedAsyncSAEngine,
        valkey_stat: ValkeyStatClient,
    ) -> None:
        self._db_source = LoginSessionDBSource(db)
        self._cache_source = LoginSessionCacheSource(valkey_stat)

    @login_session_repository_resilience.apply()
    async def create_session(
        self,
        user_uuid: UUID,
        session_token: str,
        client_ip: str,
    ) -> LoginSessionData:
        """Create a new login session. Writes to DB first, then updates cache."""
        data = await self._db_source.create_session(user_uuid, session_token, client_ip)
        # The cache score is the creation time, so the lowest score is the oldest session.
        score = data.created_at.timestamp() if data.created_at else time.time()
        with suppress_with_log(
            [Exception], message="Failed to add session to cache after creation"
        ):
            await self._cache_source.add_session(user_uuid, session_token, score)
        return data

    @login_session_repository_resilience.apply()
    async def expire_session(
        self,
        user_uuid: UUID,
        session_token: str,
        reason: LoginSessionExpiryReason,
    ) -> LoginSessionData | None:
        """Expire a session. Writes to DB first, then removes from cache.

        Returns the expired session, or None when the DB holds no active
        session for *session_token*. The cache entry is removed either way.
        """
        data = await self._db_source.expire_session(session_token, reason)
        with suppress_with_log(
            [Exception], message="Failed to remove session from cache after expiry"
        ):
            await self._cache_source.remove_session(user_uuid, session_token)
        return data

    @login_session_repository_resilience.apply()
    async def evict_oldest_session(self, user_uuid: UUID) -> str | None:
        """
        Evict the oldest session for a user.

        Pops the oldest token from the cache first and expires it in the DB.
        If the cache is unavailable or empty, or the DB update for the popped
        token fails, falls back to selecting the oldest active session from
        the DB.
        Returns the evicted session_token or None if no active sessions.
        """
        # Pre-bind so the name exists even when the cache call is suppressed.
        token: str | None = None
        with suppress_with_log([Exception], message="Failed to pop oldest session from cache"):
            token = await self._cache_source.pop_oldest_session(user_uuid)

        if token is not None:
            # Kept separate from the cache call so that a DB failure is not
            # reported as a cache failure. On failure we fall through to the
            # DB-driven path below instead of losing the eviction.
            with suppress_with_log(
                [Exception], message="Failed to expire cache-evicted session in DB"
            ):
                await self._db_source.expire_session(token, LoginSessionExpiryReason.EVICTED)
                return token

        # Fallback: list from DB and expire the oldest
        sessions = await self._db_source.list_active_sessions(user_uuid)
        if not sessions:
            return None
        oldest = min(sessions, key=lambda s: s.created_at)
        await self._db_source.expire_session(oldest.session_token, LoginSessionExpiryReason.EVICTED)
        # Mirror expire_session(): drop the token from the sorted set as well,
        # otherwise an entry that the failed pop left behind stays in the cache.
        with suppress_with_log(
            [Exception], message="Failed to remove evicted session from cache"
        ):
            await self._cache_source.remove_session(user_uuid, oldest.session_token)
        return oldest.session_token

    @login_session_repository_resilience.apply()
    async def count_active_sessions(self, user_uuid: UUID) -> int:
        """
        Count active sessions for a user.
        Cache-first with fallback to DB.

        NOTE(review): the DB is consulted only when the cache call raises; a
        cache that answers 0 (for example after losing its data) is trusted
        as-is — confirm this is acceptable for limit enforcement.
        """
        try:
            return await self._cache_source.count_sessions(user_uuid)
        except Exception as e:
            log.warning("Failed to count sessions from cache: {}", e)
            return await self._db_source.count_active_sessions(user_uuid)

    @login_session_repository_resilience.apply()
    async def list_active_sessions(self, user_uuid: UUID) -> list[LoginSessionData]:
        """List active sessions for a user from DB."""
        return await self._db_source.list_active_sessions(user_uuid)
+ """ + try: + yield + except tuple(exceptions) as e: + if message: + log.log(log_level, "{}: {}", message, e) + else: + log.log(log_level, "Suppressed exception: {}", e) diff --git a/src/ai/backend/manager/services/login_session/__init__.py b/src/ai/backend/manager/services/login_session/__init__.py new file mode 100644 index 00000000000..85233b610ff --- /dev/null +++ b/src/ai/backend/manager/services/login_session/__init__.py @@ -0,0 +1,5 @@ +"""Login session service module.""" + +from .service import LoginSessionService + +__all__ = ["LoginSessionService"] diff --git a/src/ai/backend/manager/services/login_session/actions/__init__.py b/src/ai/backend/manager/services/login_session/actions/__init__.py new file mode 100644 index 00000000000..2e5972ba623 --- /dev/null +++ b/src/ai/backend/manager/services/login_session/actions/__init__.py @@ -0,0 +1 @@ +"""Login session service actions.""" diff --git a/src/ai/backend/manager/services/login_session/actions/base.py b/src/ai/backend/manager/services/login_session/actions/base.py new file mode 100644 index 00000000000..eab9c2ca5f9 --- /dev/null +++ b/src/ai/backend/manager/services/login_session/actions/base.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass +from typing import override + +from ai.backend.common.data.permission.types import EntityType +from ai.backend.manager.actions.action import BaseAction + + +@dataclass +class LoginSessionAction(BaseAction): + @override + @classmethod + def entity_type(cls) -> EntityType: + return EntityType.LOGIN_SESSION diff --git a/src/ai/backend/manager/services/login_session/actions/check_concurrency_limit.py b/src/ai/backend/manager/services/login_session/actions/check_concurrency_limit.py new file mode 100644 index 00000000000..394af8af630 --- /dev/null +++ b/src/ai/backend/manager/services/login_session/actions/check_concurrency_limit.py @@ -0,0 +1,28 @@ +from dataclasses import dataclass +from typing import override +from uuid import UUID + +from 
ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.services.login_session.actions.base import LoginSessionAction + + +@dataclass +class CheckConcurrencyLimitAction(LoginSessionAction): + user_uuid: UUID + max_concurrent_logins: int | None + + @override + @classmethod + def operation_type(cls) -> ActionOperationType: + return ActionOperationType.GET + + +@dataclass +class CheckConcurrencyLimitActionResult(BaseActionResult): + active_sessions: int + limit_exceeded: bool + + @override + def entity_id(self) -> str | None: + return None diff --git a/src/ai/backend/manager/services/login_session/actions/create_session.py b/src/ai/backend/manager/services/login_session/actions/create_session.py new file mode 100644 index 00000000000..82f934163de --- /dev/null +++ b/src/ai/backend/manager/services/login_session/actions/create_session.py @@ -0,0 +1,29 @@ +from dataclasses import dataclass +from typing import override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.data.login_session.types import LoginSessionData +from ai.backend.manager.services.login_session.actions.base import LoginSessionAction + + +@dataclass +class CreateLoginSessionAction(LoginSessionAction): + user_uuid: UUID + session_token: str + client_ip: str + + @override + @classmethod + def operation_type(cls) -> ActionOperationType: + return ActionOperationType.CREATE + + +@dataclass +class CreateLoginSessionActionResult(BaseActionResult): + session: LoginSessionData + + @override + def entity_id(self) -> str | None: + return str(self.session.id) diff --git a/src/ai/backend/manager/services/login_session/actions/evict_oldest_session.py b/src/ai/backend/manager/services/login_session/actions/evict_oldest_session.py new file mode 100644 index 00000000000..6bc1f030852 --- /dev/null +++ 
@dataclass
class EvictOldestSessionAction(LoginSessionAction):
    """Action requesting eviction of the oldest login session of one user."""

    # UUID of the user whose oldest session should be evicted.
    user_uuid: UUID

    @override
    @classmethod
    def operation_type(cls) -> ActionOperationType:
        """Report this action as a DELETE operation."""
        return ActionOperationType.DELETE


@dataclass
class EvictOldestSessionActionResult(BaseActionResult):
    """Result of an eviction request."""

    # Token of the session that was evicted; None when there was nothing to evict.
    evicted_session_token: str | None

    @override
    def entity_id(self) -> str | None:
        """Return the evicted session token as the entity id.

        NOTE(review): the sibling create/expire results return the session row
        id here, whereas this returns the session token itself. If entity ids
        are written to logs or audit records, the token would be recorded —
        confirm whether the row id should be carried in this result instead.
        """
        return self.evicted_session_token
class LoginSessionService:
    """Translates login-session actions into repository calls and wraps the results."""

    _repository: LoginSessionRepository

    def __init__(self, repository: LoginSessionRepository) -> None:
        self._repository = repository

    async def create_session(
        self, action: CreateLoginSessionAction
    ) -> CreateLoginSessionActionResult:
        """Create a login session from the action's user, token and client IP."""
        created = await self._repository.create_session(
            user_uuid=action.user_uuid,
            session_token=action.session_token,
            client_ip=action.client_ip,
        )
        return CreateLoginSessionActionResult(session=created)

    async def expire_session(
        self, action: ExpireLoginSessionAction
    ) -> ExpireLoginSessionActionResult:
        """Expire the session named by the action; the result holds None if none was active."""
        expired = await self._repository.expire_session(
            user_uuid=action.user_uuid,
            session_token=action.session_token,
            reason=action.reason,
        )
        return ExpireLoginSessionActionResult(session=expired)

    async def evict_oldest_session(
        self, action: EvictOldestSessionAction
    ) -> EvictOldestSessionActionResult:
        """Evict the user's oldest session and report the evicted token, if any."""
        token = await self._repository.evict_oldest_session(action.user_uuid)
        return EvictOldestSessionActionResult(evicted_session_token=token)

    async def check_concurrency_limit(
        self, action: CheckConcurrencyLimitAction
    ) -> CheckConcurrencyLimitActionResult:
        """Compare the user's active session count against the action's limit.

        A limit of None never counts as exceeded; otherwise the limit is
        exceeded once the active count reaches it.
        """
        current = await self._repository.count_active_sessions(action.user_uuid)
        cap = action.max_concurrent_logins
        if cap is None:
            exceeded = False
        else:
            exceeded = current >= cap
        return CheckConcurrencyLimitActionResult(
            active_sessions=current,
            limit_exceeded=exceeded,
        )