Skip to content
Draft
1 change: 1 addition & 0 deletions changes/9624.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add RBAC validator infrastructure to Session actions following BEP-1048 patterns
8 changes: 7 additions & 1 deletion src/ai/backend/manager/api/gql/data_loader/kernel/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from collections.abc import Sequence

from ai.backend.common.contexts.user import current_user
from ai.backend.common.types import KernelId
from ai.backend.manager.data.kernel.types import KernelInfo
from ai.backend.manager.errors.user import UserNotFound
from ai.backend.manager.repositories.base import BatchQuerier, NoPagination
from ai.backend.manager.repositories.scheduler.options import KernelConditions
from ai.backend.manager.services.session.actions.search_kernel import SearchKernelsAction
Expand All @@ -26,13 +28,17 @@ async def load_kernels_by_ids(
if not kernel_ids:
return []

user = current_user()
if user is None:
raise UserNotFound("User not found in context")

querier = BatchQuerier(
pagination=NoPagination(),
conditions=[KernelConditions.by_ids(kernel_ids)],
)

action_result = await processor.search_kernels.wait_for_complete(
SearchKernelsAction(querier=querier)
SearchKernelsAction(querier=querier, user_id=user.user_id)
)

kernel_map: dict[KernelId, KernelInfo] = {kernel.id: kernel for kernel in action_result.data}
Expand Down
8 changes: 7 additions & 1 deletion src/ai/backend/manager/api/gql/data_loader/session/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from collections.abc import Sequence

from ai.backend.common.contexts.user import current_user
from ai.backend.common.types import SessionId
from ai.backend.manager.data.session.types import SessionData
from ai.backend.manager.errors.user import UserNotFound
from ai.backend.manager.repositories.base import BatchQuerier, NoPagination
from ai.backend.manager.repositories.scheduler.options import SessionConditions
from ai.backend.manager.services.session.actions.search import SearchSessionsAction
Expand All @@ -17,13 +19,17 @@ async def load_sessions_by_ids(
if not session_ids:
return []

user = current_user()
if user is None:
raise UserNotFound("User not found in context")

querier = BatchQuerier(
pagination=NoPagination(),
conditions=[SessionConditions.by_ids(session_ids)],
)

action_result = await processor.search_sessions.wait_for_complete(
SearchSessionsAction(querier=querier)
SearchSessionsAction(querier=querier, user_id=user.user_id)
)

session_map: dict[SessionId, SessionData] = {
Expand Down
8 changes: 7 additions & 1 deletion src/ai/backend/manager/api/gql/kernel/fetcher/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import strawberry
from strawberry import Info

from ai.backend.common.contexts.user import current_user
from ai.backend.common.types import KernelId
from ai.backend.manager.api.gql.adapter import PaginationOptions, PaginationSpec
from ai.backend.manager.api.gql.base import encode_cursor
Expand All @@ -16,6 +17,7 @@
KernelV2OrderByGQL,
)
from ai.backend.manager.api.gql.types import StrawberryGQLContext
from ai.backend.manager.errors.user import UserNotFound
from ai.backend.manager.models.kernel import KernelRow
from ai.backend.manager.repositories.base import QueryCondition
from ai.backend.manager.repositories.scheduler.options import KernelConditions
Expand Down Expand Up @@ -45,6 +47,10 @@ async def fetch_kernels(
offset: int | None = None,
base_conditions: list[QueryCondition] | None = None,
) -> KernelV2ConnectionGQL:
user = current_user()
if user is None:
raise UserNotFound("User not found in context")

querier = info.context.gql_adapter.build_querier(
PaginationOptions(
first=first,
Expand All @@ -61,7 +67,7 @@ async def fetch_kernels(
)

action_result = await info.context.processors.session.search_kernels.wait_for_complete(
SearchKernelsAction(querier=querier)
SearchKernelsAction(querier=querier, user_id=user.user_id)
)
nodes = [KernelV2GQL.from_kernel_info(kernel_info) for kernel_info in action_result.data]
edges = [KernelV2EdgeGQL(node=node, cursor=encode_cursor(node.id)) for node in nodes]
Expand Down
8 changes: 7 additions & 1 deletion src/ai/backend/manager/api/gql/session/fetcher/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import strawberry
from strawberry import Info

from ai.backend.common.contexts.user import current_user
from ai.backend.manager.api.gql.adapter import PaginationOptions, PaginationSpec
from ai.backend.manager.api.gql.base import encode_cursor
from ai.backend.manager.api.gql.session.types import (
Expand All @@ -15,6 +16,7 @@
SessionV2OrderByGQL,
)
from ai.backend.manager.api.gql.types import StrawberryGQLContext
from ai.backend.manager.errors.user import UserNotFound
from ai.backend.manager.models.session import SessionRow
from ai.backend.manager.repositories.base import QueryCondition
from ai.backend.manager.repositories.scheduler.options import SessionConditions, SessionOrders
Expand Down Expand Up @@ -44,6 +46,10 @@ async def fetch_sessions(
offset: int | None = None,
base_conditions: list[QueryCondition] | None = None,
) -> SessionV2ConnectionGQL:
user = current_user()
if user is None:
raise UserNotFound("User not found in context")

querier = info.context.gql_adapter.build_querier(
PaginationOptions(
first=first,
Expand All @@ -60,7 +66,7 @@ async def fetch_sessions(
)

action_result = await info.context.processors.session.search_sessions.wait_for_complete(
SearchSessionsAction(querier=querier)
SearchSessionsAction(querier=querier, user_id=user.user_id)
)

nodes = [SessionV2GQL.from_data(session_data) for session_data in action_result.data]
Expand Down
8 changes: 7 additions & 1 deletion src/ai/backend/manager/api/gql/session/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from strawberry import ID, Info
from strawberry.relay import Connection, Edge, Node, NodeID

from ai.backend.common.contexts.user import current_user
from ai.backend.common.types import SessionId
from ai.backend.manager.api.gql.base import OrderDirection, StringFilter, UUIDFilter, encode_cursor
from ai.backend.manager.api.gql.common.types import (
Expand Down Expand Up @@ -41,6 +42,7 @@
from ai.backend.manager.api.gql.types import GQLFilter, GQLOrderBy, StrawberryGQLContext
from ai.backend.manager.api.gql.user.types.node import UserV2GQL
from ai.backend.manager.data.session.types import SessionData, SessionStatus
from ai.backend.manager.errors.user import UserNotFound
from ai.backend.manager.repositories.base import (
BatchQuerier,
NoPagination,
Expand Down Expand Up @@ -437,13 +439,17 @@ async def images(self) -> ImageV2ConnectionGQL:
description="Added in 26.3.0. The kernels belonging to this session."
)
async def kernels(self, info: Info[StrawberryGQLContext]) -> KernelV2ConnectionGQL:
user = current_user()
if user is None:
raise UserNotFound("User not found in context")

session_id = SessionId(UUID(str(self.id)))
querier = BatchQuerier(
pagination=NoPagination(),
conditions=[KernelConditions.by_session_ids([session_id])],
)
action_result = await info.context.processors.session.search_kernels.wait_for_complete(
SearchKernelsAction(querier=querier)
SearchKernelsAction(querier=querier, user_id=user.user_id)
)
nodes = [KernelV2GQL.from_kernel_info(kernel) for kernel in action_result.data]
edges = [KernelV2EdgeGQL(node=node, cursor=encode_cursor(node.id)) for node in nodes]
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/manager/api/gql_legacy/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,7 +861,7 @@ async def mutate_and_get_payload(

result = await graph_ctx.processors.session.modify_session.wait_for_complete(
ModifySessionAction(
session_id=session_id,
session_uuid=session_id,
updater=Updater(
spec=SessionUpdaterSpec(
name=OptionalState[str].from_graphql(name),
Expand Down
10 changes: 8 additions & 2 deletions src/ai/backend/manager/api/rest/compute_sessions/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Final

from ai.backend.common.api_handlers import APIResponse, BodyParam
from ai.backend.common.contexts.user import current_user
from ai.backend.common.dto.manager.compute_session import (
PaginationInfo,
SearchComputeSessionsRequest,
Expand All @@ -15,6 +16,7 @@
from ai.backend.common.types import SessionId
from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.dto.context import UserContext
from ai.backend.manager.errors.user import UserNotFound
from ai.backend.manager.services.processors import Processors
from ai.backend.manager.services.session.actions.search import SearchSessionsAction
from ai.backend.manager.services.session.actions.search_kernel import SearchKernelsAction
Expand All @@ -39,10 +41,14 @@ async def search_sessions(
"""Search compute sessions with nested container data."""
log.info("SEARCH_SESSIONS (ak:{})", ctx.access_key)

user = current_user()
if user is None:
raise UserNotFound("User not found in context")

# Step 1: Search sessions
session_querier = self._adapter.build_session_querier(body.parsed)
session_result = await self._processors.session.search_sessions.wait_for_complete(
SearchSessionsAction(querier=session_querier)
SearchSessionsAction(querier=session_querier, user_id=user.user_id)
)

# Step 2: Fetch kernels for found sessions
Expand All @@ -51,7 +57,7 @@ async def search_sessions(
if session_ids:
kernel_querier = self._adapter.build_kernel_querier_for_sessions(session_ids)
kernel_result = await self._processors.session.search_kernels.wait_for_complete(
SearchKernelsAction(querier=kernel_querier)
SearchKernelsAction(querier=kernel_querier, user_id=user.user_id)
)
kernels_by_session = self._adapter.group_kernels_by_session(kernel_result.data)

Expand Down
7 changes: 7 additions & 0 deletions src/ai/backend/manager/api/rest/session/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pydantic import BaseModel

from ai.backend.common.api_handlers import APIResponse, BaseResponseModel, BodyParam, QueryParam
from ai.backend.common.contexts.user import current_user
from ai.backend.common.dto.manager.session.request import (
CommitSessionRequest,
CompleteRequest,
Expand Down Expand Up @@ -91,6 +92,7 @@
from ai.backend.manager.errors.api import InvalidAPIParameters
from ai.backend.manager.errors.auth import InsufficientPrivilege
from ai.backend.manager.errors.resource import NoCurrentTaskContext
from ai.backend.manager.errors.user import UserNotFound
from ai.backend.manager.models.user import UserRole
from ai.backend.manager.services.agent.actions.sync_agent_registry import (
SyncAgentRegistryAction,
Expand Down Expand Up @@ -491,6 +493,10 @@ async def match_sessions(
request = ctx.request
params = query.parsed
requester_access_key, owner_access_key = await get_access_key_scopes(request)
user = current_user()
if user is None:
raise UserNotFound("User not found in context")

log.info(
"MATCH_SESSIONS(ak:{0}/{1}, prefix:{2})",
requester_access_key,
Expand All @@ -501,6 +507,7 @@ async def match_sessions(
MatchSessionsAction(
id_or_name_prefix=params.id,
owner_access_key=owner_access_key,
user_id=user.user_id,
)
)
return APIResponse.build(HTTPStatus.OK, MatchSessionsResponse(matches=result.result))
Expand Down
6 changes: 5 additions & 1 deletion src/ai/backend/manager/services/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,11 @@ def create(cls, args: ProcessorArgs, action_monitors: list[ActionMonitor]) -> Se
vfolder_sharing_processors = VFolderSharingProcessors(
services.vfolder_sharing, action_monitors
)
session_processors = SessionProcessors(services.session, action_monitors)
session_processors = SessionProcessors(
services.session,
action_monitors,
args.service_args.repositories.permission_controller.repository,
)
keypair_resource_policy_processors = KeypairResourcePolicyProcessors(
services.keypair_resource_policy, action_monitors
)
Expand Down
27 changes: 25 additions & 2 deletions src/ai/backend/manager/services/session/actions/create_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,23 @@
from dataclasses import dataclass
from typing import Any, override

from ai.backend.common.data.permission.types import RBACElementType, ScopeType
from ai.backend.common.types import AccessKey, SessionTypes
from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.manager.actions.types import ActionOperationType
from ai.backend.manager.data.permission.types import RBACElementRef
from ai.backend.manager.models.user import UserRole
from ai.backend.manager.services.session.base import SessionAction
from ai.backend.manager.services.session.base import SessionScopeAction


@dataclass
class CreateClusterAction(SessionAction):
class CreateClusterAction(SessionScopeAction):
"""Create a new cluster session.

RBAC validation checks if the user has CREATE permission in USER scope.
Scope is always USER scope with user_id.
"""

session_name: str
user_id: uuid.UUID
user_role: UserRole
Expand All @@ -37,6 +45,21 @@ def entity_id(self) -> str | None:
def operation_type(cls) -> ActionOperationType:
return ActionOperationType.CREATE

@override
def scope_type(self) -> ScopeType:
return ScopeType.USER

@override
def scope_id(self) -> str:
return str(self.user_id)

@override
def target_element(self) -> RBACElementRef:
return RBACElementRef(
element_type=RBACElementType.USER,
element_id=str(self.user_id),
)


@dataclass
class CreateClusterActionResult(BaseActionResult):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@

import yarl

from ai.backend.common.data.permission.types import RBACElementType, ScopeType
from ai.backend.common.types import AccessKey, ClusterMode, SessionTypes
from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.manager.actions.types import ActionOperationType
from ai.backend.manager.data.permission.types import RBACElementRef
from ai.backend.manager.models.user import UserRole
from ai.backend.manager.services.session.base import SessionAction
from ai.backend.manager.services.session.base import SessionScopeAction


# TODO: Idea: Refactor this type using pydantic and utilize as API model
Expand Down Expand Up @@ -40,7 +42,13 @@ class CreateFromParamsActionParams:


@dataclass
class CreateFromParamsAction(SessionAction):
class CreateFromParamsAction(SessionScopeAction):
"""Create a new session from parameters.

RBAC validation checks if the user has CREATE permission in USER scope.
Scope is always USER scope with user_id.
"""

params: CreateFromParamsActionParams
user_id: uuid.UUID
user_role: UserRole
Expand All @@ -57,6 +65,21 @@ def entity_id(self) -> str | None:
def operation_type(cls) -> ActionOperationType:
return ActionOperationType.CREATE

@override
def scope_type(self) -> ScopeType:
return ScopeType.USER

@override
def scope_id(self) -> str:
return str(self.user_id)

@override
def target_element(self) -> RBACElementRef:
return RBACElementRef(
element_type=RBACElementType.USER,
element_id=str(self.user_id),
)


@dataclass
class CreateFromParamsActionResult(BaseActionResult):
Expand Down
Loading
Loading