diff --git a/changes/9624.feature.md b/changes/9624.feature.md new file mode 100644 index 00000000000..10a261771f3 --- /dev/null +++ b/changes/9624.feature.md @@ -0,0 +1 @@ +Add RBAC validator infrastructure to Session actions following BEP-1048 patterns diff --git a/src/ai/backend/manager/api/gql/data_loader/kernel/loader.py b/src/ai/backend/manager/api/gql/data_loader/kernel/loader.py index f2adf405c54..5ea816895fb 100644 --- a/src/ai/backend/manager/api/gql/data_loader/kernel/loader.py +++ b/src/ai/backend/manager/api/gql/data_loader/kernel/loader.py @@ -2,8 +2,10 @@ from collections.abc import Sequence +from ai.backend.common.contexts.user import current_user from ai.backend.common.types import KernelId from ai.backend.manager.data.kernel.types import KernelInfo +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.repositories.base import BatchQuerier, NoPagination from ai.backend.manager.repositories.scheduler.options import KernelConditions from ai.backend.manager.services.session.actions.search_kernel import SearchKernelsAction @@ -26,13 +28,17 @@ async def load_kernels_by_ids( if not kernel_ids: return [] + user = current_user() + if user is None: + raise UserNotFound("User not found in context") + querier = BatchQuerier( pagination=NoPagination(), conditions=[KernelConditions.by_ids(kernel_ids)], ) action_result = await processor.search_kernels.wait_for_complete( - SearchKernelsAction(querier=querier) + SearchKernelsAction(querier=querier, user_id=user.user_id) ) kernel_map: dict[KernelId, KernelInfo] = {kernel.id: kernel for kernel in action_result.data} diff --git a/src/ai/backend/manager/api/gql/data_loader/session/loader.py b/src/ai/backend/manager/api/gql/data_loader/session/loader.py index 59e0f2a176e..8cd5708b59f 100644 --- a/src/ai/backend/manager/api/gql/data_loader/session/loader.py +++ b/src/ai/backend/manager/api/gql/data_loader/session/loader.py @@ -2,8 +2,10 @@ from collections.abc import Sequence +from ai.backend.common.contexts.user import current_user from ai.backend.common.types import SessionId from ai.backend.manager.data.session.types import SessionData +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.repositories.base import BatchQuerier, NoPagination from ai.backend.manager.repositories.scheduler.options import SessionConditions from ai.backend.manager.services.session.actions.search import SearchSessionsAction @@ -17,13 +19,17 @@ async def load_sessions_by_ids( if not session_ids: return [] + user = current_user() + if user is None: + raise UserNotFound("User not found in context") + querier = BatchQuerier( pagination=NoPagination(), conditions=[SessionConditions.by_ids(session_ids)], ) action_result = await processor.search_sessions.wait_for_complete( - SearchSessionsAction(querier=querier) + SearchSessionsAction(querier=querier, user_id=user.user_id) ) session_map: dict[SessionId, SessionData] = { diff --git a/src/ai/backend/manager/api/gql/kernel/fetcher/kernel.py b/src/ai/backend/manager/api/gql/kernel/fetcher/kernel.py index 50bda150070..30a803e60aa 100644 --- a/src/ai/backend/manager/api/gql/kernel/fetcher/kernel.py +++ b/src/ai/backend/manager/api/gql/kernel/fetcher/kernel.py @@ -5,6 +5,7 @@ import strawberry from strawberry import Info +from ai.backend.common.contexts.user import current_user from ai.backend.common.types import KernelId from ai.backend.manager.api.gql.adapter import PaginationOptions, PaginationSpec from ai.backend.manager.api.gql.base import encode_cursor @@ -16,6 +17,7 @@ KernelV2OrderByGQL, ) from ai.backend.manager.api.gql.types import StrawberryGQLContext +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.models.kernel import KernelRow from ai.backend.manager.repositories.base import QueryCondition from ai.backend.manager.repositories.scheduler.options import KernelConditions @@ -45,6 +47,10 @@ async def fetch_kernels( offset: int | None = None, base_conditions: list[QueryCondition] | None = None, ) -> KernelV2ConnectionGQL: + user = current_user() + if user is None: + raise UserNotFound("User not found in context") + querier = info.context.gql_adapter.build_querier( PaginationOptions( first=first, @@ -61,7 +67,7 @@ async def fetch_kernels( ) action_result = await info.context.processors.session.search_kernels.wait_for_complete( - SearchKernelsAction(querier=querier) + SearchKernelsAction(querier=querier, user_id=user.user_id) ) nodes = [KernelV2GQL.from_kernel_info(kernel_info) for kernel_info in action_result.data] edges = [KernelV2EdgeGQL(node=node, cursor=encode_cursor(node.id)) for node in nodes] diff --git a/src/ai/backend/manager/api/gql/session/fetcher/session.py b/src/ai/backend/manager/api/gql/session/fetcher/session.py index 07b88f0e069..603405e00e2 100644 --- a/src/ai/backend/manager/api/gql/session/fetcher/session.py +++ b/src/ai/backend/manager/api/gql/session/fetcher/session.py @@ -5,6 +5,7 @@ import strawberry from strawberry import Info +from ai.backend.common.contexts.user import current_user from ai.backend.manager.api.gql.adapter import PaginationOptions, PaginationSpec from ai.backend.manager.api.gql.base import encode_cursor from ai.backend.manager.api.gql.session.types import ( @@ -15,6 +16,7 @@ SessionV2OrderByGQL, ) from ai.backend.manager.api.gql.types import StrawberryGQLContext +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.models.session import SessionRow from ai.backend.manager.repositories.base import QueryCondition from ai.backend.manager.repositories.scheduler.options import SessionConditions, SessionOrders @@ -44,6 +46,10 @@ async def fetch_sessions( offset: int | None = None, base_conditions: list[QueryCondition] | None = None, ) -> SessionV2ConnectionGQL: + user = current_user() + if user is None: + raise UserNotFound("User not found in context") + querier = info.context.gql_adapter.build_querier( PaginationOptions( first=first, @@ -60,7 +66,7 @@ async def fetch_sessions( ) action_result = await info.context.processors.session.search_sessions.wait_for_complete( - SearchSessionsAction(querier=querier) + SearchSessionsAction(querier=querier, user_id=user.user_id) ) nodes = [SessionV2GQL.from_data(session_data) for session_data in action_result.data] diff --git a/src/ai/backend/manager/api/gql/session/types.py b/src/ai/backend/manager/api/gql/session/types.py index 0e384f69b78..c96b629e726 100644 --- a/src/ai/backend/manager/api/gql/session/types.py +++ b/src/ai/backend/manager/api/gql/session/types.py @@ -12,6 +12,7 @@ from strawberry import ID, Info from strawberry.relay import Connection, Edge, Node, NodeID +from ai.backend.common.contexts.user import current_user from ai.backend.common.types import SessionId from ai.backend.manager.api.gql.base import OrderDirection, StringFilter, UUIDFilter, encode_cursor from ai.backend.manager.api.gql.common.types import ( @@ -41,6 +42,7 @@ from ai.backend.manager.api.gql.types import GQLFilter, GQLOrderBy, StrawberryGQLContext from ai.backend.manager.api.gql.user.types.node import UserV2GQL from ai.backend.manager.data.session.types import SessionData, SessionStatus +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.repositories.base import ( BatchQuerier, NoPagination, @@ -437,13 +439,17 @@ async def images(self) -> ImageV2ConnectionGQL: description="Added in 26.3.0. The kernels belonging to this session." ) async def kernels(self, info: Info[StrawberryGQLContext]) -> KernelV2ConnectionGQL: + user = current_user() + if user is None: + raise UserNotFound("User not found in context") + session_id = SessionId(UUID(str(self.id))) querier = BatchQuerier( pagination=NoPagination(), conditions=[KernelConditions.by_session_ids([session_id])], ) action_result = await info.context.processors.session.search_kernels.wait_for_complete( - SearchKernelsAction(querier=querier) + SearchKernelsAction(querier=querier, user_id=user.user_id) ) nodes = [KernelV2GQL.from_kernel_info(kernel) for kernel in action_result.data] edges = [KernelV2EdgeGQL(node=node, cursor=encode_cursor(node.id)) for node in nodes] diff --git a/src/ai/backend/manager/api/gql_legacy/session.py b/src/ai/backend/manager/api/gql_legacy/session.py index 1bea8f9d124..6dcb7efb7e9 100644 --- a/src/ai/backend/manager/api/gql_legacy/session.py +++ b/src/ai/backend/manager/api/gql_legacy/session.py @@ -861,7 +861,7 @@ async def mutate_and_get_payload( result = await graph_ctx.processors.session.modify_session.wait_for_complete( ModifySessionAction( - session_id=session_id, + session_uuid=session_id, updater=Updater( spec=SessionUpdaterSpec( name=OptionalState[str].from_graphql(name), diff --git a/src/ai/backend/manager/api/rest/compute_sessions/handler.py b/src/ai/backend/manager/api/rest/compute_sessions/handler.py index dd4a37834ba..703202a0391 100644 --- a/src/ai/backend/manager/api/rest/compute_sessions/handler.py +++ b/src/ai/backend/manager/api/rest/compute_sessions/handler.py @@ -7,6 +7,7 @@ from typing import Final from ai.backend.common.api_handlers import APIResponse, BodyParam +from ai.backend.common.contexts.user import current_user from ai.backend.common.dto.manager.compute_session import ( PaginationInfo, SearchComputeSessionsRequest, @@ -15,6 +16,7 @@ from ai.backend.common.types import SessionId from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.dto.context import UserContext +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.services.processors import Processors from ai.backend.manager.services.session.actions.search import SearchSessionsAction from ai.backend.manager.services.session.actions.search_kernel import SearchKernelsAction @@ -39,10 +41,14 @@ async def search_sessions( """Search compute sessions with nested container data.""" log.info("SEARCH_SESSIONS (ak:{})", ctx.access_key) + user = current_user() + if user is None: + raise UserNotFound("User not found in context") + # Step 1: Search sessions session_querier = self._adapter.build_session_querier(body.parsed) session_result = await self._processors.session.search_sessions.wait_for_complete( - SearchSessionsAction(querier=session_querier) + SearchSessionsAction(querier=session_querier, user_id=user.user_id) ) # Step 2: Fetch kernels for found sessions @@ -51,7 +57,7 @@ async def search_sessions( if session_ids: kernel_querier = self._adapter.build_kernel_querier_for_sessions(session_ids) kernel_result = await self._processors.session.search_kernels.wait_for_complete( - SearchKernelsAction(querier=kernel_querier) + SearchKernelsAction(querier=kernel_querier, user_id=user.user_id) ) kernels_by_session = self._adapter.group_kernels_by_session(kernel_result.data) diff --git a/src/ai/backend/manager/api/rest/session/handler.py b/src/ai/backend/manager/api/rest/session/handler.py index 98730a49a5b..e37e07439e8 100644 --- a/src/ai/backend/manager/api/rest/session/handler.py +++ b/src/ai/backend/manager/api/rest/session/handler.py @@ -20,6 +20,7 @@ from pydantic import BaseModel from ai.backend.common.api_handlers import APIResponse, BaseResponseModel, BodyParam, QueryParam +from ai.backend.common.contexts.user import current_user from ai.backend.common.dto.manager.session.request import ( CommitSessionRequest, CompleteRequest, @@ -91,6 +92,7 @@ from ai.backend.manager.errors.api import InvalidAPIParameters from ai.backend.manager.errors.auth import InsufficientPrivilege from ai.backend.manager.errors.resource import NoCurrentTaskContext +from ai.backend.manager.errors.user import UserNotFound from ai.backend.manager.models.user import UserRole from ai.backend.manager.services.agent.actions.sync_agent_registry import ( SyncAgentRegistryAction, @@ -491,6 +493,10 @@ async def match_sessions( request = ctx.request params = query.parsed requester_access_key, owner_access_key = await get_access_key_scopes(request) + user = current_user() + if user is None: + raise UserNotFound("User not found in context") + log.info( "MATCH_SESSIONS(ak:{0}/{1}, prefix:{2})", requester_access_key, @@ -501,6 +507,7 @@ async def match_sessions( MatchSessionsAction( id_or_name_prefix=params.id, owner_access_key=owner_access_key, + user_id=user.user_id, ) ) return APIResponse.build(HTTPStatus.OK, MatchSessionsResponse(matches=result.result)) diff --git a/src/ai/backend/manager/services/processors.py b/src/ai/backend/manager/services/processors.py index e7985d56e15..5abe73207b8 100644 --- a/src/ai/backend/manager/services/processors.py +++ b/src/ai/backend/manager/services/processors.py @@ -543,7 +543,11 @@ def create(cls, args: ProcessorArgs, action_monitors: list[ActionMonitor]) -> Se vfolder_sharing_processors = VFolderSharingProcessors( services.vfolder_sharing, action_monitors ) - session_processors = SessionProcessors(services.session, action_monitors) + session_processors = SessionProcessors( + services.session, + action_monitors, + args.service_args.repositories.permission_controller.repository, + ) keypair_resource_policy_processors = KeypairResourcePolicyProcessors( services.keypair_resource_policy, action_monitors ) diff --git a/src/ai/backend/manager/services/session/actions/create_cluster.py b/src/ai/backend/manager/services/session/actions/create_cluster.py index 7a2570e738f..10e843b1df7 100644 --- a/src/ai/backend/manager/services/session/actions/create_cluster.py +++ b/src/ai/backend/manager/services/session/actions/create_cluster.py @@ -3,15 +3,23 @@ from dataclasses import dataclass from typing import Any, override +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.common.types import AccessKey, SessionTypes from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.models.user import UserRole -from ai.backend.manager.services.session.base import SessionAction +from ai.backend.manager.services.session.base import SessionScopeAction @dataclass -class CreateClusterAction(SessionAction): +class CreateClusterAction(SessionScopeAction): + """Create a new cluster session. + + RBAC validation checks if the user has CREATE permission in USER scope. + Scope is always USER scope with user_id. + """ + session_name: str user_id: uuid.UUID user_role: UserRole @@ -37,6 +45,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.CREATE + @override + def scope_type(self) -> ScopeType: + return ScopeType.USER + + @override + def scope_id(self) -> str: + return str(self.user_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType.USER, + element_id=str(self.user_id), + ) + @dataclass class CreateClusterActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/create_from_params.py b/src/ai/backend/manager/services/session/actions/create_from_params.py index fe4d83a4be0..b92996af1e3 100644 --- a/src/ai/backend/manager/services/session/actions/create_from_params.py +++ b/src/ai/backend/manager/services/session/actions/create_from_params.py @@ -6,11 +6,13 @@ import yarl +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.common.types import AccessKey, ClusterMode, SessionTypes from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.models.user import UserRole -from ai.backend.manager.services.session.base import SessionAction +from ai.backend.manager.services.session.base import SessionScopeAction # TODO: Idea: Refactor this type using pydantic and utilize as API model @@ -40,7 +42,13 @@ class CreateFromParamsActionParams: @dataclass -class CreateFromParamsAction(SessionAction): +class CreateFromParamsAction(SessionScopeAction): + """Create a new session from parameters. + + RBAC validation checks if the user has CREATE permission in USER scope. + Scope is always USER scope with user_id. + """ + params: CreateFromParamsActionParams user_id: uuid.UUID user_role: UserRole @@ -57,6 +65,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.CREATE + @override + def scope_type(self) -> ScopeType: + return ScopeType.USER + + @override + def scope_id(self) -> str: + return str(self.user_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType.USER, + element_id=str(self.user_id), + ) + @dataclass class CreateFromParamsActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/create_from_template.py b/src/ai/backend/manager/services/session/actions/create_from_template.py index decfd4c8208..9479a9feeb7 100644 --- a/src/ai/backend/manager/services/session/actions/create_from_template.py +++ b/src/ai/backend/manager/services/session/actions/create_from_template.py @@ -6,12 +6,14 @@ import yarl +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.common.types import AccessKey, ClusterMode, SessionTypes from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.api.utils import Undefined +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.models.user import UserRole -from ai.backend.manager.services.session.base import SessionAction +from ai.backend.manager.services.session.base import SessionScopeAction # TODO: Idea: Refactor this type using pydantic and utilize as API model @@ -43,7 +45,13 @@ class CreateFromTemplateActionParams: @dataclass -class CreateFromTemplateAction(SessionAction): +class CreateFromTemplateAction(SessionScopeAction): + """Create a new session from template. + + RBAC validation checks if the user has CREATE permission in USER scope. + Scope is always USER scope with user_id. + """ + params: CreateFromTemplateActionParams user_id: uuid.UUID user_role: UserRole @@ -60,6 +68,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.CREATE + @override + def scope_type(self) -> ScopeType: + return ScopeType.USER + + @override + def scope_id(self) -> str: + return str(self.user_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType.USER, + element_id=str(self.user_id), + ) + @dataclass class CreateFromTemplateActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/match_sessions.py b/src/ai/backend/manager/services/session/actions/match_sessions.py index 93bb8b25e2b..ca91ae8bab5 100644 --- a/src/ai/backend/manager/services/session/actions/match_sessions.py +++ b/src/ai/backend/manager/services/session/actions/match_sessions.py @@ -1,17 +1,27 @@ +import uuid from dataclasses import dataclass from typing import Any, override +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.common.types import AccessKey from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType -from ai.backend.manager.services.session.base import SessionAction +from ai.backend.manager.data.permission.types import RBACElementRef +from ai.backend.manager.services.session.base import SessionScopeAction # TODO: Make this BatchAction @dataclass -class MatchSessionsAction(SessionAction): +class MatchSessionsAction(SessionScopeAction): + """Match sessions by ID or name prefix. + + RBAC validation checks if the user has READ permission in USER scope. + Scope is always USER scope with user_id. + """ + id_or_name_prefix: str owner_access_key: AccessKey + user_id: uuid.UUID @override def entity_id(self) -> str | None: @@ -22,6 +32,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.SEARCH + @override + def scope_type(self) -> ScopeType: + return ScopeType.USER + + @override + def scope_id(self) -> str: + return str(self.user_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType.USER, + element_id=str(self.user_id), + ) + @dataclass class MatchSessionsActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/search.py b/src/ai/backend/manager/services/session/actions/search.py index 83b2e226b2e..8d0b177490d 100644 --- a/src/ai/backend/manager/services/session/actions/search.py +++ b/src/ai/backend/manager/services/session/actions/search.py @@ -1,18 +1,28 @@ from __future__ import annotations +import uuid from dataclasses import dataclass from typing import override +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.data.session.types import SessionData from ai.backend.manager.repositories.base import BatchQuerier -from ai.backend.manager.services.session.base import SessionAction +from ai.backend.manager.services.session.base import SessionScopeAction @dataclass -class SearchSessionsAction(SessionAction): +class SearchSessionsAction(SessionScopeAction): + """Search sessions within a scope. + + RBAC validation checks if the user has READ permission in USER scope. + Scope is always USER scope with user_id. + """ + querier: BatchQuerier + user_id: uuid.UUID @override def entity_id(self) -> str | None: @@ -23,6 +33,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.SEARCH + @override + def scope_type(self) -> ScopeType: + return ScopeType.USER + + @override + def scope_id(self) -> str: + return str(self.user_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType.USER, + element_id=str(self.user_id), + ) + @dataclass class SearchSessionsActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/actions/search_kernel.py b/src/ai/backend/manager/services/session/actions/search_kernel.py index 5864e534b8d..0a3ace75f58 100644 --- a/src/ai/backend/manager/services/session/actions/search_kernel.py +++ b/src/ai/backend/manager/services/session/actions/search_kernel.py @@ -1,19 +1,28 @@ from __future__ import annotations +import uuid from dataclasses import dataclass from typing import override -from ai.backend.common.data.permission.types import EntityType +from ai.backend.common.data.permission.types import EntityType, RBACElementType, ScopeType from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.kernel.types import KernelInfo +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.repositories.base import BatchQuerier -from ai.backend.manager.services.session.base import SessionAction +from ai.backend.manager.services.session.base import SessionScopeAction @dataclass -class SearchKernelsAction(SessionAction): +class SearchKernelsAction(SessionScopeAction): + """Search kernels within a scope. + + RBAC validation checks if the user has READ permission in USER scope. + Scope is always USER scope with user_id. + """ + querier: BatchQuerier + user_id: uuid.UUID @override @classmethod @@ -29,6 +38,21 @@ def entity_id(self) -> str | None: def operation_type(cls) -> ActionOperationType: return ActionOperationType.SEARCH + @override + def scope_type(self) -> ScopeType: + return ScopeType.USER + + @override + def scope_id(self) -> str: + return str(self.user_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef( + element_type=RBACElementType.USER, + element_id=str(self.user_id), + ) + @dataclass class SearchKernelsActionResult(BaseActionResult): diff --git a/src/ai/backend/manager/services/session/base.py b/src/ai/backend/manager/services/session/base.py index bcc90407892..0693e4c77cd 100644 --- a/src/ai/backend/manager/services/session/base.py +++ b/src/ai/backend/manager/services/session/base.py @@ -3,6 +3,7 @@ from ai.backend.common.data.permission.types import EntityType from ai.backend.manager.actions.action import BaseAction, BaseBatchAction +from ai.backend.manager.actions.action.scope import BaseScopeAction, BaseScopeActionResult @dataclass @@ -19,3 +20,24 @@ class SessionBatchAction(BaseBatchAction): @classmethod def entity_type(cls) -> EntityType: return EntityType.SESSION + + +@dataclass +class SessionScopeAction(BaseScopeAction): + """Base class for session actions that operate within a scope (domain/project). + + Used for operations like creating or searching sessions within a specific scope. + Subclasses must implement scope_type(), scope_id(), and target_element() methods. + + Note: Scope should typically be USER scope (user_id), not GLOBAL. + """ + + @override + @classmethod + def entity_type(cls) -> EntityType: + return EntityType.SESSION + + +@dataclass +class SessionScopeActionResult(BaseScopeActionResult): + pass diff --git a/src/ai/backend/manager/services/session/processors.py b/src/ai/backend/manager/services/session/processors.py index be32577a4f5..bded4274e29 100644 --- a/src/ai/backend/manager/services/session/processors.py +++ b/src/ai/backend/manager/services/session/processors.py @@ -1,8 +1,14 @@ -from typing import override +from typing import cast, override from ai.backend.manager.actions.monitors.monitor import ActionMonitor from ai.backend.manager.actions.processor import ActionProcessor from ai.backend.manager.actions.types import AbstractProcessorPackage, ActionSpec +from ai.backend.manager.actions.validator.base import ActionValidator +from ai.backend.manager.actions.validators.rbac.scope import ScopeActionRBACValidator +from ai.backend.manager.actions.validators.rbac.single_entity import SingleEntityActionRBACValidator +from ai.backend.manager.repositories.permission_controller.repository import ( + PermissionControllerRepository, +) from ai.backend.manager.services.session.actions.check_and_transit_status import ( CheckAndTransitStatusAction, CheckAndTransitStatusActionResult, @@ -165,19 +171,24 @@ class SessionProcessors(AbstractProcessorPackage): CheckAndTransitStatusAction, CheckAndTransitStatusActionResult ] - def __init__(self, service: SessionService, action_monitors: list[ActionMonitor]) -> None: + def __init__( + self, + service: SessionService, + action_monitors: list[ActionMonitor], + permission_repository: PermissionControllerRepository, + ) -> None: + # Create RBAC validators + scope_validator = ScopeActionRBACValidator(permission_repository) + single_entity_validator = SingleEntityActionRBACValidator(permission_repository) + + # Actions without RBAC validation (internal/legacy) self.commit_session = ActionProcessor(service.commit_session, action_monitors) self.complete = ActionProcessor(service.complete, action_monitors) self.convert_session_to_image = ActionProcessor( service.convert_session_to_image, action_monitors ) - self.create_cluster = ActionProcessor(service.create_cluster, action_monitors) - self.create_from_params = ActionProcessor(service.create_from_params, action_monitors) - self.create_from_template = ActionProcessor(service.create_from_template, action_monitors) - self.destroy_session = ActionProcessor(service.destroy_session, action_monitors) self.download_file = ActionProcessor(service.download_file, action_monitors) self.download_files = ActionProcessor(service.download_files, action_monitors) - self.execute_session = ActionProcessor(service.execute_session, action_monitors) self.get_abusing_report = ActionProcessor(service.get_abusing_report, action_monitors) self.get_commit_status = ActionProcessor(service.get_commit_status, action_monitors) self.get_container_logs = ActionProcessor(service.get_container_logs, action_monitors) @@ -185,23 +196,70 @@ def __init__(self, service: SessionService, action_monitors: list[ActionMonitor] self.get_direct_access_info = ActionProcessor( service.get_direct_access_info, action_monitors ) - self.get_session_info = ActionProcessor(service.get_session_info, action_monitors) self.get_status_history = ActionProcessor(service.get_status_history, action_monitors) self.interrupt = ActionProcessor(service.interrupt, action_monitors) self.list_files = ActionProcessor(service.list_files, action_monitors) - self.match_sessions = ActionProcessor(service.match_sessions, action_monitors) self.rename_session = ActionProcessor(service.rename_session, action_monitors) self.restart_session = ActionProcessor(service.restart_session, action_monitors) - self.search_kernels = ActionProcessor(service.search_kernels, action_monitors) - self.search_sessions = ActionProcessor(service.search, action_monitors) self.shutdown_service = ActionProcessor(service.shutdown_service, action_monitors) self.start_service = ActionProcessor(service.start_service, action_monitors) self.upload_files = ActionProcessor(service.upload_files, action_monitors) - self.modify_session = ActionProcessor(service.modify_session, action_monitors) self.check_and_transit_status = ActionProcessor( service.check_and_transit_status, action_monitors ) + # Scope actions with RBAC validation + self.create_cluster = ActionProcessor( + service.create_cluster, + action_monitors, + validators=[cast(ActionValidator, scope_validator)], + ) + self.create_from_params = ActionProcessor( + service.create_from_params, + action_monitors, + validators=[cast(ActionValidator, scope_validator)], + ) + self.create_from_template = ActionProcessor( + service.create_from_template, + action_monitors, + validators=[cast(ActionValidator, scope_validator)], + ) + self.match_sessions = ActionProcessor( + service.match_sessions, + action_monitors, + validators=[cast(ActionValidator, scope_validator)], + ) + self.search_kernels = ActionProcessor( + service.search_kernels, + action_monitors, + validators=[cast(ActionValidator, scope_validator)], + ) + self.search_sessions = ActionProcessor( + service.search, action_monitors, validators=[cast(ActionValidator, scope_validator)] + ) + + # Single entity actions with RBAC validation + self.destroy_session = ActionProcessor( + service.destroy_session, + action_monitors, + validators=[cast(ActionValidator, single_entity_validator)], + ) + self.execute_session = ActionProcessor( + service.execute_session, + action_monitors, + validators=[cast(ActionValidator, single_entity_validator)], + ) + self.get_session_info = ActionProcessor( + service.get_session_info, + action_monitors, + validators=[cast(ActionValidator, single_entity_validator)], + ) + self.modify_session = ActionProcessor( + service.modify_session, + action_monitors, + validators=[cast(ActionValidator, single_entity_validator)], + ) + @override def supported_actions(self) -> list[ActionSpec]: return [