From 7217e217c4ae59b3db45a74efbe864b4fe2a7ef4 Mon Sep 17 00:00:00 2001 From: jopemachine Date: Fri, 13 Feb 2026 04:52:32 +0000 Subject: [PATCH 01/10] feat: Add `artifact_storages` table - Add artifact_storages table and migration script - Move name column to artifact_storages - Add proper selectinload in db_source - Add ID type and update version - Append admin prefix - Fix rename types and error domain - Add news fragment --- changes/7057.feature.md | 1 + .../graphql-reference/supergraph.graphql | 54 +++++- .../graphql-reference/v2-schema.graphql | 46 ++++- .../backend/common/data/permission/types.py | 1 + src/ai/backend/common/exception.py | 1 + src/ai/backend/common/metrics/metric.py | 1 + src/ai/backend/common/types.py | 4 + .../manager/api/gql/artifact_storage.py | 131 +++++++++++++ .../backend/manager/api/gql/object_storage.py | 19 +- src/ai/backend/manager/api/gql/schema.py | 4 + src/ai/backend/manager/api/gql/vfs_storage.py | 19 +- .../data/artifact_storages/__init__.py | 7 + .../manager/data/artifact_storages/types.py | 62 +++++++ .../manager/errors/artifact_storage.py | 21 +++ ...0662_add_artifact_storages_common_table.py | 174 ++++++++++++++++++ .../manager/models/artifact_storages.py | 84 +++++++++ .../manager/models/object_storage/row.py | 21 ++- .../backend/manager/models/vfs_storage/row.py | 28 ++- .../repositories/artifact_storage/__init__.py | 7 + .../artifact_storage/db_source/__init__.py | 3 + .../artifact_storage/db_source/db_source.py | 72 ++++++++ .../artifact_storage/repositories.py | 19 ++ .../artifact_storage/repository.py | 59 ++++++ .../repositories/object_storage/creators.py | 2 - .../object_storage/db_source/db_source.py | 91 +++++++-- .../repositories/object_storage/repository.py | 17 +- .../repositories/object_storage/updaters.py | 2 - .../manager/repositories/repositories.py | 6 + .../repositories/vfs_storage/creators.py | 2 - .../vfs_storage/db_source/db_source.py | 80 ++++++-- .../repositories/vfs_storage/repository.py | 17 +- 
.../repositories/vfs_storage/updaters.py | 2 - .../services/artifact_storage/__init__.py | 7 + .../artifact_storage/actions/__init__.py | 8 + .../services/artifact_storage/actions/base.py | 11 ++ .../artifact_storage/actions/update.py | 34 ++++ .../services/artifact_storage/processors.py | 25 +++ .../services/artifact_storage/service.py | 32 ++++ .../services/object_storage/actions/create.py | 8 +- .../services/object_storage/actions/update.py | 2 + .../services/object_storage/service.py | 4 +- src/ai/backend/manager/services/processors.py | 13 ++ .../services/vfs_storage/actions/create.py | 8 +- .../services/vfs_storage/actions/update.py | 2 + .../manager/services/vfs_storage/service.py | 4 +- 45 files changed, 1146 insertions(+), 69 deletions(-) create mode 100644 changes/7057.feature.md create mode 100644 src/ai/backend/manager/api/gql/artifact_storage.py create mode 100644 src/ai/backend/manager/data/artifact_storages/__init__.py create mode 100644 src/ai/backend/manager/data/artifact_storages/types.py create mode 100644 src/ai/backend/manager/errors/artifact_storage.py create mode 100644 src/ai/backend/manager/models/alembic/versions/35dfab3b0662_add_artifact_storages_common_table.py create mode 100644 src/ai/backend/manager/models/artifact_storages.py create mode 100644 src/ai/backend/manager/repositories/artifact_storage/__init__.py create mode 100644 src/ai/backend/manager/repositories/artifact_storage/db_source/__init__.py create mode 100644 src/ai/backend/manager/repositories/artifact_storage/db_source/db_source.py create mode 100644 src/ai/backend/manager/repositories/artifact_storage/repositories.py create mode 100644 src/ai/backend/manager/repositories/artifact_storage/repository.py create mode 100644 src/ai/backend/manager/services/artifact_storage/__init__.py create mode 100644 src/ai/backend/manager/services/artifact_storage/actions/__init__.py create mode 100644 src/ai/backend/manager/services/artifact_storage/actions/base.py create mode 100644 
src/ai/backend/manager/services/artifact_storage/actions/update.py create mode 100644 src/ai/backend/manager/services/artifact_storage/processors.py create mode 100644 src/ai/backend/manager/services/artifact_storage/service.py diff --git a/changes/7057.feature.md b/changes/7057.feature.md new file mode 100644 index 00000000000..9900e6a92ca --- /dev/null +++ b/changes/7057.feature.md @@ -0,0 +1 @@ +Add `artifact_storages` common table for storage metadata management across object storage and VFS storage backends, and add `adminUpdateArtifactStorage` GraphQL mutation for updating artifact storage metadata (e.g., name) \ No newline at end of file diff --git a/docs/manager/graphql-reference/supergraph.graphql b/docs/manager/graphql-reference/supergraph.graphql index c213b8c3005..41545bfca12 100644 --- a/docs/manager/graphql-reference/supergraph.graphql +++ b/docs/manager/graphql-reference/supergraph.graphql @@ -996,6 +996,34 @@ type ArtifactStatusChangedPayload artifactRevision: ArtifactRevision! } +"""Added in 26.3.0. Artifact storage metadata""" +type ArtifactStorage + @join__type(graph: STRAWBERRY) +{ + """The ID of the artifact storage""" + id: ID! + + """The name of the artifact storage""" + name: String! + + """The ID of the underlying storage (ObjectStorage or VFSStorage)""" + storageId: ID! + + """The type of the artifact storage""" + type: ArtifactStorageType! +} + +""" +Added in 26.3.0. The type of artifact storage backend. OBJECT_STORAGE: Object storage (e.g., S3-compatible). VFS_STORAGE: Virtual folder storage. GIT_LFS: Git LFS storage. +""" +enum ArtifactStorageType + @join__type(graph: STRAWBERRY) +{ + OBJECT_STORAGE @join__enumValue(graph: STRAWBERRY) + VFS_STORAGE @join__enumValue(graph: STRAWBERRY) + GIT_LFS @join__enumValue(graph: STRAWBERRY) +} + enum ArtifactType @join__type(graph: STRAWBERRY) { @@ -7298,6 +7326,11 @@ type Mutation """Added in 25.16.0. Delete a VFS storage""" deleteVFSStorage(input: DeleteVFSStorageInput!): DeleteVFSStoragePayload! 
@join__field(graph: STRAWBERRY) + """ + Added in 26.3.0. Update artifact storage metadata (common fields like name) + """ + adminUpdateArtifactStorage(input: UpdateArtifactStorageInput!): UpdateArtifactStoragePayload! @join__field(graph: STRAWBERRY) + """ Added in 25.15.0. @@ -11844,6 +11877,25 @@ type UpdateArtifactPayload artifact: Artifact! } +"""Added in 26.3.0. Input for updating artifact storage metadata""" +input UpdateArtifactStorageInput + @join__type(graph: STRAWBERRY) +{ + """The ID of the artifact storage""" + id: ID! + + """The new name for the artifact storage""" + name: String +} + +"""Added in 26.3.0. Payload for updating artifact storage metadata""" +type UpdateArtifactStoragePayload + @join__type(graph: STRAWBERRY) +{ + """The updated artifact storage""" + artifactStorage: ArtifactStorage! +} + input UpdateAutoScalingRuleInput @join__type(graph: STRAWBERRY) { @@ -11951,7 +12003,6 @@ input UpdateObjectStorageInput @join__type(graph: STRAWBERRY) { id: ID! - name: String host: String accessKey: String secretKey: String @@ -12167,7 +12218,6 @@ input UpdateVFSStorageInput @join__type(graph: STRAWBERRY) { id: ID! - name: String host: String basePath: String } diff --git a/docs/manager/graphql-reference/v2-schema.graphql b/docs/manager/graphql-reference/v2-schema.graphql index 2755b07cdd5..4c21a4ac08b 100644 --- a/docs/manager/graphql-reference/v2-schema.graphql +++ b/docs/manager/graphql-reference/v2-schema.graphql @@ -709,6 +709,30 @@ type ArtifactStatusChangedPayload { artifactRevision: ArtifactRevision! } +"""Added in 26.3.0. Artifact storage metadata""" +type ArtifactStorage { + """The ID of the artifact storage""" + id: ID! + + """The name of the artifact storage""" + name: String! + + """The ID of the underlying storage (ObjectStorage or VFSStorage)""" + storageId: ID! + + """The type of the artifact storage""" + type: ArtifactStorageType! +} + +""" +Added in 26.3.0. The type of artifact storage backend. 
OBJECT_STORAGE: Object storage (e.g., S3-compatible). VFS_STORAGE: Virtual folder storage. GIT_LFS: Git LFS storage. +""" +enum ArtifactStorageType { + OBJECT_STORAGE + VFS_STORAGE + GIT_LFS +} + enum ArtifactType { MODEL PACKAGE @@ -3731,6 +3755,11 @@ type Mutation { """Added in 25.16.0. Delete a VFS storage""" deleteVFSStorage(input: DeleteVFSStorageInput!): DeleteVFSStoragePayload! + """ + Added in 26.3.0. Update artifact storage metadata (common fields like name) + """ + adminUpdateArtifactStorage(input: UpdateArtifactStorageInput!): UpdateArtifactStoragePayload! + """ Added in 25.15.0. @@ -6942,6 +6971,21 @@ type UpdateArtifactPayload { artifact: Artifact! } +"""Added in 26.3.0. Input for updating artifact storage metadata""" +input UpdateArtifactStorageInput { + """The ID of the artifact storage""" + id: ID! + + """The new name for the artifact storage""" + name: String +} + +"""Added in 26.3.0. Payload for updating artifact storage metadata""" +type UpdateArtifactStoragePayload { + """The updated artifact storage""" + artifactStorage: ArtifactStorage! +} + input UpdateAutoScalingRuleInput { id: ID! metricSource: AutoScalingMetricSource @@ -7019,7 +7063,6 @@ type UpdateNotificationRulePayload { """Added in 25.14.0""" input UpdateObjectStorageInput { id: ID! - name: String host: String accessKey: String secretKey: String @@ -7209,7 +7252,6 @@ type UpdateUserV2Payload { """Added in 25.16.0. Input for updating VFS storage""" input UpdateVFSStorageInput { id: ID! 
- name: String host: String basePath: String } diff --git a/src/ai/backend/common/data/permission/types.py b/src/ai/backend/common/data/permission/types.py index 5a077d9f232..375f03bd5f7 100644 --- a/src/ai/backend/common/data/permission/types.py +++ b/src/ai/backend/common/data/permission/types.py @@ -103,6 +103,7 @@ class EntityType(enum.StrEnum): NETWORK = "network" NOTIFICATION = "notification" OBJECT_PERMISSION = "object_permission" + ARTIFACT_STORAGE = "artifact_storage" OBJECT_STORAGE = "object_storage" PERMISSION = "permission" AGENT_RESOURCE = "agent_resource" diff --git a/src/ai/backend/common/exception.py b/src/ai/backend/common/exception.py index 454e705da57..bcbf82c2d3b 100644 --- a/src/ai/backend/common/exception.py +++ b/src/ai/backend/common/exception.py @@ -137,6 +137,7 @@ class ErrorDomain(enum.StrEnum): ARTIFACT = "artifact" ARTIFACT_REGISTRY = "artifact-registry" ARTIFACT_ASSOCIATION = "artifact-association" + ARTIFACT_STORAGE = "artifact-storage" OBJECT_STORAGE = "object-storage" VFS_STORAGE = "vfs-storage" STORAGE_NAMESPACE = "storage-namespace" diff --git a/src/ai/backend/common/metrics/metric.py b/src/ai/backend/common/metrics/metric.py index 3eb3435e266..24fdf61e45e 100644 --- a/src/ai/backend/common/metrics/metric.py +++ b/src/ai/backend/common/metrics/metric.py @@ -409,6 +409,7 @@ class LayerType(enum.StrEnum): AUTH_REPOSITORY = "auth_repository" ARTIFACT_REPOSITORY = "artifact_repository" ARTIFACT_REGISTRY_REPOSITORY = "artifact_registry_repository" + ARTIFACT_STORAGE_REPOSITORY = "artifact_storage_repository" AUDIT_LOG_REPOSITORY = "audit_log_repository" CONTAINER_REGISTRY_REPOSITORY = "container_registry_repository" DEPLOYMENT_REPOSITORY = "deployment_repository" diff --git a/src/ai/backend/common/types.py b/src/ai/backend/common/types.py index bb81680d6c6..a8fb043bf79 100644 --- a/src/ai/backend/common/types.py +++ b/src/ai/backend/common/types.py @@ -306,6 +306,10 @@ def check_typed_tuple(value: tuple[Any, ...], types: tuple[type, 
...]) -> tuple[ RuleId = NewType("RuleId", UUID) SessionId = NewType("SessionId", UUID) KernelId = NewType("KernelId", UUID) +# ID of the `artifact_storages` common table (storage metadata). +ArtifactStorageId = NewType("ArtifactStorageId", UUID) +# ID of the concrete storage table (`object_storages`, `vfs_storages`, etc.). +ConcreteStorageId = NewType("ConcreteStorageId", UUID) ImageAlias = NewType("ImageAlias", str) ArchName = NewType("ArchName", str) diff --git a/src/ai/backend/manager/api/gql/artifact_storage.py b/src/ai/backend/manager/api/gql/artifact_storage.py new file mode 100644 index 00000000000..b0209713fe3 --- /dev/null +++ b/src/ai/backend/manager/api/gql/artifact_storage.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import uuid +from enum import StrEnum +from typing import Self + +import strawberry +from strawberry import ID, UNSET, Info + +from ai.backend.common.data.storage.types import ArtifactStorageType +from ai.backend.common.types import ArtifactStorageId +from ai.backend.manager.api.gql.utils import check_admin_only +from ai.backend.manager.data.artifact_storages.types import ( + ArtifactStorageData, + ArtifactStorageUpdaterSpec, +) +from ai.backend.manager.models.artifact_storages import ArtifactStorageRow +from ai.backend.manager.repositories.base.updater import Updater +from ai.backend.manager.services.artifact_storage.actions.update import ( + UpdateArtifactStorageAction, +) +from ai.backend.manager.types import OptionalState + +from .types import StrawberryGQLContext + + +@strawberry.enum( + name="ArtifactStorageType", + description=( + "Added in 26.3.0. The type of artifact storage backend. " + "OBJECT_STORAGE: Object storage (e.g., S3-compatible). " + "VFS_STORAGE: Virtual folder storage. " + "GIT_LFS: Git LFS storage." 
+ ), +) +class ArtifactStorageTypeGQL(StrEnum): + """Artifact storage type enum.""" + + OBJECT_STORAGE = "object_storage" + VFS_STORAGE = "vfs_storage" + GIT_LFS = "git_lfs" + + @classmethod + def from_internal(cls, internal_type: ArtifactStorageType) -> ArtifactStorageTypeGQL: + """Convert internal ArtifactStorageType to GraphQL enum.""" + match internal_type: + case ArtifactStorageType.OBJECT_STORAGE: + return cls.OBJECT_STORAGE + case ArtifactStorageType.VFS_STORAGE: + return cls.VFS_STORAGE + case ArtifactStorageType.GIT_LFS: + return cls.GIT_LFS + + def to_internal(self) -> ArtifactStorageType: + """Convert GraphQL enum to internal ArtifactStorageType.""" + match self: + case ArtifactStorageTypeGQL.OBJECT_STORAGE: + return ArtifactStorageType.OBJECT_STORAGE + case ArtifactStorageTypeGQL.VFS_STORAGE: + return ArtifactStorageType.VFS_STORAGE + case ArtifactStorageTypeGQL.GIT_LFS: + return ArtifactStorageType.GIT_LFS + + +@strawberry.type(name="ArtifactStorage", description="Added in 26.3.0. Artifact storage metadata") +class ArtifactStorageGQL: + id: ID = strawberry.field(description="The ID of the artifact storage") + name: str = strawberry.field(description="The name of the artifact storage") + storage_id: ID = strawberry.field( + description="The ID of the underlying storage (ObjectStorage or VFSStorage)" + ) + type: ArtifactStorageTypeGQL = strawberry.field(description="The type of the artifact storage") + + @classmethod + def from_dataclass(cls, data: ArtifactStorageData) -> Self: + return cls( + id=ID(str(data.id)), + name=data.name, + storage_id=ID(str(data.storage_id)), + type=ArtifactStorageTypeGQL.from_internal(data.type), + ) + + +@strawberry.input( + name="UpdateArtifactStorageInput", + description="Added in 26.3.0. 
Input for updating artifact storage metadata", +) +class UpdateArtifactStorageInputGQL: + """Input for updating artifact storage metadata (common fields like name).""" + + id: ID = strawberry.field(description="The ID of the artifact storage") + name: str | None = strawberry.field( + default=UNSET, description="The new name for the artifact storage" + ) + + def to_updater(self) -> Updater[ArtifactStorageRow]: + spec = ArtifactStorageUpdaterSpec( + name=OptionalState[str].from_graphql(self.name), + ) + return Updater(spec=spec, pk_value=ArtifactStorageId(uuid.UUID(self.id))) + + +@strawberry.type( + name="UpdateArtifactStoragePayload", + description="Added in 26.3.0. Payload for updating artifact storage metadata", +) +class UpdateArtifactStoragePayloadGQL: + artifact_storage: ArtifactStorageGQL = strawberry.field( + description="The updated artifact storage" + ) + + +@strawberry.mutation( # type: ignore[misc] + name="adminUpdateArtifactStorage", + description="Added in 26.3.0. Update artifact storage metadata (common fields like name)", +) +async def update_artifact_storage( + input: UpdateArtifactStorageInputGQL, info: Info[StrawberryGQLContext] +) -> UpdateArtifactStoragePayloadGQL: + check_admin_only() + processors = info.context.processors + + action_result = await processors.artifact_storage.update.wait_for_complete( + UpdateArtifactStorageAction( + updater=input.to_updater(), + ) + ) + + return UpdateArtifactStoragePayloadGQL( + artifact_storage=ArtifactStorageGQL.from_dataclass(action_result.result), + ) diff --git a/src/ai/backend/manager/api/gql/object_storage.py b/src/ai/backend/manager/api/gql/object_storage.py index 0097bb5b45b..aad51c41a5a 100644 --- a/src/ai/backend/manager/api/gql/object_storage.py +++ b/src/ai/backend/manager/api/gql/object_storage.py @@ -3,13 +3,15 @@ import json import uuid from collections.abc import Iterable -from typing import Self +from typing import TYPE_CHECKING, Self import strawberry from strawberry import ID, UNSET, Info 
from strawberry.relay import Connection, Edge, Node, NodeID +from ai.backend.common.data.storage.types import ArtifactStorageType from ai.backend.manager.api.gql.base import encode_cursor +from ai.backend.manager.data.artifact_storages.types import ArtifactStorageCreatorSpec from ai.backend.manager.data.object_storage.types import ObjectStorageData from ai.backend.manager.models.object_storage import ObjectStorageRow from ai.backend.manager.repositories.base.creator import Creator @@ -35,6 +37,9 @@ from .storage_namespace import StorageNamespace, StorageNamespaceConnection, StorageNamespaceEdge from .types import StrawberryGQLContext +if TYPE_CHECKING: + from ai.backend.manager.models.artifact_storages import ArtifactStorageRow + @strawberry.type(description="Added in 25.14.0") class ObjectStorage(Node): @@ -165,7 +170,6 @@ class CreateObjectStorageInput: def to_creator(self) -> Creator[ObjectStorageRow]: return Creator( spec=ObjectStorageCreatorSpec( - name=self.name, host=self.host, access_key=self.access_key, secret_key=self.secret_key, @@ -174,11 +178,18 @@ def to_creator(self) -> Creator[ObjectStorageRow]: ) ) + def to_meta_creator(self) -> Creator[ArtifactStorageRow]: + return Creator( + spec=ArtifactStorageCreatorSpec( + name=self.name, + storage_type=ArtifactStorageType.OBJECT_STORAGE, + ) + ) + @strawberry.input(description="Added in 25.14.0") class UpdateObjectStorageInput: id: ID - name: str | None = UNSET host: str | None = UNSET access_key: str | None = UNSET secret_key: str | None = UNSET @@ -187,7 +198,6 @@ class UpdateObjectStorageInput: def to_updater(self) -> Updater[ObjectStorageRow]: spec = ObjectStorageUpdaterSpec( - name=OptionalState[str].from_graphql(self.name), host=OptionalState[str].from_graphql(self.host), access_key=OptionalState[str].from_graphql(self.access_key), secret_key=OptionalState[str].from_graphql(self.secret_key), @@ -250,6 +260,7 @@ async def create_object_storage( action_result = await 
processors.object_storage.create.wait_for_complete( CreateObjectStorageAction( creator=input.to_creator(), + meta_creator=input.to_meta_creator(), ) ) diff --git a/src/ai/backend/manager/api/gql/schema.py b/src/ai/backend/manager/api/gql/schema.py index d1e0703ab00..1634d79876f 100644 --- a/src/ai/backend/manager/api/gql/schema.py +++ b/src/ai/backend/manager/api/gql/schema.py @@ -39,6 +39,9 @@ update_artifact, ) from .artifact_registry import default_artifact_registry +from .artifact_storage import ( + update_artifact_storage, +) from .background_task import background_task_events from .deployment import ( # Revision @@ -427,6 +430,7 @@ class Mutation: create_vfs_storage = create_vfs_storage update_vfs_storage = update_vfs_storage delete_vfs_storage = delete_vfs_storage + update_artifact_storage = update_artifact_storage register_storage_namespace = register_storage_namespace unregister_storage_namespace = unregister_storage_namespace create_huggingface_registry = create_huggingface_registry diff --git a/src/ai/backend/manager/api/gql/vfs_storage.py b/src/ai/backend/manager/api/gql/vfs_storage.py index f450721546e..3b4fd53737f 100644 --- a/src/ai/backend/manager/api/gql/vfs_storage.py +++ b/src/ai/backend/manager/api/gql/vfs_storage.py @@ -2,13 +2,15 @@ import uuid from collections.abc import Iterable -from typing import Self +from typing import TYPE_CHECKING, Self import strawberry from strawberry import ID, UNSET, Info from strawberry.relay import Connection, Edge, Node, NodeID +from ai.backend.common.data.storage.types import ArtifactStorageType from ai.backend.manager.api.gql.base import encode_cursor +from ai.backend.manager.data.artifact_storages.types import ArtifactStorageCreatorSpec from ai.backend.manager.data.vfs_storage.types import VFSStorageData from ai.backend.manager.models.vfs_storage import VFSStorageRow from ai.backend.manager.repositories.base.creator import Creator @@ -24,6 +26,9 @@ from .types import StrawberryGQLContext +if TYPE_CHECKING: + 
from ai.backend.manager.models.artifact_storages import ArtifactStorageRow + @strawberry.type(description="Added in 25.16.0. VFS Storage configuration") class VFSStorage(Node): @@ -113,23 +118,28 @@ class CreateVFSStorageInput: def to_creator(self) -> Creator[VFSStorageRow]: return Creator( spec=VFSStorageCreatorSpec( - name=self.name, host=self.host, base_path=self.base_path, ) ) + def to_meta_creator(self) -> Creator[ArtifactStorageRow]: + return Creator( + spec=ArtifactStorageCreatorSpec( + name=self.name, + storage_type=ArtifactStorageType.VFS_STORAGE, + ) + ) + @strawberry.input(description="Added in 25.16.0. Input for updating VFS storage") class UpdateVFSStorageInput: id: ID - name: str | None = UNSET host: str | None = UNSET base_path: str | None = UNSET def to_updater(self) -> Updater[VFSStorageRow]: spec = VFSStorageUpdaterSpec( - name=OptionalState[str].from_graphql(self.name), host=OptionalState[str].from_graphql(self.host), base_path=OptionalState[str].from_graphql(self.base_path), ) @@ -167,6 +177,7 @@ async def create_vfs_storage( action_result = await processors.vfs_storage.create.wait_for_complete( CreateVFSStorageAction( creator=input.to_creator(), + meta_creator=input.to_meta_creator(), ) ) diff --git a/src/ai/backend/manager/data/artifact_storages/__init__.py b/src/ai/backend/manager/data/artifact_storages/__init__.py new file mode 100644 index 00000000000..d358ac3fdaf --- /dev/null +++ b/src/ai/backend/manager/data/artifact_storages/__init__.py @@ -0,0 +1,7 @@ +from .types import ArtifactStorageCreatorSpec, ArtifactStorageData, ArtifactStorageUpdaterSpec + +__all__ = ( + "ArtifactStorageCreatorSpec", + "ArtifactStorageData", + "ArtifactStorageUpdaterSpec", +) diff --git a/src/ai/backend/manager/data/artifact_storages/types.py b/src/ai/backend/manager/data/artifact_storages/types.py new file mode 100644 index 00000000000..d1e20c6f4fa --- /dev/null +++ b/src/ai/backend/manager/data/artifact_storages/types.py @@ -0,0 +1,62 @@ +from __future__ 
import annotations + +from dataclasses import dataclass, field +from typing import Any, override + +from ai.backend.common.data.storage.types import ArtifactStorageType +from ai.backend.common.types import ArtifactStorageId, ConcreteStorageId +from ai.backend.manager.errors.common import InternalServerError +from ai.backend.manager.models.artifact_storages import ArtifactStorageRow +from ai.backend.manager.repositories.base.creator import CreatorSpec +from ai.backend.manager.repositories.base.updater import UpdaterSpec +from ai.backend.manager.types import OptionalState + + +@dataclass(frozen=True) +class ArtifactStorageData: + """Data class for ArtifactStorageRow.""" + + id: ArtifactStorageId + name: str + storage_id: ConcreteStorageId + type: ArtifactStorageType + + +class ArtifactStorageCreatorSpec(CreatorSpec[ArtifactStorageRow]): + """CreatorSpec for ArtifactStorageRow with deferred storage_id.""" + + def __init__(self, name: str, storage_type: ArtifactStorageType) -> None: + self._name = name + self._storage_type = storage_type + self._storage_id: ConcreteStorageId | None = None + + def set_storage_id(self, storage_id: ConcreteStorageId) -> None: + self._storage_id = storage_id + + @override + def build_row(self) -> ArtifactStorageRow: + if self._storage_id is None: + raise InternalServerError("storage_id must be set before building row") + return ArtifactStorageRow( + name=self._name, + storage_id=self._storage_id, + type=self._storage_type, + ) + + +@dataclass +class ArtifactStorageUpdaterSpec(UpdaterSpec[ArtifactStorageRow]): + """UpdaterSpec for ArtifactStorageRow.""" + + name: OptionalState[str] = field(default_factory=OptionalState.nop) + + @property + @override + def row_class(self) -> type[ArtifactStorageRow]: + return ArtifactStorageRow + + @override + def build_values(self) -> dict[str, Any]: + values: dict[str, Any] = {} + self.name.update_dict(values, "name") + return values diff --git a/src/ai/backend/manager/errors/artifact_storage.py 
b/src/ai/backend/manager/errors/artifact_storage.py new file mode 100644 index 00000000000..1dfb59d85f8 --- /dev/null +++ b/src/ai/backend/manager/errors/artifact_storage.py @@ -0,0 +1,21 @@ +from aiohttp import web + +from ai.backend.common.exception import ( + BackendAIError, + ErrorCode, + ErrorDetail, + ErrorDomain, + ErrorOperation, +) + + +class ArtifactStorageNotFoundError(BackendAIError, web.HTTPNotFound): + error_type = "https://api.backend.ai/probs/artifact-storage-not-found" + error_title = "Artifact Storage Not Found" + + def error_code(self) -> ErrorCode: + return ErrorCode( + domain=ErrorDomain.ARTIFACT_STORAGE, + operation=ErrorOperation.READ, + error_detail=ErrorDetail.NOT_FOUND, + ) diff --git a/src/ai/backend/manager/models/alembic/versions/35dfab3b0662_add_artifact_storages_common_table.py b/src/ai/backend/manager/models/alembic/versions/35dfab3b0662_add_artifact_storages_common_table.py new file mode 100644 index 00000000000..ef75d9db519 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/35dfab3b0662_add_artifact_storages_common_table.py @@ -0,0 +1,174 @@ +"""Add artifact_storages common table + +Revision ID: 35dfab3b0662 +Revises: ccf8ae5c90fe +Create Date: 2025-12-02 09:24:21.050932 + +""" + +import sqlalchemy as sa +from alembic import op + +from ai.backend.manager.models.base import GUID + +# revision identifiers, used by Alembic. 
+revision = "35dfab3b0662" +down_revision = "ccf8ae5c90fe" +branch_labels = None +depends_on = None + + +def _migrate_object_storages_to_artifact_storages( + conn: sa.engine.Connection, +) -> None: + """Migrate existing object_storages records to artifact_storages.""" + result = conn.execute( + sa.text(""" + SELECT id, name FROM object_storages + WHERE name IS NOT NULL + """) + ) + + for row in result: + conn.execute( + sa.text(""" + INSERT INTO artifact_storages (name, storage_id, type) + VALUES (:name, :storage_id, :type) + """), + {"name": row.name, "storage_id": row.id, "type": "object_storage"}, + ) + + # Drop the name column and constraint + op.drop_index("ix_object_storages_name", table_name="object_storages") + op.drop_column("object_storages", "name") + + +def _migrate_vfs_storages_to_artifact_storages( + conn: sa.engine.Connection, +) -> None: + """Migrate existing vfs_storages records to artifact_storages.""" + result = conn.execute( + sa.text(""" + SELECT id, name FROM vfs_storages + WHERE name IS NOT NULL + """) + ) + + for row in result: + conn.execute( + sa.text(""" + INSERT INTO artifact_storages (name, storage_id, type) + VALUES (:name, :storage_id, :type) + """), + {"name": row.name, "storage_id": row.id, "type": "vfs_storage"}, + ) + + # Drop the name column and constraint + op.drop_index("ix_vfs_storages_name", table_name="vfs_storages") + op.drop_column("vfs_storages", "name") + + +def _migrate_artifact_storages_to_object_storages( + conn: sa.engine.Connection, +) -> None: + """Migrate data back from artifact_storages to object_storages.""" + result = conn.execute( + sa.text(""" + SELECT name, storage_id FROM artifact_storages + WHERE type = 'object_storage' + """) + ) + + for row in result: + conn.execute( + sa.text(""" + UPDATE object_storages + SET name = :name + WHERE id = :storage_id + """), + {"name": row.name, "storage_id": row.storage_id}, + ) + + +def _migrate_artifact_storages_to_vfs_storages( + conn: sa.engine.Connection, +) -> None: 
+ """Migrate data back from artifact_storages to vfs_storages.""" + result = conn.execute( + sa.text(""" + SELECT name, storage_id FROM artifact_storages + WHERE type = 'vfs_storage' + """) + ) + + for row in result: + conn.execute( + sa.text(""" + UPDATE vfs_storages + SET name = :name + WHERE id = :storage_id + """), + {"name": row.name, "storage_id": row.storage_id}, + ) + + +def upgrade() -> None: + op.create_table( + "artifact_storages", + sa.Column("id", GUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("storage_id", GUID(), nullable=False), + sa.Column("type", sa.String(), nullable=False), + sa.PrimaryKeyConstraint("id", name=op.f("pk_artifact_storages")), + sa.UniqueConstraint("name", name=op.f("uq_artifact_storages_name")), + sa.UniqueConstraint("storage_id", name=op.f("uq_artifact_storages_storage_id")), + ) + + conn = op.get_bind() + + _migrate_object_storages_to_artifact_storages(conn) + _migrate_vfs_storages_to_artifact_storages(conn) + + +def downgrade() -> None: + conn = op.get_bind() + + # Add name column back to object_storages + op.add_column( + "object_storages", + sa.Column("name", sa.VARCHAR(), autoincrement=False, nullable=True), + ) + + _migrate_artifact_storages_to_object_storages(conn) + + # Make name column NOT NULL + op.alter_column( + "object_storages", + "name", + existing_type=sa.VARCHAR(), + nullable=False, + ) + + # Recreate constraints + op.create_index("ix_object_storages_name", "object_storages", ["name"], unique=True) + + # Add name column back to vfs_storages + op.add_column( + "vfs_storages", + sa.Column("name", sa.VARCHAR(), autoincrement=False, nullable=True), + ) + + _migrate_artifact_storages_to_vfs_storages(conn) + + # Make name column NOT NULL + op.alter_column( + "vfs_storages", + "name", + existing_type=sa.VARCHAR(), + nullable=False, + ) + + # Recreate constraints + op.create_index("ix_vfs_storages_name", "vfs_storages", ["name"], 
unique=True) + + op.drop_table("artifact_storages") diff --git a/src/ai/backend/manager/models/artifact_storages.py b/src/ai/backend/manager/models/artifact_storages.py new file mode 100644 index 00000000000..ac3409f80a4 --- /dev/null +++ b/src/ai/backend/manager/models/artifact_storages.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import logging +import uuid +from typing import TYPE_CHECKING + +import sqlalchemy as sa +from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship + +from ai.backend.common.data.storage.types import ArtifactStorageType +from ai.backend.logging import BraceStyleAdapter + +from .base import ( + GUID, + Base, +) + +if TYPE_CHECKING: + from ai.backend.manager.data.artifact_storages.types import ArtifactStorageData + from ai.backend.manager.models.object_storage import ObjectStorageRow + from ai.backend.manager.models.vfs_storage import VFSStorageRow + +log = BraceStyleAdapter(logging.getLogger(__spec__.name)) + +__all__ = ("ArtifactStorageRow",) + + +def _get_object_storage_join_condition() -> sa.ColumnElement[bool]: + from ai.backend.manager.models.object_storage import ObjectStorageRow + + return ObjectStorageRow.id == foreign(ArtifactStorageRow.storage_id) + + +def _get_vfs_storage_join_condition() -> sa.ColumnElement[bool]: + from ai.backend.manager.models.vfs_storage import VFSStorageRow + + return VFSStorageRow.id == foreign(ArtifactStorageRow.storage_id) + + +class ArtifactStorageRow(Base): # type: ignore[misc] + """ + Common information of all artifact storage records. 
+ """ + + __tablename__ = "artifact_storages" + + id: Mapped[uuid.UUID] = mapped_column( + "id", GUID, primary_key=True, server_default=sa.text("uuid_generate_v4()") + ) + name: Mapped[str] = mapped_column("name", sa.String, nullable=False, unique=True) + storage_id: Mapped[uuid.UUID] = mapped_column("storage_id", GUID, nullable=False, unique=True) + type: Mapped[str] = mapped_column("type", sa.String, nullable=False) + + object_storages: Mapped[ObjectStorageRow | None] = relationship( + "ObjectStorageRow", + back_populates="meta", + primaryjoin=_get_object_storage_join_condition, + uselist=False, + viewonly=True, + ) + vfs_storages: Mapped[VFSStorageRow | None] = relationship( + "VFSStorageRow", + back_populates="meta", + primaryjoin=_get_vfs_storage_join_condition, + uselist=False, + viewonly=True, + ) + + def __str__(self) -> str: + return f"ArtifactStorageRow(id={self.id}, storage_id={self.storage_id}, type={self.type}, name={self.name})" + + def __repr__(self) -> str: + return self.__str__() + + def to_dataclass(self) -> ArtifactStorageData: + from ai.backend.common.types import ArtifactStorageId, ConcreteStorageId + from ai.backend.manager.data.artifact_storages.types import ArtifactStorageData + + return ArtifactStorageData( + id=ArtifactStorageId(self.id), + name=self.name, + storage_id=ConcreteStorageId(self.storage_id), + type=ArtifactStorageType(self.type), + ) diff --git a/src/ai/backend/manager/models/object_storage/row.py b/src/ai/backend/manager/models/object_storage/row.py index 2a1d241b079..254f80c8170 100644 --- a/src/ai/backend/manager/models/object_storage/row.py +++ b/src/ai/backend/manager/models/object_storage/row.py @@ -7,6 +7,7 @@ import sqlalchemy as sa from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship +from ai.backend.common.exception import RelationNotLoadedError from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.data.object_storage.types import ObjectStorageData from 
ai.backend.manager.models.base import ( @@ -15,6 +16,7 @@ ) if TYPE_CHECKING: + from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.models.association_artifacts_storages import ( AssociationArtifactsStorageRow, ) @@ -39,6 +41,12 @@ def _get_object_storage_namespace_join_cond() -> sa.ColumnElement[bool]: return foreign(StorageNamespaceRow.storage_id) == ObjectStorageRow.id +def _get_object_storage_meta_join_cond() -> sa.ColumnElement[bool]: + from ai.backend.manager.models.artifact_storages import ArtifactStorageRow + + return ObjectStorageRow.id == foreign(ArtifactStorageRow.storage_id) + + class ObjectStorageRow(Base): # type: ignore[misc] """ Represents an object storage configuration. @@ -51,7 +59,6 @@ class ObjectStorageRow(Base): # type: ignore[misc] id: Mapped[uuid.UUID] = mapped_column( "id", GUID, primary_key=True, server_default=sa.text("uuid_generate_v4()") ) - name: Mapped[str] = mapped_column("name", sa.String, index=True, unique=True, nullable=False) host: Mapped[str] = mapped_column("host", sa.String, index=True, nullable=False) access_key: Mapped[str] = mapped_column( "access_key", @@ -87,12 +94,18 @@ class ObjectStorageRow(Base): # type: ignore[misc] back_populates="object_storage_row", primaryjoin=_get_object_storage_namespace_join_cond, ) + meta: Mapped[ArtifactStorageRow | None] = relationship( + "ArtifactStorageRow", + back_populates="object_storages", + primaryjoin=_get_object_storage_meta_join_cond, + uselist=False, + viewonly=True, + ) def __str__(self) -> str: return ( f"ObjectStorageRow(" f"id={self.id}, " - f"name={self.name}, " f"host={self.host}, " f"access_key={self.access_key}, " f"secret_key={self.secret_key}, " @@ -104,9 +117,11 @@ def __repr__(self) -> str: return self.__str__() def to_dataclass(self) -> ObjectStorageData: + if self.meta is None: + raise RelationNotLoadedError() return ObjectStorageData( id=self.id, - name=self.name, + name=self.meta.name, host=self.host, 
access_key=self.access_key, secret_key=self.secret_key, diff --git a/src/ai/backend/manager/models/vfs_storage/row.py b/src/ai/backend/manager/models/vfs_storage/row.py index 11ca46f4041..5b92a40962c 100644 --- a/src/ai/backend/manager/models/vfs_storage/row.py +++ b/src/ai/backend/manager/models/vfs_storage/row.py @@ -8,6 +8,7 @@ import sqlalchemy as sa from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship +from ai.backend.common.exception import RelationNotLoadedError from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.data.vfs_storage.types import VFSStorageData from ai.backend.manager.models.base import ( @@ -16,6 +17,7 @@ ) if TYPE_CHECKING: + from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.models.association_artifacts_storages import ( AssociationArtifactsStorageRow, ) @@ -33,6 +35,12 @@ def _get_vfs_storage_association_artifact_join_cond() -> sa.ColumnElement[bool]: return VFSStorageRow.id == foreign(AssociationArtifactsStorageRow.storage_namespace_id) +def _get_vfs_storage_meta_join_cond() -> sa.ColumnElement[bool]: + from ai.backend.manager.models.artifact_storages import ArtifactStorageRow + + return VFSStorageRow.id == foreign(ArtifactStorageRow.storage_id) + + class VFSStorageRow(Base): # type: ignore[misc] """ Represents a VFS storage configuration. 
@@ -45,7 +53,6 @@ class VFSStorageRow(Base): # type: ignore[misc] id: Mapped[uuid.UUID] = mapped_column( "id", GUID, primary_key=True, server_default=sa.text("uuid_generate_v4()") ) - name: Mapped[str] = mapped_column("name", sa.String, index=True, unique=True, nullable=False) host: Mapped[str] = mapped_column("host", sa.String, nullable=False) base_path: Mapped[str] = mapped_column("base_path", sa.String, nullable=False) @@ -57,23 +64,26 @@ class VFSStorageRow(Base): # type: ignore[misc] overlaps="association_artifacts_storages_rows,object_storage_row", ) ) + meta: Mapped[ArtifactStorageRow | None] = relationship( + "ArtifactStorageRow", + back_populates="vfs_storages", + primaryjoin=_get_vfs_storage_meta_join_cond, + uselist=False, + viewonly=True, + ) def __str__(self) -> str: - return ( - f"VFSStorageRow(" - f"id={self.id}, " - f"name={self.name}, " - f"host={self.host}, " - f"base_path={self.base_path})" - ) + return f"VFSStorageRow(id={self.id}, host={self.host}, base_path={self.base_path})" def __repr__(self) -> str: return self.__str__() def to_dataclass(self) -> VFSStorageData: + if self.meta is None: + raise RelationNotLoadedError() return VFSStorageData( id=self.id, - name=self.name, + name=self.meta.name, host=self.host, base_path=Path(self.base_path), ) diff --git a/src/ai/backend/manager/repositories/artifact_storage/__init__.py b/src/ai/backend/manager/repositories/artifact_storage/__init__.py new file mode 100644 index 00000000000..adc5b5889c3 --- /dev/null +++ b/src/ai/backend/manager/repositories/artifact_storage/__init__.py @@ -0,0 +1,7 @@ +from .repositories import ArtifactStorageRepositories +from .repository import ArtifactStorageRepository + +__all__ = ( + "ArtifactStorageRepositories", + "ArtifactStorageRepository", +) diff --git a/src/ai/backend/manager/repositories/artifact_storage/db_source/__init__.py b/src/ai/backend/manager/repositories/artifact_storage/db_source/__init__.py new file mode 100644 index 00000000000..9c06935ade8 --- 
/dev/null +++ b/src/ai/backend/manager/repositories/artifact_storage/db_source/__init__.py @@ -0,0 +1,3 @@ +from .db_source import ArtifactStorageDBSource + +__all__ = ("ArtifactStorageDBSource",) diff --git a/src/ai/backend/manager/repositories/artifact_storage/db_source/db_source.py b/src/ai/backend/manager/repositories/artifact_storage/db_source/db_source.py new file mode 100644 index 00000000000..cac2d7178c7 --- /dev/null +++ b/src/ai/backend/manager/repositories/artifact_storage/db_source/db_source.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import uuid + +import sqlalchemy as sa + +from ai.backend.manager.data.artifact_storages.types import ArtifactStorageData +from ai.backend.manager.errors.artifact_storage import ArtifactStorageNotFoundError +from ai.backend.manager.models.artifact_storages import ArtifactStorageRow +from ai.backend.manager.models.utils import ExtendedAsyncSAEngine +from ai.backend.manager.repositories.base.updater import Updater, execute_updater + + +class ArtifactStorageDBSource: + """Database source for artifact storage operations.""" + + _db: ExtendedAsyncSAEngine + + def __init__(self, db: ExtendedAsyncSAEngine) -> None: + self._db = db + + async def get_by_id(self, storage_id: uuid.UUID) -> ArtifactStorageData: + """ + Get an existing artifact storage configuration from the database by ID. + """ + async with self._db.begin_session() as db_session: + query = sa.select(ArtifactStorageRow).where(ArtifactStorageRow.id == storage_id) + result = await db_session.execute(query) + row = result.scalar_one_or_none() + if row is None: + raise ArtifactStorageNotFoundError( + f"Artifact storage with ID {storage_id} not found." + ) + return row.to_dataclass() + + async def get_by_storage_id(self, storage_id: uuid.UUID) -> ArtifactStorageData: + """ + Get an existing artifact storage configuration from the database by storage_id. 
+ """ + async with self._db.begin_session() as db_session: + query = sa.select(ArtifactStorageRow).where(ArtifactStorageRow.storage_id == storage_id) + result = await db_session.execute(query) + row = result.scalar_one_or_none() + if row is None: + raise ArtifactStorageNotFoundError( + f"Artifact storage with storage_id {storage_id} not found." + ) + return row.to_dataclass() + + async def update( + self, + updater: Updater[ArtifactStorageRow], + ) -> ArtifactStorageData: + """ + Update an existing artifact storage configuration in the database. + """ + async with self._db.begin_session() as db_session: + # Execute update (may return None if no values to update, which is fine) + await execute_updater(db_session, updater) + + artifact_storage_id = uuid.UUID(str(updater.pk_value)) + # Re-query to get the updated row + query = sa.select(ArtifactStorageRow).where( + ArtifactStorageRow.id == artifact_storage_id + ) + row_result = await db_session.execute(query) + row = row_result.scalar_one_or_none() + if row is None: + raise ArtifactStorageNotFoundError( + f"Artifact storage with ID {artifact_storage_id} not found." 
+ ) + return row.to_dataclass() diff --git a/src/ai/backend/manager/repositories/artifact_storage/repositories.py b/src/ai/backend/manager/repositories/artifact_storage/repositories.py new file mode 100644 index 00000000000..316b4ef67de --- /dev/null +++ b/src/ai/backend/manager/repositories/artifact_storage/repositories.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass +from typing import Self + +from ai.backend.manager.repositories.types import RepositoryArgs + +from .repository import ArtifactStorageRepository + + +@dataclass +class ArtifactStorageRepositories: + repository: ArtifactStorageRepository + + @classmethod + def create(cls, args: RepositoryArgs) -> Self: + return cls( + repository=ArtifactStorageRepository( + db=args.db, + ), + ) diff --git a/src/ai/backend/manager/repositories/artifact_storage/repository.py b/src/ai/backend/manager/repositories/artifact_storage/repository.py new file mode 100644 index 00000000000..6c02311e011 --- /dev/null +++ b/src/ai/backend/manager/repositories/artifact_storage/repository.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import uuid + +from ai.backend.common.exception import BackendAIError +from ai.backend.common.metrics.metric import DomainType, LayerType +from ai.backend.common.resilience.policies.metrics import MetricArgs, MetricPolicy +from ai.backend.common.resilience.policies.retry import BackoffStrategy, RetryArgs, RetryPolicy +from ai.backend.common.resilience.resilience import Resilience +from ai.backend.manager.data.artifact_storages.types import ArtifactStorageData +from ai.backend.manager.models.artifact_storages import ArtifactStorageRow +from ai.backend.manager.models.utils import ExtendedAsyncSAEngine +from ai.backend.manager.repositories.artifact_storage.db_source.db_source import ( + ArtifactStorageDBSource, +) +from ai.backend.manager.repositories.base.updater import Updater + +artifact_storage_repository_resilience = Resilience( + policies=[ + MetricPolicy( + MetricArgs( + 
domain=DomainType.REPOSITORY, + layer=LayerType.ARTIFACT_STORAGE_REPOSITORY, + ) + ), + RetryPolicy( + RetryArgs( + max_retries=10, + retry_delay=0.1, + backoff_strategy=BackoffStrategy.FIXED, + non_retryable_exceptions=(BackendAIError,), + ) + ), + ] +) + + +class ArtifactStorageRepository: + """Repository layer for artifact storage operations.""" + + _db_source: ArtifactStorageDBSource + + def __init__(self, db: ExtendedAsyncSAEngine) -> None: + self._db_source = ArtifactStorageDBSource(db) + + @artifact_storage_repository_resilience.apply() + async def get_by_id(self, storage_id: uuid.UUID) -> ArtifactStorageData: + return await self._db_source.get_by_id(storage_id) + + @artifact_storage_repository_resilience.apply() + async def get_by_storage_id(self, storage_id: uuid.UUID) -> ArtifactStorageData: + return await self._db_source.get_by_storage_id(storage_id) + + @artifact_storage_repository_resilience.apply() + async def update( + self, + updater: Updater[ArtifactStorageRow], + ) -> ArtifactStorageData: + return await self._db_source.update(updater) diff --git a/src/ai/backend/manager/repositories/object_storage/creators.py b/src/ai/backend/manager/repositories/object_storage/creators.py index e1a0b4c006f..190dce280ae 100644 --- a/src/ai/backend/manager/repositories/object_storage/creators.py +++ b/src/ai/backend/manager/repositories/object_storage/creators.py @@ -13,7 +13,6 @@ class ObjectStorageCreatorSpec(CreatorSpec[ObjectStorageRow]): """CreatorSpec for object storage creation.""" - name: str host: str access_key: str secret_key: str @@ -23,7 +22,6 @@ class ObjectStorageCreatorSpec(CreatorSpec[ObjectStorageRow]): @override def build_row(self) -> ObjectStorageRow: return ObjectStorageRow( - name=self.name, host=self.host, access_key=self.access_key, secret_key=self.secret_key, diff --git a/src/ai/backend/manager/repositories/object_storage/db_source/db_source.py b/src/ai/backend/manager/repositories/object_storage/db_source/db_source.py index 
96695f9e806..c0e7632c879 100644 --- a/src/ai/backend/manager/repositories/object_storage/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/object_storage/db_source/db_source.py @@ -5,10 +5,14 @@ import sqlalchemy as sa from sqlalchemy.orm import selectinload +from ai.backend.common.types import ConcreteStorageId +from ai.backend.manager.data.artifact_storages.types import ArtifactStorageCreatorSpec from ai.backend.manager.data.object_storage.types import ObjectStorageData, ObjectStorageListResult +from ai.backend.manager.errors.common import InternalServerError from ai.backend.manager.errors.object_storage import ( ObjectStorageNotFoundError, ) +from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.models.object_storage import ObjectStorageRow from ai.backend.manager.models.storage_namespace import StorageNamespaceRow from ai.backend.manager.models.utils import ExtendedAsyncSAEngine @@ -30,21 +34,37 @@ async def get_by_name(self, storage_name: str) -> ObjectStorageData: Get an existing object storage configuration from the database. """ async with self._db.begin_readonly_session_read_committed() as db_session: - query = sa.select(ObjectStorageRow).where(ObjectStorageRow.name == storage_name) + query = ( + sa.select(ArtifactStorageRow) + .where(ArtifactStorageRow.name == storage_name) + .options( + selectinload(ArtifactStorageRow.object_storages).selectinload( + ObjectStorageRow.meta + ) + ) + ) result = await db_session.execute(query) row = result.scalar_one_or_none() if row is None: raise ObjectStorageNotFoundError( f"Object storage with name {storage_name} not found." 
) - return row.to_dataclass() + if row.object_storages is None: + raise ObjectStorageNotFoundError( + f"Object storage not found for name {storage_name}" + ) + return row.object_storages.to_dataclass() async def get_by_id(self, storage_id: uuid.UUID) -> ObjectStorageData: """ Get an existing object storage configuration from the database by ID. """ async with self._db.begin_readonly_session_read_committed() as db_session: - query = sa.select(ObjectStorageRow).where(ObjectStorageRow.id == storage_id) + query = ( + sa.select(ObjectStorageRow) + .where(ObjectStorageRow.id == storage_id) + .options(selectinload(ObjectStorageRow.meta)) + ) result = await db_session.execute(query) row = result.scalar_one_or_none() if row is None: @@ -59,7 +79,11 @@ async def get_by_namespace_id(self, storage_namespace_id: uuid.UUID) -> ObjectSt query = ( sa.select(StorageNamespaceRow) .where(StorageNamespaceRow.id == storage_namespace_id) - .options(selectinload(StorageNamespaceRow.object_storage_row)) + .options( + selectinload(StorageNamespaceRow.object_storage_row).selectinload( + ObjectStorageRow.meta + ) + ) ) result = await db_session.execute(query) row = result.scalar_one_or_none() @@ -73,25 +97,56 @@ async def get_by_namespace_id(self, storage_namespace_id: uuid.UUID) -> ObjectSt ) return row.object_storage_row.to_dataclass() - async def create(self, creator: Creator[ObjectStorageRow]) -> ObjectStorageData: + async def create( + self, creator: Creator[ObjectStorageRow], meta_creator: Creator[ArtifactStorageRow] + ) -> ObjectStorageData: """ Create a new object storage configuration in the database. 
""" async with self._db.begin_session() as db_session: creator_result = await execute_creator(db_session, creator) - return creator_result.row.to_dataclass() + new_row = creator_result.row + + # Set the storage_id on the meta creator spec and create ArtifactStorageRow + meta_spec = meta_creator.spec + if not isinstance(meta_spec, ArtifactStorageCreatorSpec): + raise InternalServerError("meta_creator.spec must be ArtifactStorageCreatorSpec") + meta_spec.set_storage_id(ConcreteStorageId(new_row.id)) + await execute_creator(db_session, meta_creator) - async def update(self, updater: Updater[ObjectStorageRow]) -> ObjectStorageData: + # Re-query to load the meta relationship + query = ( + sa.select(ObjectStorageRow) + .where(ObjectStorageRow.id == new_row.id) + .options(selectinload(ObjectStorageRow.meta)) + ) + row_result = await db_session.execute(query) + row = row_result.scalar_one() + return row.to_dataclass() + + async def update( + self, + updater: Updater[ObjectStorageRow], + ) -> ObjectStorageData: """ Update an existing object storage configuration in the database. """ async with self._db.begin_session() as db_session: - result = await execute_updater(db_session, updater) - if result is None: - raise ObjectStorageNotFoundError( - f"Object storage with ID {updater.pk_value} not found." 
- ) - return result.row.to_dataclass() + # Execute update (may return None if no values to update, which is fine) + await execute_updater(db_session, updater) + + storage_id = uuid.UUID(str(updater.pk_value)) + # Re-query to load the meta relationship + query = ( + sa.select(ObjectStorageRow) + .where(ObjectStorageRow.id == storage_id) + .options(selectinload(ObjectStorageRow.meta)) + ) + row_result = await db_session.execute(query) + row = row_result.scalar_one_or_none() + if row is None: + raise ObjectStorageNotFoundError(f"Object storage with ID {storage_id} not found.") + return row.to_dataclass() async def delete(self, storage_id: uuid.UUID) -> uuid.UUID: """ @@ -107,6 +162,12 @@ async def delete(self, storage_id: uuid.UUID) -> uuid.UUID: deleted_id = result.scalar() if deleted_id is None: raise ObjectStorageNotFoundError(f"Object storage with ID {storage_id} not found.") + + # Delete the ArtifactStorageRow + delete_meta_query = sa.delete(ArtifactStorageRow).where( + ArtifactStorageRow.storage_id == storage_id + ) + await db_session.execute(delete_meta_query) return deleted_id async def list_object_storages(self) -> list[ObjectStorageData]: @@ -114,7 +175,7 @@ async def list_object_storages(self) -> list[ObjectStorageData]: List all object storage configurations from the database. 
""" async with self._db.begin_readonly_session_read_committed() as db_session: - query = sa.select(ObjectStorageRow) + query = sa.select(ObjectStorageRow).options(selectinload(ObjectStorageRow.meta)) result = await db_session.execute(query) rows = result.scalars().all() return [row.to_dataclass() for row in rows] @@ -125,7 +186,7 @@ async def search( ) -> ObjectStorageListResult: """Searches Object storages with total count.""" async with self._db.begin_readonly_session() as db_sess: - query = sa.select(ObjectStorageRow) + query = sa.select(ObjectStorageRow).options(selectinload(ObjectStorageRow.meta)) result = await execute_batch_querier( db_sess, diff --git a/src/ai/backend/manager/repositories/object_storage/repository.py b/src/ai/backend/manager/repositories/object_storage/repository.py index 59eea2cd3bc..6d5e909f62e 100644 --- a/src/ai/backend/manager/repositories/object_storage/repository.py +++ b/src/ai/backend/manager/repositories/object_storage/repository.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import uuid +from typing import TYPE_CHECKING from ai.backend.common.exception import ( BackendAIError, @@ -15,6 +18,9 @@ from ai.backend.manager.repositories.base.updater import Updater from ai.backend.manager.repositories.object_storage.db_source.db_source import ObjectStorageDBSource +if TYPE_CHECKING: + from ai.backend.manager.models.artifact_storages import ArtifactStorageRow + object_storage_repository_resilience = Resilience( policies=[ MetricPolicy( @@ -53,11 +59,16 @@ async def get_by_namespace_id(self, storage_namespace_id: uuid.UUID) -> ObjectSt return await self._db_source.get_by_namespace_id(storage_namespace_id) @object_storage_repository_resilience.apply() - async def create(self, creator: Creator[ObjectStorageRow]) -> ObjectStorageData: - return await self._db_source.create(creator) + async def create( + self, creator: Creator[ObjectStorageRow], meta_creator: Creator[ArtifactStorageRow] + ) -> ObjectStorageData: + return await 
self._db_source.create(creator, meta_creator) @object_storage_repository_resilience.apply() - async def update(self, updater: Updater[ObjectStorageRow]) -> ObjectStorageData: + async def update( + self, + updater: Updater[ObjectStorageRow], + ) -> ObjectStorageData: return await self._db_source.update(updater) @object_storage_repository_resilience.apply() diff --git a/src/ai/backend/manager/repositories/object_storage/updaters.py b/src/ai/backend/manager/repositories/object_storage/updaters.py index e9dd0d08ba9..27059dd7ee7 100644 --- a/src/ai/backend/manager/repositories/object_storage/updaters.py +++ b/src/ai/backend/manager/repositories/object_storage/updaters.py @@ -14,7 +14,6 @@ class ObjectStorageUpdaterSpec(UpdaterSpec[ObjectStorageRow]): """UpdaterSpec for object storage updates.""" - name: OptionalState[str] = field(default_factory=OptionalState.nop) host: OptionalState[str] = field(default_factory=OptionalState.nop) access_key: OptionalState[str] = field(default_factory=OptionalState.nop) secret_key: OptionalState[str] = field(default_factory=OptionalState.nop) @@ -29,7 +28,6 @@ def row_class(self) -> type[ObjectStorageRow]: @override def build_values(self) -> dict[str, Any]: to_update: dict[str, Any] = {} - self.name.update_dict(to_update, "name") self.host.update_dict(to_update, "host") self.access_key.update_dict(to_update, "access_key") self.secret_key.update_dict(to_update, "secret_key") diff --git a/src/ai/backend/manager/repositories/repositories.py b/src/ai/backend/manager/repositories/repositories.py index faf2b29c45b..c56d9f9a169 100644 --- a/src/ai/backend/manager/repositories/repositories.py +++ b/src/ai/backend/manager/repositories/repositories.py @@ -7,6 +7,9 @@ from ai.backend.manager.repositories.artifact_registry.repositories import ( ArtifactRegistryRepositories, ) +from ai.backend.manager.repositories.artifact_storage.repositories import ( + ArtifactStorageRepositories, +) from ai.backend.manager.repositories.audit_log.repositories 
import AuditLogRepositories from ai.backend.manager.repositories.auth.repositories import AuthRepositories from ai.backend.manager.repositories.container_registry.repositories import ( @@ -104,6 +107,7 @@ class Repositories: huggingface_registry: HuggingFaceRegistryRepositories artifact: ArtifactRepositories artifact_registry: ArtifactRegistryRepositories + artifact_storage: ArtifactStorageRepositories storage_namespace: StorageNamespaceRepositories audit_log: AuditLogRepositories @@ -146,6 +150,7 @@ def create(cls, args: RepositoryArgs) -> Self: artifact_repositories = ArtifactRepositories.create(args) huggingface_registry_repositories = HuggingFaceRegistryRepositories.create(args) artifact_registries = ArtifactRegistryRepositories.create(args) + artifact_storage_repositories = ArtifactStorageRepositories.create(args) storage_namespace_repositories = StorageNamespaceRepositories.create(args) audit_log_repositories = AuditLogRepositories.create(args) @@ -187,6 +192,7 @@ def create(cls, args: RepositoryArgs) -> Self: huggingface_registry=huggingface_registry_repositories, artifact=artifact_repositories, artifact_registry=artifact_registries, + artifact_storage=artifact_storage_repositories, storage_namespace=storage_namespace_repositories, audit_log=audit_log_repositories, ) diff --git a/src/ai/backend/manager/repositories/vfs_storage/creators.py b/src/ai/backend/manager/repositories/vfs_storage/creators.py index aec0a785dc3..5b2af29793f 100644 --- a/src/ai/backend/manager/repositories/vfs_storage/creators.py +++ b/src/ai/backend/manager/repositories/vfs_storage/creators.py @@ -13,14 +13,12 @@ class VFSStorageCreatorSpec(CreatorSpec[VFSStorageRow]): """CreatorSpec for VFS storage creation.""" - name: str host: str base_path: str @override def build_row(self) -> VFSStorageRow: return VFSStorageRow( - name=self.name, host=self.host, base_path=self.base_path, ) diff --git a/src/ai/backend/manager/repositories/vfs_storage/db_source/db_source.py 
b/src/ai/backend/manager/repositories/vfs_storage/db_source/db_source.py index 50024e3a124..fae7bf76ee2 100644 --- a/src/ai/backend/manager/repositories/vfs_storage/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/vfs_storage/db_source/db_source.py @@ -3,11 +3,16 @@ import uuid import sqlalchemy as sa +from sqlalchemy.orm import selectinload +from ai.backend.common.types import ConcreteStorageId +from ai.backend.manager.data.artifact_storages.types import ArtifactStorageCreatorSpec from ai.backend.manager.data.vfs_storage.types import VFSStorageData, VFSStorageListResult +from ai.backend.manager.errors.common import InternalServerError from ai.backend.manager.errors.vfs_storage import ( VFSStorageNotFoundError, ) +from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.models.utils import ExtendedAsyncSAEngine from ai.backend.manager.models.vfs_storage import VFSStorageRow from ai.backend.manager.repositories.base import BatchQuerier, execute_batch_querier @@ -28,42 +33,87 @@ async def get_by_name(self, storage_name: str) -> VFSStorageData: Get an existing VFS storage configuration from the database. 
""" async with self._db.begin_readonly_session_read_committed() as db_session: - query = sa.select(VFSStorageRow).where(VFSStorageRow.name == storage_name) + query = ( + sa.select(ArtifactStorageRow) + .where(ArtifactStorageRow.name == storage_name) + .options( + selectinload(ArtifactStorageRow.vfs_storages).selectinload(VFSStorageRow.meta) + ) + ) result = await db_session.execute(query) row = result.scalar_one_or_none() if row is None: raise VFSStorageNotFoundError(f"VFS storage with name {storage_name} not found.") - return row.to_dataclass() + if row.vfs_storages is None: + raise VFSStorageNotFoundError(f"VFS storage not found for name {storage_name}") + return row.vfs_storages.to_dataclass() async def get_by_id(self, storage_id: uuid.UUID) -> VFSStorageData: """ Get an existing VFS storage configuration from the database by ID. """ async with self._db.begin_readonly_session_read_committed() as db_session: - query = sa.select(VFSStorageRow).where(VFSStorageRow.id == storage_id) + query = ( + sa.select(VFSStorageRow) + .where(VFSStorageRow.id == storage_id) + .options(selectinload(VFSStorageRow.meta)) + ) result = await db_session.execute(query) row = result.scalar_one_or_none() if row is None: raise VFSStorageNotFoundError(f"VFS storage with ID {storage_id} not found.") return row.to_dataclass() - async def create(self, creator: Creator[VFSStorageRow]) -> VFSStorageData: + async def create( + self, creator: Creator[VFSStorageRow], meta_creator: Creator[ArtifactStorageRow] + ) -> VFSStorageData: """ Create a new VFS storage configuration in the database. 
""" async with self._db.begin_session() as db_session: creator_result = await execute_creator(db_session, creator) - return creator_result.row.to_dataclass() + new_row = creator_result.row + + # Set the storage_id on the meta creator spec and create ArtifactStorageRow + meta_spec = meta_creator.spec + if not isinstance(meta_spec, ArtifactStorageCreatorSpec): + raise InternalServerError("meta_creator.spec must be ArtifactStorageCreatorSpec") + meta_spec.set_storage_id(ConcreteStorageId(new_row.id)) + await execute_creator(db_session, meta_creator) + + # Re-query to load the meta relationship + query = ( + sa.select(VFSStorageRow) + .where(VFSStorageRow.id == new_row.id) + .options(selectinload(VFSStorageRow.meta)) + ) + row_result = await db_session.execute(query) + row = row_result.scalar_one() + return row.to_dataclass() - async def update(self, updater: Updater[VFSStorageRow]) -> VFSStorageData: + async def update( + self, + updater: Updater[VFSStorageRow], + ) -> VFSStorageData: """ Update an existing VFS storage configuration in the database. 
""" async with self._db.begin_session() as db_session: - result = await execute_updater(db_session, updater) - if result is None: - raise VFSStorageNotFoundError(f"VFS storage with ID {updater.pk_value} not found.") - return result.row.to_dataclass() + # Execute update (may return None if no values to update, which is fine) + await execute_updater(db_session, updater) + + storage_id = uuid.UUID(str(updater.pk_value)) + # Re-query to load the meta relationship + query = ( + sa.select(VFSStorageRow) + .where(VFSStorageRow.id == storage_id) + .options(selectinload(VFSStorageRow.meta)) + ) + row_result = await db_session.execute(query) + row = row_result.scalar_one_or_none() + if row is None: + raise VFSStorageNotFoundError(f"VFS storage with ID {storage_id} not found.") + return row.to_dataclass() async def delete(self, storage_id: uuid.UUID) -> uuid.UUID: """ @@ -79,6 +129,12 @@ async def delete(self, storage_id: uuid.UUID) -> uuid.UUID: deleted_id = result.scalar() if deleted_id is None: raise VFSStorageNotFoundError(f"VFS storage with ID {storage_id} not found.") + + # Delete the ArtifactStorageRow + delete_meta_query = sa.delete(ArtifactStorageRow).where( + ArtifactStorageRow.storage_id == storage_id + ) + await db_session.execute(delete_meta_query) return deleted_id async def list_vfs_storages(self) -> list[VFSStorageData]: @@ -86,7 +142,7 @@ async def list_vfs_storages(self) -> list[VFSStorageData]: List all VFS storage configurations from the database. 
""" async with self._db.begin_readonly_session_read_committed() as db_session: - query = sa.select(VFSStorageRow) + query = sa.select(VFSStorageRow).options(selectinload(VFSStorageRow.meta)) result = await db_session.execute(query) rows = result.scalars().all() return [row.to_dataclass() for row in rows] @@ -97,7 +153,7 @@ async def search( ) -> VFSStorageListResult: """Searches VFS storages with total count.""" async with self._db.begin_readonly_session() as db_sess: - query = sa.select(VFSStorageRow) + query = sa.select(VFSStorageRow).options(selectinload(VFSStorageRow.meta)) result = await execute_batch_querier( db_sess, diff --git a/src/ai/backend/manager/repositories/vfs_storage/repository.py b/src/ai/backend/manager/repositories/vfs_storage/repository.py index 5ab611a6233..641978d4343 100644 --- a/src/ai/backend/manager/repositories/vfs_storage/repository.py +++ b/src/ai/backend/manager/repositories/vfs_storage/repository.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import uuid +from typing import TYPE_CHECKING from ai.backend.common.exception import BackendAIError from ai.backend.common.metrics.metric import DomainType, LayerType @@ -13,6 +16,9 @@ from ai.backend.manager.repositories.base.updater import Updater from ai.backend.manager.repositories.vfs_storage.db_source.db_source import VFSStorageDBSource +if TYPE_CHECKING: + from ai.backend.manager.models.artifact_storages import ArtifactStorageRow + vfs_storage_repository_resilience = Resilience( policies=[ MetricPolicy( @@ -47,11 +53,16 @@ async def get_by_id(self, storage_id: uuid.UUID) -> VFSStorageData: return await self._db_source.get_by_id(storage_id) @vfs_storage_repository_resilience.apply() - async def create(self, creator: Creator[VFSStorageRow]) -> VFSStorageData: - return await self._db_source.create(creator) + async def create( + self, creator: Creator[VFSStorageRow], meta_creator: Creator[ArtifactStorageRow] + ) -> VFSStorageData: + return await self._db_source.create(creator, 
meta_creator) @vfs_storage_repository_resilience.apply() - async def update(self, updater: Updater[VFSStorageRow]) -> VFSStorageData: + async def update( + self, + updater: Updater[VFSStorageRow], + ) -> VFSStorageData: return await self._db_source.update(updater) @vfs_storage_repository_resilience.apply() diff --git a/src/ai/backend/manager/repositories/vfs_storage/updaters.py b/src/ai/backend/manager/repositories/vfs_storage/updaters.py index 96e8f682c2c..b85b6a71f84 100644 --- a/src/ai/backend/manager/repositories/vfs_storage/updaters.py +++ b/src/ai/backend/manager/repositories/vfs_storage/updaters.py @@ -14,7 +14,6 @@ class VFSStorageUpdaterSpec(UpdaterSpec[VFSStorageRow]): """UpdaterSpec for VFS storage updates.""" - name: OptionalState[str] = field(default_factory=OptionalState.nop) host: OptionalState[str] = field(default_factory=OptionalState.nop) base_path: OptionalState[str] = field(default_factory=OptionalState.nop) @@ -26,7 +25,6 @@ def row_class(self) -> type[VFSStorageRow]: @override def build_values(self) -> dict[str, Any]: to_update: dict[str, Any] = {} - self.name.update_dict(to_update, "name") self.host.update_dict(to_update, "host") self.base_path.update_dict(to_update, "base_path") return to_update diff --git a/src/ai/backend/manager/services/artifact_storage/__init__.py b/src/ai/backend/manager/services/artifact_storage/__init__.py new file mode 100644 index 00000000000..42c1c6fa761 --- /dev/null +++ b/src/ai/backend/manager/services/artifact_storage/__init__.py @@ -0,0 +1,7 @@ +from .processors import ArtifactStorageProcessors +from .service import ArtifactStorageService + +__all__ = ( + "ArtifactStorageProcessors", + "ArtifactStorageService", +) diff --git a/src/ai/backend/manager/services/artifact_storage/actions/__init__.py b/src/ai/backend/manager/services/artifact_storage/actions/__init__.py new file mode 100644 index 00000000000..670110e2da4 --- /dev/null +++ b/src/ai/backend/manager/services/artifact_storage/actions/__init__.py @@ -0,0 
+1,8 @@ +from .base import ArtifactStorageAction +from .update import UpdateArtifactStorageAction, UpdateArtifactStorageActionResult + +__all__ = ( + "ArtifactStorageAction", + "UpdateArtifactStorageAction", + "UpdateArtifactStorageActionResult", +) diff --git a/src/ai/backend/manager/services/artifact_storage/actions/base.py b/src/ai/backend/manager/services/artifact_storage/actions/base.py new file mode 100644 index 00000000000..9625aeaff71 --- /dev/null +++ b/src/ai/backend/manager/services/artifact_storage/actions/base.py @@ -0,0 +1,11 @@ +from typing import override + +from ai.backend.common.data.permission.types import EntityType +from ai.backend.manager.actions.action import BaseAction + + +class ArtifactStorageAction(BaseAction): + @override + @classmethod + def entity_type(cls) -> EntityType: + return EntityType.ARTIFACT_STORAGE diff --git a/src/ai/backend/manager/services/artifact_storage/actions/update.py b/src/ai/backend/manager/services/artifact_storage/actions/update.py new file mode 100644 index 00000000000..12d3980cf4a --- /dev/null +++ b/src/ai/backend/manager/services/artifact_storage/actions/update.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import override + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.data.artifact_storages.types import ArtifactStorageData +from ai.backend.manager.models.artifact_storages import ArtifactStorageRow +from ai.backend.manager.repositories.base.updater import Updater +from ai.backend.manager.services.artifact_storage.actions.base import ArtifactStorageAction + + +@dataclass +class UpdateArtifactStorageAction(ArtifactStorageAction): + updater: Updater[ArtifactStorageRow] + + @override + def entity_id(self) -> str | None: + return str(self.updater.pk_value) + + @override + @classmethod + def operation_type(cls) -> ActionOperationType: + return 
ActionOperationType.UPDATE + + +@dataclass +class UpdateArtifactStorageActionResult(BaseActionResult): + result: ArtifactStorageData + + @override + def entity_id(self) -> str | None: + return str(self.result.id) diff --git a/src/ai/backend/manager/services/artifact_storage/processors.py b/src/ai/backend/manager/services/artifact_storage/processors.py new file mode 100644 index 00000000000..1ede3b08bef --- /dev/null +++ b/src/ai/backend/manager/services/artifact_storage/processors.py @@ -0,0 +1,25 @@ +from typing import override + +from ai.backend.manager.actions.monitors.monitor import ActionMonitor +from ai.backend.manager.actions.processor import ActionProcessor +from ai.backend.manager.actions.types import AbstractProcessorPackage, ActionSpec +from ai.backend.manager.services.artifact_storage.actions.update import ( + UpdateArtifactStorageAction, + UpdateArtifactStorageActionResult, +) +from ai.backend.manager.services.artifact_storage.service import ArtifactStorageService + + +class ArtifactStorageProcessors(AbstractProcessorPackage): + update: ActionProcessor[UpdateArtifactStorageAction, UpdateArtifactStorageActionResult] + + def __init__( + self, service: ArtifactStorageService, action_monitors: list[ActionMonitor] + ) -> None: + self.update = ActionProcessor(service.update, action_monitors) + + @override + def supported_actions(self) -> list[ActionSpec]: + return [ + UpdateArtifactStorageAction.spec(), + ] diff --git a/src/ai/backend/manager/services/artifact_storage/service.py b/src/ai/backend/manager/services/artifact_storage/service.py new file mode 100644 index 00000000000..ee71c1724c6 --- /dev/null +++ b/src/ai/backend/manager/services/artifact_storage/service.py @@ -0,0 +1,32 @@ +import logging + +from ai.backend.logging.utils import BraceStyleAdapter +from ai.backend.manager.repositories.artifact_storage.repository import ArtifactStorageRepository +from ai.backend.manager.services.artifact_storage.actions.update import ( + 
UpdateArtifactStorageAction, + UpdateArtifactStorageActionResult, +) + +log = BraceStyleAdapter(logging.getLogger(__spec__.name)) + + +class ArtifactStorageService: + """Service layer for artifact storage operations.""" + + _artifact_storage_repository: ArtifactStorageRepository + + def __init__( + self, + artifact_storage_repository: ArtifactStorageRepository, + ) -> None: + self._artifact_storage_repository = artifact_storage_repository + + async def update( + self, action: UpdateArtifactStorageAction + ) -> UpdateArtifactStorageActionResult: + """ + Update an existing artifact storage. + """ + log.info("Updating artifact storage with id: {}", action.updater.pk_value) + storage_data = await self._artifact_storage_repository.update(action.updater) + return UpdateArtifactStorageActionResult(result=storage_data) diff --git a/src/ai/backend/manager/services/object_storage/actions/create.py b/src/ai/backend/manager/services/object_storage/actions/create.py index 6ac9ab7e804..aa5294caf3f 100644 --- a/src/ai/backend/manager/services/object_storage/actions/create.py +++ b/src/ai/backend/manager/services/object_storage/actions/create.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from dataclasses import dataclass -from typing import override +from typing import TYPE_CHECKING, override from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType @@ -8,10 +10,14 @@ from ai.backend.manager.repositories.base.creator import Creator from ai.backend.manager.services.object_storage.actions.base import ObjectStorageAction +if TYPE_CHECKING: + from ai.backend.manager.models.artifact_storages import ArtifactStorageRow + @dataclass class CreateObjectStorageAction(ObjectStorageAction): creator: Creator[ObjectStorageRow] + meta_creator: Creator[ArtifactStorageRow] @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/object_storage/actions/update.py 
b/src/ai/backend/manager/services/object_storage/actions/update.py index f51e8576bf3..3a84786d58e 100644 --- a/src/ai/backend/manager/services/object_storage/actions/update.py +++ b/src/ai/backend/manager/services/object_storage/actions/update.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass from typing import override diff --git a/src/ai/backend/manager/services/object_storage/service.py b/src/ai/backend/manager/services/object_storage/service.py index 61d2322b7ee..7d7a139e7f1 100644 --- a/src/ai/backend/manager/services/object_storage/service.py +++ b/src/ai/backend/manager/services/object_storage/service.py @@ -77,7 +77,9 @@ async def create(self, action: CreateObjectStorageAction) -> CreateObjectStorage Create a new object storage. """ log.info("Creating object storage with data: {}", action.creator) - storage_data = await self._object_storage_repository.create(action.creator) + storage_data = await self._object_storage_repository.create( + action.creator, action.meta_creator + ) return CreateObjectStorageActionResult(result=storage_data) async def update(self, action: UpdateObjectStorageAction) -> UpdateObjectStorageActionResult: diff --git a/src/ai/backend/manager/services/processors.py b/src/ai/backend/manager/services/processors.py index e7985d56e15..e1077fd4fb2 100644 --- a/src/ai/backend/manager/services/processors.py +++ b/src/ai/backend/manager/services/processors.py @@ -38,6 +38,8 @@ from ai.backend.manager.services.artifact_registry.service import ArtifactRegistryService from ai.backend.manager.services.artifact_revision.processors import ArtifactRevisionProcessors from ai.backend.manager.services.artifact_revision.service import ArtifactRevisionService +from ai.backend.manager.services.artifact_storage.processors import ArtifactStorageProcessors +from ai.backend.manager.services.artifact_storage.service import ArtifactStorageService from ai.backend.manager.services.audit_log.processors import 
AuditLogProcessors from ai.backend.manager.services.audit_log.service import AuditLogService from ai.backend.manager.services.auth.processors import AuthProcessors @@ -202,6 +204,7 @@ class Services: artifact: ArtifactService artifact_revision: ArtifactRevisionService artifact_registry: ArtifactRegistryService + artifact_storage: ArtifactStorageService deployment: DeploymentService storage_namespace: StorageNamespaceService audit_log: AuditLogService @@ -406,6 +409,9 @@ def create(cls, args: ServiceArgs) -> Self: repositories.reservoir_registry.repository, repositories.artifact_registry.repository, ) + artifact_storage_service = ArtifactStorageService( + artifact_storage_repository=repositories.artifact_storage.repository, + ) deployment_service = DeploymentService( args.deployment_controller, args.deployment_controller._deployment_repository, @@ -459,6 +465,7 @@ def create(cls, args: ServiceArgs) -> Self: artifact=artifact_service, artifact_revision=artifact_revision_service, artifact_registry=artifact_registry_service, + artifact_storage=artifact_storage_service, deployment=deployment_service, storage_namespace=storage_namespace_service, audit_log=audit_log_service, @@ -511,6 +518,7 @@ class Processors(AbstractProcessorPackage): artifact: ArtifactProcessors artifact_registry: ArtifactRegistryProcessors artifact_revision: ArtifactRevisionProcessors + artifact_storage: ArtifactStorageProcessors deployment: DeploymentProcessors storage_namespace: StorageNamespaceProcessors audit_log: AuditLogProcessors @@ -585,6 +593,9 @@ def create(cls, args: ProcessorArgs, action_monitors: list[ActionMonitor]) -> Se artifact_revision_processors = ArtifactRevisionProcessors( services.artifact_revision, action_monitors ) + artifact_storage_processors = ArtifactStorageProcessors( + services.artifact_storage, action_monitors + ) deployment_processors = DeploymentProcessors(services.deployment, action_monitors) @@ -637,6 +648,7 @@ def create(cls, args: ProcessorArgs, action_monitors: 
list[ActionMonitor]) -> Se artifact=artifact_processors, artifact_registry=artifact_registry_processors, artifact_revision=artifact_revision_processors, + artifact_storage=artifact_storage_processors, deployment=deployment_processors, storage_namespace=storage_namespace_processors, audit_log=audit_log_processors, @@ -683,6 +695,7 @@ def supported_actions(self) -> list[ActionSpec]: *self.vfs_storage.supported_actions(), *self.artifact_registry.supported_actions(), *self.artifact_revision.supported_actions(), + *self.artifact_storage.supported_actions(), *self.artifact.supported_actions(), *(self.deployment.supported_actions() if self.deployment else []), *self.storage_namespace.supported_actions(), diff --git a/src/ai/backend/manager/services/vfs_storage/actions/create.py b/src/ai/backend/manager/services/vfs_storage/actions/create.py index b4a668eec77..a3dc0737b00 100644 --- a/src/ai/backend/manager/services/vfs_storage/actions/create.py +++ b/src/ai/backend/manager/services/vfs_storage/actions/create.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from dataclasses import dataclass -from typing import override +from typing import TYPE_CHECKING, override from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType @@ -8,10 +10,14 @@ from ai.backend.manager.repositories.base.creator import Creator from ai.backend.manager.services.vfs_storage.actions.base import VFSStorageAction +if TYPE_CHECKING: + from ai.backend.manager.models.artifact_storages import ArtifactStorageRow + @dataclass class CreateVFSStorageAction(VFSStorageAction): creator: Creator[VFSStorageRow] + meta_creator: Creator[ArtifactStorageRow] @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/vfs_storage/actions/update.py b/src/ai/backend/manager/services/vfs_storage/actions/update.py index e33ae6f9d41..5821455f1fc 100644 --- a/src/ai/backend/manager/services/vfs_storage/actions/update.py 
+++ b/src/ai/backend/manager/services/vfs_storage/actions/update.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass from typing import override diff --git a/src/ai/backend/manager/services/vfs_storage/service.py b/src/ai/backend/manager/services/vfs_storage/service.py index 7736ffce346..685f745ff98 100644 --- a/src/ai/backend/manager/services/vfs_storage/service.py +++ b/src/ai/backend/manager/services/vfs_storage/service.py @@ -66,7 +66,9 @@ async def create(self, action: CreateVFSStorageAction) -> CreateVFSStorageAction Create a new VFS storage. """ log.info("Creating VFS storage with data: {}", action.creator) - storage_data = await self._vfs_storage_repository.create(action.creator) + storage_data = await self._vfs_storage_repository.create( + action.creator, action.meta_creator + ) return CreateVFSStorageActionResult(result=storage_data) async def update(self, action: UpdateVFSStorageAction) -> UpdateVFSStorageActionResult: From de9bec37a884d1b9bcfeb50b800538b11dbcaa40 Mon Sep 17 00:00:00 2001 From: Gyubong Date: Tue, 24 Feb 2026 15:18:16 +0900 Subject: [PATCH 02/10] fix: Reflect feedback --- .../data/storage/exceptions.py} | 0 src/ai/backend/common/data/storage/types.py | 12 ++++++++++++ src/ai/backend/common/types.py | 2 +- .../backend/manager/api/gql/artifact_storage.py | 7 ++----- .../manager/data/artifact_storages/__init__.py | 3 +-- .../manager/data/artifact_storages/types.py | 16 +++------------- .../backend/manager/models/artifact_storages.py | 9 +++------ .../artifact_storage/db_source/db_source.py | 4 ++-- .../repositories/artifact_storage/repository.py | 2 +- .../object_storage/db_source/db_source.py | 4 ++-- .../vfs_storage/db_source/db_source.py | 4 ++-- .../services/artifact_storage/actions/update.py | 2 +- 12 files changed, 30 insertions(+), 35 deletions(-) rename src/ai/backend/{manager/errors/artifact_storage.py => common/data/storage/exceptions.py} (100%) diff --git 
a/src/ai/backend/manager/errors/artifact_storage.py b/src/ai/backend/common/data/storage/exceptions.py similarity index 100% rename from src/ai/backend/manager/errors/artifact_storage.py rename to src/ai/backend/common/data/storage/exceptions.py diff --git a/src/ai/backend/common/data/storage/types.py b/src/ai/backend/common/data/storage/types.py index 971ce98c53b..60b06bb3731 100644 --- a/src/ai/backend/common/data/storage/types.py +++ b/src/ai/backend/common/data/storage/types.py @@ -1,10 +1,12 @@ from __future__ import annotations import enum +from dataclasses import dataclass from pydantic import BaseModel, ConfigDict from ai.backend.common.type_adapters import VFolderIDField +from ai.backend.common.types import ArtifactStorageId, ConcreteArtifactStorageId class VFolderStorageTarget(BaseModel): @@ -31,6 +33,16 @@ class ArtifactStorageType(enum.StrEnum): GIT_LFS = "git_lfs" +@dataclass(frozen=True) +class ArtifactStorageData: + """Data class for artifact storage metadata.""" + + id: ArtifactStorageId + name: str + storage_id: ConcreteArtifactStorageId + type: ArtifactStorageType + + class ArtifactStorageImportStep(enum.StrEnum): DOWNLOAD = "download" VERIFY = "verify" diff --git a/src/ai/backend/common/types.py b/src/ai/backend/common/types.py index a8fb043bf79..e3daaa2811e 100644 --- a/src/ai/backend/common/types.py +++ b/src/ai/backend/common/types.py @@ -309,7 +309,7 @@ def check_typed_tuple(value: tuple[Any, ...], types: tuple[type, ...]) -> tuple[ # ID of the `artifact_storages` common table (storage metadata). ArtifactStorageId = NewType("ArtifactStorageId", UUID) # ID of the concrete storage table (`object_storages`, `vfs_storages`, etc.). 
-ConcreteStorageId = NewType("ConcreteStorageId", UUID) +ConcreteArtifactStorageId = NewType("ConcreteArtifactStorageId", UUID) ImageAlias = NewType("ImageAlias", str) ArchName = NewType("ArchName", str) diff --git a/src/ai/backend/manager/api/gql/artifact_storage.py b/src/ai/backend/manager/api/gql/artifact_storage.py index b0209713fe3..1b4857e03fa 100644 --- a/src/ai/backend/manager/api/gql/artifact_storage.py +++ b/src/ai/backend/manager/api/gql/artifact_storage.py @@ -7,13 +7,10 @@ import strawberry from strawberry import ID, UNSET, Info -from ai.backend.common.data.storage.types import ArtifactStorageType +from ai.backend.common.data.storage.types import ArtifactStorageData, ArtifactStorageType from ai.backend.common.types import ArtifactStorageId from ai.backend.manager.api.gql.utils import check_admin_only -from ai.backend.manager.data.artifact_storages.types import ( - ArtifactStorageData, - ArtifactStorageUpdaterSpec, -) +from ai.backend.manager.data.artifact_storages.types import ArtifactStorageUpdaterSpec from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.repositories.base.updater import Updater from ai.backend.manager.services.artifact_storage.actions.update import ( diff --git a/src/ai/backend/manager/data/artifact_storages/__init__.py b/src/ai/backend/manager/data/artifact_storages/__init__.py index d358ac3fdaf..bb7896b3b7b 100644 --- a/src/ai/backend/manager/data/artifact_storages/__init__.py +++ b/src/ai/backend/manager/data/artifact_storages/__init__.py @@ -1,7 +1,6 @@ -from .types import ArtifactStorageCreatorSpec, ArtifactStorageData, ArtifactStorageUpdaterSpec +from .types import ArtifactStorageCreatorSpec, ArtifactStorageUpdaterSpec __all__ = ( "ArtifactStorageCreatorSpec", - "ArtifactStorageData", "ArtifactStorageUpdaterSpec", ) diff --git a/src/ai/backend/manager/data/artifact_storages/types.py b/src/ai/backend/manager/data/artifact_storages/types.py index d1e20c6f4fa..c070599d7d2 100644 --- 
a/src/ai/backend/manager/data/artifact_storages/types.py +++ b/src/ai/backend/manager/data/artifact_storages/types.py @@ -4,7 +4,7 @@ from typing import Any, override from ai.backend.common.data.storage.types import ArtifactStorageType -from ai.backend.common.types import ArtifactStorageId, ConcreteStorageId +from ai.backend.common.types import ConcreteArtifactStorageId from ai.backend.manager.errors.common import InternalServerError from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.repositories.base.creator import CreatorSpec @@ -12,25 +12,15 @@ from ai.backend.manager.types import OptionalState -@dataclass(frozen=True) -class ArtifactStorageData: - """Data class for ArtifactStorageRow.""" - - id: ArtifactStorageId - name: str - storage_id: ConcreteStorageId - type: ArtifactStorageType - - class ArtifactStorageCreatorSpec(CreatorSpec[ArtifactStorageRow]): """CreatorSpec for ArtifactStorageRow with deferred storage_id.""" def __init__(self, name: str, storage_type: ArtifactStorageType) -> None: self._name = name self._storage_type = storage_type - self._storage_id: ConcreteStorageId | None = None + self._storage_id: ConcreteArtifactStorageId | None = None - def set_storage_id(self, storage_id: ConcreteStorageId) -> None: + def set_storage_id(self, storage_id: ConcreteArtifactStorageId) -> None: self._storage_id = storage_id @override diff --git a/src/ai/backend/manager/models/artifact_storages.py b/src/ai/backend/manager/models/artifact_storages.py index ac3409f80a4..040d07e7053 100644 --- a/src/ai/backend/manager/models/artifact_storages.py +++ b/src/ai/backend/manager/models/artifact_storages.py @@ -7,7 +7,8 @@ import sqlalchemy as sa from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship -from ai.backend.common.data.storage.types import ArtifactStorageType +from ai.backend.common.data.storage.types import ArtifactStorageData, ArtifactStorageType +from ai.backend.common.types import 
ArtifactStorageId, ConcreteArtifactStorageId from ai.backend.logging import BraceStyleAdapter from .base import ( @@ -16,7 +17,6 @@ ) if TYPE_CHECKING: - from ai.backend.manager.data.artifact_storages.types import ArtifactStorageData from ai.backend.manager.models.object_storage import ObjectStorageRow from ai.backend.manager.models.vfs_storage import VFSStorageRow @@ -73,12 +73,9 @@ def __repr__(self) -> str: return self.__str__() def to_dataclass(self) -> ArtifactStorageData: - from ai.backend.common.types import ArtifactStorageId, ConcreteStorageId - from ai.backend.manager.data.artifact_storages.types import ArtifactStorageData - return ArtifactStorageData( id=ArtifactStorageId(self.id), name=self.name, - storage_id=ConcreteStorageId(self.storage_id), + storage_id=ConcreteArtifactStorageId(self.storage_id), type=ArtifactStorageType(self.type), ) diff --git a/src/ai/backend/manager/repositories/artifact_storage/db_source/db_source.py b/src/ai/backend/manager/repositories/artifact_storage/db_source/db_source.py index cac2d7178c7..e77f041bddd 100644 --- a/src/ai/backend/manager/repositories/artifact_storage/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/artifact_storage/db_source/db_source.py @@ -4,8 +4,8 @@ import sqlalchemy as sa -from ai.backend.manager.data.artifact_storages.types import ArtifactStorageData -from ai.backend.manager.errors.artifact_storage import ArtifactStorageNotFoundError +from ai.backend.common.data.storage.exceptions import ArtifactStorageNotFoundError +from ai.backend.common.data.storage.types import ArtifactStorageData from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.models.utils import ExtendedAsyncSAEngine from ai.backend.manager.repositories.base.updater import Updater, execute_updater diff --git a/src/ai/backend/manager/repositories/artifact_storage/repository.py b/src/ai/backend/manager/repositories/artifact_storage/repository.py index 6c02311e011..5edaee0274b 
100644 --- a/src/ai/backend/manager/repositories/artifact_storage/repository.py +++ b/src/ai/backend/manager/repositories/artifact_storage/repository.py @@ -2,12 +2,12 @@ import uuid +from ai.backend.common.data.storage.types import ArtifactStorageData from ai.backend.common.exception import BackendAIError from ai.backend.common.metrics.metric import DomainType, LayerType from ai.backend.common.resilience.policies.metrics import MetricArgs, MetricPolicy from ai.backend.common.resilience.policies.retry import BackoffStrategy, RetryArgs, RetryPolicy from ai.backend.common.resilience.resilience import Resilience -from ai.backend.manager.data.artifact_storages.types import ArtifactStorageData from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.models.utils import ExtendedAsyncSAEngine from ai.backend.manager.repositories.artifact_storage.db_source.db_source import ( diff --git a/src/ai/backend/manager/repositories/object_storage/db_source/db_source.py b/src/ai/backend/manager/repositories/object_storage/db_source/db_source.py index c0e7632c879..7de454e0888 100644 --- a/src/ai/backend/manager/repositories/object_storage/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/object_storage/db_source/db_source.py @@ -5,7 +5,7 @@ import sqlalchemy as sa from sqlalchemy.orm import selectinload -from ai.backend.common.types import ConcreteStorageId +from ai.backend.common.types import ConcreteArtifactStorageId from ai.backend.manager.data.artifact_storages.types import ArtifactStorageCreatorSpec from ai.backend.manager.data.object_storage.types import ObjectStorageData, ObjectStorageListResult from ai.backend.manager.errors.common import InternalServerError @@ -111,7 +111,7 @@ async def create( meta_spec = meta_creator.spec if not isinstance(meta_spec, ArtifactStorageCreatorSpec): raise InternalServerError("meta_creator.spec must be ArtifactStorageCreatorSpec") - meta_spec.set_storage_id(ConcreteStorageId(new_row.id)) 
+ meta_spec.set_storage_id(ConcreteArtifactStorageId(new_row.id)) await execute_creator(db_session, meta_creator) # Re-query to load the meta relationship diff --git a/src/ai/backend/manager/repositories/vfs_storage/db_source/db_source.py b/src/ai/backend/manager/repositories/vfs_storage/db_source/db_source.py index fae7bf76ee2..3ca8b8805d8 100644 --- a/src/ai/backend/manager/repositories/vfs_storage/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/vfs_storage/db_source/db_source.py @@ -5,7 +5,7 @@ import sqlalchemy as sa from sqlalchemy.orm import selectinload -from ai.backend.common.types import ConcreteStorageId +from ai.backend.common.types import ConcreteArtifactStorageId from ai.backend.manager.data.artifact_storages.types import ArtifactStorageCreatorSpec from ai.backend.manager.data.vfs_storage.types import VFSStorageData, VFSStorageListResult from ai.backend.manager.errors.common import InternalServerError @@ -78,7 +78,7 @@ async def create( meta_spec = meta_creator.spec if not isinstance(meta_spec, ArtifactStorageCreatorSpec): raise InternalServerError("meta_creator.spec must be ArtifactStorageCreatorSpec") - meta_spec.set_storage_id(ConcreteStorageId(new_row.id)) + meta_spec.set_storage_id(ConcreteArtifactStorageId(new_row.id)) await execute_creator(db_session, meta_creator) # Re-query to load the meta relationship diff --git a/src/ai/backend/manager/services/artifact_storage/actions/update.py b/src/ai/backend/manager/services/artifact_storage/actions/update.py index 12d3980cf4a..c0acb5055ec 100644 --- a/src/ai/backend/manager/services/artifact_storage/actions/update.py +++ b/src/ai/backend/manager/services/artifact_storage/actions/update.py @@ -3,9 +3,9 @@ from dataclasses import dataclass from typing import override +from ai.backend.common.data.storage.types import ArtifactStorageData from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType -from 
ai.backend.manager.data.artifact_storages.types import ArtifactStorageData from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.repositories.base.updater import Updater from ai.backend.manager.services.artifact_storage.actions.base import ArtifactStorageAction From 984ec0ffb2034d4c6a9f4db8c8a31336f1f7b32b Mon Sep 17 00:00:00 2001 From: Gyubong Date: Tue, 24 Feb 2026 15:38:09 +0900 Subject: [PATCH 03/10] fix: alembic conflict --- .../35dfab3b0662_add_artifact_storages_common_table.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ai/backend/manager/models/alembic/versions/35dfab3b0662_add_artifact_storages_common_table.py b/src/ai/backend/manager/models/alembic/versions/35dfab3b0662_add_artifact_storages_common_table.py index ef75d9db519..95410586e02 100644 --- a/src/ai/backend/manager/models/alembic/versions/35dfab3b0662_add_artifact_storages_common_table.py +++ b/src/ai/backend/manager/models/alembic/versions/35dfab3b0662_add_artifact_storages_common_table.py @@ -1,7 +1,7 @@ """Add artifact_storages common table Revision ID: 35dfab3b0662 -Revises: ccf8ae5c90fe +Revises: 03ff6767b2e4 Create Date: 2025-12-02 09:24:21.050932 """ @@ -13,7 +13,7 @@ # revision identifiers, used by Alembic.
revision = "35dfab3b0662" -down_revision = "ccf8ae5c90fe" +down_revision = "03ff6767b2e4" branch_labels = None depends_on = None From ecdb3b2aa9903a043e197f03d2ea25fbc050b240 Mon Sep 17 00:00:00 2001 From: Gyubong Date: Thu, 26 Feb 2026 12:47:46 +0900 Subject: [PATCH 04/10] fix: Apply SQLAlchemy Joined Table Inheritance to artifact_storages (experimental) --- src/ai/backend/common/data/storage/types.py | 3 +- src/ai/backend/common/types.py | 2 - .../manager/api/gql/artifact_storage.py | 4 - .../backend/manager/api/gql/object_storage.py | 17 +--- src/ai/backend/manager/api/gql/vfs_storage.py | 17 +--- .../data/artifact_storages/__init__.py | 7 +- .../manager/data/artifact_storages/types.py | 26 ------ .../manager/data/object_storage/types.py | 8 +- .../backend/manager/data/vfs_storage/types.py | 9 +- ...64643926_apply_jti_to_artifact_storages.py | 81 +++++++++++++++++ .../manager/models/artifact_storages.py | 44 ++------- .../association_artifacts_storages/row.py | 2 + .../manager/models/object_storage/row.py | 35 +++----- .../manager/models/storage_namespace/row.py | 1 + .../backend/manager/models/vfs_storage/row.py | 36 +++----- .../artifact_storage/db_source/db_source.py | 16 ---- .../artifact_storage/repository.py | 4 - .../repositories/object_storage/creators.py | 2 + .../object_storage/db_source/db_source.py | 90 +++++-------------- .../repositories/object_storage/repository.py | 10 +-- .../repositories/vfs_storage/creators.py | 2 + .../vfs_storage/db_source/db_source.py | 79 +++++----------- .../repositories/vfs_storage/repository.py | 10 +-- .../services/object_storage/actions/create.py | 6 +- .../services/object_storage/service.py | 4 +- .../services/vfs_storage/actions/create.py | 6 +- .../manager/services/vfs_storage/service.py | 4 +- .../test_object_storage_service.py | 2 + .../vfs_storage/test_vfs_storage_service.py | 2 + 29 files changed, 192 insertions(+), 337 deletions(-) create mode 100644
src/ai/backend/manager/models/alembic/versions/7b5764643926_apply_jti_to_artifact_storages.py diff --git a/src/ai/backend/common/data/storage/types.py b/src/ai/backend/common/data/storage/types.py index 60b06bb3731..a3e9023424b 100644 --- a/src/ai/backend/common/data/storage/types.py +++ b/src/ai/backend/common/data/storage/types.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict from ai.backend.common.type_adapters import VFolderIDField -from ai.backend.common.types import ArtifactStorageId, ConcreteArtifactStorageId +from ai.backend.common.types import ArtifactStorageId class VFolderStorageTarget(BaseModel): @@ -39,7 +39,6 @@ class ArtifactStorageData: id: ArtifactStorageId name: str - storage_id: ConcreteArtifactStorageId type: ArtifactStorageType diff --git a/src/ai/backend/common/types.py b/src/ai/backend/common/types.py index e3daaa2811e..15abf3f6f55 100644 --- a/src/ai/backend/common/types.py +++ b/src/ai/backend/common/types.py @@ -308,8 +308,6 @@ def check_typed_tuple(value: tuple[Any, ...], types: tuple[type, ...]) -> tuple[ KernelId = NewType("KernelId", UUID) # ID of the `artifact_storages` common table (storage metadata). ArtifactStorageId = NewType("ArtifactStorageId", UUID) -# ID of the concrete storage table (`object_storages`, `vfs_storages`, etc.). 
-ConcreteArtifactStorageId = NewType("ConcreteArtifactStorageId", UUID) ImageAlias = NewType("ImageAlias", str) ArchName = NewType("ArchName", str) diff --git a/src/ai/backend/manager/api/gql/artifact_storage.py b/src/ai/backend/manager/api/gql/artifact_storage.py index 1b4857e03fa..0d5b7df9bcb 100644 --- a/src/ai/backend/manager/api/gql/artifact_storage.py +++ b/src/ai/backend/manager/api/gql/artifact_storage.py @@ -63,9 +63,6 @@ def to_internal(self) -> ArtifactStorageType: class ArtifactStorageGQL: id: ID = strawberry.field(description="The ID of the artifact storage") name: str = strawberry.field(description="The name of the artifact storage") - storage_id: ID = strawberry.field( - description="The ID of the underlying storage (ObjectStorage or VFSStorage)" - ) type: ArtifactStorageTypeGQL = strawberry.field(description="The type of the artifact storage") @classmethod @@ -73,7 +70,6 @@ def from_dataclass(cls, data: ArtifactStorageData) -> Self: return cls( id=ID(str(data.id)), name=data.name, - storage_id=ID(str(data.storage_id)), type=ArtifactStorageTypeGQL.from_internal(data.type), ) diff --git a/src/ai/backend/manager/api/gql/object_storage.py b/src/ai/backend/manager/api/gql/object_storage.py index aad51c41a5a..f11afda6a7c 100644 --- a/src/ai/backend/manager/api/gql/object_storage.py +++ b/src/ai/backend/manager/api/gql/object_storage.py @@ -3,15 +3,13 @@ import json import uuid from collections.abc import Iterable -from typing import TYPE_CHECKING, Self +from typing import Self import strawberry from strawberry import ID, UNSET, Info from strawberry.relay import Connection, Edge, Node, NodeID -from ai.backend.common.data.storage.types import ArtifactStorageType from ai.backend.manager.api.gql.base import encode_cursor -from ai.backend.manager.data.artifact_storages.types import ArtifactStorageCreatorSpec from ai.backend.manager.data.object_storage.types import ObjectStorageData from ai.backend.manager.models.object_storage import ObjectStorageRow from 
ai.backend.manager.repositories.base.creator import Creator @@ -37,9 +35,6 @@ from .storage_namespace import StorageNamespace, StorageNamespaceConnection, StorageNamespaceEdge from .types import StrawberryGQLContext -if TYPE_CHECKING: - from ai.backend.manager.models.artifact_storages import ArtifactStorageRow - @strawberry.type(description="Added in 25.14.0") class ObjectStorage(Node): @@ -170,6 +165,7 @@ class CreateObjectStorageInput: def to_creator(self) -> Creator[ObjectStorageRow]: return Creator( spec=ObjectStorageCreatorSpec( + name=self.name, host=self.host, access_key=self.access_key, secret_key=self.secret_key, @@ -178,14 +174,6 @@ def to_creator(self) -> Creator[ObjectStorageRow]: ) ) - def to_meta_creator(self) -> Creator[ArtifactStorageRow]: - return Creator( - spec=ArtifactStorageCreatorSpec( - name=self.name, - storage_type=ArtifactStorageType.OBJECT_STORAGE, - ) - ) - @strawberry.input(description="Added in 25.14.0") class UpdateObjectStorageInput: @@ -260,7 +248,6 @@ async def create_object_storage( action_result = await processors.object_storage.create.wait_for_complete( CreateObjectStorageAction( creator=input.to_creator(), - meta_creator=input.to_meta_creator(), ) ) diff --git a/src/ai/backend/manager/api/gql/vfs_storage.py b/src/ai/backend/manager/api/gql/vfs_storage.py index 3b4fd53737f..db5f6ae07bb 100644 --- a/src/ai/backend/manager/api/gql/vfs_storage.py +++ b/src/ai/backend/manager/api/gql/vfs_storage.py @@ -2,15 +2,13 @@ import uuid from collections.abc import Iterable -from typing import TYPE_CHECKING, Self +from typing import Self import strawberry from strawberry import ID, UNSET, Info from strawberry.relay import Connection, Edge, Node, NodeID -from ai.backend.common.data.storage.types import ArtifactStorageType from ai.backend.manager.api.gql.base import encode_cursor -from ai.backend.manager.data.artifact_storages.types import ArtifactStorageCreatorSpec from ai.backend.manager.data.vfs_storage.types import VFSStorageData from 
ai.backend.manager.models.vfs_storage import VFSStorageRow from ai.backend.manager.repositories.base.creator import Creator @@ -26,9 +24,6 @@ from .types import StrawberryGQLContext -if TYPE_CHECKING: - from ai.backend.manager.models.artifact_storages import ArtifactStorageRow - @strawberry.type(description="Added in 25.16.0. VFS Storage configuration") class VFSStorage(Node): @@ -118,19 +113,12 @@ class CreateVFSStorageInput: def to_creator(self) -> Creator[VFSStorageRow]: return Creator( spec=VFSStorageCreatorSpec( + name=self.name, host=self.host, base_path=self.base_path, ) ) - def to_meta_creator(self) -> Creator[ArtifactStorageRow]: - return Creator( - spec=ArtifactStorageCreatorSpec( - name=self.name, - storage_type=ArtifactStorageType.VFS_STORAGE, - ) - ) - @strawberry.input(description="Added in 25.16.0. Input for updating VFS storage") class UpdateVFSStorageInput: @@ -177,7 +165,6 @@ async def create_vfs_storage( action_result = await processors.vfs_storage.create.wait_for_complete( CreateVFSStorageAction( creator=input.to_creator(), - meta_creator=input.to_meta_creator(), ) ) diff --git a/src/ai/backend/manager/data/artifact_storages/__init__.py b/src/ai/backend/manager/data/artifact_storages/__init__.py index bb7896b3b7b..854ceaea72b 100644 --- a/src/ai/backend/manager/data/artifact_storages/__init__.py +++ b/src/ai/backend/manager/data/artifact_storages/__init__.py @@ -1,6 +1,3 @@ -from .types import ArtifactStorageCreatorSpec, ArtifactStorageUpdaterSpec +from .types import ArtifactStorageUpdaterSpec -__all__ = ( - "ArtifactStorageCreatorSpec", - "ArtifactStorageUpdaterSpec", -) +__all__ = ("ArtifactStorageUpdaterSpec",) diff --git a/src/ai/backend/manager/data/artifact_storages/types.py b/src/ai/backend/manager/data/artifact_storages/types.py index c070599d7d2..39566293801 100644 --- a/src/ai/backend/manager/data/artifact_storages/types.py +++ b/src/ai/backend/manager/data/artifact_storages/types.py @@ -3,37 +3,11 @@ from dataclasses import dataclass, 
field from typing import Any, override -from ai.backend.common.data.storage.types import ArtifactStorageType -from ai.backend.common.types import ConcreteArtifactStorageId -from ai.backend.manager.errors.common import InternalServerError from ai.backend.manager.models.artifact_storages import ArtifactStorageRow -from ai.backend.manager.repositories.base.creator import CreatorSpec from ai.backend.manager.repositories.base.updater import UpdaterSpec from ai.backend.manager.types import OptionalState -class ArtifactStorageCreatorSpec(CreatorSpec[ArtifactStorageRow]): - """CreatorSpec for ArtifactStorageRow with deferred storage_id.""" - - def __init__(self, name: str, storage_type: ArtifactStorageType) -> None: - self._name = name - self._storage_type = storage_type - self._storage_id: ConcreteArtifactStorageId | None = None - - def set_storage_id(self, storage_id: ConcreteArtifactStorageId) -> None: - self._storage_id = storage_id - - @override - def build_row(self) -> ArtifactStorageRow: - if self._storage_id is None: - raise InternalServerError("storage_id must be set before building row") - return ArtifactStorageRow( - name=self._name, - storage_id=self._storage_id, - type=self._storage_type, - ) - - @dataclass class ArtifactStorageUpdaterSpec(UpdaterSpec[ArtifactStorageRow]): """UpdaterSpec for ArtifactStorageRow.""" diff --git a/src/ai/backend/manager/data/object_storage/types.py b/src/ai/backend/manager/data/object_storage/types.py index 777e3bf5292..30dfce517e6 100644 --- a/src/ai/backend/manager/data/object_storage/types.py +++ b/src/ai/backend/manager/data/object_storage/types.py @@ -1,8 +1,8 @@ from __future__ import annotations -import uuid from dataclasses import dataclass +from ai.backend.common.data.storage.types import ArtifactStorageData from ai.backend.common.dto.manager.response import ObjectStorageResponse @@ -16,10 +16,8 @@ class ObjectStorageListResult: has_previous_page: bool -@dataclass -class ObjectStorageData: - id: uuid.UUID - name: str 
+@dataclass(frozen=True) +class ObjectStorageData(ArtifactStorageData): host: str access_key: str secret_key: str diff --git a/src/ai/backend/manager/data/vfs_storage/types.py b/src/ai/backend/manager/data/vfs_storage/types.py index d2b4ba9c380..eedc90eda81 100644 --- a/src/ai/backend/manager/data/vfs_storage/types.py +++ b/src/ai/backend/manager/data/vfs_storage/types.py @@ -1,9 +1,10 @@ from __future__ import annotations -import uuid from dataclasses import dataclass from pathlib import Path +from ai.backend.common.data.storage.types import ArtifactStorageData + @dataclass class VFSStorageListResult: @@ -15,9 +16,7 @@ class VFSStorageListResult: has_previous_page: bool -@dataclass -class VFSStorageData: - id: uuid.UUID - name: str +@dataclass(frozen=True) +class VFSStorageData(ArtifactStorageData): host: str base_path: Path diff --git a/src/ai/backend/manager/models/alembic/versions/7b5764643926_apply_jti_to_artifact_storages.py b/src/ai/backend/manager/models/alembic/versions/7b5764643926_apply_jti_to_artifact_storages.py new file mode 100644 index 00000000000..0cefcf29f54 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/7b5764643926_apply_jti_to_artifact_storages.py @@ -0,0 +1,81 @@ +"""Apply joined table inheritance to artifact_storages + +Revision ID: 7b5764643926 +Revises: 35dfab3b0662 +Create Date: 2026-02-26 00:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op + +from ai.backend.manager.models.base import GUID + +# revision identifiers, used by Alembic. +revision = "7b5764643926" +down_revision = "35dfab3b0662" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + conn = op.get_bind() + + # 1. For each row in artifact_storages, set id = storage_id (align PKs) + # We need to update artifact_storages.id to match the child table id (storage_id). + # Since id is a PK, we need to handle this carefully. + conn.execute( + sa.text(""" + UPDATE artifact_storages SET id = storage_id + """) + ) + + # 2. 
Drop the storage_id unique constraint and column + op.drop_constraint("uq_artifact_storages_storage_id", "artifact_storages", type_="unique") + op.drop_column("artifact_storages", "storage_id") + + # 3. Add FK constraints: child.id -> artifact_storages.id + op.create_foreign_key( + "fk_object_storages_id_artifact_storages", + "object_storages", + "artifact_storages", + ["id"], + ["id"], + ondelete="CASCADE", + ) + op.create_foreign_key( + "fk_vfs_storages_id_artifact_storages", + "vfs_storages", + "artifact_storages", + ["id"], + ["id"], + ondelete="CASCADE", + ) + + +def downgrade() -> None: + conn = op.get_bind() + + # 1. Drop FK constraints + op.drop_constraint("fk_vfs_storages_id_artifact_storages", "vfs_storages", type_="foreignkey") + op.drop_constraint( + "fk_object_storages_id_artifact_storages", "object_storages", type_="foreignkey" + ) + + # 2. Re-add storage_id column (copy id into it, since they were aligned) + op.add_column( + "artifact_storages", + sa.Column("storage_id", GUID(), nullable=True), + ) + + conn.execute( + sa.text(""" + UPDATE artifact_storages SET storage_id = id + """) + ) + + op.alter_column("artifact_storages", "storage_id", nullable=False) + + op.create_unique_constraint( + "uq_artifact_storages_storage_id", "artifact_storages", ["storage_id"] + ) diff --git a/src/ai/backend/manager/models/artifact_storages.py b/src/ai/backend/manager/models/artifact_storages.py index 040d07e7053..f48c5d90c45 100644 --- a/src/ai/backend/manager/models/artifact_storages.py +++ b/src/ai/backend/manager/models/artifact_storages.py @@ -2,13 +2,12 @@ import logging import uuid -from typing import TYPE_CHECKING import sqlalchemy as sa -from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship +from sqlalchemy.orm import Mapped, mapped_column from ai.backend.common.data.storage.types import ArtifactStorageData, ArtifactStorageType -from ai.backend.common.types import ArtifactStorageId, ConcreteArtifactStorageId +from ai.backend.common.types 
import ArtifactStorageId from ai.backend.logging import BraceStyleAdapter from .base import ( @@ -16,30 +15,15 @@ Base, ) -if TYPE_CHECKING: - from ai.backend.manager.models.object_storage import ObjectStorageRow - from ai.backend.manager.models.vfs_storage import VFSStorageRow - log = BraceStyleAdapter(logging.getLogger(__spec__.name)) __all__ = ("ArtifactStorageRow",) -def _get_object_storage_join_condition() -> sa.ColumnElement[bool]: - from ai.backend.manager.models.object_storage import ObjectStorageRow - - return ObjectStorageRow.id == foreign(ArtifactStorageRow.storage_id) - - -def _get_vfs_storage_join_condition() -> sa.ColumnElement[bool]: - from ai.backend.manager.models.vfs_storage import VFSStorageRow - - return VFSStorageRow.id == foreign(ArtifactStorageRow.storage_id) - - class ArtifactStorageRow(Base): # type: ignore[misc] """ Common information of all artifact storage records. + Uses SQLAlchemy Joined Table Inheritance as the base class. """ __tablename__ = "artifact_storages" @@ -48,26 +32,15 @@ class ArtifactStorageRow(Base): # type: ignore[misc] "id", GUID, primary_key=True, server_default=sa.text("uuid_generate_v4()") ) name: Mapped[str] = mapped_column("name", sa.String, nullable=False, unique=True) - storage_id: Mapped[uuid.UUID] = mapped_column("storage_id", GUID, nullable=False, unique=True) type: Mapped[str] = mapped_column("type", sa.String, nullable=False) - object_storages: Mapped[ObjectStorageRow | None] = relationship( - "ObjectStorageRow", - back_populates="meta", - primaryjoin=_get_object_storage_join_condition, - uselist=False, - viewonly=True, - ) - vfs_storages: Mapped[VFSStorageRow | None] = relationship( - "VFSStorageRow", - back_populates="meta", - primaryjoin=_get_vfs_storage_join_condition, - uselist=False, - viewonly=True, - ) + __mapper_args__ = { + "polymorphic_on": "type", + "polymorphic_identity": "base", + } def __str__(self) -> str: - return f"ArtifactStorageRow(id={self.id}, storage_id={self.storage_id}, 
type={self.type}, name={self.name})" + return f"ArtifactStorageRow(id={self.id}, type={self.type}, name={self.name})" def __repr__(self) -> str: return self.__str__() @@ -76,6 +49,5 @@ def to_dataclass(self) -> ArtifactStorageData: return ArtifactStorageData( id=ArtifactStorageId(self.id), name=self.name, - storage_id=ConcreteArtifactStorageId(self.storage_id), type=ArtifactStorageType(self.type), ) diff --git a/src/ai/backend/manager/models/association_artifacts_storages/row.py b/src/ai/backend/manager/models/association_artifacts_storages/row.py index 4627caf3724..49d79fc1426 100644 --- a/src/ai/backend/manager/models/association_artifacts_storages/row.py +++ b/src/ai/backend/manager/models/association_artifacts_storages/row.py @@ -77,6 +77,7 @@ class AssociationArtifactsStorageRow(Base): # type: ignore[misc] back_populates="association_artifacts_storages_rows", primaryjoin=_get_association_object_storage_join_cond, overlaps="vfs_storage_row", + viewonly=True, ) # only valid when storage_type is "vfs" @@ -85,4 +86,5 @@ class AssociationArtifactsStorageRow(Base): # type: ignore[misc] back_populates="association_artifacts_storages_rows", primaryjoin=_get_association_vfs_storage_join_cond, overlaps="object_storage_row", + viewonly=True, ) diff --git a/src/ai/backend/manager/models/object_storage/row.py b/src/ai/backend/manager/models/object_storage/row.py index 254f80c8170..2d2f451e6ab 100644 --- a/src/ai/backend/manager/models/object_storage/row.py +++ b/src/ai/backend/manager/models/object_storage/row.py @@ -7,16 +7,16 @@ import sqlalchemy as sa from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship -from ai.backend.common.exception import RelationNotLoadedError +from ai.backend.common.data.storage.types import ArtifactStorageType +from ai.backend.common.types import ArtifactStorageId from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.data.object_storage.types import ObjectStorageData +from 
ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.models.base import ( GUID, - Base, ) if TYPE_CHECKING: - from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.models.association_artifacts_storages import ( AssociationArtifactsStorageRow, ) @@ -41,23 +41,18 @@ def _get_object_storage_namespace_join_cond() -> sa.ColumnElement[bool]: return foreign(StorageNamespaceRow.storage_id) == ObjectStorageRow.id -def _get_object_storage_meta_join_cond() -> sa.ColumnElement[bool]: - from ai.backend.manager.models.artifact_storages import ArtifactStorageRow - - return ObjectStorageRow.id == foreign(ArtifactStorageRow.storage_id) - - -class ObjectStorageRow(Base): # type: ignore[misc] +class ObjectStorageRow(ArtifactStorageRow): """ Represents an object storage configuration. This model is used to store the details of object storage services such as access keys, endpoints. + Uses SQLAlchemy Joined Table Inheritance (child of ArtifactStorageRow). 
""" __tablename__ = "object_storages" id: Mapped[uuid.UUID] = mapped_column( - "id", GUID, primary_key=True, server_default=sa.text("uuid_generate_v4()") + "id", GUID, sa.ForeignKey("artifact_storages.id", ondelete="CASCADE"), primary_key=True ) host: Mapped[str] = mapped_column("host", sa.String, index=True, nullable=False) access_key: Mapped[str] = mapped_column( @@ -87,21 +82,20 @@ class ObjectStorageRow(Base): # type: ignore[misc] back_populates="object_storage_row", primaryjoin=_get_object_storage_association_artifact_join_cond, overlaps="vfs_storage_row", + viewonly=True, ) ) namespace_rows: Mapped[list[StorageNamespaceRow]] = relationship( "StorageNamespaceRow", back_populates="object_storage_row", primaryjoin=_get_object_storage_namespace_join_cond, - ) - meta: Mapped[ArtifactStorageRow | None] = relationship( - "ArtifactStorageRow", - back_populates="object_storages", - primaryjoin=_get_object_storage_meta_join_cond, - uselist=False, viewonly=True, ) + __mapper_args__ = { + "polymorphic_identity": "object_storage", + } + def __str__(self) -> str: return ( f"ObjectStorageRow(" @@ -117,11 +111,10 @@ def __repr__(self) -> str: return self.__str__() def to_dataclass(self) -> ObjectStorageData: - if self.meta is None: - raise RelationNotLoadedError() return ObjectStorageData( - id=self.id, - name=self.meta.name, + id=ArtifactStorageId(self.id), + name=self.name, + type=ArtifactStorageType(self.type), host=self.host, access_key=self.access_key, secret_key=self.secret_key, diff --git a/src/ai/backend/manager/models/storage_namespace/row.py b/src/ai/backend/manager/models/storage_namespace/row.py index 4e7bea62c8b..2c4f71c79e8 100644 --- a/src/ai/backend/manager/models/storage_namespace/row.py +++ b/src/ai/backend/manager/models/storage_namespace/row.py @@ -49,6 +49,7 @@ class StorageNamespaceRow(Base): # type: ignore[misc] "ObjectStorageRow", back_populates="namespace_rows", primaryjoin=_get_storage_namespace_join_cond, + viewonly=True, ) def to_dataclass(self) 
-> StorageNamespaceData: diff --git a/src/ai/backend/manager/models/vfs_storage/row.py b/src/ai/backend/manager/models/vfs_storage/row.py index 5b92a40962c..b2485d85754 100644 --- a/src/ai/backend/manager/models/vfs_storage/row.py +++ b/src/ai/backend/manager/models/vfs_storage/row.py @@ -8,16 +8,16 @@ import sqlalchemy as sa from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship -from ai.backend.common.exception import RelationNotLoadedError +from ai.backend.common.data.storage.types import ArtifactStorageType +from ai.backend.common.types import ArtifactStorageId from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.data.vfs_storage.types import VFSStorageData +from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.models.base import ( GUID, - Base, ) if TYPE_CHECKING: - from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.models.association_artifacts_storages import ( AssociationArtifactsStorageRow, ) @@ -35,23 +35,18 @@ def _get_vfs_storage_association_artifact_join_cond() -> sa.ColumnElement[bool]: return VFSStorageRow.id == foreign(AssociationArtifactsStorageRow.storage_namespace_id) -def _get_vfs_storage_meta_join_cond() -> sa.ColumnElement[bool]: - from ai.backend.manager.models.artifact_storages import ArtifactStorageRow - - return VFSStorageRow.id == foreign(ArtifactStorageRow.storage_id) - - -class VFSStorageRow(Base): # type: ignore[misc] +class VFSStorageRow(ArtifactStorageRow): """ Represents a VFS storage configuration. This model is used to store the details of VFS storage backends such as base paths, subpaths, and chunk sizes. + Uses SQLAlchemy Joined Table Inheritance (child of ArtifactStorageRow). 
""" __tablename__ = "vfs_storages" id: Mapped[uuid.UUID] = mapped_column( - "id", GUID, primary_key=True, server_default=sa.text("uuid_generate_v4()") + "id", GUID, sa.ForeignKey("artifact_storages.id", ondelete="CASCADE"), primary_key=True ) host: Mapped[str] = mapped_column("host", sa.String, nullable=False) base_path: Mapped[str] = mapped_column("base_path", sa.String, nullable=False) @@ -62,15 +57,13 @@ class VFSStorageRow(Base): # type: ignore[misc] back_populates="vfs_storage_row", primaryjoin=_get_vfs_storage_association_artifact_join_cond, overlaps="association_artifacts_storages_rows,object_storage_row", + viewonly=True, ) ) - meta: Mapped[ArtifactStorageRow | None] = relationship( - "ArtifactStorageRow", - back_populates="vfs_storages", - primaryjoin=_get_vfs_storage_meta_join_cond, - uselist=False, - viewonly=True, - ) + + __mapper_args__ = { + "polymorphic_identity": "vfs_storage", + } def __str__(self) -> str: return f"VFSStorageRow(id={self.id}, host={self.host}, base_path={self.base_path})" @@ -79,11 +72,10 @@ def __repr__(self) -> str: return self.__str__() def to_dataclass(self) -> VFSStorageData: - if self.meta is None: - raise RelationNotLoadedError() return VFSStorageData( - id=self.id, - name=self.meta.name, + id=ArtifactStorageId(self.id), + name=self.name, + type=ArtifactStorageType(self.type), host=self.host, base_path=Path(self.base_path), ) diff --git a/src/ai/backend/manager/repositories/artifact_storage/db_source/db_source.py b/src/ai/backend/manager/repositories/artifact_storage/db_source/db_source.py index e77f041bddd..fa3bf256173 100644 --- a/src/ai/backend/manager/repositories/artifact_storage/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/artifact_storage/db_source/db_source.py @@ -33,20 +33,6 @@ async def get_by_id(self, storage_id: uuid.UUID) -> ArtifactStorageData: ) return row.to_dataclass() - async def get_by_storage_id(self, storage_id: uuid.UUID) -> ArtifactStorageData: - """ - Get an existing artifact 
storage configuration from the database by storage_id. - """ - async with self._db.begin_session() as db_session: - query = sa.select(ArtifactStorageRow).where(ArtifactStorageRow.storage_id == storage_id) - result = await db_session.execute(query) - row = result.scalar_one_or_none() - if row is None: - raise ArtifactStorageNotFoundError( - f"Artifact storage with storage_id {storage_id} not found." - ) - return row.to_dataclass() - async def update( self, updater: Updater[ArtifactStorageRow], @@ -55,11 +41,9 @@ async def update( Update an existing artifact storage configuration in the database. """ async with self._db.begin_session() as db_session: - # Execute update (may return None if no values to update, which is fine) await execute_updater(db_session, updater) artifact_storage_id = uuid.UUID(str(updater.pk_value)) - # Re-query to get the updated row query = sa.select(ArtifactStorageRow).where( ArtifactStorageRow.id == artifact_storage_id ) diff --git a/src/ai/backend/manager/repositories/artifact_storage/repository.py b/src/ai/backend/manager/repositories/artifact_storage/repository.py index 5edaee0274b..f2acae550f3 100644 --- a/src/ai/backend/manager/repositories/artifact_storage/repository.py +++ b/src/ai/backend/manager/repositories/artifact_storage/repository.py @@ -47,10 +47,6 @@ def __init__(self, db: ExtendedAsyncSAEngine) -> None: async def get_by_id(self, storage_id: uuid.UUID) -> ArtifactStorageData: return await self._db_source.get_by_id(storage_id) - @artifact_storage_repository_resilience.apply() - async def get_by_storage_id(self, storage_id: uuid.UUID) -> ArtifactStorageData: - return await self._db_source.get_by_storage_id(storage_id) - @artifact_storage_repository_resilience.apply() async def update( self, diff --git a/src/ai/backend/manager/repositories/object_storage/creators.py b/src/ai/backend/manager/repositories/object_storage/creators.py index 190dce280ae..e1a0b4c006f 100644 --- 
a/src/ai/backend/manager/repositories/object_storage/creators.py +++ b/src/ai/backend/manager/repositories/object_storage/creators.py @@ -13,6 +13,7 @@ class ObjectStorageCreatorSpec(CreatorSpec[ObjectStorageRow]): """CreatorSpec for object storage creation.""" + name: str host: str access_key: str secret_key: str @@ -22,6 +23,7 @@ class ObjectStorageCreatorSpec(CreatorSpec[ObjectStorageRow]): @override def build_row(self) -> ObjectStorageRow: return ObjectStorageRow( + name=self.name, host=self.host, access_key=self.access_key, secret_key=self.secret_key, diff --git a/src/ai/backend/manager/repositories/object_storage/db_source/db_source.py b/src/ai/backend/manager/repositories/object_storage/db_source/db_source.py index 7de454e0888..0545c17241f 100644 --- a/src/ai/backend/manager/repositories/object_storage/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/object_storage/db_source/db_source.py @@ -5,20 +5,16 @@ import sqlalchemy as sa from sqlalchemy.orm import selectinload -from ai.backend.common.types import ConcreteArtifactStorageId -from ai.backend.manager.data.artifact_storages.types import ArtifactStorageCreatorSpec from ai.backend.manager.data.object_storage.types import ObjectStorageData, ObjectStorageListResult -from ai.backend.manager.errors.common import InternalServerError from ai.backend.manager.errors.object_storage import ( ObjectStorageNotFoundError, ) -from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.models.object_storage import ObjectStorageRow from ai.backend.manager.models.storage_namespace import StorageNamespaceRow from ai.backend.manager.models.utils import ExtendedAsyncSAEngine from ai.backend.manager.repositories.base import BatchQuerier, execute_batch_querier from ai.backend.manager.repositories.base.creator import Creator, execute_creator -from ai.backend.manager.repositories.base.updater import Updater, execute_updater +from ai.backend.manager.repositories.base.updater 
import Updater class ObjectStorageDBSource: @@ -34,37 +30,21 @@ async def get_by_name(self, storage_name: str) -> ObjectStorageData: Get an existing object storage configuration from the database. """ async with self._db.begin_readonly_session_read_committed() as db_session: - query = ( - sa.select(ArtifactStorageRow) - .where(ArtifactStorageRow.name == storage_name) - .options( - selectinload(ArtifactStorageRow.object_storages).selectinload( - ObjectStorageRow.meta - ) - ) - ) + query = sa.select(ObjectStorageRow).where(ObjectStorageRow.name == storage_name) result = await db_session.execute(query) row = result.scalar_one_or_none() if row is None: raise ObjectStorageNotFoundError( f"Object storage with name {storage_name} not found." ) - if row.object_storages is None: - raise ObjectStorageNotFoundError( - f"Object storage not found for name {storage_name}" - ) - return row.object_storages.to_dataclass() + return row.to_dataclass() async def get_by_id(self, storage_id: uuid.UUID) -> ObjectStorageData: """ Get an existing object storage configuration from the database by ID. """ async with self._db.begin_readonly_session_read_committed() as db_session: - query = ( - sa.select(ObjectStorageRow) - .where(ObjectStorageRow.id == storage_id) - .options(selectinload(ObjectStorageRow.meta)) - ) + query = sa.select(ObjectStorageRow).where(ObjectStorageRow.id == storage_id) result = await db_session.execute(query) row = result.scalar_one_or_none() if row is None: @@ -73,17 +53,13 @@ async def get_by_id(self, storage_id: uuid.UUID) -> ObjectStorageData: async def get_by_namespace_id(self, storage_namespace_id: uuid.UUID) -> ObjectStorageData: """ - Get an existing object storage configuration from the database by ID. + Get an existing object storage configuration from the database by namespace ID. 
""" async with self._db.begin_readonly_session_read_committed() as db_session: query = ( sa.select(StorageNamespaceRow) .where(StorageNamespaceRow.id == storage_namespace_id) - .options( - selectinload(StorageNamespaceRow.object_storage_row).selectinload( - ObjectStorageRow.meta - ) - ) + .options(selectinload(StorageNamespaceRow.object_storage_row)) ) result = await db_session.execute(query) row = result.scalar_one_or_none() @@ -97,32 +73,14 @@ async def get_by_namespace_id(self, storage_namespace_id: uuid.UUID) -> ObjectSt ) return row.object_storage_row.to_dataclass() - async def create( - self, creator: Creator[ObjectStorageRow], meta_creator: Creator[ArtifactStorageRow] - ) -> ObjectStorageData: + async def create(self, creator: Creator[ObjectStorageRow]) -> ObjectStorageData: """ Create a new object storage configuration in the database. + JTI handles inserting into both artifact_storages and object_storages. """ async with self._db.begin_session() as db_session: creator_result = await execute_creator(db_session, creator) - new_row = creator_result.row - - # Set the storage_id on the meta creator spec and create ArtifactStorageRow - meta_spec = meta_creator.spec - if not isinstance(meta_spec, ArtifactStorageCreatorSpec): - raise InternalServerError("meta_creator.spec must be ArtifactStorageCreatorSpec") - meta_spec.set_storage_id(ConcreteArtifactStorageId(new_row.id)) - await execute_creator(db_session, meta_creator) - - # Re-query to load the meta relationship - query = ( - sa.select(ObjectStorageRow) - .where(ObjectStorageRow.id == new_row.id) - .options(selectinload(ObjectStorageRow.meta)) - ) - row_result = await db_session.execute(query) - row = row_result.scalar_one() - return row.to_dataclass() + return creator_result.row.to_dataclass() async def update( self, @@ -130,18 +88,19 @@ async def update( ) -> ObjectStorageData: """ Update an existing object storage configuration in the database. 
+ Uses plain UPDATE + SELECT instead of execute_updater's RETURNING+from_statement, + which is incompatible with JTI (RETURNING only includes child table columns). """ async with self._db.begin_session() as db_session: - # Execute update (may return None if no values to update, which is fine) - await execute_updater(db_session, updater) - storage_id = uuid.UUID(str(updater.pk_value)) - # Re-query to load the meta relationship - query = ( - sa.select(ObjectStorageRow) - .where(ObjectStorageRow.id == storage_id) - .options(selectinload(ObjectStorageRow.meta)) - ) + values = updater.spec.build_values() + if values: + table = ObjectStorageRow.__table__ + pk_col = list(table.primary_key.columns)[0] + update_stmt = sa.update(table).values(values).where(pk_col == storage_id) + await db_session.execute(update_stmt) + + query = sa.select(ObjectStorageRow).where(ObjectStorageRow.id == storage_id) row_result = await db_session.execute(query) row = row_result.scalar_one_or_none() if row is None: @@ -151,6 +110,7 @@ async def update( async def delete(self, storage_id: uuid.UUID) -> uuid.UUID: """ Delete an existing object storage configuration from the database. + FK cascade handles deleting the ArtifactStorageRow. """ async with self._db.begin_session() as db_session: delete_query = ( @@ -162,12 +122,6 @@ async def delete(self, storage_id: uuid.UUID) -> uuid.UUID: deleted_id = result.scalar() if deleted_id is None: raise ObjectStorageNotFoundError(f"Object storage with ID {storage_id} not found.") - - # Delete the ArtifactStorageRow - delete_meta_query = sa.delete(ArtifactStorageRow).where( - ArtifactStorageRow.storage_id == storage_id - ) - await db_session.execute(delete_meta_query) return deleted_id async def list_object_storages(self) -> list[ObjectStorageData]: @@ -175,7 +129,7 @@ async def list_object_storages(self) -> list[ObjectStorageData]: List all object storage configurations from the database. 
""" async with self._db.begin_readonly_session_read_committed() as db_session: - query = sa.select(ObjectStorageRow).options(selectinload(ObjectStorageRow.meta)) + query = sa.select(ObjectStorageRow) result = await db_session.execute(query) rows = result.scalars().all() return [row.to_dataclass() for row in rows] @@ -186,7 +140,7 @@ async def search( ) -> ObjectStorageListResult: """Searches Object storages with total count.""" async with self._db.begin_readonly_session() as db_sess: - query = sa.select(ObjectStorageRow).options(selectinload(ObjectStorageRow.meta)) + query = sa.select(ObjectStorageRow) result = await execute_batch_querier( db_sess, diff --git a/src/ai/backend/manager/repositories/object_storage/repository.py b/src/ai/backend/manager/repositories/object_storage/repository.py index 6d5e909f62e..5b896093298 100644 --- a/src/ai/backend/manager/repositories/object_storage/repository.py +++ b/src/ai/backend/manager/repositories/object_storage/repository.py @@ -1,7 +1,6 @@ from __future__ import annotations import uuid -from typing import TYPE_CHECKING from ai.backend.common.exception import ( BackendAIError, @@ -18,9 +17,6 @@ from ai.backend.manager.repositories.base.updater import Updater from ai.backend.manager.repositories.object_storage.db_source.db_source import ObjectStorageDBSource -if TYPE_CHECKING: - from ai.backend.manager.models.artifact_storages import ArtifactStorageRow - object_storage_repository_resilience = Resilience( policies=[ MetricPolicy( @@ -59,10 +55,8 @@ async def get_by_namespace_id(self, storage_namespace_id: uuid.UUID) -> ObjectSt return await self._db_source.get_by_namespace_id(storage_namespace_id) @object_storage_repository_resilience.apply() - async def create( - self, creator: Creator[ObjectStorageRow], meta_creator: Creator[ArtifactStorageRow] - ) -> ObjectStorageData: - return await self._db_source.create(creator, meta_creator) + async def create(self, creator: Creator[ObjectStorageRow]) -> ObjectStorageData: + return 
await self._db_source.create(creator) @object_storage_repository_resilience.apply() async def update( diff --git a/src/ai/backend/manager/repositories/vfs_storage/creators.py b/src/ai/backend/manager/repositories/vfs_storage/creators.py index 5b2af29793f..aec0a785dc3 100644 --- a/src/ai/backend/manager/repositories/vfs_storage/creators.py +++ b/src/ai/backend/manager/repositories/vfs_storage/creators.py @@ -13,12 +13,14 @@ class VFSStorageCreatorSpec(CreatorSpec[VFSStorageRow]): """CreatorSpec for VFS storage creation.""" + name: str host: str base_path: str @override def build_row(self) -> VFSStorageRow: return VFSStorageRow( + name=self.name, host=self.host, base_path=self.base_path, ) diff --git a/src/ai/backend/manager/repositories/vfs_storage/db_source/db_source.py b/src/ai/backend/manager/repositories/vfs_storage/db_source/db_source.py index 3ca8b8805d8..ebc9a7a142b 100644 --- a/src/ai/backend/manager/repositories/vfs_storage/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/vfs_storage/db_source/db_source.py @@ -3,21 +3,16 @@ import uuid import sqlalchemy as sa -from sqlalchemy.orm import selectinload -from ai.backend.common.types import ConcreteArtifactStorageId -from ai.backend.manager.data.artifact_storages.types import ArtifactStorageCreatorSpec from ai.backend.manager.data.vfs_storage.types import VFSStorageData, VFSStorageListResult -from ai.backend.manager.errors.common import InternalServerError from ai.backend.manager.errors.vfs_storage import ( VFSStorageNotFoundError, ) -from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.models.utils import ExtendedAsyncSAEngine from ai.backend.manager.models.vfs_storage import VFSStorageRow from ai.backend.manager.repositories.base import BatchQuerier, execute_batch_querier from ai.backend.manager.repositories.base.creator import Creator, execute_creator -from ai.backend.manager.repositories.base.updater import Updater, execute_updater +from 
ai.backend.manager.repositories.base.updater import Updater class VFSStorageDBSource: @@ -33,63 +28,33 @@ async def get_by_name(self, storage_name: str) -> VFSStorageData: Get an existing VFS storage configuration from the database. """ async with self._db.begin_readonly_session_read_committed() as db_session: - query = ( - sa.select(ArtifactStorageRow) - .where(ArtifactStorageRow.name == storage_name) - .options( - selectinload(ArtifactStorageRow.vfs_storages).selectinload(VFSStorageRow.meta) - ) - ) + query = sa.select(VFSStorageRow).where(VFSStorageRow.name == storage_name) result = await db_session.execute(query) row = result.scalar_one_or_none() if row is None: raise VFSStorageNotFoundError(f"VFS storage with name {storage_name} not found.") - if row.vfs_storages is None: - raise VFSStorageNotFoundError(f"VFS storage not found for name {storage_name}") - return row.vfs_storages.to_dataclass() + return row.to_dataclass() async def get_by_id(self, storage_id: uuid.UUID) -> VFSStorageData: """ Get an existing VFS storage configuration from the database by ID. """ async with self._db.begin_readonly_session_read_committed() as db_session: - query = ( - sa.select(VFSStorageRow) - .where(VFSStorageRow.id == storage_id) - .options(selectinload(VFSStorageRow.meta)) - ) + query = sa.select(VFSStorageRow).where(VFSStorageRow.id == storage_id) result = await db_session.execute(query) row = result.scalar_one_or_none() if row is None: raise VFSStorageNotFoundError(f"VFS storage with ID {storage_id} not found.") return row.to_dataclass() - async def create( - self, creator: Creator[VFSStorageRow], meta_creator: Creator[ArtifactStorageRow] - ) -> VFSStorageData: + async def create(self, creator: Creator[VFSStorageRow]) -> VFSStorageData: """ Create a new VFS storage configuration in the database. + JTI handles inserting into both artifact_storages and vfs_storages. 
""" async with self._db.begin_session() as db_session: creator_result = await execute_creator(db_session, creator) - new_row = creator_result.row - - # Set the storage_id on the meta creator spec and create ArtifactStorageRow - meta_spec = meta_creator.spec - if not isinstance(meta_spec, ArtifactStorageCreatorSpec): - raise InternalServerError("meta_creator.spec must be ArtifactStorageCreatorSpec") - meta_spec.set_storage_id(ConcreteArtifactStorageId(new_row.id)) - await execute_creator(db_session, meta_creator) - - # Re-query to load the meta relationship - query = ( - sa.select(VFSStorageRow) - .where(VFSStorageRow.id == new_row.id) - .options(selectinload(VFSStorageRow.meta)) - ) - row_result = await db_session.execute(query) - row = row_result.scalar_one() - return row.to_dataclass() + return creator_result.row.to_dataclass() async def update( self, @@ -97,18 +62,19 @@ async def update( ) -> VFSStorageData: """ Update an existing VFS storage configuration in the database. + Uses plain UPDATE + SELECT instead of execute_updater's RETURNING+from_statement, + which is incompatible with JTI (RETURNING only includes child table columns). 
""" async with self._db.begin_session() as db_session: - # Execute update (may return None if no values to update, which is fine) - await execute_updater(db_session, updater) - storage_id = uuid.UUID(str(updater.pk_value)) - # Re-query to load the meta relationship - query = ( - sa.select(VFSStorageRow) - .where(VFSStorageRow.id == storage_id) - .options(selectinload(VFSStorageRow.meta)) - ) + values = updater.spec.build_values() + if values: + table = VFSStorageRow.__table__ + pk_col = list(table.primary_key.columns)[0] + update_stmt = sa.update(table).values(values).where(pk_col == storage_id) + await db_session.execute(update_stmt) + + query = sa.select(VFSStorageRow).where(VFSStorageRow.id == storage_id) row_result = await db_session.execute(query) row = row_result.scalar_one_or_none() if row is None: @@ -118,6 +84,7 @@ async def update( async def delete(self, storage_id: uuid.UUID) -> uuid.UUID: """ Delete an existing VFS storage configuration from the database. + FK cascade handles deleting the ArtifactStorageRow. """ async with self._db.begin_session() as db_session: delete_query = ( @@ -129,12 +96,6 @@ async def delete(self, storage_id: uuid.UUID) -> uuid.UUID: deleted_id = result.scalar() if deleted_id is None: raise VFSStorageNotFoundError(f"VFS storage with ID {storage_id} not found.") - - # Delete the ArtifactStorageRow - delete_meta_query = sa.delete(ArtifactStorageRow).where( - ArtifactStorageRow.storage_id == storage_id - ) - await db_session.execute(delete_meta_query) return deleted_id async def list_vfs_storages(self) -> list[VFSStorageData]: @@ -142,7 +103,7 @@ async def list_vfs_storages(self) -> list[VFSStorageData]: List all VFS storage configurations from the database. 
""" async with self._db.begin_readonly_session_read_committed() as db_session: - query = sa.select(VFSStorageRow).options(selectinload(VFSStorageRow.meta)) + query = sa.select(VFSStorageRow) result = await db_session.execute(query) rows = result.scalars().all() return [row.to_dataclass() for row in rows] @@ -153,7 +114,7 @@ async def search( ) -> VFSStorageListResult: """Searches VFS storages with total count.""" async with self._db.begin_readonly_session() as db_sess: - query = sa.select(VFSStorageRow).options(selectinload(VFSStorageRow.meta)) + query = sa.select(VFSStorageRow) result = await execute_batch_querier( db_sess, diff --git a/src/ai/backend/manager/repositories/vfs_storage/repository.py b/src/ai/backend/manager/repositories/vfs_storage/repository.py index 641978d4343..57852b2da53 100644 --- a/src/ai/backend/manager/repositories/vfs_storage/repository.py +++ b/src/ai/backend/manager/repositories/vfs_storage/repository.py @@ -1,7 +1,6 @@ from __future__ import annotations import uuid -from typing import TYPE_CHECKING from ai.backend.common.exception import BackendAIError from ai.backend.common.metrics.metric import DomainType, LayerType @@ -16,9 +15,6 @@ from ai.backend.manager.repositories.base.updater import Updater from ai.backend.manager.repositories.vfs_storage.db_source.db_source import VFSStorageDBSource -if TYPE_CHECKING: - from ai.backend.manager.models.artifact_storages import ArtifactStorageRow - vfs_storage_repository_resilience = Resilience( policies=[ MetricPolicy( @@ -53,10 +49,8 @@ async def get_by_id(self, storage_id: uuid.UUID) -> VFSStorageData: return await self._db_source.get_by_id(storage_id) @vfs_storage_repository_resilience.apply() - async def create( - self, creator: Creator[VFSStorageRow], meta_creator: Creator[ArtifactStorageRow] - ) -> VFSStorageData: - return await self._db_source.create(creator, meta_creator) + async def create(self, creator: Creator[VFSStorageRow]) -> VFSStorageData: + return await 
self._db_source.create(creator) @vfs_storage_repository_resilience.apply() async def update( diff --git a/src/ai/backend/manager/services/object_storage/actions/create.py b/src/ai/backend/manager/services/object_storage/actions/create.py index aa5294caf3f..b135206bb9b 100644 --- a/src/ai/backend/manager/services/object_storage/actions/create.py +++ b/src/ai/backend/manager/services/object_storage/actions/create.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, override +from typing import override from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType @@ -10,14 +10,10 @@ from ai.backend.manager.repositories.base.creator import Creator from ai.backend.manager.services.object_storage.actions.base import ObjectStorageAction -if TYPE_CHECKING: - from ai.backend.manager.models.artifact_storages import ArtifactStorageRow - @dataclass class CreateObjectStorageAction(ObjectStorageAction): creator: Creator[ObjectStorageRow] - meta_creator: Creator[ArtifactStorageRow] @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/object_storage/service.py b/src/ai/backend/manager/services/object_storage/service.py index 7d7a139e7f1..61d2322b7ee 100644 --- a/src/ai/backend/manager/services/object_storage/service.py +++ b/src/ai/backend/manager/services/object_storage/service.py @@ -77,9 +77,7 @@ async def create(self, action: CreateObjectStorageAction) -> CreateObjectStorage Create a new object storage. 
""" log.info("Creating object storage with data: {}", action.creator) - storage_data = await self._object_storage_repository.create( - action.creator, action.meta_creator - ) + storage_data = await self._object_storage_repository.create(action.creator) return CreateObjectStorageActionResult(result=storage_data) async def update(self, action: UpdateObjectStorageAction) -> UpdateObjectStorageActionResult: diff --git a/src/ai/backend/manager/services/vfs_storage/actions/create.py b/src/ai/backend/manager/services/vfs_storage/actions/create.py index a3dc0737b00..d1de8b4ad2d 100644 --- a/src/ai/backend/manager/services/vfs_storage/actions/create.py +++ b/src/ai/backend/manager/services/vfs_storage/actions/create.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, override +from typing import override from ai.backend.manager.actions.action import BaseActionResult from ai.backend.manager.actions.types import ActionOperationType @@ -10,14 +10,10 @@ from ai.backend.manager.repositories.base.creator import Creator from ai.backend.manager.services.vfs_storage.actions.base import VFSStorageAction -if TYPE_CHECKING: - from ai.backend.manager.models.artifact_storages import ArtifactStorageRow - @dataclass class CreateVFSStorageAction(VFSStorageAction): creator: Creator[VFSStorageRow] - meta_creator: Creator[ArtifactStorageRow] @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/vfs_storage/service.py b/src/ai/backend/manager/services/vfs_storage/service.py index 685f745ff98..7736ffce346 100644 --- a/src/ai/backend/manager/services/vfs_storage/service.py +++ b/src/ai/backend/manager/services/vfs_storage/service.py @@ -66,9 +66,7 @@ async def create(self, action: CreateVFSStorageAction) -> CreateVFSStorageAction Create a new VFS storage. 
""" log.info("Creating VFS storage with data: {}", action.creator) - storage_data = await self._vfs_storage_repository.create( - action.creator, action.meta_creator - ) + storage_data = await self._vfs_storage_repository.create(action.creator) return CreateVFSStorageActionResult(result=storage_data) async def update(self, action: UpdateVFSStorageAction) -> UpdateVFSStorageActionResult: diff --git a/tests/unit/manager/services/object_storage/test_object_storage_service.py b/tests/unit/manager/services/object_storage/test_object_storage_service.py index b4e53eaa5f8..5aa493b8491 100644 --- a/tests/unit/manager/services/object_storage/test_object_storage_service.py +++ b/tests/unit/manager/services/object_storage/test_object_storage_service.py @@ -10,6 +10,7 @@ import pytest +from ai.backend.common.data.storage.types import ArtifactStorageType from ai.backend.manager.data.object_storage.types import ( ObjectStorageData, ObjectStorageListResult, @@ -76,6 +77,7 @@ def sample_object_storage_data(self) -> ObjectStorageData: return ObjectStorageData( id=uuid4(), name="test-object-storage", + type=ArtifactStorageType.OBJECT_STORAGE, host="storage-proxy-1", access_key="test-access-key", secret_key="test-secret-key", diff --git a/tests/unit/manager/services/vfs_storage/test_vfs_storage_service.py b/tests/unit/manager/services/vfs_storage/test_vfs_storage_service.py index 3ef710b5ab6..66919e04726 100644 --- a/tests/unit/manager/services/vfs_storage/test_vfs_storage_service.py +++ b/tests/unit/manager/services/vfs_storage/test_vfs_storage_service.py @@ -11,6 +11,7 @@ import pytest +from ai.backend.common.data.storage.types import ArtifactStorageType from ai.backend.manager.data.vfs_storage.types import VFSStorageData, VFSStorageListResult from ai.backend.manager.repositories.base import BatchQuerier, OffsetPagination from ai.backend.manager.repositories.vfs_storage.repository import VFSStorageRepository @@ -42,6 +43,7 @@ def sample_vfs_storage_data(self) -> VFSStorageData: 
return VFSStorageData( id=uuid4(), name="test-vfs-storage", + type=ArtifactStorageType.VFS_STORAGE, host="localhost", base_path=Path("/mnt/vfs/test"), ) From 9c1d73b67544aded2c56b4a51a7acb4396ef0ee0 Mon Sep 17 00:00:00 2001 From: Gyubong Date: Thu, 26 Feb 2026 03:51:09 +0000 Subject: [PATCH 05/10] chore: update api schema dump Co-authored-by: octodog --- docs/manager/graphql-reference/supergraph.graphql | 3 --- docs/manager/graphql-reference/v2-schema.graphql | 3 --- 2 files changed, 6 deletions(-) diff --git a/docs/manager/graphql-reference/supergraph.graphql b/docs/manager/graphql-reference/supergraph.graphql index 41545bfca12..b4135af28a4 100644 --- a/docs/manager/graphql-reference/supergraph.graphql +++ b/docs/manager/graphql-reference/supergraph.graphql @@ -1006,9 +1006,6 @@ type ArtifactStorage """The name of the artifact storage""" name: String! - """The ID of the underlying storage (ObjectStorage or VFSStorage)""" - storageId: ID! - """The type of the artifact storage""" type: ArtifactStorageType! } diff --git a/docs/manager/graphql-reference/v2-schema.graphql b/docs/manager/graphql-reference/v2-schema.graphql index 4c21a4ac08b..d3e15184b32 100644 --- a/docs/manager/graphql-reference/v2-schema.graphql +++ b/docs/manager/graphql-reference/v2-schema.graphql @@ -717,9 +717,6 @@ type ArtifactStorage { """The name of the artifact storage""" name: String! - """The ID of the underlying storage (ObjectStorage or VFSStorage)""" - storageId: ID! - """The type of the artifact storage""" type: ArtifactStorageType! 
} From 1e0a7ad8054d313e12eae265283dd9b0ddec1244 Mon Sep 17 00:00:00 2001 From: Gyubong Date: Thu, 26 Feb 2026 13:02:09 +0900 Subject: [PATCH 06/10] fix: Migration script --- ...0662_add_artifact_storages_common_table.py | 102 ++++++++---------- ...64643926_apply_jti_to_artifact_storages.py | 81 -------------- 2 files changed, 47 insertions(+), 136 deletions(-) delete mode 100644 src/ai/backend/manager/models/alembic/versions/7b5764643926_apply_jti_to_artifact_storages.py diff --git a/src/ai/backend/manager/models/alembic/versions/35dfab3b0662_add_artifact_storages_common_table.py b/src/ai/backend/manager/models/alembic/versions/35dfab3b0662_add_artifact_storages_common_table.py index 95410586e02..b1e948a4904 100644 --- a/src/ai/backend/manager/models/alembic/versions/35dfab3b0662_add_artifact_storages_common_table.py +++ b/src/ai/backend/manager/models/alembic/versions/35dfab3b0662_add_artifact_storages_common_table.py @@ -1,7 +1,7 @@ -"""Add artifact_storages common table +"""Add artifact_storages common table with JTI Revision ID: 35dfab3b0662 -Revises: 03ff6767b2e4 +Revises: ffcf0ed13a26 Create Date: 2025-12-02 09:24:21.050932 """ @@ -13,7 +13,7 @@ # revision identifiers, used by Alembic. 
revision = "35dfab3b0662" -down_revision = "03ff6767b2e4" +down_revision = "ffcf0ed13a26" branch_labels = None depends_on = None @@ -21,23 +21,16 @@ def _migrate_object_storages_to_artifact_storages( conn: sa.engine.Connection, ) -> None: - """Migrate existing object_storages records to artifact_storages.""" - result = conn.execute( + """Migrate existing object_storages records to artifact_storages (JTI: id = child id).""" + conn.execute( sa.text(""" - SELECT id, name FROM object_storages + INSERT INTO artifact_storages (id, name, type) + SELECT id, name, 'object_storage' + FROM object_storages WHERE name IS NOT NULL """) ) - for row in result: - conn.execute( - sa.text(""" - INSERT INTO artifact_storages (name, storage_id, type) - VALUES (:name, :storage_id, :type) - """), - {"name": row.name, "storage_id": row.id, "type": "object_storage"}, - ) - # Drop the name column and constraint op.drop_index("ix_object_storages_name", table_name="object_storages") op.drop_column("object_storages", "name") @@ -46,23 +39,16 @@ def _migrate_object_storages_to_artifact_storages( def _migrate_vfs_storages_to_artifact_storages( conn: sa.engine.Connection, ) -> None: - """Migrate existing vfs_storages records to artifact_storages.""" - result = conn.execute( + """Migrate existing vfs_storages records to artifact_storages (JTI: id = child id).""" + conn.execute( sa.text(""" - SELECT id, name FROM vfs_storages + INSERT INTO artifact_storages (id, name, type) + SELECT id, name, 'vfs_storage' + FROM vfs_storages WHERE name IS NOT NULL """) ) - for row in result: - conn.execute( - sa.text(""" - INSERT INTO artifact_storages (name, storage_id, type) - VALUES (:name, :storage_id, :type) - """), - {"name": row.name, "storage_id": row.id, "type": "vfs_storage"}, - ) - # Drop the name column and constraint op.drop_index("ix_vfs_storages_name", table_name="vfs_storages") op.drop_column("vfs_storages", "name") @@ -72,56 +58,38 @@ def _migrate_artifact_storages_to_object_storages( conn: 
sa.engine.Connection, ) -> None: """Migrate data back from artifact_storages to object_storages.""" - result = conn.execute( + conn.execute( sa.text(""" - SELECT name, storage_id FROM artifact_storages - WHERE type = 'object_storage' + UPDATE object_storages o + SET name = a.name + FROM artifact_storages a + WHERE o.id = a.id AND a.type = 'object_storage' """) ) - for row in result: - conn.execute( - sa.text(""" - UPDATE object_storages - SET name = :name - WHERE id = :storage_id - """), - {"name": row.name, "storage_id": row.storage_id}, - ) - def _migrate_artifact_storages_to_vfs_storages( conn: sa.engine.Connection, ) -> None: """Migrate data back from artifact_storages to vfs_storages.""" - result = conn.execute( + conn.execute( sa.text(""" - SELECT name, storage_id FROM artifact_storages - WHERE type = 'vfs_storage' + UPDATE vfs_storages v + SET name = a.name + FROM artifact_storages a + WHERE v.id = a.id AND a.type = 'vfs_storage' """) ) - for row in result: - conn.execute( - sa.text(""" - UPDATE vfs_storages - SET name = :name - WHERE id = :storage_id - """), - {"name": row.name, "storage_id": row.storage_id}, - ) - def upgrade() -> None: op.create_table( "artifact_storages", sa.Column("id", GUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False), sa.Column("name", sa.String(), nullable=False), - sa.Column("storage_id", GUID(), nullable=False), sa.Column("type", sa.String(), nullable=False), sa.PrimaryKeyConstraint("id", name=op.f("pk_artifact_storages")), sa.UniqueConstraint("name", name=op.f("uq_artifact_storages_name")), - sa.UniqueConstraint("storage_id", name=op.f("uq_artifact_storages_storage_id")), ) conn = op.get_bind() @@ -129,10 +97,34 @@ def upgrade() -> None: _migrate_object_storages_to_artifact_storages(conn) _migrate_vfs_storages_to_artifact_storages(conn) + # Add FK constraints: child.id → artifact_storages.id (JTI) + op.create_foreign_key( + "fk_object_storages_id_artifact_storages", + "object_storages", + "artifact_storages", + 
["id"], + ["id"], + ondelete="CASCADE", + ) + op.create_foreign_key( + "fk_vfs_storages_id_artifact_storages", + "vfs_storages", + "artifact_storages", + ["id"], + ["id"], + ondelete="CASCADE", + ) + def downgrade() -> None: conn = op.get_bind() + # Drop FK constraints + op.drop_constraint( + "fk_object_storages_id_artifact_storages", "object_storages", type_="foreignkey" + ) + op.drop_constraint("fk_vfs_storages_id_artifact_storages", "vfs_storages", type_="foreignkey") + # Add name column back to object_storages op.add_column( "object_storages", diff --git a/src/ai/backend/manager/models/alembic/versions/7b5764643926_apply_jti_to_artifact_storages.py b/src/ai/backend/manager/models/alembic/versions/7b5764643926_apply_jti_to_artifact_storages.py deleted file mode 100644 index 0cefcf29f54..00000000000 --- a/src/ai/backend/manager/models/alembic/versions/7b5764643926_apply_jti_to_artifact_storages.py +++ /dev/null @@ -1,81 +0,0 @@ -"""Apply joined table inheritance to artifact_storages - -Revision ID: 7b5764643926 -Revises: 35dfab3b0662 -Create Date: 2026-02-26 00:00:00.000000 - -""" - -import sqlalchemy as sa -from alembic import op - -from ai.backend.manager.models.base import GUID - -# revision identifiers, used by Alembic. -revision = "7b5764643926" -down_revision = "35dfab3b0662" -branch_labels = None -depends_on = None - - -def upgrade() -> None: - conn = op.get_bind() - - # 1. For each row in artifact_storages, set id = storage_id (align PKs) - # We need to update artifact_storages.id to match the child table id (storage_id). - # Since id is a PK, we need to handle this carefully. - conn.execute( - sa.text(""" - UPDATE artifact_storages SET id = storage_id - """) - ) - - # 2. Drop the storage_id unique constraint and column - op.drop_constraint("uq_artifact_storages_storage_id", "artifact_storages", type_="unique") - op.drop_column("artifact_storages", "storage_id") - - # 3. 
Add FK constraints: child.id -> artifact_storages.id - op.create_foreign_key( - "fk_object_storages_id_artifact_storages", - "object_storages", - "artifact_storages", - ["id"], - ["id"], - ondelete="CASCADE", - ) - op.create_foreign_key( - "fk_vfs_storages_id_artifact_storages", - "vfs_storages", - "artifact_storages", - ["id"], - ["id"], - ondelete="CASCADE", - ) - - -def downgrade() -> None: - conn = op.get_bind() - - # 1. Drop FK constraints - op.drop_constraint("fk_vfs_storages_id_artifact_storages", "vfs_storages", type_="foreignkey") - op.drop_constraint( - "fk_object_storages_id_artifact_storages", "object_storages", type_="foreignkey" - ) - - # 2. Re-add storage_id column (copy id into it, since they were aligned) - op.add_column( - "artifact_storages", - sa.Column("storage_id", GUID(), nullable=True), - ) - - conn.execute( - sa.text(""" - UPDATE artifact_storages SET storage_id = id - """) - ) - - op.alter_column("artifact_storages", "storage_id", nullable=False) - - op.create_unique_constraint( - "uq_artifact_storages_storage_id", "artifact_storages", ["storage_id"] - ) From 032aab382fb45bd11f44cf1ddddf8e1e77932aea Mon Sep 17 00:00:00 2001 From: jopemachine Date: Wed, 4 Mar 2026 03:57:46 +0000 Subject: [PATCH 07/10] wip --- .../api/gql/data_loader/object_storage/loader.py | 4 ++-- .../api/gql/data_loader/vfs_storage/loader.py | 4 ++-- .../manager/api/object_storage/test_dataloader.py | 13 +++++++++---- .../unit/manager/api/vfs_storage/test_dataloader.py | 13 +++++++++---- .../object_storage/test_object_storage_service.py | 3 ++- .../vfs_storage/test_vfs_storage_service.py | 3 ++- 6 files changed, 26 insertions(+), 14 deletions(-) diff --git a/src/ai/backend/manager/api/gql/data_loader/object_storage/loader.py b/src/ai/backend/manager/api/gql/data_loader/object_storage/loader.py index 3ba443d5b85..8c30b4c73f9 100644 --- a/src/ai/backend/manager/api/gql/data_loader/object_storage/loader.py +++ 
b/src/ai/backend/manager/api/gql/data_loader/object_storage/loader.py @@ -1,8 +1,8 @@ from __future__ import annotations -import uuid from collections.abc import Sequence +from ai.backend.common.types import ArtifactStorageId from ai.backend.manager.data.object_storage.types import ObjectStorageData from ai.backend.manager.repositories.base import BatchQuerier, OffsetPagination from ai.backend.manager.repositories.object_storage.options import ObjectStorageConditions @@ -12,7 +12,7 @@ async def load_object_storages_by_ids( processor: ObjectStorageProcessors, - storage_ids: Sequence[uuid.UUID], + storage_ids: Sequence[ArtifactStorageId], ) -> list[ObjectStorageData | None]: """Batch load object storages by their IDs. diff --git a/src/ai/backend/manager/api/gql/data_loader/vfs_storage/loader.py b/src/ai/backend/manager/api/gql/data_loader/vfs_storage/loader.py index 4fa1b707cf6..19c94a5d1b6 100644 --- a/src/ai/backend/manager/api/gql/data_loader/vfs_storage/loader.py +++ b/src/ai/backend/manager/api/gql/data_loader/vfs_storage/loader.py @@ -1,8 +1,8 @@ from __future__ import annotations -import uuid from collections.abc import Sequence +from ai.backend.common.types import ArtifactStorageId from ai.backend.manager.data.vfs_storage.types import VFSStorageData from ai.backend.manager.repositories.base import BatchQuerier, OffsetPagination from ai.backend.manager.repositories.vfs_storage.options import VFSStorageConditions @@ -12,7 +12,7 @@ async def load_vfs_storages_by_ids( processor: VFSStorageProcessors, - storage_ids: Sequence[uuid.UUID], + storage_ids: Sequence[ArtifactStorageId], ) -> list[VFSStorageData | None]: """Batch load VFS storages by their IDs. 
diff --git a/tests/unit/manager/api/object_storage/test_dataloader.py b/tests/unit/manager/api/object_storage/test_dataloader.py index 2ae5afc075c..71dabc4c7b4 100644 --- a/tests/unit/manager/api/object_storage/test_dataloader.py +++ b/tests/unit/manager/api/object_storage/test_dataloader.py @@ -5,6 +5,7 @@ import uuid from unittest.mock import AsyncMock, MagicMock +from ai.backend.common.types import ArtifactStorageId from ai.backend.manager.api.gql.data_loader.object_storage.loader import ( load_object_storages_by_ids, ) @@ -15,7 +16,7 @@ class TestLoadObjectStoragesByIds: """Tests for load_object_storages_by_ids function.""" @staticmethod - def create_mock_storage(storage_id: uuid.UUID) -> MagicMock: + def create_mock_storage(storage_id: ArtifactStorageId) -> MagicMock: mock = MagicMock(spec=ObjectStorageData) mock.id = storage_id return mock @@ -43,7 +44,11 @@ async def test_empty_ids_returns_empty_list(self) -> None: async def test_returns_storages_in_request_order(self) -> None: # Given - id1, id2, id3 = uuid.uuid4(), uuid.uuid4(), uuid.uuid4() + id1, id2, id3 = ( + ArtifactStorageId(uuid.uuid4()), + ArtifactStorageId(uuid.uuid4()), + ArtifactStorageId(uuid.uuid4()), + ) storage1 = self.create_mock_storage(id1) storage2 = self.create_mock_storage(id2) storage3 = self.create_mock_storage(id3) @@ -59,8 +64,8 @@ async def test_returns_storages_in_request_order(self) -> None: async def test_returns_none_for_missing_ids(self) -> None: # Given - existing_id = uuid.uuid4() - missing_id = uuid.uuid4() + existing_id = ArtifactStorageId(uuid.uuid4()) + missing_id = ArtifactStorageId(uuid.uuid4()) existing_storage = self.create_mock_storage(existing_id) mock_processor = self.create_mock_processor([existing_storage]) diff --git a/tests/unit/manager/api/vfs_storage/test_dataloader.py b/tests/unit/manager/api/vfs_storage/test_dataloader.py index 17ce34c96d8..6ad8c222e0f 100644 --- a/tests/unit/manager/api/vfs_storage/test_dataloader.py +++ 
b/tests/unit/manager/api/vfs_storage/test_dataloader.py @@ -5,6 +5,7 @@ import uuid from unittest.mock import AsyncMock, MagicMock +from ai.backend.common.types import ArtifactStorageId from ai.backend.manager.api.gql.data_loader.vfs_storage.loader import load_vfs_storages_by_ids from ai.backend.manager.data.vfs_storage.types import VFSStorageData @@ -13,7 +14,7 @@ class TestLoadVFSStoragesByIds: """Tests for load_vfs_storages_by_ids function.""" @staticmethod - def create_mock_storage(storage_id: uuid.UUID) -> MagicMock: + def create_mock_storage(storage_id: ArtifactStorageId) -> MagicMock: return MagicMock(spec=VFSStorageData, id=storage_id) @staticmethod @@ -39,7 +40,11 @@ async def test_empty_ids_returns_empty_list(self) -> None: async def test_returns_storages_in_request_order(self) -> None: # Given - id1, id2, id3 = uuid.uuid4(), uuid.uuid4(), uuid.uuid4() + id1, id2, id3 = ( + ArtifactStorageId(uuid.uuid4()), + ArtifactStorageId(uuid.uuid4()), + ArtifactStorageId(uuid.uuid4()), + ) storage1 = self.create_mock_storage(id1) storage2 = self.create_mock_storage(id2) storage3 = self.create_mock_storage(id3) @@ -55,8 +60,8 @@ async def test_returns_storages_in_request_order(self) -> None: async def test_returns_none_for_missing_ids(self) -> None: # Given - existing_id = uuid.uuid4() - missing_id = uuid.uuid4() + existing_id = ArtifactStorageId(uuid.uuid4()) + missing_id = ArtifactStorageId(uuid.uuid4()) existing_storage = self.create_mock_storage(existing_id) mock_processor = self.create_mock_processor([existing_storage]) diff --git a/tests/unit/manager/services/object_storage/test_object_storage_service.py b/tests/unit/manager/services/object_storage/test_object_storage_service.py index 5aa493b8491..cb894826e6c 100644 --- a/tests/unit/manager/services/object_storage/test_object_storage_service.py +++ b/tests/unit/manager/services/object_storage/test_object_storage_service.py @@ -11,6 +11,7 @@ import pytest from ai.backend.common.data.storage.types import 
ArtifactStorageType +from ai.backend.common.types import ArtifactStorageId from ai.backend.manager.data.object_storage.types import ( ObjectStorageData, ObjectStorageListResult, @@ -75,7 +76,7 @@ def object_storage_service( def sample_object_storage_data(self) -> ObjectStorageData: """Create sample Object storage data""" return ObjectStorageData( - id=uuid4(), + id=ArtifactStorageId(uuid4()), name="test-object-storage", type=ArtifactStorageType.OBJECT_STORAGE, host="storage-proxy-1", diff --git a/tests/unit/manager/services/vfs_storage/test_vfs_storage_service.py b/tests/unit/manager/services/vfs_storage/test_vfs_storage_service.py index 66919e04726..1fac5cdf065 100644 --- a/tests/unit/manager/services/vfs_storage/test_vfs_storage_service.py +++ b/tests/unit/manager/services/vfs_storage/test_vfs_storage_service.py @@ -12,6 +12,7 @@ import pytest from ai.backend.common.data.storage.types import ArtifactStorageType +from ai.backend.common.types import ArtifactStorageId from ai.backend.manager.data.vfs_storage.types import VFSStorageData, VFSStorageListResult from ai.backend.manager.repositories.base import BatchQuerier, OffsetPagination from ai.backend.manager.repositories.vfs_storage.repository import VFSStorageRepository @@ -41,7 +42,7 @@ def vfs_storage_service( def sample_vfs_storage_data(self) -> VFSStorageData: """Create sample VFS storage data""" return VFSStorageData( - id=uuid4(), + id=ArtifactStorageId(uuid4()), name="test-vfs-storage", type=ArtifactStorageType.VFS_STORAGE, host="localhost", From 5371c73bf8824df524b54831b21083bd77869e9d Mon Sep 17 00:00:00 2001 From: jopemachine Date: Wed, 4 Mar 2026 03:59:36 +0000 Subject: [PATCH 08/10] wip --- .../35dfab3b0662_add_artifact_storages_common_table.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ai/backend/manager/models/alembic/versions/35dfab3b0662_add_artifact_storages_common_table.py 
b/src/ai/backend/manager/models/alembic/versions/35dfab3b0662_add_artifact_storages_common_table.py index b1e948a4904..3966812ba79 100644 --- a/src/ai/backend/manager/models/alembic/versions/35dfab3b0662_add_artifact_storages_common_table.py +++ b/src/ai/backend/manager/models/alembic/versions/35dfab3b0662_add_artifact_storages_common_table.py @@ -1,7 +1,7 @@ """Add artifact_storages common table with JTI Revision ID: 35dfab3b0662 -Revises: ffcf0ed13a26 +Revises: 3f5c20f7bb07 Create Date: 2025-12-02 09:24:21.050932 """ @@ -13,7 +13,7 @@ # revision identifiers, used by Alembic. revision = "35dfab3b0662" -down_revision = "ffcf0ed13a26" +down_revision = "3f5c20f7bb07" branch_labels = None depends_on = None From 8be3451783822271fbeb9f9409fd3190b174d737 Mon Sep 17 00:00:00 2001 From: jopemachine Date: Wed, 4 Mar 2026 04:37:29 +0000 Subject: [PATCH 09/10] test: Fix Broken tests --- tests/component/object_storage/conftest.py | 22 ++++++++++++++++++- .../test_object_storage_repository.py | 2 ++ .../test_vfs_storage_repository.py | 3 ++- 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/tests/component/object_storage/conftest.py b/tests/component/object_storage/conftest.py index 30c8421db64..feafda73790 100644 --- a/tests/component/object_storage/conftest.py +++ b/tests/component/object_storage/conftest.py @@ -19,6 +19,7 @@ from ai.backend.manager.api.rest.object_storage.registry import register_object_storage_routes from ai.backend.manager.api.rest.types import ModuleRegistrar from ai.backend.manager.api.types import CleanupContext +from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.models.object_storage import ObjectStorageRow from ai.backend.manager.models.storage_namespace.row import StorageNamespaceRow from ai.backend.manager.repositories.repositories import Repositories @@ -143,7 +144,21 @@ async def _create(**overrides: Any) -> ObjectStorageFixtureData: } defaults.update(overrides) async with 
db_engine.begin() as conn: - await conn.execute(sa.insert(ObjectStorageRow.__table__).values(**defaults)) + # Insert parent row first (Joined Table Inheritance) + await conn.execute( + sa.insert(ArtifactStorageRow.__table__).values( + id=defaults["id"], + name=defaults["name"], + type="object_storage", + ) + ) + # Insert child row with remaining columns + child_cols = { + k: v + for k, v in defaults.items() + if k in ("id", "host", "access_key", "secret_key", "endpoint", "region") + } + await conn.execute(sa.insert(ObjectStorageRow.__table__).values(**child_cols)) created_ids.append(defaults["id"]) return defaults @@ -159,6 +174,11 @@ async def _create(**overrides: Any) -> ObjectStorageFixtureData: await conn.execute( sa.delete(ObjectStorageRow.__table__).where(ObjectStorageRow.__table__.c.id == sid) ) + await conn.execute( + sa.delete(ArtifactStorageRow.__table__).where( + ArtifactStorageRow.__table__.c.id == sid + ) + ) @pytest.fixture() diff --git a/tests/unit/manager/repositories/object_storage/test_object_storage_repository.py b/tests/unit/manager/repositories/object_storage/test_object_storage_repository.py index 3be3bb1844d..d665e96c58e 100644 --- a/tests/unit/manager/repositories/object_storage/test_object_storage_repository.py +++ b/tests/unit/manager/repositories/object_storage/test_object_storage_repository.py @@ -11,6 +11,7 @@ import pytest from ai.backend.manager.errors.object_storage import ObjectStorageNotFoundError +from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.models.object_storage import ObjectStorageRow from ai.backend.manager.models.utils import ExtendedAsyncSAEngine from ai.backend.manager.repositories.base import BatchQuerier, OffsetPagination @@ -34,6 +35,7 @@ async def db_with_cleanup( async with with_tables( database_connection, [ + ArtifactStorageRow, ObjectStorageRow, ], ): diff --git a/tests/unit/manager/repositories/vfs_storage/test_vfs_storage_repository.py 
b/tests/unit/manager/repositories/vfs_storage/test_vfs_storage_repository.py index 7b139770028..4d40d39c685 100644 --- a/tests/unit/manager/repositories/vfs_storage/test_vfs_storage_repository.py +++ b/tests/unit/manager/repositories/vfs_storage/test_vfs_storage_repository.py @@ -10,6 +10,7 @@ import pytest +from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.models.utils import ExtendedAsyncSAEngine from ai.backend.manager.models.vfs_storage import VFSStorageRow from ai.backend.manager.repositories.base import BatchQuerier, OffsetPagination @@ -32,7 +33,7 @@ async def db_with_cleanup( """Database connection with tables created. TRUNCATE CASCADE handles cleanup.""" async with with_tables( database_connection, - [VFSStorageRow], + [ArtifactStorageRow, VFSStorageRow], ): yield database_connection From 18253bc0e680f473e2dc41df2c6e5b24d78f6cc3 Mon Sep 17 00:00:00 2001 From: jopemachine Date: Wed, 4 Mar 2026 06:41:36 +0000 Subject: [PATCH 10/10] test: Add --- .../test_object_storage_repository.py | 129 ++++++++++++ .../test_vfs_storage_repository.py | 184 ++++++++++++++++++ 2 files changed, 313 insertions(+) diff --git a/tests/unit/manager/repositories/object_storage/test_object_storage_repository.py b/tests/unit/manager/repositories/object_storage/test_object_storage_repository.py index d665e96c58e..95afca4fbbf 100644 --- a/tests/unit/manager/repositories/object_storage/test_object_storage_repository.py +++ b/tests/unit/manager/repositories/object_storage/test_object_storage_repository.py @@ -15,7 +15,12 @@ from ai.backend.manager.models.object_storage import ObjectStorageRow from ai.backend.manager.models.utils import ExtendedAsyncSAEngine from ai.backend.manager.repositories.base import BatchQuerier, OffsetPagination +from ai.backend.manager.repositories.base.creator import Creator +from ai.backend.manager.repositories.base.updater import Updater +from ai.backend.manager.repositories.object_storage.creators import 
ObjectStorageCreatorSpec from ai.backend.manager.repositories.object_storage.repository import ObjectStorageRepository +from ai.backend.manager.repositories.object_storage.updaters import ObjectStorageUpdaterSpec +from ai.backend.manager.types import OptionalState from ai.backend.testutils.db import with_tables @@ -128,6 +133,26 @@ async def object_storage_repository( repo = ObjectStorageRepository(db=db_with_cleanup) yield repo + @pytest.fixture + def object_storage_creator_spec(self) -> ObjectStorageCreatorSpec: + """Spec for creating a new object storage""" + return ObjectStorageCreatorSpec( + name="new-object-storage", + host="storage-proxy-2", + access_key="new-access-key", + secret_key="new-secret-key", + endpoint="https://s3.new.example.com", + region="ap-northeast-2", + ) + + @pytest.fixture + def object_storage_updater_spec(self) -> ObjectStorageUpdaterSpec: + """Spec for updating object storage fields""" + return ObjectStorageUpdaterSpec( + host=OptionalState.update("updated-host"), + endpoint=OptionalState.update("https://s3.updated.example.com"), + ) + # ========================================================================= # Tests - Get by ID # ========================================================================= @@ -341,3 +366,107 @@ async def test_search_with_pagination_filter_and_order( # Verify ordering is ascending result_names = [storage.name for storage in result.items] assert result_names == sorted(result_names) + + # ========================================================================= + # Tests - Create + # ========================================================================= + + async def test_create( + self, + object_storage_repository: ObjectStorageRepository, + object_storage_creator_spec: ObjectStorageCreatorSpec, + ) -> None: + """Test creating a new object storage via Creator""" + result = await object_storage_repository.create(Creator(spec=object_storage_creator_spec)) + + assert result.name == "new-object-storage" + 
assert result.host == "storage-proxy-2" + assert result.access_key == "new-access-key" + assert result.secret_key == "new-secret-key" + assert result.endpoint == "https://s3.new.example.com" + assert result.region == "ap-northeast-2" + assert result.id is not None + + # Verify persisted in DB + fetched = await object_storage_repository.get_by_id(result.id) + assert fetched.name == "new-object-storage" + + async def test_create_duplicate_name_raises_error( + self, + object_storage_repository: ObjectStorageRepository, + sample_storage_id: uuid.UUID, + object_storage_creator_spec: ObjectStorageCreatorSpec, + ) -> None: + """Test creating object storage with duplicate name raises error""" + duplicate_spec = ObjectStorageCreatorSpec( + name="test-object-storage", + host=object_storage_creator_spec.host, + access_key=object_storage_creator_spec.access_key, + secret_key=object_storage_creator_spec.secret_key, + endpoint=object_storage_creator_spec.endpoint, + region=object_storage_creator_spec.region, + ) + with pytest.raises(Exception): + await object_storage_repository.create(Creator(spec=duplicate_spec)) + + # ========================================================================= + # Tests - Update + # ========================================================================= + + async def test_update( + self, + object_storage_repository: ObjectStorageRepository, + sample_storage_id: uuid.UUID, + object_storage_updater_spec: ObjectStorageUpdaterSpec, + ) -> None: + """Test updating an existing object storage via Updater""" + updater = Updater(spec=object_storage_updater_spec, pk_value=sample_storage_id) + result = await object_storage_repository.update(updater) + + assert result.id == sample_storage_id + assert result.host == "updated-host" + assert result.endpoint == "https://s3.updated.example.com" + # Unchanged fields remain the same + assert result.name == "test-object-storage" + assert result.access_key == "test-access-key" + + async def test_update_not_found( + 
self, + object_storage_repository: ObjectStorageRepository, + ) -> None: + """Test updating non-existent object storage raises error""" + not_found_updater = Updater( + spec=ObjectStorageUpdaterSpec( + host=OptionalState.update("updated-host"), + ), + pk_value=uuid.uuid4(), + ) + + with pytest.raises(ObjectStorageNotFoundError): + await object_storage_repository.update(not_found_updater) + + # ========================================================================= + # Tests - Delete + # ========================================================================= + + async def test_delete( + self, + object_storage_repository: ObjectStorageRepository, + sample_storage_id: uuid.UUID, + ) -> None: + """Test deleting an existing object storage""" + deleted_id = await object_storage_repository.delete(sample_storage_id) + + assert deleted_id == sample_storage_id + + # Verify it no longer exists + with pytest.raises(ObjectStorageNotFoundError): + await object_storage_repository.get_by_id(sample_storage_id) + + async def test_delete_not_found( + self, + object_storage_repository: ObjectStorageRepository, + ) -> None: + """Test deleting non-existent object storage raises error""" + with pytest.raises(ObjectStorageNotFoundError): + await object_storage_repository.delete(uuid.uuid4()) diff --git a/tests/unit/manager/repositories/vfs_storage/test_vfs_storage_repository.py b/tests/unit/manager/repositories/vfs_storage/test_vfs_storage_repository.py index 4d40d39c685..2ed6b310b2f 100644 --- a/tests/unit/manager/repositories/vfs_storage/test_vfs_storage_repository.py +++ b/tests/unit/manager/repositories/vfs_storage/test_vfs_storage_repository.py @@ -10,11 +10,17 @@ import pytest +from ai.backend.manager.errors.vfs_storage import VFSStorageNotFoundError from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.models.utils import ExtendedAsyncSAEngine from ai.backend.manager.models.vfs_storage import VFSStorageRow from 
ai.backend.manager.repositories.base import BatchQuerier, OffsetPagination +from ai.backend.manager.repositories.base.creator import Creator +from ai.backend.manager.repositories.base.updater import Updater +from ai.backend.manager.repositories.vfs_storage.creators import VFSStorageCreatorSpec from ai.backend.manager.repositories.vfs_storage.repository import VFSStorageRepository +from ai.backend.manager.repositories.vfs_storage.updaters import VFSStorageUpdaterSpec +from ai.backend.manager.types import OptionalState from ai.backend.testutils.db import with_tables @@ -65,6 +71,23 @@ async def sample_vfs_storage_id( yield storage_id + @pytest.fixture + def vfs_storage_creator_spec(self) -> VFSStorageCreatorSpec: + """Spec for creating a new VFS storage""" + return VFSStorageCreatorSpec( + name="new-vfs-storage", + host="storage-host-1", + base_path="/mnt/nfs/new", + ) + + @pytest.fixture + def vfs_storage_updater_spec(self) -> VFSStorageUpdaterSpec: + """Spec for updating VFS storage fields""" + return VFSStorageUpdaterSpec( + host=OptionalState.update("updated-host"), + base_path=OptionalState.update("/mnt/vfs/updated"), + ) + @pytest.fixture async def sample_vfs_storages_for_filtering( self, @@ -310,3 +333,164 @@ async def test_search_vfs_storages_with_pagination_filter_and_order( # Verify ordering is ascending result_names = [storage.name for storage in result.items] assert result_names == sorted(result_names) + + # ========================================================================= + # Tests - Get by ID + # ========================================================================= + + async def test_get_by_id( + self, + vfs_storage_repository: VFSStorageRepository, + sample_vfs_storage_id: uuid.UUID, + ) -> None: + """Test retrieving VFS storage by ID""" + result = await vfs_storage_repository.get_by_id(sample_vfs_storage_id) + + assert result.id == sample_vfs_storage_id + assert result.name == "test-vfs-storage" + assert result.host == "localhost" + assert 
str(result.base_path) == "/mnt/vfs/test" + + async def test_get_by_id_not_found( + self, + vfs_storage_repository: VFSStorageRepository, + ) -> None: + """Test retrieving non-existent VFS storage raises error""" + with pytest.raises(VFSStorageNotFoundError): + await vfs_storage_repository.get_by_id(uuid.uuid4()) + + # ========================================================================= + # Tests - Get by Name + # ========================================================================= + + async def test_get_by_name( + self, + vfs_storage_repository: VFSStorageRepository, + sample_vfs_storage_id: uuid.UUID, + ) -> None: + """Test retrieving VFS storage by name""" + result = await vfs_storage_repository.get_by_name("test-vfs-storage") + + assert result.id == sample_vfs_storage_id + assert result.name == "test-vfs-storage" + + async def test_get_by_name_not_found( + self, + vfs_storage_repository: VFSStorageRepository, + ) -> None: + """Test retrieving non-existent VFS storage by name raises error""" + with pytest.raises(VFSStorageNotFoundError): + await vfs_storage_repository.get_by_name("non-existent-storage") + + # ========================================================================= + # Tests - Create + # ========================================================================= + + async def test_create( + self, + vfs_storage_repository: VFSStorageRepository, + vfs_storage_creator_spec: VFSStorageCreatorSpec, + ) -> None: + """Test creating a new VFS storage via Creator""" + result = await vfs_storage_repository.create(Creator(spec=vfs_storage_creator_spec)) + + assert result.name == "new-vfs-storage" + assert result.host == "storage-host-1" + assert str(result.base_path) == "/mnt/nfs/new" + assert result.id is not None + + # Verify persisted in DB + fetched = await vfs_storage_repository.get_by_id(result.id) + assert fetched.name == "new-vfs-storage" + + async def test_create_duplicate_name_raises_error( + self, + vfs_storage_repository: 
VFSStorageRepository, + sample_vfs_storage_id: uuid.UUID, + vfs_storage_creator_spec: VFSStorageCreatorSpec, + ) -> None: + """Test creating VFS storage with duplicate name raises error""" + duplicate_spec = VFSStorageCreatorSpec( + name="test-vfs-storage", + host=vfs_storage_creator_spec.host, + base_path=vfs_storage_creator_spec.base_path, + ) + with pytest.raises(Exception): + await vfs_storage_repository.create(Creator(spec=duplicate_spec)) + + # ========================================================================= + # Tests - Update + # ========================================================================= + + async def test_update( + self, + vfs_storage_repository: VFSStorageRepository, + sample_vfs_storage_id: uuid.UUID, + vfs_storage_updater_spec: VFSStorageUpdaterSpec, + ) -> None: + """Test updating an existing VFS storage via Updater""" + updater = Updater(spec=vfs_storage_updater_spec, pk_value=sample_vfs_storage_id) + result = await vfs_storage_repository.update(updater) + + assert result.id == sample_vfs_storage_id + assert result.host == "updated-host" + assert str(result.base_path) == "/mnt/vfs/updated" + # Unchanged fields remain the same + assert result.name == "test-vfs-storage" + + async def test_update_partial( + self, + vfs_storage_repository: VFSStorageRepository, + sample_vfs_storage_id: uuid.UUID, + ) -> None: + """Test partial update only changes specified fields""" + partial_spec = VFSStorageUpdaterSpec( + host=OptionalState.update("partial-updated-host"), + ) + updater = Updater(spec=partial_spec, pk_value=sample_vfs_storage_id) + result = await vfs_storage_repository.update(updater) + + assert result.host == "partial-updated-host" + # base_path should remain unchanged + assert str(result.base_path) == "/mnt/vfs/test" + + async def test_update_not_found( + self, + vfs_storage_repository: VFSStorageRepository, + ) -> None: + """Test updating non-existent VFS storage raises error""" + not_found_updater = Updater( + 
spec=VFSStorageUpdaterSpec( + host=OptionalState.update("updated-host"), + ), + pk_value=uuid.uuid4(), + ) + + with pytest.raises(VFSStorageNotFoundError): + await vfs_storage_repository.update(not_found_updater) + + # ========================================================================= + # Tests - Delete + # ========================================================================= + + async def test_delete( + self, + vfs_storage_repository: VFSStorageRepository, + sample_vfs_storage_id: uuid.UUID, + ) -> None: + """Test deleting an existing VFS storage""" + deleted_id = await vfs_storage_repository.delete(sample_vfs_storage_id) + + assert deleted_id == sample_vfs_storage_id + + # Verify it no longer exists + with pytest.raises(VFSStorageNotFoundError): + await vfs_storage_repository.get_by_id(sample_vfs_storage_id) + + async def test_delete_not_found( + self, + vfs_storage_repository: VFSStorageRepository, + ) -> None: + """Test deleting non-existent VFS storage raises error""" + with pytest.raises(VFSStorageNotFoundError): + await vfs_storage_repository.delete(uuid.uuid4())