diff --git a/changes/9620.feature.md b/changes/9620.feature.md new file mode 100644 index 00000000000..cd4ad19fe87 --- /dev/null +++ b/changes/9620.feature.md @@ -0,0 +1 @@ +Add `sessions` resolver to strawberry `AgentV2` schema diff --git a/docs/manager/graphql-reference/supergraph.graphql b/docs/manager/graphql-reference/supergraph.graphql index ff61daf3b17..8f47602107a 100644 --- a/docs/manager/graphql-reference/supergraph.graphql +++ b/docs/manager/graphql-reference/supergraph.graphql @@ -520,6 +520,11 @@ type AgentV2 implements Node Added in 26.2.0. List of kernels running on this agent with pagination support. """ kernels(filter: KernelV2Filter = null, orderBy: [KernelV2OrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): KernelV2Connection! + + """ + Added in 26.3.0. List of sessions running on this agent with pagination support. + """ + sessions(filter: SessionV2Filter = null, orderBy: [SessionV2OrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): SessionV2Connection! } """ diff --git a/docs/manager/graphql-reference/v2-schema.graphql b/docs/manager/graphql-reference/v2-schema.graphql index e98175038c1..28adb46d89c 100644 --- a/docs/manager/graphql-reference/v2-schema.graphql +++ b/docs/manager/graphql-reference/v2-schema.graphql @@ -316,6 +316,11 @@ type AgentV2 implements Node { Added in 26.2.0. List of kernels running on this agent with pagination support. """ kernels(filter: KernelV2Filter = null, orderBy: [KernelV2OrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): KernelV2Connection! + + """ + Added in 26.3.0. List of sessions running on this agent with pagination support. + """ + sessions(filter: SessionV2Filter = null, orderBy: [SessionV2OrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): SessionV2Connection! } """ diff --git a/src/ai/backend/manager/api/gql/agent/types.py b/src/ai/backend/manager/api/gql/agent/types.py index 5d7228dd450..4bb1566c721 100644 --- a/src/ai/backend/manager/api/gql/agent/types.py +++ b/src/ai/backend/manager/api/gql/agent/types.py @@ -20,6 +20,11 @@ KernelV2FilterGQL, KernelV2OrderByGQL, ) + from ai.backend.manager.api.gql.session.types import ( + SessionV2ConnectionGQL, + SessionV2FilterGQL, + SessionV2OrderByGQL, + ) from ai.backend.manager.api.gql.utils import dedent_strip from ai.backend.manager.data.agent.types import AgentDetailData, AgentStatus from ai.backend.manager.models.rbac.permission_defs import AgentPermission @@ -32,6 +37,7 @@ ) from ai.backend.manager.repositories.scheduler.options import ( KernelConditions, + SessionConditions, ) @@ -439,7 +445,48 @@ async def kernels( last=last, limit=limit, offset=offset, - base_conditions=[KernelConditions.by_agent_id(str(self._agent_id))], + base_conditions=[KernelConditions.by_agent_id(self._agent_id)], + ) + + @strawberry.field( # type: ignore[misc] + description="Added in 26.3.0. List of sessions running on this agent with pagination support." + ) + async def sessions( + self, + info: Info[StrawberryGQLContext], + filter: Annotated[ + SessionV2FilterGQL, strawberry.lazy("ai.backend.manager.api.gql.session.types") + ] + | None = None, + order_by: list[ + Annotated[ + SessionV2OrderByGQL, strawberry.lazy("ai.backend.manager.api.gql.session.types") + ] + ] + | None = None, + before: str | None = None, + after: str | None = None, + first: int | None = None, + last: int | None = None, + limit: int | None = None, + offset: int | None = None, + ) -> Annotated[ + SessionV2ConnectionGQL, strawberry.lazy("ai.backend.manager.api.gql.session.types") + ]: + """Fetch sessions associated with this agent.""" + from ai.backend.manager.api.gql.session.fetcher.session import fetch_sessions + + return await fetch_sessions( + info=info, + filter=filter, + order_by=order_by, + before=before, + after=after, + first=first, + last=last, + limit=limit, + offset=offset, + base_conditions=[SessionConditions.by_agent_id(self._agent_id)], ) @classmethod diff --git a/src/ai/backend/manager/repositories/scheduler/options.py b/src/ai/backend/manager/repositories/scheduler/options.py index a7d69bb3c60..ba0852637de 100644 --- a/src/ai/backend/manager/repositories/scheduler/options.py +++ b/src/ai/backend/manager/repositories/scheduler/options.py @@ -18,7 +18,7 @@ ) from ai.backend.manager.api.gql.kernel.types import KernelStatusInMatchSpec -from ai.backend.common.types import KernelId, SessionId +from ai.backend.common.types import AgentId, KernelId, SessionId from ai.backend.manager.data.kernel.types import KernelStatus from ai.backend.manager.data.session.types import KernelMatchType, SessionStatus from ai.backend.manager.models.image import ImageRow @@ -456,6 +456,15 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: return inner + @staticmethod + def by_agent_id(agent_id: AgentId) -> QueryCondition: + """Filter sessions that have kernels running on the given agent.""" + + def inner() -> sa.sql.expression.ColumnElement[bool]: + return sa.literal(agent_id) == sa.any_(SessionRow.agent_ids) + + return inner + @staticmethod def by_cursor_forward(cursor_id: str) -> QueryCondition: """Cursor condition for forward pagination (after cursor). @@ -619,7 +628,7 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: return inner @staticmethod - def by_agent_id(agent_id: str) -> QueryCondition: + def by_agent_id(agent_id: AgentId) -> QueryCondition: """Filter kernels by agent ID.""" def inner() -> sa.sql.expression.ColumnElement[bool]: