diff --git a/changes/9372.fix.md b/changes/9372.fix.md new file mode 100644 index 00000000000..5935f402b8a --- /dev/null +++ b/changes/9372.fix.md @@ -0,0 +1 @@ +Migrate groups table's `container_registry` to `container_registry_id` based schema \ No newline at end of file diff --git a/docs/manager/graphql-reference/schema.graphql b/docs/manager/graphql-reference/schema.graphql index 7a72139c66d..bd26d41a63e 100644 --- a/docs/manager/graphql-reference/schema.graphql +++ b/docs/manager/graphql-reference/schema.graphql @@ -941,11 +941,13 @@ type GroupNode implements Node @key(fields: "id") { integration_id: String resource_policy: String + """ + Added in 24.03.0. The default container registry resolved from container_registry_id. Kept for backward compatibility. + """ + container_registry: JSONString + """Added in 24.03.7. One of ['GENERAL', 'MODEL_STORE'].""" type: String - - """Added in 24.03.7.""" - container_registry: JSONString scaling_groups: [String] """Added in 25.3.0.""" @@ -1068,11 +1070,13 @@ type Group { integration_id: String resource_policy: String - """Added in 24.03.0.""" - type: String + """ + Added in 24.03.0. The default container registry resolved from container_registry_id. Kept for backward compatibility. + """ + container_registry: JSONString """Added in 24.03.0.""" - container_registry: JSONString + type: String scaling_groups: [String] } @@ -2709,8 +2713,10 @@ input GroupInput { integration_id: String = "" resource_policy: String = "default" - """Added in 24.03.0""" - container_registry: JSONString = "{}" + """ + Added in 25.3.0. The default container registry used as the target for session commits when no container registry is explicitly specified. + """ + container_registry_id: UUID } type ModifyGroup { @@ -2731,8 +2737,10 @@ input ModifyGroupInput { integration_id: String resource_policy: String - """Added in 24.03.0""" - container_registry: JSONString = "{}" + """ + Added in 25.3.0. The default container registry used as the target for session commits when no container registry is explicitly specified. + """ + container_registry_id: UUID } """Instead of deleting the group, just mark it as inactive.""" diff --git a/docs/manager/graphql-reference/supergraph.graphql b/docs/manager/graphql-reference/supergraph.graphql index f6ad4f1cf68..1ff5c77b644 100644 --- a/docs/manager/graphql-reference/supergraph.graphql +++ b/docs/manager/graphql-reference/supergraph.graphql @@ -4607,11 +4607,13 @@ type Group integration_id: String resource_policy: String - """Added in 24.03.0.""" - type: String + """ + Added in 24.03.0. The default container registry resolved from container_registry_id. Kept for backward compatibility. + """ + container_registry: JSONString """Added in 24.03.0.""" - container_registry: JSONString + type: String scaling_groups: [String] } @@ -4653,8 +4655,10 @@ input GroupInput integration_id: String = "" resource_policy: String = "default" - """Added in 24.03.0""" - container_registry: JSONString = "{}" + """ + Added in 25.3.0. The default container registry used as the target for session commits when no container registry is explicitly specified. + """ + container_registry_id: UUID } type GroupNode implements Node @@ -4678,11 +4682,13 @@ type GroupNode implements Node integration_id: String @join__field(graph: GRAPHENE) resource_policy: String @join__field(graph: GRAPHENE) + """ + Added in 24.03.0. The default container registry resolved from container_registry_id. Kept for backward compatibility. + """ + container_registry: JSONString @join__field(graph: GRAPHENE) + """Added in 24.03.7. One of ['GENERAL', 'MODEL_STORE'].""" type: String @join__field(graph: GRAPHENE) - - """Added in 24.03.7.""" - container_registry: JSONString @join__field(graph: GRAPHENE) scaling_groups: [String] @join__field(graph: GRAPHENE) """Added in 25.3.0.""" @@ -6421,8 +6427,10 @@ input ModifyGroupInput integration_id: String resource_policy: String - """Added in 24.03.0""" - container_registry: JSONString = "{}" + """ + Added in 25.3.0. The default container registry used as the target for session commits when no container registry is explicitly specified. + """ + container_registry_id: UUID } type ModifyImage diff --git a/src/ai/backend/common/dto/manager/group/request.py b/src/ai/backend/common/dto/manager/group/request.py index 551cb828172..939b20bc039 100644 --- a/src/ai/backend/common/dto/manager/group/request.py +++ b/src/ai/backend/common/dto/manager/group/request.py @@ -61,6 +61,13 @@ class CreateGroupRequest(BaseRequestModel): ) integration_id: str | None = Field(default=None, description="External integration ID") resource_policy: str | None = Field(default=None, description="Resource policy name") + container_registry_id: UUID | None = Field( + default=None, + description=( + "The default container registry used as the target for session commits" + " when no container registry is explicitly specified." + ), + ) class UpdateGroupRequest(BaseRequestModel): @@ -77,6 +84,13 @@ class UpdateGroupRequest(BaseRequestModel): ) integration_id: str | None = Field(default=None, description="Updated external integration ID") resource_policy: str | None = Field(default=None, description="Updated resource policy name") + container_registry_id: UUID | None = Field( + default=None, + description=( + "The default container registry used as the target for session commits" + " when no container registry is explicitly specified." + ), + ) class AddGroupMembersRequest(BaseRequestModel): diff --git a/src/ai/backend/common/dto/manager/group/response.py b/src/ai/backend/common/dto/manager/group/response.py index f3b757ddf0c..7fdaa772e97 100644 --- a/src/ai/backend/common/dto/manager/group/response.py +++ b/src/ai/backend/common/dto/manager/group/response.py @@ -50,8 +50,12 @@ class GroupDTO(BaseModel): default=None, description="Allowed vfolder host permissions" ) resource_policy: str | None = Field(default=None, description="Resource policy name") - container_registry: dict[str, Any] | None = Field( - default=None, description="Container registry configuration" + container_registry_id: UUID | None = Field( + default=None, + description=( + "The default container registry used as the target for session commits" + " when no container registry is explicitly specified." + ), ) diff --git a/src/ai/backend/manager/api/gql_legacy/group.py b/src/ai/backend/manager/api/gql_legacy/group.py index eb38556e890..0feb261f0eb 100644 --- a/src/ai/backend/manager/api/gql_legacy/group.py +++ b/src/ai/backend/manager/api/gql_legacy/group.py @@ -38,10 +38,15 @@ from ai.backend.manager.models.rbac.context import ClientContext from ai.backend.manager.models.rbac.permission_defs import ProjectPermission from ai.backend.manager.models.user import UserRole +from ai.backend.manager.repositories.base import BatchQuerier, NoPagination from ai.backend.manager.repositories.base.creator import Creator from ai.backend.manager.repositories.base.updater import Updater +from ai.backend.manager.repositories.container_registry.options import ContainerRegistryConditions from ai.backend.manager.repositories.group.creators import GroupCreatorSpec from ai.backend.manager.repositories.group.updaters import GroupUpdaterSpec +from ai.backend.manager.services.container_registry.actions.search_container_registries import ( + SearchContainerRegistriesAction, +) from ai.backend.manager.services.group.actions.create_group import CreateGroupAction from ai.backend.manager.services.group.actions.delete_group import ( DeleteGroupAction, @@ -95,6 +100,8 @@ class GroupNode(graphene.ObjectType): # type: ignore[misc] class Meta: interfaces = (AsyncNode,) + _container_registry_id: uuid.UUID | None = None + row_id = graphene.UUID(description="Added in 24.03.7. The undecoded id value stored in DB.") name = graphene.String() description = graphene.String() @@ -106,8 +113,10 @@ class Meta: allowed_vfolder_hosts = graphene.JSONString() integration_id = graphene.String() resource_policy = graphene.String() + container_registry = graphene.JSONString( + description="Added in 24.03.0. The default container registry resolved from container_registry_id. Kept for backward compatibility.", + ) type = graphene.String(description=f"Added in 24.03.7. One of {[t.name for t in ProjectType]}.") - container_registry = graphene.JSONString(description="Added in 24.03.7.") scaling_groups = graphene.List( lambda: graphene.String, ) @@ -144,7 +153,7 @@ def from_row( graph_ctx: GraphQueryContext, row: GroupRow, ) -> Self: - return cls( + obj = cls( id=row.id, row_id=row.id, name=row.name, @@ -158,8 +167,9 @@ def from_row( integration_id=row.integration_id, resource_policy=row.resource_policy, type=row.type.name, - container_registry=row.container_registry, ) + obj._container_registry_id = row.container_registry_id + return obj async def resolve_scaling_groups(self, info: graphene.ResolveInfo) -> Sequence[ScalingGroup]: graph_ctx: GraphQueryContext = info.context @@ -170,6 +180,30 @@ async def resolve_scaling_groups(self, info: graphene.ResolveInfo) -> Sequence[S sgroups = await loader.load(self.id) return [sg.name for sg in sgroups] + async def resolve_container_registry( + self, info: graphene.ResolveInfo + ) -> dict[str, str | None] | None: + if self._container_registry_id is None: + return None + graph_ctx: GraphQueryContext = info.context + registry_id = uuid.UUID(str(self._container_registry_id)) + action = SearchContainerRegistriesAction( + querier=BatchQuerier( + pagination=NoPagination(), + conditions=[ContainerRegistryConditions.by_ids([registry_id])], + ), + ) + result = await graph_ctx.processors.container_registry.search_container_registries.wait_for_complete( + action + ) + if not result.data: + return None + registry = result.data[0] + return { + "registry": registry.registry_name, + "project": registry.project, + } + async def resolve_user_nodes( self, info: graphene.ResolveInfo, @@ -343,6 +377,8 @@ def parse_value(value: str) -> ProjectPermission: class Group(graphene.ObjectType): # type: ignore[misc] + _container_registry_id: uuid.UUID | None = None + id = graphene.UUID() name = graphene.String() description = graphene.String() @@ -354,8 +390,10 @@ class Group(graphene.ObjectType): # type: ignore[misc] allowed_vfolder_hosts = graphene.JSONString() integration_id = graphene.String() resource_policy = graphene.String() + container_registry = graphene.JSONString( + description="Added in 24.03.0. The default container registry resolved from container_registry_id. Kept for backward compatibility.", + ) type = graphene.String(description="Added in 24.03.0.") - container_registry = graphene.JSONString(description="Added in 24.03.0.") scaling_groups = graphene.List(lambda: graphene.String) @@ -363,7 +401,7 @@ class Group(graphene.ObjectType): # type: ignore[misc] def from_row(cls, graph_ctx: GraphQueryContext, row: Row[Any] | None) -> Group | None: if row is None: return None - return cls( + obj = cls( id=row.id, name=row.name, description=row.description, @@ -378,14 +416,15 @@ def from_row(cls, graph_ctx: GraphQueryContext, row: Row[Any] | None) -> Group | integration_id=row.integration_id, resource_policy=row.resource_policy, type=row.type.name, - container_registry=row.container_registry, ) + obj._container_registry_id = row.container_registry_id + return obj @classmethod def from_dto(cls, dto: GroupData | None) -> Self | None: if dto is None: return None - return cls( + obj = cls( id=dto.id, name=dto.name, description=dto.description, @@ -400,8 +439,33 @@ def from_dto(cls, dto: GroupData | None) -> Self | None: integration_id=dto.integration_id, resource_policy=dto.resource_policy, type=dto.type.name, - container_registry=dto.container_registry, ) + obj._container_registry_id = dto.container_registry_id + return obj + + async def resolve_container_registry( + self, info: graphene.ResolveInfo + ) -> dict[str, str | None] | None: + if self._container_registry_id is None: + return None + graph_ctx: GraphQueryContext = info.context + registry_id = uuid.UUID(str(self._container_registry_id)) + action = SearchContainerRegistriesAction( + querier=BatchQuerier( + pagination=NoPagination(), + conditions=[ContainerRegistryConditions.by_ids([registry_id])], + ), + ) + result = await graph_ctx.processors.container_registry.search_container_registries.wait_for_complete( + action + ) + if not result.data: + return None + registry = result.data[0] + return { + "registry": registry.registry_name, + "project": registry.project, + } async def resolve_scaling_groups(self, info: graphene.ResolveInfo) -> Sequence[ScalingGroup]: graph_ctx: GraphQueryContext = info.context @@ -555,8 +619,9 @@ class GroupInput(graphene.InputObjectType): # type: ignore[misc] allowed_vfolder_hosts = graphene.JSONString(required=False, default_value={}) integration_id = graphene.String(required=False, default_value="") resource_policy = graphene.String(required=False, default_value="default") - container_registry = graphene.JSONString( - required=False, default_value={}, description="Added in 24.03.0" + container_registry_id = graphene.UUID( + required=False, + description="Added in 25.3.0. The default container registry used as the target for session commits when no container registry is explicitly specified.", ) def to_action(self, name: str) -> CreateGroupAction: @@ -578,7 +643,7 @@ def value_or_none(value: Any) -> Any: ) integration_id_val = value_or_none(self.integration_id) resource_policy_val = value_or_none(self.resource_policy) - container_registry_val = value_or_none(self.container_registry) + container_registry_id_val = value_or_none(self.container_registry_id) return CreateGroupAction( creator=Creator( @@ -592,7 +657,7 @@ def value_or_none(value: Any) -> Any: allowed_vfolder_hosts=allowed_vfolder_hosts_val, integration_id=integration_id_val, resource_policy=resource_policy_val, - container_registry=container_registry_val, + container_registry_id=container_registry_id_val, ) ), ) @@ -609,8 +674,9 @@ class ModifyGroupInput(graphene.InputObjectType): # type: ignore[misc] allowed_vfolder_hosts = graphene.JSONString(required=False) integration_id = graphene.String(required=False) resource_policy = graphene.String(required=False) - container_registry = graphene.JSONString( - required=False, default_value={}, description="Added in 24.03.0" + container_registry_id = graphene.UUID( + required=False, + description="Added in 25.3.0. The default container registry used as the target for session commits when no container registry is explicitly specified.", ) def to_action(self, group_id: uuid.UUID) -> ModifyGroupAction: @@ -641,8 +707,8 @@ def to_action(self, group_id: uuid.UUID) -> ModifyGroupAction: resource_policy=OptionalState[str].from_graphql( self.resource_policy, ), - container_registry=TriState[dict[str, str]].from_graphql( - self.container_registry, + container_registry_id=TriState[uuid.UUID].from_graphql( + self.container_registry_id, ), ) return ModifyGroupAction( diff --git a/src/ai/backend/manager/data/group/types.py b/src/ai/backend/manager/data/group/types.py index b7bf807b344..2de84cd5a37 100644 --- a/src/ai/backend/manager/data/group/types.py +++ b/src/ai/backend/manager/data/group/types.py @@ -52,7 +52,7 @@ class GroupData: dotfiles: bytes resource_policy: str type: ProjectType - container_registry: dict[str, str] | None + container_registry_id: uuid.UUID | None = None def scope_id(self) -> ScopeId: return ScopeId( @@ -88,8 +88,8 @@ class GroupModifier(PartialModifier): ) integration_id: OptionalState[str] = field(default_factory=OptionalState[str].nop) resource_policy: OptionalState[str] = field(default_factory=OptionalState[str].nop) - container_registry: TriState[dict[str, str]] = field( - default_factory=TriState[dict[str, str]].nop + container_registry_id: TriState[uuid.UUID] = field( + default_factory=TriState[uuid.UUID].nop, ) @override @@ -103,5 +103,5 @@ def fields_to_update(self) -> dict[str, Any]: self.allowed_vfolder_hosts.update_dict(to_update, "allowed_vfolder_hosts") self.integration_id.update_dict(to_update, "integration_id") self.resource_policy.update_dict(to_update, "resource_policy") - self.container_registry.update_dict(to_update, "container_registry") + self.container_registry_id.update_dict(to_update, "container_registry_id") return to_update diff --git a/src/ai/backend/manager/models/alembic/versions/78a9a25d7af3_remove_groups_container_registry_column.py b/src/ai/backend/manager/models/alembic/versions/78a9a25d7af3_remove_groups_container_registry_column.py new file mode 100644 index 00000000000..508bb5b0d5b --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/78a9a25d7af3_remove_groups_container_registry_column.py @@ -0,0 +1,84 @@ +"""remove groups.container_registry jsonb column + +Revision ID: 78a9a25d7af3 +Revises: ffcf0ed13a26 +Create Date: 2026-02-26 00:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "78a9a25d7af3" +down_revision = "ffcf0ed13a26" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + conn = op.get_bind() + + # Migrate existing container_registry JSONB data into the association table. + # The JSONB column stores {"registry": "", "project": ""}. + # We match against the container_registries table and insert into the + # junction table (association_container_registries_groups) if not already present. + conn.execute( + sa.text(""" + INSERT INTO association_container_registries_groups (registry_id, group_id) + SELECT cr.id, g.id + FROM groups g + JOIN container_registries cr + ON cr.registry_name = g.container_registry->>'registry' + AND cr.project = g.container_registry->>'project' + WHERE g.container_registry IS NOT NULL + AND g.container_registry != '{}'::jsonb + ON CONFLICT DO NOTHING + """) + ) + + op.drop_column("groups", "container_registry") + + # Add a new container_registry_id UUID column (no FK constraint). + op.add_column( + "groups", + sa.Column("container_registry_id", postgresql.UUID(as_uuid=True), nullable=True), + ) + + # Populate container_registry_id from the association table (pick one per group). + conn.execute( + sa.text(""" + UPDATE groups g + SET container_registry_id = ( + SELECT registry_id + FROM association_container_registries_groups + WHERE group_id = g.id + LIMIT 1 + ) + """) + ) + + +def downgrade() -> None: + op.drop_column("groups", "container_registry_id") + + op.add_column( + "groups", + sa.Column("container_registry", postgresql.JSONB(), nullable=True), + ) + + # Best-effort: backfill one registry per group from the association table. + conn = op.get_bind() + conn.execute( + sa.text(""" + UPDATE groups g + SET container_registry = jsonb_build_object( + 'registry', cr.registry_name, + 'project', cr.project + ) + FROM association_container_registries_groups acr + JOIN container_registries cr ON cr.id = acr.registry_id + WHERE acr.group_id = g.id + """) + ) diff --git a/src/ai/backend/manager/models/group/row.py b/src/ai/backend/manager/models/group/row.py index 89be2aa0f7d..431904ea687 100644 --- a/src/ai/backend/manager/models/group/row.py +++ b/src/ai/backend/manager/models/group/row.py @@ -15,7 +15,6 @@ ) import sqlalchemy as sa -import trafaret as t from sqlalchemy.exc import NoResultFound from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection from sqlalchemy.ext.asyncio import AsyncSession @@ -45,7 +44,6 @@ EnumValueType, ResourceSlotColumn, SlugType, - StructuredJSONColumn, VFolderHostPermissionColumn, ) from ai.backend.manager.models.rbac import ( @@ -115,12 +113,6 @@ def _get_association_container_registries_groups_join_condition() -> sa.ColumnEl MAXIMUM_DOTFILE_SIZE = 64 * 1024 # 61 KiB -container_registry_iv = t.Dict({}) | t.Dict({ - t.Key("registry"): t.String(), - t.Key("project"): t.String(), -}) - - class AssocGroupUserRow(Base): # type: ignore[misc] __tablename__ = "association_groups_users" __table_args__ = ( @@ -210,13 +202,9 @@ class GroupRow(Base): # type: ignore[misc] nullable=False, default=ProjectType.GENERAL, ) - container_registry: Mapped[dict[str, Any] | None] = mapped_column( - "container_registry", - StructuredJSONColumn(container_registry_iv), - nullable=True, - default=None, + container_registry_id: Mapped[uuid.UUID | None] = mapped_column( + "container_registry_id", GUID, nullable=True ) - # Relationships (defined with deferred join conditions to avoid circular imports) sessions: Mapped[list[SessionRow]] = relationship("SessionRow", back_populates="group") domain: Mapped[DomainRow] = relationship("DomainRow", back_populates="groups") @@ -263,7 +251,7 @@ def to_data(self) -> GroupData: dotfiles=self.dotfiles, resource_policy=self.resource_policy, type=self.type, - container_registry=self.container_registry, + container_registry_id=self.container_registry_id, ) @classmethod @@ -378,7 +366,6 @@ class ProjectModel(RBACModel[ProjectPermission]): _allowed_vfolder_hosts: VFolderHostPermissionMap _dotfiles: bytes _resource_policy: str - _container_registry: dict[str, str] | None _permissions: frozenset[ProjectPermission] = field(default_factory=frozenset) @@ -411,11 +398,6 @@ def dotfiles(self) -> bytes: def resource_policy(self) -> str: return self._resource_policy - @property - @required_permission(ProjectPermission.READ_SENSITIVE_ATTRIBUTE) - def container_registry(self) -> dict[str, Any] | None: - return self._container_registry - @classmethod def from_row(cls, row: GroupRow, permissions: Iterable[ProjectPermission]) -> Self: return cls( @@ -432,7 +414,6 @@ def from_row(cls, row: GroupRow, permissions: Iterable[ProjectPermission]) -> Se _allowed_vfolder_hosts=row.allowed_vfolder_hosts, _dotfiles=row.dotfiles, _resource_policy=row.resource_policy, - _container_registry=row.container_registry, _permissions=frozenset(permissions), ) diff --git a/src/ai/backend/manager/repositories/container_registry_quota/db_source/db_source.py b/src/ai/backend/manager/repositories/container_registry_quota/db_source/db_source.py index 0d12f1c3535..8b02a34e544 100644 --- a/src/ai/backend/manager/repositories/container_registry_quota/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/container_registry_quota/db_source/db_source.py @@ -2,7 +2,6 @@ from __future__ import annotations -from typing import Any from uuid import UUID import sqlalchemy as sa @@ -28,31 +27,20 @@ async def fetch_container_registry_row( ) -> PerProjectContainerRegistryInfo: async with self._db.begin_readonly_session() as db_sess: project_id = scope_id.project_id - project_row = await self._fetch_project_row(db_sess, project_id) + container_registry_id = await self._fetch_project_container_registry_id( + db_sess, project_id + ) - if project_row is None: + if container_registry_id is None: raise ContainerRegistryNotFound( f"Container registry info does not exist or is invalid in the project. (project: {project_id})" ) - container_registry: dict[str, Any] | None = project_row.container_registry - if ( - not container_registry - or "registry" not in container_registry - or "project" not in container_registry - ): - raise ContainerRegistryNotFound( - f"Container registry info does not exist or is invalid in the project. (project: {project_id})" - ) - registry_name, project = ( - container_registry["registry"], - container_registry["project"], - ) - registry_row = await self._fetch_registry_row(db_sess, registry_name, project) + registry_row = await self._fetch_registry_row_by_id(db_sess, container_registry_id) if registry_row is None: raise ContainerRegistryNotFound( - f"Container registry row not found. (registry: {registry_name}, project: {project})" + f"Container registry row not found. (id: {container_registry_id})" ) return PerProjectContainerRegistryInfo( @@ -68,28 +56,29 @@ async def fetch_container_registry_row( extra=registry_row.extra or {}, ) - async def _fetch_project_row( + async def _fetch_project_container_registry_id( self, db_sess: SASession, project_id: UUID, - ) -> GroupRow | None: + ) -> UUID | None: project_query = ( sa.select(GroupRow) .where(GroupRow.id == project_id) - .options(load_only(GroupRow.container_registry)) + .options(load_only(GroupRow.container_registry_id)) ) result = await db_sess.execute(project_query) - return result.scalar_one_or_none() + project_row = result.scalar_one_or_none() + if project_row is None: + return None + return project_row.container_registry_id - async def _fetch_registry_row( + async def _fetch_registry_row_by_id( self, db_sess: SASession, - registry_name: str, - project: str, + registry_id: UUID, ) -> ContainerRegistryRow | None: registry_query = sa.select(ContainerRegistryRow).where( - (ContainerRegistryRow.registry_name == registry_name) - & (ContainerRegistryRow.project == project) + ContainerRegistryRow.id == registry_id ) result = await db_sess.execute(registry_query) return result.scalars().one_or_none() diff --git a/src/ai/backend/manager/repositories/group/creators.py b/src/ai/backend/manager/repositories/group/creators.py index 6be9552f17c..f9494cc0823 100644 --- a/src/ai/backend/manager/repositories/group/creators.py +++ b/src/ai/backend/manager/repositories/group/creators.py @@ -2,6 +2,7 @@ from __future__ import annotations +import uuid from dataclasses import dataclass from typing import override @@ -23,8 +24,8 @@ class GroupCreatorSpec(CreatorSpec[GroupRow]): allowed_vfolder_hosts: VFolderHostPermissionMap | None = None integration_id: str | None = None resource_policy: str | None = None - container_registry: dict[str, str] | None = None dotfiles: bytes | None = None + container_registry_id: uuid.UUID | None = None @override def build_row(self) -> GroupRow: @@ -39,5 +40,5 @@ def build_row(self) -> GroupRow: integration_id=self.integration_id, resource_policy=self.resource_policy, dotfiles=self.dotfiles, - container_registry=self.container_registry, + container_registry_id=self.container_registry_id, ) diff --git a/src/ai/backend/manager/repositories/group/updaters.py b/src/ai/backend/manager/repositories/group/updaters.py index e64a6b667a2..6684ca1e74b 100644 --- a/src/ai/backend/manager/repositories/group/updaters.py +++ b/src/ai/backend/manager/repositories/group/updaters.py @@ -1,5 +1,6 @@ from __future__ import annotations +import uuid from dataclasses import dataclass, field from typing import Any, override @@ -25,8 +26,8 @@ class GroupUpdaterSpec(UpdaterSpec[GroupRow]): ) integration_id: OptionalState[str] = field(default_factory=OptionalState[str].nop) resource_policy: OptionalState[str] = field(default_factory=OptionalState[str].nop) - container_registry: TriState[dict[str, str]] = field( - default_factory=TriState[dict[str, str]].nop + container_registry_id: TriState[uuid.UUID] = field( + default_factory=TriState[uuid.UUID].nop, ) @property @@ -45,5 +46,5 @@ def build_values(self) -> dict[str, Any]: self.allowed_vfolder_hosts.update_dict(to_update, "allowed_vfolder_hosts") self.integration_id.update_dict(to_update, "integration_id") self.resource_policy.update_dict(to_update, "resource_policy") - self.container_registry.update_dict(to_update, "container_registry") + self.container_registry_id.update_dict(to_update, "container_registry_id") return to_update diff --git a/src/ai/backend/testutils/extra_fixtures.py b/src/ai/backend/testutils/extra_fixtures.py index a7e18108779..b62993a87c5 100644 --- a/src/ai/backend/testutils/extra_fixtures.py +++ b/src/ai/backend/testutils/extra_fixtures.py @@ -23,10 +23,7 @@ "resource_policy": "default", "total_resource_slots": {}, "allowed_vfolder_hosts": {}, - "container_registry": { - "registry": "mock_registry", - "project": "mock_project", - }, + "container_registry_id": "00000000-0000-0000-0000-000000000000", "type": "general", } ], diff --git a/tests/component/manager/models/gql_models/test_group.py b/tests/component/manager/models/gql_models/test_group.py index b85070964ef..910649f5565 100644 --- a/tests/component/manager/models/gql_models/test_group.py +++ b/tests/component/manager/models/gql_models/test_group.py @@ -137,7 +137,6 @@ async def test_default_value_types_correctly_processed( assert result.integration_id is None assert result.total_resource_slots == ResourceSlot.from_user_input({}, None) assert result.type == ProjectType.GENERAL - assert result.container_registry == {} assert result.dotfiles == b"\x90" @@ -163,10 +162,6 @@ async def test_db_data_insertion( "dotfiles": b"test_dotfiles", "resource_policy": "default", "type": ProjectType.MODEL_STORE, - "container_registry": { - "registry": "example_registry", - "project": "example_project", - }, } async with database_engine.begin_session() as session: @@ -183,5 +178,4 @@ async def test_db_data_insertion( assert result.allowed_vfolder_hosts == {"local:volume1": {VFolderHostPermission.CREATE}} assert result.resource_policy == data["resource_policy"] assert result.type == data["type"] - assert result.container_registry == data["container_registry"] assert result.dotfiles == data["dotfiles"] diff --git a/tests/unit/client_v2/test_group.py b/tests/unit/client_v2/test_group.py index 86e665513b6..552b588ac22 100644 --- a/tests/unit/client_v2/test_group.py +++ b/tests/unit/client_v2/test_group.py @@ -93,7 +93,6 @@ def _last_request_call(mock_session: MagicMock) -> tuple[str, str, dict[str, Any "total_resource_slots": None, "allowed_vfolder_hosts": None, "resource_policy": None, - "container_registry": None, } _SAMPLE_MEMBER_DTO_1: dict[str, Any] = { diff --git a/tests/unit/manager/api/gql/test_project_v2_types.py b/tests/unit/manager/api/gql/test_project_v2_types.py index 5b97adb1a59..3560c48f6fd 100644 --- a/tests/unit/manager/api/gql/test_project_v2_types.py +++ b/tests/unit/manager/api/gql/test_project_v2_types.py @@ -34,7 +34,6 @@ def test_from_data_basic_conversion(self) -> None: dotfiles=b"", resource_policy="default-policy", type=ProjectType.GENERAL, - container_registry=None, ) # Convert to GraphQL type @@ -75,7 +74,6 @@ def test_from_data_project_type_conversion(self) -> None: dotfiles=b"", resource_policy="default", type=project_type, - container_registry=None, ) project_gql = ProjectV2GQL.from_data(data) @@ -98,7 +96,6 @@ def test_from_data_vfolder_hosts_conversion(self) -> None: dotfiles=b"", resource_policy="default", type=ProjectType.GENERAL, - container_registry=None, ) project_gql_empty = ProjectV2GQL.from_data(data_empty) @@ -130,7 +127,6 @@ def test_from_data_vfolder_hosts_conversion(self) -> None: dotfiles=b"", resource_policy="default", type=ProjectType.GENERAL, - container_registry=None, ) project_gql_with_hosts = ProjectV2GQL.from_data(data_with_hosts) @@ -171,7 +167,6 @@ def test_from_data_optional_fields(self) -> None: dotfiles=b"", resource_policy="default", type=ProjectType.GENERAL, - container_registry=None, ) project_gql = ProjectV2GQL.from_data(data) diff --git a/tests/unit/manager/api/group/test_group_node.py b/tests/unit/manager/api/group/test_group_node.py index df1af6251e0..e9b705b4776 100644 --- a/tests/unit/manager/api/group/test_group_node.py +++ b/tests/unit/manager/api/group/test_group_node.py @@ -63,7 +63,7 @@ def group_data_response(self) -> GroupData: resource_policy="default", type=ProjectType.GENERAL, integration_id=None, - container_registry={}, + container_registry_id=None, ) @pytest.fixture @@ -204,7 +204,6 @@ def mock_group_row(self) -> MagicMock: row.integration_id = None row.resource_policy = "default" row.type = ProjectType.GENERAL - row.container_registry = {} return row @pytest.fixture diff --git a/tests/unit/manager/repositories/auth/test_auth_repository.py b/tests/unit/manager/repositories/auth/test_auth_repository.py index b6c4e33f571..69b5264d43a 100644 --- a/tests/unit/manager/repositories/auth/test_auth_repository.py +++ b/tests/unit/manager/repositories/auth/test_auth_repository.py @@ -312,7 +312,7 @@ async def sample_group_data( dotfiles=group.dotfiles, resource_policy=group.resource_policy, type=group.type, - container_registry=group.container_registry, + container_registry_id=group.container_registry_id, ) yield group_data diff --git a/tests/unit/manager/services/test_group.py b/tests/unit/manager/services/test_group.py index 0b30b7dc87f..18b7602dac3 100644 --- a/tests/unit/manager/services/test_group.py +++ b/tests/unit/manager/services/test_group.py @@ -78,7 +78,7 @@ def sample_group_data(self) -> GroupData: dotfiles=b"\x90", resource_policy="default", type=ProjectType.GENERAL, - container_registry={}, + container_registry_id=None, ) async def test_create_with_valid_data_returns_group( @@ -202,7 +202,7 @@ def modified_group_data(self) -> GroupData: dotfiles=b"\x90", resource_policy="default", type=ProjectType.GENERAL, - container_registry={}, + container_registry_id=None, ) async def test_modify_with_valid_data_returns_updated_group(