Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions DEVIATION.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Deviation Report: BA-4905

| Item | Type | Reason / Alternative |
|------|------|----------------------|
| Task 1: `LoginSecurityPolicy(BaseModel)` placed in `data/login_session/types.py` | Alternative applied | `data/` CLAUDE.md prohibits Pydantic imports. Pydantic models used with PydanticColumn follow the `models/{domain}/types.py` pattern (see `models/scaling_group/types.py`, `models/resource_slot/types.py`). `LoginSecurityPolicy` placed in `models/login_session/types.py` instead. |
1 change: 1 addition & 0 deletions changes/9720.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add LoginSecurityPolicy model, login_sessions table, LoginSessionRepository (Valkey Sorted Set + DB), and LoginSessionService for concurrent login session management.
1 change: 1 addition & 0 deletions src/ai/backend/common/data/permission/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class EntityType(enum.StrEnum):
RESOURCE_ALLOCATION = "resource_allocation"
RESOURCE_GROUP = "resource_group"
PROMETHEUS_QUERY_PRESET = "prometheus_query_preset"
LOGIN_SESSION = "login_session"
RESOURCE_PRESET = "resource_preset"
ROLE = "role"
DOTFILE = "dotfile"
Expand Down
1 change: 1 addition & 0 deletions src/ai/backend/common/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ class LayerType(enum.StrEnum):
PROMETHEUS_QUERY_PRESET_REPOSITORY = "prometheus_query_preset_repository"
PROJECT_RESOURCE_POLICY_REPOSITORY = "project_resource_policy_repository"
RESERVOIR_REGISTRY_REPOSITORY = "reservoir_registry_repository"
LOGIN_SESSION_REPOSITORY = "login_session_repository"
RESOURCE_PRESET_REPOSITORY = "resource_preset_repository"
SCALING_GROUP_REPOSITORY = "scaling_group_repository"
SCHEDULE_REPOSITORY = "schedule_repository"
Expand Down
Empty file.
23 changes: 23 additions & 0 deletions src/ai/backend/manager/data/login_session/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from __future__ import annotations

import enum
from dataclasses import dataclass, field
from datetime import datetime
from uuid import UUID


class LoginSessionExpiryReason(enum.StrEnum):
LOGOUT = "logout"
EVICTED = "evicted"
EXPIRED = "expired"


@dataclass(frozen=True)
class LoginSessionData:
id: UUID
user_uuid: UUID
session_token: str
client_ip: str
created_at: datetime
expired_at: datetime | None = field(default=None)
reason: LoginSessionExpiryReason | None = field(default=None)
2 changes: 2 additions & 0 deletions src/ai/backend/manager/data/user/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class UserData:
container_uid: int | None = field(compare=False)
container_main_gid: int | None = field(compare=False)
container_gids: list[int] | None = field(compare=False)
login_security_policy: dict[str, Any] | None = field(default=None, compare=False)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoid using dict. We recommend handling data through PydanticModel or dataclass.


def scope_id(self) -> ScopeId:
return ScopeId(
Expand Down Expand Up @@ -134,6 +135,7 @@ def from_row(cls, row: Row[Any]) -> Self:
container_uid=row.container_uid,
container_main_gid=row.container_main_gid,
container_gids=row.container_gids,
login_security_policy=None,
)


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""Add login_sessions table and login_security_policy column to users

Revision ID: ba49050abc12
Revises: ffcf0ed13a26
Create Date: 2026-03-06 00:00:00.000000

"""

import sqlalchemy as sa
from alembic import op

from ai.backend.manager.models.base import GUID

# revision identifiers, used by Alembic.
revision = "ba49050abc12"
down_revision = "ffcf0ed13a26"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.create_table(
"login_sessions",
sa.Column(
"id",
GUID(),
server_default=sa.text("uuid_generate_v4()"),
nullable=False,
),
sa.Column("user_uuid", GUID(), nullable=False),
sa.Column("session_token", sa.String(length=512), nullable=False),
sa.Column("client_ip", sa.String(length=64), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column("expired_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("reason", sa.String(length=64), nullable=True),
sa.ForeignKeyConstraint(
["user_uuid"],
["users.uuid"],
name=op.f("fk_login_sessions_user_uuid_users"),
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_login_sessions")),
sa.UniqueConstraint("session_token", name=op.f("uq_login_sessions_session_token")),
)
op.create_index(
op.f("ix_login_sessions_user_uuid"),
"login_sessions",
["user_uuid"],
unique=False,
)
op.add_column(
"users",
sa.Column(
"login_security_policy",
sa.dialects.postgresql.JSONB(none_as_null=True),
nullable=True,
),
)


def downgrade() -> None:
op.drop_column("users", "login_security_policy")
op.drop_index(op.f("ix_login_sessions_user_uuid"), table_name="login_sessions")
op.drop_table("login_sessions")
3 changes: 3 additions & 0 deletions src/ai/backend/manager/models/login_session/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ai.backend.manager.models.login_session.row import LoginSessionRow

__all__ = ("LoginSessionRow",)
72 changes: 72 additions & 0 deletions src/ai/backend/manager/models/login_session/row.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from __future__ import annotations

import uuid
from collections.abc import Sequence
from datetime import datetime

import sqlalchemy as sa
from sqlalchemy.orm import Mapped, mapped_column

from ai.backend.manager.data.login_session.types import LoginSessionData, LoginSessionExpiryReason
from ai.backend.manager.models.base import (
GUID,
Base,
StrEnumType,
)

__all__: Sequence[str] = ("LoginSessionRow",)


class LoginSessionRow(Base): # type: ignore[misc]
__tablename__ = "login_sessions"
__table_args__ = (
sa.UniqueConstraint("session_token", name="uq_login_sessions_session_token"),
sa.Index("ix_login_sessions_user_uuid", "user_uuid"),
)

id: Mapped[uuid.UUID] = mapped_column(
"id", GUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")
)
user_uuid: Mapped[uuid.UUID] = mapped_column(
"user_uuid",
GUID,
sa.ForeignKey("users.uuid", ondelete="CASCADE"),
nullable=False,
)
session_token: Mapped[str] = mapped_column(
"session_token",
sa.String(length=512),
nullable=False,
)
client_ip: Mapped[str] = mapped_column(
"client_ip",
sa.String(length=64),
nullable=False,
)
created_at: Mapped[datetime] = mapped_column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
)
expired_at: Mapped[datetime | None] = mapped_column(
"expired_at",
sa.DateTime(timezone=True),
nullable=True,
)
reason: Mapped[LoginSessionExpiryReason | None] = mapped_column(
"reason",
StrEnumType(LoginSessionExpiryReason),
nullable=True,
)

def to_dataclass(self) -> LoginSessionData:
return LoginSessionData(
id=self.id,
user_uuid=self.user_uuid,
session_token=self.session_token,
client_ip=self.client_ip,
created_at=self.created_at,
expired_at=self.expired_at,
reason=self.reason,
)
20 changes: 20 additions & 0 deletions src/ai/backend/manager/models/login_session/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import annotations

from pydantic import BaseModel, ConfigDict

__all__ = ("LoginSecurityPolicy",)


class LoginSecurityPolicy(BaseModel):
"""Login security policy for controlling concurrent session limits.

Stored as JSONB in the users table via PydanticColumn.
"""

model_config = ConfigDict(frozen=True)

max_concurrent_logins: int | None = None
"""Maximum number of concurrent login sessions allowed.

None means unlimited.
"""
8 changes: 8 additions & 0 deletions src/ai/backend/manager/models/user/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@
Base,
EnumValueType,
IPColumn,
PydanticColumn,
)
from ai.backend.manager.models.hasher import PasswordHasherFactory
from ai.backend.manager.models.hasher.types import HashInfo, PasswordColumn, PasswordInfo
from ai.backend.manager.models.login_session.types import LoginSecurityPolicy
from ai.backend.manager.models.types import (
QueryCondition,
QueryOption,
Expand Down Expand Up @@ -237,6 +239,9 @@ class UserRow(Base): # type: ignore[misc]
container_gids: Mapped[list[int] | None] = mapped_column(
"container_gids", sa.ARRAY(sa.Integer), nullable=True, server_default=sa.null()
)
login_security_policy: Mapped[LoginSecurityPolicy | None] = mapped_column(
"login_security_policy", PydanticColumn(LoginSecurityPolicy), nullable=True
)
Comment on lines +242 to +244
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use the default value rather than leaving it optional.


# Relationships
sessions: Mapped[list[SessionRow]] = relationship(
Expand Down Expand Up @@ -431,6 +436,9 @@ def to_data(self) -> UserData:
container_uid=self.container_uid,
container_main_gid=self.container_main_gid,
container_gids=self.container_gids,
login_security_policy=self.login_security_policy.model_dump()
if self.login_security_policy is not None
else None,
)


Expand Down
5 changes: 5 additions & 0 deletions src/ai/backend/manager/repositories/login_session/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Login session repository module."""

from .repository import LoginSessionRepository

__all__ = ["LoginSessionRepository"]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Cache source for login session repository."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Cache source for login session repository operations using Valkey Sorted Sets."""

from __future__ import annotations

import logging
from uuid import UUID

from ai.backend.common.clients.valkey_client.valkey_stat.client import ValkeyStatClient
from ai.backend.logging.utils import BraceStyleAdapter

log = BraceStyleAdapter(logging.getLogger(__spec__.name))

_KEY_PREFIX = "login_session"


class LoginSessionCacheSource:
"""
Cache source for login session operations.
Uses Valkey Sorted Set keyed by `login_session:{user_uuid}`.
Score = UNIX timestamp of session creation.
Member = session_token.
"""

_valkey_stat: ValkeyStatClient

def __init__(self, valkey_stat: ValkeyStatClient) -> None:
self._valkey_stat = valkey_stat

def _key(self, user_uuid: UUID) -> str:
return f"{_KEY_PREFIX}:{user_uuid}"

async def add_session(self, user_uuid: UUID, session_token: str, score: float) -> None:
"""Register a session in the sorted set (ZADD)."""
await self._valkey_stat.execute_command([
"ZADD",
self._key(user_uuid),
str(score),
session_token,
])

async def session_score(self, user_uuid: UUID, session_token: str) -> float | None:
"""Check if session exists and return its score (ZSCORE). Returns None if not found."""
result = await self._valkey_stat.execute_command([
"ZSCORE",
self._key(user_uuid),
session_token,
])
if result is None:
return None
return float(result)

async def count_sessions(self, user_uuid: UUID) -> int:
"""Return number of active sessions for user (ZCARD)."""
result = await self._valkey_stat.execute_command(["ZCARD", self._key(user_uuid)])
return int(result) if result is not None else 0

async def pop_oldest_session(self, user_uuid: UUID) -> str | None:
"""Evict the oldest session (lowest score) and return its token (ZPOPMIN)."""
result = await self._valkey_stat.execute_command(["ZPOPMIN", self._key(user_uuid)])
if not result:
return None
# ZPOPMIN returns [member, score] interleaved; first element is the member
member = result[0]
if isinstance(member, bytes):
return member.decode()
return str(member)

async def remove_session(self, user_uuid: UUID, session_token: str) -> None:
"""Remove a session from the sorted set (ZREM)."""
await self._valkey_stat.execute_command(["ZREM", self._key(user_uuid), session_token])
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Database source for login session repository."""
Loading
Loading