diff --git a/changes/7057.feature.md b/changes/7057.feature.md new file mode 100644 index 00000000000..9900e6a92ca --- /dev/null +++ b/changes/7057.feature.md @@ -0,0 +1 @@ +Add `artifact_storages` common table for storage metadata management across object storage and VFS storage backends, and add `adminUpdateArtifactStorage` GraphQL mutation for updating artifact storage metadata (e.g., name) \ No newline at end of file diff --git a/docs/manager/graphql-reference/supergraph.graphql b/docs/manager/graphql-reference/supergraph.graphql index c213b8c3005..b4135af28a4 100644 --- a/docs/manager/graphql-reference/supergraph.graphql +++ b/docs/manager/graphql-reference/supergraph.graphql @@ -996,6 +996,31 @@ type ArtifactStatusChangedPayload artifactRevision: ArtifactRevision! } +"""Added in 26.3.0. Artifact storage metadata""" +type ArtifactStorage + @join__type(graph: STRAWBERRY) +{ + """The ID of the artifact storage""" + id: ID! + + """The name of the artifact storage""" + name: String! + + """The type of the artifact storage""" + type: ArtifactStorageType! +} + +""" +Added in 26.3.0. The type of artifact storage backend. OBJECT_STORAGE: Object storage (e.g., S3-compatible). VFS_STORAGE: Virtual folder storage. GIT_LFS: Git LFS storage. +""" +enum ArtifactStorageType + @join__type(graph: STRAWBERRY) +{ + OBJECT_STORAGE @join__enumValue(graph: STRAWBERRY) + VFS_STORAGE @join__enumValue(graph: STRAWBERRY) + GIT_LFS @join__enumValue(graph: STRAWBERRY) +} + enum ArtifactType @join__type(graph: STRAWBERRY) { @@ -7298,6 +7323,11 @@ type Mutation """Added in 25.16.0. Delete a VFS storage""" deleteVFSStorage(input: DeleteVFSStorageInput!): DeleteVFSStoragePayload! @join__field(graph: STRAWBERRY) + """ + Added in 26.3.0. Update artifact storage metadata (common fields like name) + """ + adminUpdateArtifactStorage(input: UpdateArtifactStorageInput!): UpdateArtifactStoragePayload! @join__field(graph: STRAWBERRY) + """ Added in 25.15.0. @@ -11844,6 +11874,25 @@ type UpdateArtifactPayload artifact: Artifact! } +"""Added in 26.3.0. Input for updating artifact storage metadata""" +input UpdateArtifactStorageInput + @join__type(graph: STRAWBERRY) +{ + """The ID of the artifact storage""" + id: ID! + + """The new name for the artifact storage""" + name: String +} + +"""Added in 26.3.0. Payload for updating artifact storage metadata""" +type UpdateArtifactStoragePayload + @join__type(graph: STRAWBERRY) +{ + """The updated artifact storage""" + artifactStorage: ArtifactStorage! +} + input UpdateAutoScalingRuleInput @join__type(graph: STRAWBERRY) { @@ -11951,7 +12000,6 @@ input UpdateObjectStorageInput @join__type(graph: STRAWBERRY) { id: ID! - name: String host: String accessKey: String secretKey: String @@ -12167,7 +12215,6 @@ input UpdateVFSStorageInput @join__type(graph: STRAWBERRY) { id: ID! - name: String host: String basePath: String } diff --git a/docs/manager/graphql-reference/v2-schema.graphql b/docs/manager/graphql-reference/v2-schema.graphql index 2755b07cdd5..d3e15184b32 100644 --- a/docs/manager/graphql-reference/v2-schema.graphql +++ b/docs/manager/graphql-reference/v2-schema.graphql @@ -709,6 +709,27 @@ type ArtifactStatusChangedPayload { artifactRevision: ArtifactRevision! } +"""Added in 26.3.0. Artifact storage metadata""" +type ArtifactStorage { + """The ID of the artifact storage""" + id: ID! + + """The name of the artifact storage""" + name: String! + + """The type of the artifact storage""" + type: ArtifactStorageType! +} + +""" +Added in 26.3.0. The type of artifact storage backend. OBJECT_STORAGE: Object storage (e.g., S3-compatible). VFS_STORAGE: Virtual folder storage. GIT_LFS: Git LFS storage. +""" +enum ArtifactStorageType { + OBJECT_STORAGE + VFS_STORAGE + GIT_LFS +} + enum ArtifactType { MODEL PACKAGE @@ -3731,6 +3752,11 @@ type Mutation { """Added in 25.16.0. Delete a VFS storage""" deleteVFSStorage(input: DeleteVFSStorageInput!): DeleteVFSStoragePayload! + """ + Added in 26.3.0. Update artifact storage metadata (common fields like name) + """ + adminUpdateArtifactStorage(input: UpdateArtifactStorageInput!): UpdateArtifactStoragePayload! + """ Added in 25.15.0. @@ -6942,6 +6968,21 @@ type UpdateArtifactPayload { artifact: Artifact! } +"""Added in 26.3.0. Input for updating artifact storage metadata""" +input UpdateArtifactStorageInput { + """The ID of the artifact storage""" + id: ID! + + """The new name for the artifact storage""" + name: String +} + +"""Added in 26.3.0. Payload for updating artifact storage metadata""" +type UpdateArtifactStoragePayload { + """The updated artifact storage""" + artifactStorage: ArtifactStorage! +} + input UpdateAutoScalingRuleInput { id: ID! metricSource: AutoScalingMetricSource @@ -7019,7 +7060,6 @@ type UpdateNotificationRulePayload { """Added in 25.14.0""" input UpdateObjectStorageInput { id: ID! - name: String host: String accessKey: String secretKey: String @@ -7209,7 +7249,6 @@ type UpdateUserV2Payload { """Added in 25.16.0. Input for updating VFS storage""" input UpdateVFSStorageInput { id: ID! - name: String host: String basePath: String } diff --git a/src/ai/backend/common/data/permission/types.py b/src/ai/backend/common/data/permission/types.py index 5a077d9f232..375f03bd5f7 100644 --- a/src/ai/backend/common/data/permission/types.py +++ b/src/ai/backend/common/data/permission/types.py @@ -103,6 +103,7 @@ class EntityType(enum.StrEnum): NETWORK = "network" NOTIFICATION = "notification" OBJECT_PERMISSION = "object_permission" + ARTIFACT_STORAGE = "artifact_storage" OBJECT_STORAGE = "object_storage" PERMISSION = "permission" AGENT_RESOURCE = "agent_resource" diff --git a/src/ai/backend/common/data/storage/exceptions.py b/src/ai/backend/common/data/storage/exceptions.py new file mode 100644 index 00000000000..1dfb59d85f8 --- /dev/null +++ b/src/ai/backend/common/data/storage/exceptions.py @@ -0,0 +1,21 @@ +from aiohttp import web + +from ai.backend.common.exception import ( + BackendAIError, + ErrorCode, + ErrorDetail, + ErrorDomain, + ErrorOperation, +) + + +class ArtifactStorageNotFoundError(BackendAIError, web.HTTPNotFound): + error_type = "https://api.backend.ai/probs/artifact-storage-not-found" + error_title = "Artifact Storage Not Found" + + def error_code(self) -> ErrorCode: + return ErrorCode( + domain=ErrorDomain.ARTIFACT_STORAGE, + operation=ErrorOperation.READ, + error_detail=ErrorDetail.NOT_FOUND, + ) diff --git a/src/ai/backend/common/data/storage/types.py b/src/ai/backend/common/data/storage/types.py index 971ce98c53b..a3e9023424b 100644 --- a/src/ai/backend/common/data/storage/types.py +++ b/src/ai/backend/common/data/storage/types.py @@ -1,10 +1,12 @@ from __future__ import annotations import enum +from dataclasses import dataclass from pydantic import BaseModel, ConfigDict from ai.backend.common.type_adapters import VFolderIDField +from ai.backend.common.types import ArtifactStorageId class VFolderStorageTarget(BaseModel): @@ -31,6 +33,15 @@ class ArtifactStorageType(enum.StrEnum): GIT_LFS = "git_lfs" +@dataclass(frozen=True) +class ArtifactStorageData: + """Data class for artifact storage metadata.""" + + id: ArtifactStorageId + name: str + type: ArtifactStorageType + + class ArtifactStorageImportStep(enum.StrEnum): DOWNLOAD = "download" VERIFY = "verify" diff --git a/src/ai/backend/common/exception.py b/src/ai/backend/common/exception.py index 454e705da57..bcbf82c2d3b 100644 --- a/src/ai/backend/common/exception.py +++ b/src/ai/backend/common/exception.py @@ -137,6 +137,7 @@ class ErrorDomain(enum.StrEnum): ARTIFACT = "artifact" ARTIFACT_REGISTRY = "artifact-registry" ARTIFACT_ASSOCIATION = "artifact-association" + ARTIFACT_STORAGE = "artifact-storage" OBJECT_STORAGE = "object-storage" VFS_STORAGE = "vfs-storage" STORAGE_NAMESPACE = "storage-namespace" diff --git a/src/ai/backend/common/metrics/metric.py b/src/ai/backend/common/metrics/metric.py index 3eb3435e266..24fdf61e45e 100644 --- a/src/ai/backend/common/metrics/metric.py +++ b/src/ai/backend/common/metrics/metric.py @@ -409,6 +409,7 @@ class LayerType(enum.StrEnum): AUTH_REPOSITORY = "auth_repository" ARTIFACT_REPOSITORY = "artifact_repository" ARTIFACT_REGISTRY_REPOSITORY = "artifact_registry_repository" + ARTIFACT_STORAGE_REPOSITORY = "artifact_storage_repository" AUDIT_LOG_REPOSITORY = "audit_log_repository" CONTAINER_REGISTRY_REPOSITORY = "container_registry_repository" DEPLOYMENT_REPOSITORY = "deployment_repository" diff --git a/src/ai/backend/common/types.py b/src/ai/backend/common/types.py index bb81680d6c6..15abf3f6f55 100644 --- a/src/ai/backend/common/types.py +++ b/src/ai/backend/common/types.py @@ -306,6 +306,8 @@ def check_typed_tuple(value: tuple[Any, ...], types: tuple[type, ...]) -> tuple[ RuleId = NewType("RuleId", UUID) SessionId = NewType("SessionId", UUID) KernelId = NewType("KernelId", UUID) +# ID of the `artifact_storages` common table (storage metadata). +ArtifactStorageId = NewType("ArtifactStorageId", UUID) ImageAlias = NewType("ImageAlias", str) ArchName = NewType("ArchName", str) diff --git a/src/ai/backend/manager/api/gql/artifact_storage.py b/src/ai/backend/manager/api/gql/artifact_storage.py new file mode 100644 index 00000000000..0d5b7df9bcb --- /dev/null +++ b/src/ai/backend/manager/api/gql/artifact_storage.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +import uuid +from enum import StrEnum +from typing import Self + +import strawberry +from strawberry import ID, UNSET, Info + +from ai.backend.common.data.storage.types import ArtifactStorageData, ArtifactStorageType +from ai.backend.common.types import ArtifactStorageId +from ai.backend.manager.api.gql.utils import check_admin_only +from ai.backend.manager.data.artifact_storages.types import ArtifactStorageUpdaterSpec +from ai.backend.manager.models.artifact_storages import ArtifactStorageRow +from ai.backend.manager.repositories.base.updater import Updater +from ai.backend.manager.services.artifact_storage.actions.update import ( + UpdateArtifactStorageAction, +) +from ai.backend.manager.types import OptionalState + +from .types import StrawberryGQLContext + + +@strawberry.enum( + name="ArtifactStorageType", + description=( + "Added in 26.3.0. The type of artifact storage backend. " + "OBJECT_STORAGE: Object storage (e.g., S3-compatible). " + "VFS_STORAGE: Virtual folder storage. " + "GIT_LFS: Git LFS storage." + ), +) +class ArtifactStorageTypeGQL(StrEnum): + """Artifact storage type enum.""" + + OBJECT_STORAGE = "object_storage" + VFS_STORAGE = "vfs_storage" + GIT_LFS = "git_lfs" + + @classmethod + def from_internal(cls, internal_type: ArtifactStorageType) -> ArtifactStorageTypeGQL: + """Convert internal ArtifactStorageType to GraphQL enum.""" + match internal_type: + case ArtifactStorageType.OBJECT_STORAGE: + return cls.OBJECT_STORAGE + case ArtifactStorageType.VFS_STORAGE: + return cls.VFS_STORAGE + case ArtifactStorageType.GIT_LFS: + return cls.GIT_LFS + + def to_internal(self) -> ArtifactStorageType: + """Convert GraphQL enum to internal ArtifactStorageType.""" + match self: + case ArtifactStorageTypeGQL.OBJECT_STORAGE: + return ArtifactStorageType.OBJECT_STORAGE + case ArtifactStorageTypeGQL.VFS_STORAGE: + return ArtifactStorageType.VFS_STORAGE + case ArtifactStorageTypeGQL.GIT_LFS: + return ArtifactStorageType.GIT_LFS + + +@strawberry.type(name="ArtifactStorage", description="Added in 26.3.0. Artifact storage metadata") +class ArtifactStorageGQL: + id: ID = strawberry.field(description="The ID of the artifact storage") + name: str = strawberry.field(description="The name of the artifact storage") + type: ArtifactStorageTypeGQL = strawberry.field(description="The type of the artifact storage") + + @classmethod + def from_dataclass(cls, data: ArtifactStorageData) -> Self: + return cls( + id=ID(str(data.id)), + name=data.name, + type=ArtifactStorageTypeGQL.from_internal(data.type), + ) + + +@strawberry.input( + name="UpdateArtifactStorageInput", + description="Added in 26.3.0. Input for updating artifact storage metadata", +) +class UpdateArtifactStorageInputGQL: + """Input for updating artifact storage metadata (common fields like name).""" + + id: ID = strawberry.field(description="The ID of the artifact storage") + name: str | None = strawberry.field( + default=UNSET, description="The new name for the artifact storage" + ) + + def to_updater(self) -> Updater[ArtifactStorageRow]: + spec = ArtifactStorageUpdaterSpec( + name=OptionalState[str].from_graphql(self.name), + ) + return Updater(spec=spec, pk_value=ArtifactStorageId(uuid.UUID(self.id))) + + +@strawberry.type( + name="UpdateArtifactStoragePayload", + description="Added in 26.3.0. Payload for updating artifact storage metadata", +) +class UpdateArtifactStoragePayloadGQL: + artifact_storage: ArtifactStorageGQL = strawberry.field( + description="The updated artifact storage" + ) + + +@strawberry.mutation( # type: ignore[misc] + name="adminUpdateArtifactStorage", + description="Added in 26.3.0. Update artifact storage metadata (common fields like name)", +) +async def update_artifact_storage( + input: UpdateArtifactStorageInputGQL, info: Info[StrawberryGQLContext] +) -> UpdateArtifactStoragePayloadGQL: + check_admin_only() + processors = info.context.processors + + action_result = await processors.artifact_storage.update.wait_for_complete( + UpdateArtifactStorageAction( + updater=input.to_updater(), + ) + ) + + return UpdateArtifactStoragePayloadGQL( + artifact_storage=ArtifactStorageGQL.from_dataclass(action_result.result), + ) diff --git a/src/ai/backend/manager/api/gql/data_loader/object_storage/loader.py b/src/ai/backend/manager/api/gql/data_loader/object_storage/loader.py index 3ba443d5b85..8c30b4c73f9 100644 --- a/src/ai/backend/manager/api/gql/data_loader/object_storage/loader.py +++ b/src/ai/backend/manager/api/gql/data_loader/object_storage/loader.py @@ -1,8 +1,8 @@ from __future__ import annotations -import uuid from collections.abc import Sequence +from ai.backend.common.types import ArtifactStorageId from ai.backend.manager.data.object_storage.types import ObjectStorageData from ai.backend.manager.repositories.base import BatchQuerier, OffsetPagination from ai.backend.manager.repositories.object_storage.options import ObjectStorageConditions @@ -12,7 +12,7 @@ async def load_object_storages_by_ids( processor: ObjectStorageProcessors, - storage_ids: Sequence[uuid.UUID], + storage_ids: Sequence[ArtifactStorageId], ) -> list[ObjectStorageData | None]: """Batch load object storages by their IDs. diff --git a/src/ai/backend/manager/api/gql/data_loader/vfs_storage/loader.py b/src/ai/backend/manager/api/gql/data_loader/vfs_storage/loader.py index 4fa1b707cf6..19c94a5d1b6 100644 --- a/src/ai/backend/manager/api/gql/data_loader/vfs_storage/loader.py +++ b/src/ai/backend/manager/api/gql/data_loader/vfs_storage/loader.py @@ -1,8 +1,8 @@ from __future__ import annotations -import uuid from collections.abc import Sequence +from ai.backend.common.types import ArtifactStorageId from ai.backend.manager.data.vfs_storage.types import VFSStorageData from ai.backend.manager.repositories.base import BatchQuerier, OffsetPagination from ai.backend.manager.repositories.vfs_storage.options import VFSStorageConditions @@ -12,7 +12,7 @@ async def load_vfs_storages_by_ids( processor: VFSStorageProcessors, - storage_ids: Sequence[uuid.UUID], + storage_ids: Sequence[ArtifactStorageId], ) -> list[VFSStorageData | None]: """Batch load VFS storages by their IDs. diff --git a/src/ai/backend/manager/api/gql/object_storage.py b/src/ai/backend/manager/api/gql/object_storage.py index 0097bb5b45b..f11afda6a7c 100644 --- a/src/ai/backend/manager/api/gql/object_storage.py +++ b/src/ai/backend/manager/api/gql/object_storage.py @@ -178,7 +178,6 @@ def to_creator(self) -> Creator[ObjectStorageRow]: @strawberry.input(description="Added in 25.14.0") class UpdateObjectStorageInput: id: ID - name: str | None = UNSET host: str | None = UNSET access_key: str | None = UNSET secret_key: str | None = UNSET @@ -187,7 +186,6 @@ class UpdateObjectStorageInput: def to_updater(self) -> Updater[ObjectStorageRow]: spec = ObjectStorageUpdaterSpec( - name=OptionalState[str].from_graphql(self.name), host=OptionalState[str].from_graphql(self.host), access_key=OptionalState[str].from_graphql(self.access_key), secret_key=OptionalState[str].from_graphql(self.secret_key), diff --git a/src/ai/backend/manager/api/gql/schema.py b/src/ai/backend/manager/api/gql/schema.py index d1e0703ab00..1634d79876f 100644 --- a/src/ai/backend/manager/api/gql/schema.py +++ b/src/ai/backend/manager/api/gql/schema.py @@ -39,6 +39,9 @@ update_artifact, ) from .artifact_registry import default_artifact_registry +from .artifact_storage import ( + update_artifact_storage, +) from .background_task import background_task_events from .deployment import ( # Revision @@ -427,6 +430,7 @@ class Mutation: create_vfs_storage = create_vfs_storage update_vfs_storage = update_vfs_storage delete_vfs_storage = delete_vfs_storage + update_artifact_storage = update_artifact_storage register_storage_namespace = register_storage_namespace unregister_storage_namespace = unregister_storage_namespace create_huggingface_registry = create_huggingface_registry diff --git a/src/ai/backend/manager/api/gql/vfs_storage.py b/src/ai/backend/manager/api/gql/vfs_storage.py index f450721546e..db5f6ae07bb 100644 --- a/src/ai/backend/manager/api/gql/vfs_storage.py +++ b/src/ai/backend/manager/api/gql/vfs_storage.py @@ -123,13 +123,11 @@ def to_creator(self) -> Creator[VFSStorageRow]: @strawberry.input(description="Added in 25.16.0. Input for updating VFS storage") class UpdateVFSStorageInput: id: ID - name: str | None = UNSET host: str | None = UNSET base_path: str | None = UNSET def to_updater(self) -> Updater[VFSStorageRow]: spec = VFSStorageUpdaterSpec( - name=OptionalState[str].from_graphql(self.name), host=OptionalState[str].from_graphql(self.host), base_path=OptionalState[str].from_graphql(self.base_path), ) diff --git a/src/ai/backend/manager/data/artifact_storages/__init__.py b/src/ai/backend/manager/data/artifact_storages/__init__.py new file mode 100644 index 00000000000..854ceaea72b --- /dev/null +++ b/src/ai/backend/manager/data/artifact_storages/__init__.py @@ -0,0 +1,3 @@ +from .types import ArtifactStorageUpdaterSpec + +__all__ = ("ArtifactStorageUpdaterSpec",) diff --git a/src/ai/backend/manager/data/artifact_storages/types.py b/src/ai/backend/manager/data/artifact_storages/types.py new file mode 100644 index 00000000000..39566293801 --- /dev/null +++ b/src/ai/backend/manager/data/artifact_storages/types.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, override + +from ai.backend.manager.models.artifact_storages import ArtifactStorageRow +from ai.backend.manager.repositories.base.updater import UpdaterSpec +from ai.backend.manager.types import OptionalState + + +@dataclass +class ArtifactStorageUpdaterSpec(UpdaterSpec[ArtifactStorageRow]): + """UpdaterSpec for ArtifactStorageRow.""" + + name: OptionalState[str] = field(default_factory=OptionalState.nop) + + @property + @override + def row_class(self) -> type[ArtifactStorageRow]: + return ArtifactStorageRow + + @override + def build_values(self) -> dict[str, Any]: + values: dict[str, Any] = {} + self.name.update_dict(values, "name") + return values diff --git a/src/ai/backend/manager/data/object_storage/types.py b/src/ai/backend/manager/data/object_storage/types.py index 777e3bf5292..30dfce517e6 100644 --- a/src/ai/backend/manager/data/object_storage/types.py +++ b/src/ai/backend/manager/data/object_storage/types.py @@ -1,8 +1,8 @@ from __future__ import annotations -import uuid from dataclasses import dataclass +from ai.backend.common.data.storage.types import ArtifactStorageData from ai.backend.common.dto.manager.response import ObjectStorageResponse @@ -16,10 +16,8 @@ class ObjectStorageListResult: has_previous_page: bool -@dataclass -class ObjectStorageData: - id: uuid.UUID - name: str +@dataclass(frozen=True) +class ObjectStorageData(ArtifactStorageData): host: str access_key: str secret_key: str diff --git a/src/ai/backend/manager/data/vfs_storage/types.py b/src/ai/backend/manager/data/vfs_storage/types.py index d2b4ba9c380..eedc90eda81 100644 --- a/src/ai/backend/manager/data/vfs_storage/types.py +++ b/src/ai/backend/manager/data/vfs_storage/types.py @@ -1,9 +1,10 @@ from __future__ import annotations -import uuid from dataclasses import dataclass from pathlib import Path +from ai.backend.common.data.storage.types import ArtifactStorageData + @dataclass class VFSStorageListResult: @@ -15,9 +16,7 @@ class VFSStorageListResult: has_previous_page: bool -@dataclass -class VFSStorageData: - id: uuid.UUID - name: str +@dataclass(frozen=True) +class VFSStorageData(ArtifactStorageData): host: str base_path: Path diff --git a/src/ai/backend/manager/models/alembic/versions/35dfab3b0662_add_artifact_storages_common_table.py b/src/ai/backend/manager/models/alembic/versions/35dfab3b0662_add_artifact_storages_common_table.py new file mode 100644 index 00000000000..3966812ba79 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/35dfab3b0662_add_artifact_storages_common_table.py @@ -0,0 +1,166 @@ +"""Add artifact_storages common table with JTI + +Revision ID: 35dfab3b0662 +Revises: 3f5c20f7bb07 +Create Date: 2025-12-02 09:24:21.050932 + +""" + +import sqlalchemy as sa +from alembic import op + +from ai.backend.manager.models.base import GUID + +# revision identifiers, used by Alembic. +revision = "35dfab3b0662" +down_revision = "3f5c20f7bb07" +branch_labels = None +depends_on = None + + +def _migrate_object_storages_to_artifact_storages( + conn: sa.engine.Connection, +) -> None: + """Migrate existing object_storages records to artifact_storages (JTI: id = child id).""" + conn.execute( + sa.text(""" + INSERT INTO artifact_storages (id, name, type) + SELECT id, name, 'object_storage' + FROM object_storages + WHERE name IS NOT NULL + """) + ) + + # Drop the name column and constraint + op.drop_index("ix_object_storages_name", table_name="object_storages") + op.drop_column("object_storages", "name") + + +def _migrate_vfs_storages_to_artifact_storages( + conn: sa.engine.Connection, +) -> None: + """Migrate existing vfs_storages records to artifact_storages (JTI: id = child id).""" + conn.execute( + sa.text(""" + INSERT INTO artifact_storages (id, name, type) + SELECT id, name, 'vfs_storage' + FROM vfs_storages + WHERE name IS NOT NULL + """) + ) + + # Drop the name column and constraint + op.drop_index("ix_vfs_storages_name", table_name="vfs_storages") + op.drop_column("vfs_storages", "name") + + +def _migrate_artifact_storages_to_object_storages( + conn: sa.engine.Connection, +) -> None: + """Migrate data back from artifact_storages to object_storages.""" + conn.execute( + sa.text(""" + UPDATE object_storages o + SET name = a.name + FROM artifact_storages a + WHERE o.id = a.id AND a.type = 'object_storage' + """) + ) + + +def _migrate_artifact_storages_to_vfs_storages( + conn: sa.engine.Connection, +) -> None: + """Migrate data back from artifact_storages to vfs_storages.""" + conn.execute( + sa.text(""" + UPDATE vfs_storages v + SET name = a.name + FROM artifact_storages a + WHERE v.id = a.id AND a.type = 'vfs_storage' + """) + ) + + +def upgrade() -> None: + op.create_table( + "artifact_storages", + sa.Column("id", GUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("type", sa.String(), nullable=False), + sa.PrimaryKeyConstraint("id", name=op.f("pk_artifact_storages")), + sa.UniqueConstraint("name", name=op.f("uq_artifact_storages_name")), + ) + + conn = op.get_bind() + + _migrate_object_storages_to_artifact_storages(conn) + _migrate_vfs_storages_to_artifact_storages(conn) + + # Add FK constraints: child.id → artifact_storages.id (JTI) + op.create_foreign_key( + "fk_object_storages_id_artifact_storages", + "object_storages", + "artifact_storages", + ["id"], + ["id"], + ondelete="CASCADE", + ) + op.create_foreign_key( + "fk_vfs_storages_id_artifact_storages", + "vfs_storages", + "artifact_storages", + ["id"], + ["id"], + ondelete="CASCADE", + ) + + +def downgrade() -> None: + conn = op.get_bind() + + # Drop FK constraints + op.drop_constraint( + "fk_object_storages_id_artifact_storages", "object_storages", type_="foreignkey" + ) + op.drop_constraint("fk_vfs_storages_id_artifact_storages", "vfs_storages", type_="foreignkey") + + # Add name column back to object_storages + op.add_column( + "object_storages", + sa.Column("name", sa.VARCHAR(), autoincrement=False, nullable=True), + ) + + _migrate_artifact_storages_to_object_storages(conn) + + # Make name column NOT NULL + op.alter_column( + "object_storages", + "name", + existing_type=sa.VARCHAR(), + nullable=False, + ) + + # Recreate constraints + op.create_index("ix_object_storages_name", "object_storages", ["name"], unique=True) + + # Add name column back to vfs_storages + op.add_column( + "vfs_storages", + sa.Column("name", sa.VARCHAR(), autoincrement=False, nullable=True), + ) + + _migrate_artifact_storages_to_vfs_storages(conn) + + # Make name column NOT NULL + op.alter_column( + "vfs_storages", + "name", + existing_type=sa.VARCHAR(), + nullable=False, + ) + + # Recreate constraints + op.create_index("ix_vfs_storages_name", "vfs_storages", ["name"], unique=True) + + op.drop_table("artifact_storages") diff --git a/src/ai/backend/manager/models/artifact_storages.py b/src/ai/backend/manager/models/artifact_storages.py new file mode 100644 index 00000000000..f48c5d90c45 --- /dev/null +++ b/src/ai/backend/manager/models/artifact_storages.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import logging +import uuid + +import sqlalchemy as sa +from sqlalchemy.orm import Mapped, mapped_column + +from ai.backend.common.data.storage.types import ArtifactStorageData, ArtifactStorageType +from ai.backend.common.types import ArtifactStorageId +from ai.backend.logging import BraceStyleAdapter + +from .base import ( + GUID, + Base, +) + +log = BraceStyleAdapter(logging.getLogger(__spec__.name)) + +__all__ = ("ArtifactStorageRow",) + + +class ArtifactStorageRow(Base): # type: ignore[misc] + """ + Common information of all artifact storage records. + Uses SQLAlchemy Joined Table Inheritance as the base class. + """ + + __tablename__ = "artifact_storages" + + id: Mapped[uuid.UUID] = mapped_column( + "id", GUID, primary_key=True, server_default=sa.text("uuid_generate_v4()") + ) + name: Mapped[str] = mapped_column("name", sa.String, nullable=False, unique=True) + type: Mapped[str] = mapped_column("type", sa.String, nullable=False) + + __mapper_args__ = { + "polymorphic_on": "type", + "polymorphic_identity": "base", + } + + def __str__(self) -> str: + return f"ArtifactStorageRow(id={self.id}, type={self.type}, name={self.name})" + + def __repr__(self) -> str: + return self.__str__() + + def to_dataclass(self) -> ArtifactStorageData: + return ArtifactStorageData( + id=ArtifactStorageId(self.id), + name=self.name, + type=ArtifactStorageType(self.type), + ) diff --git a/src/ai/backend/manager/models/association_artifacts_storages/row.py b/src/ai/backend/manager/models/association_artifacts_storages/row.py index 4627caf3724..49d79fc1426 100644 --- a/src/ai/backend/manager/models/association_artifacts_storages/row.py +++ b/src/ai/backend/manager/models/association_artifacts_storages/row.py @@ -77,6 +77,7 @@ class AssociationArtifactsStorageRow(Base): # type: ignore[misc] back_populates="association_artifacts_storages_rows", primaryjoin=_get_association_object_storage_join_cond, overlaps="vfs_storage_row", + viewonly=True, ) # only valid when storage_type is "vfs" @@ -85,4 +86,5 @@ class AssociationArtifactsStorageRow(Base): # type: ignore[misc] back_populates="association_artifacts_storages_rows", primaryjoin=_get_association_vfs_storage_join_cond, overlaps="object_storage_row", + viewonly=True, ) diff --git a/src/ai/backend/manager/models/object_storage/row.py b/src/ai/backend/manager/models/object_storage/row.py index 2a1d241b079..2d2f451e6ab 100644 --- a/src/ai/backend/manager/models/object_storage/row.py +++ b/src/ai/backend/manager/models/object_storage/row.py @@ -7,11 +7,13 @@ import sqlalchemy as sa from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship +from ai.backend.common.data.storage.types import ArtifactStorageType +from ai.backend.common.types import ArtifactStorageId from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.data.object_storage.types import ObjectStorageData +from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.models.base import ( GUID, - Base, ) if TYPE_CHECKING: @@ -39,19 +41,19 @@ def _get_object_storage_namespace_join_cond() -> sa.ColumnElement[bool]: return foreign(StorageNamespaceRow.storage_id) == ObjectStorageRow.id -class ObjectStorageRow(Base): # type: ignore[misc] +class ObjectStorageRow(ArtifactStorageRow): """ Represents an object storage configuration. This model is used to store the details of object storage services such as access keys, endpoints. + Uses SQLAlchemy Joined Table Inheritance (child of ArtifactStorageRow). """ __tablename__ = "object_storages" id: Mapped[uuid.UUID] = mapped_column( - "id", GUID, primary_key=True, server_default=sa.text("uuid_generate_v4()") + "id", GUID, sa.ForeignKey("artifact_storages.id", ondelete="CASCADE"), primary_key=True ) - name: Mapped[str] = mapped_column("name", sa.String, index=True, unique=True, nullable=False) host: Mapped[str] = mapped_column("host", sa.String, index=True, nullable=False) access_key: Mapped[str] = mapped_column( "access_key", @@ -80,19 +82,24 @@ class ObjectStorageRow(Base): # type: ignore[misc] back_populates="object_storage_row", primaryjoin=_get_object_storage_association_artifact_join_cond, overlaps="vfs_storage_row", + viewonly=True, ) ) namespace_rows: Mapped[list[StorageNamespaceRow]] = relationship( "StorageNamespaceRow", back_populates="object_storage_row", primaryjoin=_get_object_storage_namespace_join_cond, + viewonly=True, ) + __mapper_args__ = { + "polymorphic_identity": "object_storage", + } + def __str__(self) -> str: return ( f"ObjectStorageRow(" f"id={self.id}, " - f"name={self.name}, " f"host={self.host}, " f"access_key={self.access_key}, " f"secret_key={self.secret_key}, " @@ -105,8 +112,9 @@ def __repr__(self) -> str: def to_dataclass(self) -> ObjectStorageData: return ObjectStorageData( - id=self.id, + id=ArtifactStorageId(self.id), name=self.name, + type=ArtifactStorageType(self.type), host=self.host, access_key=self.access_key, secret_key=self.secret_key, diff --git a/src/ai/backend/manager/models/storage_namespace/row.py b/src/ai/backend/manager/models/storage_namespace/row.py index 4e7bea62c8b..2c4f71c79e8 100644 --- a/src/ai/backend/manager/models/storage_namespace/row.py +++ b/src/ai/backend/manager/models/storage_namespace/row.py @@ -49,6 +49,7 @@ class StorageNamespaceRow(Base): # type: ignore[misc] "ObjectStorageRow", back_populates="namespace_rows", primaryjoin=_get_storage_namespace_join_cond, + viewonly=True, ) def to_dataclass(self) -> StorageNamespaceData: diff --git a/src/ai/backend/manager/models/vfs_storage/row.py b/src/ai/backend/manager/models/vfs_storage/row.py index 11ca46f4041..b2485d85754 100644 --- a/src/ai/backend/manager/models/vfs_storage/row.py +++ b/src/ai/backend/manager/models/vfs_storage/row.py @@ -8,11 +8,13 @@ import sqlalchemy as sa from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship +from ai.backend.common.data.storage.types import ArtifactStorageType +from ai.backend.common.types import ArtifactStorageId from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.data.vfs_storage.types import VFSStorageData +from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.models.base import ( GUID, - Base, ) if TYPE_CHECKING: @@ -33,19 +35,19 @@ def _get_vfs_storage_association_artifact_join_cond() -> sa.ColumnElement[bool]: return VFSStorageRow.id == foreign(AssociationArtifactsStorageRow.storage_namespace_id) -class VFSStorageRow(Base): # type: ignore[misc] +class VFSStorageRow(ArtifactStorageRow): """ Represents a VFS storage configuration. This model is used to store the details of VFS storage backends such as base paths, subpaths, and chunk sizes. + Uses SQLAlchemy Joined Table Inheritance (child of ArtifactStorageRow). """ __tablename__ = "vfs_storages" id: Mapped[uuid.UUID] = mapped_column( - "id", GUID, primary_key=True, server_default=sa.text("uuid_generate_v4()") + "id", GUID, sa.ForeignKey("artifact_storages.id", ondelete="CASCADE"), primary_key=True ) - name: Mapped[str] = mapped_column("name", sa.String, index=True, unique=True, nullable=False) host: Mapped[str] = mapped_column("host", sa.String, nullable=False) base_path: Mapped[str] = mapped_column("base_path", sa.String, nullable=False) @@ -55,25 +57,25 @@ class VFSStorageRow(Base): # type: ignore[misc] back_populates="vfs_storage_row", primaryjoin=_get_vfs_storage_association_artifact_join_cond, overlaps="association_artifacts_storages_rows,object_storage_row", + viewonly=True, ) ) + __mapper_args__ = { + "polymorphic_identity": "vfs_storage", + } + def __str__(self) -> str: - return ( - f"VFSStorageRow(" - f"id={self.id}, " - f"name={self.name}, " - f"host={self.host}, " - f"base_path={self.base_path})" - ) + return f"VFSStorageRow(id={self.id}, host={self.host}, base_path={self.base_path})" def __repr__(self) -> str: return self.__str__() def to_dataclass(self) -> VFSStorageData: return VFSStorageData( - id=self.id, + id=ArtifactStorageId(self.id), name=self.name, + type=ArtifactStorageType(self.type), host=self.host, base_path=Path(self.base_path), ) diff --git a/src/ai/backend/manager/repositories/artifact_storage/__init__.py b/src/ai/backend/manager/repositories/artifact_storage/__init__.py new file mode 100644 index 00000000000..adc5b5889c3 --- /dev/null +++ b/src/ai/backend/manager/repositories/artifact_storage/__init__.py @@ -0,0 +1,7 @@ +from .repositories import ArtifactStorageRepositories +from .repository import ArtifactStorageRepository + +__all__ = ( + "ArtifactStorageRepositories", + "ArtifactStorageRepository", +) diff --git a/src/ai/backend/manager/repositories/artifact_storage/db_source/__init__.py b/src/ai/backend/manager/repositories/artifact_storage/db_source/__init__.py new file mode 100644 index 00000000000..9c06935ade8 --- /dev/null +++ b/src/ai/backend/manager/repositories/artifact_storage/db_source/__init__.py @@ -0,0 +1,3 @@ +from .db_source import ArtifactStorageDBSource + +__all__ = ("ArtifactStorageDBSource",) diff --git a/src/ai/backend/manager/repositories/artifact_storage/db_source/db_source.py b/src/ai/backend/manager/repositories/artifact_storage/db_source/db_source.py new file mode 100644 index 00000000000..fa3bf256173 --- /dev/null +++ b/src/ai/backend/manager/repositories/artifact_storage/db_source/db_source.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import uuid + +import sqlalchemy as sa + +from ai.backend.common.data.storage.exceptions import ArtifactStorageNotFoundError +from ai.backend.common.data.storage.types import ArtifactStorageData +from ai.backend.manager.models.artifact_storages import ArtifactStorageRow +from ai.backend.manager.models.utils import ExtendedAsyncSAEngine +from ai.backend.manager.repositories.base.updater import Updater, execute_updater + + +class ArtifactStorageDBSource: + """Database source for artifact storage operations.""" + + _db: ExtendedAsyncSAEngine + + def __init__(self, db: ExtendedAsyncSAEngine) -> None: + self._db = db + + async def get_by_id(self, storage_id: uuid.UUID) -> ArtifactStorageData: + """ + Get an existing artifact storage configuration from the database by ID. + """ + async with self._db.begin_session() as db_session: + query = sa.select(ArtifactStorageRow).where(ArtifactStorageRow.id == storage_id) + result = await db_session.execute(query) + row = result.scalar_one_or_none() + if row is None: + raise ArtifactStorageNotFoundError( + f"Artifact storage with ID {storage_id} not found." + ) + return row.to_dataclass() + + async def update( + self, + updater: Updater[ArtifactStorageRow], + ) -> ArtifactStorageData: + """ + Update an existing artifact storage configuration in the database. + """ + async with self._db.begin_session() as db_session: + await execute_updater(db_session, updater) + + artifact_storage_id = uuid.UUID(str(updater.pk_value)) + query = sa.select(ArtifactStorageRow).where( + ArtifactStorageRow.id == artifact_storage_id + ) + row_result = await db_session.execute(query) + row = row_result.scalar_one_or_none() + if row is None: + raise ArtifactStorageNotFoundError( + f"Artifact storage with ID {artifact_storage_id} not found." + ) + return row.to_dataclass() diff --git a/src/ai/backend/manager/repositories/artifact_storage/repositories.py b/src/ai/backend/manager/repositories/artifact_storage/repositories.py new file mode 100644 index 00000000000..316b4ef67de --- /dev/null +++ b/src/ai/backend/manager/repositories/artifact_storage/repositories.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass +from typing import Self + +from ai.backend.manager.repositories.types import RepositoryArgs + +from .repository import ArtifactStorageRepository + + +@dataclass +class ArtifactStorageRepositories: + repository: ArtifactStorageRepository + + @classmethod + def create(cls, args: RepositoryArgs) -> Self: + return cls( + repository=ArtifactStorageRepository( + db=args.db, + ), + ) diff --git a/src/ai/backend/manager/repositories/artifact_storage/repository.py b/src/ai/backend/manager/repositories/artifact_storage/repository.py new file mode 100644 index 00000000000..f2acae550f3 --- /dev/null +++ b/src/ai/backend/manager/repositories/artifact_storage/repository.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import uuid + +from ai.backend.common.data.storage.types import ArtifactStorageData +from ai.backend.common.exception import BackendAIError +from ai.backend.common.metrics.metric import DomainType, LayerType +from ai.backend.common.resilience.policies.metrics import MetricArgs, MetricPolicy +from ai.backend.common.resilience.policies.retry import BackoffStrategy, RetryArgs, RetryPolicy +from ai.backend.common.resilience.resilience import Resilience +from ai.backend.manager.models.artifact_storages import ArtifactStorageRow +from ai.backend.manager.models.utils import ExtendedAsyncSAEngine +from ai.backend.manager.repositories.artifact_storage.db_source.db_source import ( + ArtifactStorageDBSource, +) +from ai.backend.manager.repositories.base.updater import Updater + +artifact_storage_repository_resilience = Resilience( + policies=[ + MetricPolicy( + MetricArgs( + domain=DomainType.REPOSITORY, + layer=LayerType.ARTIFACT_STORAGE_REPOSITORY, + ) + ), + RetryPolicy( + RetryArgs( + max_retries=10, + retry_delay=0.1, + backoff_strategy=BackoffStrategy.FIXED, + non_retryable_exceptions=(BackendAIError,), + ) + ), + ] +) + + +class ArtifactStorageRepository: + """Repository layer for artifact storage operations.""" + + _db_source: ArtifactStorageDBSource + + def __init__(self, db: ExtendedAsyncSAEngine) -> None: + self._db_source = ArtifactStorageDBSource(db) + + @artifact_storage_repository_resilience.apply() + async def get_by_id(self, storage_id: uuid.UUID) -> ArtifactStorageData: + return await self._db_source.get_by_id(storage_id) + + @artifact_storage_repository_resilience.apply() + async def update( + self, + updater: Updater[ArtifactStorageRow], + ) -> ArtifactStorageData: + return await self._db_source.update(updater) diff --git a/src/ai/backend/manager/repositories/object_storage/db_source/db_source.py b/src/ai/backend/manager/repositories/object_storage/db_source/db_source.py index 96695f9e806..0545c17241f 100644 --- a/src/ai/backend/manager/repositories/object_storage/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/object_storage/db_source/db_source.py @@ -14,7 +14,7 @@ from ai.backend.manager.models.utils import ExtendedAsyncSAEngine from ai.backend.manager.repositories.base import BatchQuerier, execute_batch_querier from ai.backend.manager.repositories.base.creator import Creator, execute_creator -from ai.backend.manager.repositories.base.updater import Updater, execute_updater +from ai.backend.manager.repositories.base.updater import Updater class ObjectStorageDBSource: @@ -53,7 +53,7 @@ async def get_by_id(self, storage_id: uuid.UUID) -> ObjectStorageData: async def get_by_namespace_id(self, storage_namespace_id: uuid.UUID) -> ObjectStorageData: """ - Get an existing object storage configuration from the database by ID. + Get an existing object storage configuration from the database by namespace ID. """ async with self._db.begin_readonly_session_read_committed() as db_session: query = ( @@ -76,26 +76,41 @@ async def get_by_namespace_id(self, storage_namespace_id: uuid.UUID) -> ObjectSt async def create(self, creator: Creator[ObjectStorageRow]) -> ObjectStorageData: """ Create a new object storage configuration in the database. + JTI handles inserting into both artifact_storages and object_storages. """ async with self._db.begin_session() as db_session: creator_result = await execute_creator(db_session, creator) return creator_result.row.to_dataclass() - async def update(self, updater: Updater[ObjectStorageRow]) -> ObjectStorageData: + async def update( + self, + updater: Updater[ObjectStorageRow], + ) -> ObjectStorageData: """ Update an existing object storage configuration in the database. + Uses plain UPDATE + SELECT instead of execute_updater's RETURNING+from_statement, + which is incompatible with JTI (RETURNING only includes child table columns). """ async with self._db.begin_session() as db_session: - result = await execute_updater(db_session, updater) - if result is None: - raise ObjectStorageNotFoundError( - f"Object storage with ID {updater.pk_value} not found." - ) - return result.row.to_dataclass() + storage_id = uuid.UUID(str(updater.pk_value)) + values = updater.spec.build_values() + if values: + table = ObjectStorageRow.__table__ + pk_col = list(table.primary_key.columns)[0] + update_stmt = sa.update(table).values(values).where(pk_col == storage_id) + await db_session.execute(update_stmt) + + query = sa.select(ObjectStorageRow).where(ObjectStorageRow.id == storage_id) + row_result = await db_session.execute(query) + row = row_result.scalar_one_or_none() + if row is None: + raise ObjectStorageNotFoundError(f"Object storage with ID {storage_id} not found.") + return row.to_dataclass() async def delete(self, storage_id: uuid.UUID) -> uuid.UUID: """ Delete an existing object storage configuration from the database. + FK cascade handles deleting the ArtifactStorageRow. """ async with self._db.begin_session() as db_session: delete_query = ( diff --git a/src/ai/backend/manager/repositories/object_storage/repository.py b/src/ai/backend/manager/repositories/object_storage/repository.py index 59eea2cd3bc..5b896093298 100644 --- a/src/ai/backend/manager/repositories/object_storage/repository.py +++ b/src/ai/backend/manager/repositories/object_storage/repository.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import uuid from ai.backend.common.exception import ( @@ -57,7 +59,10 @@ async def create(self, creator: Creator[ObjectStorageRow]) -> ObjectStorageData: return await self._db_source.create(creator) @object_storage_repository_resilience.apply() - async def update(self, updater: Updater[ObjectStorageRow]) -> ObjectStorageData: + async def update( + self, + updater: Updater[ObjectStorageRow], + ) -> ObjectStorageData: return await self._db_source.update(updater) @object_storage_repository_resilience.apply() diff --git a/src/ai/backend/manager/repositories/object_storage/updaters.py b/src/ai/backend/manager/repositories/object_storage/updaters.py index e9dd0d08ba9..27059dd7ee7 100644 --- a/src/ai/backend/manager/repositories/object_storage/updaters.py +++ b/src/ai/backend/manager/repositories/object_storage/updaters.py @@ -14,7 +14,6 @@ class ObjectStorageUpdaterSpec(UpdaterSpec[ObjectStorageRow]): """UpdaterSpec for object storage updates.""" - name: OptionalState[str] = field(default_factory=OptionalState.nop) host: OptionalState[str] = field(default_factory=OptionalState.nop) access_key: OptionalState[str] = field(default_factory=OptionalState.nop) secret_key: OptionalState[str] = field(default_factory=OptionalState.nop) @@ -29,7 +28,6 @@ def row_class(self) -> type[ObjectStorageRow]: @override def build_values(self) -> dict[str, Any]: to_update: dict[str, Any] = {} - self.name.update_dict(to_update, "name") self.host.update_dict(to_update, "host") self.access_key.update_dict(to_update, "access_key") self.secret_key.update_dict(to_update, "secret_key") diff --git a/src/ai/backend/manager/repositories/repositories.py b/src/ai/backend/manager/repositories/repositories.py index faf2b29c45b..c56d9f9a169 100644 --- a/src/ai/backend/manager/repositories/repositories.py +++ b/src/ai/backend/manager/repositories/repositories.py @@ -7,6 +7,9 @@ from ai.backend.manager.repositories.artifact_registry.repositories import ( ArtifactRegistryRepositories, ) +from ai.backend.manager.repositories.artifact_storage.repositories import ( + ArtifactStorageRepositories, +) from ai.backend.manager.repositories.audit_log.repositories import AuditLogRepositories from ai.backend.manager.repositories.auth.repositories import AuthRepositories from ai.backend.manager.repositories.container_registry.repositories import ( @@ -104,6 +107,7 @@ class Repositories: huggingface_registry: HuggingFaceRegistryRepositories artifact: ArtifactRepositories artifact_registry: ArtifactRegistryRepositories + artifact_storage: ArtifactStorageRepositories storage_namespace: StorageNamespaceRepositories audit_log: AuditLogRepositories @@ -146,6 +150,7 @@ def create(cls, args: RepositoryArgs) -> Self: artifact_repositories = ArtifactRepositories.create(args) huggingface_registry_repositories = HuggingFaceRegistryRepositories.create(args) artifact_registries = ArtifactRegistryRepositories.create(args) + artifact_storage_repositories = ArtifactStorageRepositories.create(args) storage_namespace_repositories = StorageNamespaceRepositories.create(args) audit_log_repositories = AuditLogRepositories.create(args) @@ -187,6 +192,7 @@ def create(cls, args: RepositoryArgs) -> Self: huggingface_registry=huggingface_registry_repositories, artifact=artifact_repositories, artifact_registry=artifact_registries, + artifact_storage=artifact_storage_repositories, storage_namespace=storage_namespace_repositories, audit_log=audit_log_repositories, ) diff --git a/src/ai/backend/manager/repositories/vfs_storage/db_source/db_source.py b/src/ai/backend/manager/repositories/vfs_storage/db_source/db_source.py index 50024e3a124..ebc9a7a142b 100644 --- a/src/ai/backend/manager/repositories/vfs_storage/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/vfs_storage/db_source/db_source.py @@ -12,7 +12,7 @@ from ai.backend.manager.models.vfs_storage import VFSStorageRow from ai.backend.manager.repositories.base import BatchQuerier, execute_batch_querier from ai.backend.manager.repositories.base.creator import Creator, execute_creator -from ai.backend.manager.repositories.base.updater import Updater, execute_updater +from ai.backend.manager.repositories.base.updater import Updater class VFSStorageDBSource: @@ -50,24 +50,41 @@ async def get_by_id(self, storage_id: uuid.UUID) -> VFSStorageData: async def create(self, creator: Creator[VFSStorageRow]) -> VFSStorageData: """ Create a new VFS storage configuration in the database. + JTI handles inserting into both artifact_storages and vfs_storages. """ async with self._db.begin_session() as db_session: creator_result = await execute_creator(db_session, creator) return creator_result.row.to_dataclass() - async def update(self, updater: Updater[VFSStorageRow]) -> VFSStorageData: + async def update( + self, + updater: Updater[VFSStorageRow], + ) -> VFSStorageData: """ Update an existing VFS storage configuration in the database. + Uses plain UPDATE + SELECT instead of execute_updater's RETURNING+from_statement, + which is incompatible with JTI (RETURNING only includes child table columns). """ async with self._db.begin_session() as db_session: - result = await execute_updater(db_session, updater) - if result is None: - raise VFSStorageNotFoundError(f"VFS storage with ID {updater.pk_value} not found.") - return result.row.to_dataclass() + storage_id = uuid.UUID(str(updater.pk_value)) + values = updater.spec.build_values() + if values: + table = VFSStorageRow.__table__ + pk_col = list(table.primary_key.columns)[0] + update_stmt = sa.update(table).values(values).where(pk_col == storage_id) + await db_session.execute(update_stmt) + + query = sa.select(VFSStorageRow).where(VFSStorageRow.id == storage_id) + row_result = await db_session.execute(query) + row = row_result.scalar_one_or_none() + if row is None: + raise VFSStorageNotFoundError(f"VFS storage with ID {storage_id} not found.") + return row.to_dataclass() async def delete(self, storage_id: uuid.UUID) -> uuid.UUID: """ Delete an existing VFS storage configuration from the database. + FK cascade handles deleting the ArtifactStorageRow. """ async with self._db.begin_session() as db_session: delete_query = ( diff --git a/src/ai/backend/manager/repositories/vfs_storage/repository.py b/src/ai/backend/manager/repositories/vfs_storage/repository.py index 5ab611a6233..57852b2da53 100644 --- a/src/ai/backend/manager/repositories/vfs_storage/repository.py +++ b/src/ai/backend/manager/repositories/vfs_storage/repository.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import uuid from ai.backend.common.exception import BackendAIError @@ -51,7 +53,10 @@ async def create(self, creator: Creator[VFSStorageRow]) -> VFSStorageData: return await self._db_source.create(creator) @vfs_storage_repository_resilience.apply() - async def update(self, updater: Updater[VFSStorageRow]) -> VFSStorageData: + async def update( + self, + updater: Updater[VFSStorageRow], + ) -> VFSStorageData: return await self._db_source.update(updater) @vfs_storage_repository_resilience.apply() diff --git a/src/ai/backend/manager/repositories/vfs_storage/updaters.py b/src/ai/backend/manager/repositories/vfs_storage/updaters.py index 96e8f682c2c..b85b6a71f84 100644 --- a/src/ai/backend/manager/repositories/vfs_storage/updaters.py +++ b/src/ai/backend/manager/repositories/vfs_storage/updaters.py @@ -14,7 +14,6 @@ class VFSStorageUpdaterSpec(UpdaterSpec[VFSStorageRow]): """UpdaterSpec for VFS storage updates.""" - name: OptionalState[str] = field(default_factory=OptionalState.nop) host: OptionalState[str] = field(default_factory=OptionalState.nop) base_path: OptionalState[str] = field(default_factory=OptionalState.nop) @@ -26,7 +25,6 @@ def row_class(self) -> type[VFSStorageRow]: @override def build_values(self) -> dict[str, Any]: to_update: dict[str, Any] = {} - self.name.update_dict(to_update, "name") self.host.update_dict(to_update, "host") self.base_path.update_dict(to_update, "base_path") return to_update diff --git a/src/ai/backend/manager/services/artifact_storage/__init__.py b/src/ai/backend/manager/services/artifact_storage/__init__.py new file mode 100644 index 00000000000..42c1c6fa761 --- /dev/null +++ b/src/ai/backend/manager/services/artifact_storage/__init__.py @@ -0,0 +1,7 @@ +from .processors import ArtifactStorageProcessors +from .service import ArtifactStorageService + +__all__ = ( + "ArtifactStorageProcessors", + "ArtifactStorageService", +) diff --git a/src/ai/backend/manager/services/artifact_storage/actions/__init__.py b/src/ai/backend/manager/services/artifact_storage/actions/__init__.py new file mode 100644 index 00000000000..670110e2da4 --- /dev/null +++ b/src/ai/backend/manager/services/artifact_storage/actions/__init__.py @@ -0,0 +1,8 @@ +from .base import ArtifactStorageAction +from .update import UpdateArtifactStorageAction, UpdateArtifactStorageActionResult + +__all__ = ( + "ArtifactStorageAction", + "UpdateArtifactStorageAction", + "UpdateArtifactStorageActionResult", +) diff --git a/src/ai/backend/manager/services/artifact_storage/actions/base.py b/src/ai/backend/manager/services/artifact_storage/actions/base.py new file mode 100644 index 00000000000..9625aeaff71 --- /dev/null +++ b/src/ai/backend/manager/services/artifact_storage/actions/base.py @@ -0,0 +1,11 @@ +from typing import override + +from ai.backend.common.data.permission.types import EntityType +from ai.backend.manager.actions.action import BaseAction + + +class ArtifactStorageAction(BaseAction): + @override + @classmethod + def entity_type(cls) -> EntityType: + return EntityType.ARTIFACT_STORAGE diff --git a/src/ai/backend/manager/services/artifact_storage/actions/update.py b/src/ai/backend/manager/services/artifact_storage/actions/update.py new file mode 100644 index 00000000000..c0acb5055ec --- /dev/null +++ b/src/ai/backend/manager/services/artifact_storage/actions/update.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import override + +from ai.backend.common.data.storage.types import ArtifactStorageData +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.models.artifact_storages import ArtifactStorageRow +from ai.backend.manager.repositories.base.updater import Updater +from ai.backend.manager.services.artifact_storage.actions.base import ArtifactStorageAction + + +@dataclass +class UpdateArtifactStorageAction(ArtifactStorageAction): + updater: Updater[ArtifactStorageRow] + + @override + def entity_id(self) -> str | None: + return str(self.updater.pk_value) + + @override + @classmethod + def operation_type(cls) -> ActionOperationType: + return ActionOperationType.UPDATE + + +@dataclass +class UpdateArtifactStorageActionResult(BaseActionResult): + result: ArtifactStorageData + + @override + def entity_id(self) -> str | None: + return str(self.result.id) diff --git a/src/ai/backend/manager/services/artifact_storage/processors.py b/src/ai/backend/manager/services/artifact_storage/processors.py new file mode 100644 index 00000000000..1ede3b08bef --- /dev/null +++ b/src/ai/backend/manager/services/artifact_storage/processors.py @@ -0,0 +1,25 @@ +from typing import override + +from ai.backend.manager.actions.monitors.monitor import ActionMonitor +from ai.backend.manager.actions.processor import ActionProcessor +from ai.backend.manager.actions.types import AbstractProcessorPackage, ActionSpec +from ai.backend.manager.services.artifact_storage.actions.update import ( + UpdateArtifactStorageAction, + UpdateArtifactStorageActionResult, +) +from ai.backend.manager.services.artifact_storage.service import ArtifactStorageService + + +class ArtifactStorageProcessors(AbstractProcessorPackage): + update: ActionProcessor[UpdateArtifactStorageAction, UpdateArtifactStorageActionResult] + + def __init__( + self, service: ArtifactStorageService, action_monitors: list[ActionMonitor] + ) -> None: + self.update = ActionProcessor(service.update, action_monitors) + + @override + def supported_actions(self) -> list[ActionSpec]: + return [ + UpdateArtifactStorageAction.spec(), + ] diff --git a/src/ai/backend/manager/services/artifact_storage/service.py b/src/ai/backend/manager/services/artifact_storage/service.py new file mode 100644 index 00000000000..ee71c1724c6 --- /dev/null +++ b/src/ai/backend/manager/services/artifact_storage/service.py @@ -0,0 +1,32 @@ +import logging + +from ai.backend.logging.utils import BraceStyleAdapter +from ai.backend.manager.repositories.artifact_storage.repository import ArtifactStorageRepository +from ai.backend.manager.services.artifact_storage.actions.update import ( + UpdateArtifactStorageAction, + UpdateArtifactStorageActionResult, +) + +log = BraceStyleAdapter(logging.getLogger(__spec__.name)) + + +class ArtifactStorageService: + """Service layer for artifact storage operations.""" + + _artifact_storage_repository: ArtifactStorageRepository + + def __init__( + self, + artifact_storage_repository: ArtifactStorageRepository, + ) -> None: + self._artifact_storage_repository = artifact_storage_repository + + async def update( + self, action: UpdateArtifactStorageAction + ) -> UpdateArtifactStorageActionResult: + """ + Update an existing artifact storage. + """ + log.info("Updating artifact storage with id: {}", action.updater.pk_value) + storage_data = await self._artifact_storage_repository.update(action.updater) + return UpdateArtifactStorageActionResult(result=storage_data) diff --git a/src/ai/backend/manager/services/object_storage/actions/create.py b/src/ai/backend/manager/services/object_storage/actions/create.py index 6ac9ab7e804..b135206bb9b 100644 --- a/src/ai/backend/manager/services/object_storage/actions/create.py +++ b/src/ai/backend/manager/services/object_storage/actions/create.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass from typing import override diff --git a/src/ai/backend/manager/services/object_storage/actions/update.py b/src/ai/backend/manager/services/object_storage/actions/update.py index f51e8576bf3..3a84786d58e 100644 --- a/src/ai/backend/manager/services/object_storage/actions/update.py +++ b/src/ai/backend/manager/services/object_storage/actions/update.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass from typing import override diff --git a/src/ai/backend/manager/services/processors.py b/src/ai/backend/manager/services/processors.py index e7985d56e15..e1077fd4fb2 100644 --- a/src/ai/backend/manager/services/processors.py +++ b/src/ai/backend/manager/services/processors.py @@ -38,6 +38,8 @@ from ai.backend.manager.services.artifact_registry.service import ArtifactRegistryService from ai.backend.manager.services.artifact_revision.processors import ArtifactRevisionProcessors from ai.backend.manager.services.artifact_revision.service import ArtifactRevisionService +from ai.backend.manager.services.artifact_storage.processors import ArtifactStorageProcessors +from ai.backend.manager.services.artifact_storage.service import ArtifactStorageService from ai.backend.manager.services.audit_log.processors import AuditLogProcessors from ai.backend.manager.services.audit_log.service import AuditLogService from ai.backend.manager.services.auth.processors import AuthProcessors @@ -202,6 +204,7 @@ class Services: artifact: ArtifactService artifact_revision: ArtifactRevisionService artifact_registry: ArtifactRegistryService + artifact_storage: ArtifactStorageService deployment: DeploymentService storage_namespace: StorageNamespaceService audit_log: AuditLogService @@ -406,6 +409,9 @@ def create(cls, args: ServiceArgs) -> Self: repositories.reservoir_registry.repository, repositories.artifact_registry.repository, ) + artifact_storage_service = ArtifactStorageService( + artifact_storage_repository=repositories.artifact_storage.repository, + ) deployment_service = DeploymentService( args.deployment_controller, args.deployment_controller._deployment_repository, @@ -459,6 +465,7 @@ def create(cls, args: ServiceArgs) -> Self: artifact=artifact_service, artifact_revision=artifact_revision_service, artifact_registry=artifact_registry_service, + artifact_storage=artifact_storage_service, deployment=deployment_service, storage_namespace=storage_namespace_service, audit_log=audit_log_service, @@ -511,6 +518,7 @@ class Processors(AbstractProcessorPackage): artifact: ArtifactProcessors artifact_registry: ArtifactRegistryProcessors artifact_revision: ArtifactRevisionProcessors + artifact_storage: ArtifactStorageProcessors deployment: DeploymentProcessors storage_namespace: StorageNamespaceProcessors audit_log: AuditLogProcessors @@ -585,6 +593,9 @@ def create(cls, args: ProcessorArgs, action_monitors: list[ActionMonitor]) -> Se artifact_revision_processors = ArtifactRevisionProcessors( services.artifact_revision, action_monitors ) + artifact_storage_processors = ArtifactStorageProcessors( + services.artifact_storage, action_monitors + ) deployment_processors = DeploymentProcessors(services.deployment, action_monitors) @@ -637,6 +648,7 @@ def create(cls, args: ProcessorArgs, action_monitors: list[ActionMonitor]) -> Se artifact=artifact_processors, artifact_registry=artifact_registry_processors, artifact_revision=artifact_revision_processors, + artifact_storage=artifact_storage_processors, deployment=deployment_processors, storage_namespace=storage_namespace_processors, audit_log=audit_log_processors, @@ -683,6 +695,7 @@ def supported_actions(self) -> list[ActionSpec]: *self.vfs_storage.supported_actions(), *self.artifact_registry.supported_actions(), *self.artifact_revision.supported_actions(), + *self.artifact_storage.supported_actions(), *self.artifact.supported_actions(), *(self.deployment.supported_actions() if self.deployment else []), *self.storage_namespace.supported_actions(), diff --git a/src/ai/backend/manager/services/vfs_storage/actions/create.py b/src/ai/backend/manager/services/vfs_storage/actions/create.py index b4a668eec77..d1de8b4ad2d 100644 --- a/src/ai/backend/manager/services/vfs_storage/actions/create.py +++ b/src/ai/backend/manager/services/vfs_storage/actions/create.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass from typing import override diff --git a/src/ai/backend/manager/services/vfs_storage/actions/update.py b/src/ai/backend/manager/services/vfs_storage/actions/update.py index e33ae6f9d41..5821455f1fc 100644 --- a/src/ai/backend/manager/services/vfs_storage/actions/update.py +++ b/src/ai/backend/manager/services/vfs_storage/actions/update.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass from typing import override diff --git a/tests/component/object_storage/conftest.py b/tests/component/object_storage/conftest.py index 30c8421db64..feafda73790 100644 --- a/tests/component/object_storage/conftest.py +++ b/tests/component/object_storage/conftest.py @@ -19,6 +19,7 @@ from ai.backend.manager.api.rest.object_storage.registry import register_object_storage_routes from ai.backend.manager.api.rest.types import ModuleRegistrar from ai.backend.manager.api.types import CleanupContext +from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.models.object_storage import ObjectStorageRow from ai.backend.manager.models.storage_namespace.row import StorageNamespaceRow from ai.backend.manager.repositories.repositories import Repositories @@ -143,7 +144,21 @@ async def _create(**overrides: Any) -> ObjectStorageFixtureData: } defaults.update(overrides) async with db_engine.begin() as conn: - await conn.execute(sa.insert(ObjectStorageRow.__table__).values(**defaults)) + # Insert parent row first (Joined Table Inheritance) + await conn.execute( + sa.insert(ArtifactStorageRow.__table__).values( + id=defaults["id"], + name=defaults["name"], + type="object_storage", + ) + ) + # Insert child row with remaining columns + child_cols = { + k: v + for k, v in defaults.items() + if k in ("id", "host", "access_key", "secret_key", "endpoint", "region") + } + await conn.execute(sa.insert(ObjectStorageRow.__table__).values(**child_cols)) created_ids.append(defaults["id"]) return defaults @@ -159,6 +174,11 @@ async def _create(**overrides: Any) -> ObjectStorageFixtureData: await conn.execute( sa.delete(ObjectStorageRow.__table__).where(ObjectStorageRow.__table__.c.id == sid) ) + await conn.execute( + sa.delete(ArtifactStorageRow.__table__).where( + ArtifactStorageRow.__table__.c.id == sid + ) + ) @pytest.fixture() diff --git a/tests/unit/manager/api/object_storage/test_dataloader.py b/tests/unit/manager/api/object_storage/test_dataloader.py index 2ae5afc075c..71dabc4c7b4 100644 --- a/tests/unit/manager/api/object_storage/test_dataloader.py +++ b/tests/unit/manager/api/object_storage/test_dataloader.py @@ -5,6 +5,7 @@ import uuid from unittest.mock import AsyncMock, MagicMock +from ai.backend.common.types import ArtifactStorageId from ai.backend.manager.api.gql.data_loader.object_storage.loader import ( load_object_storages_by_ids, ) @@ -15,7 +16,7 @@ class TestLoadObjectStoragesByIds: """Tests for load_object_storages_by_ids function.""" @staticmethod - def create_mock_storage(storage_id: uuid.UUID) -> MagicMock: + def create_mock_storage(storage_id: ArtifactStorageId) -> MagicMock: mock = MagicMock(spec=ObjectStorageData) mock.id = storage_id return mock @@ -43,7 +44,11 @@ async def test_empty_ids_returns_empty_list(self) -> None: async def test_returns_storages_in_request_order(self) -> None: # Given - id1, id2, id3 = uuid.uuid4(), uuid.uuid4(), uuid.uuid4() + id1, id2, id3 = ( + ArtifactStorageId(uuid.uuid4()), + ArtifactStorageId(uuid.uuid4()), + ArtifactStorageId(uuid.uuid4()), + ) storage1 = self.create_mock_storage(id1) storage2 = self.create_mock_storage(id2) storage3 = self.create_mock_storage(id3) @@ -59,8 +64,8 @@ async def test_returns_storages_in_request_order(self) -> None: async def test_returns_none_for_missing_ids(self) -> None: # Given - existing_id = uuid.uuid4() - missing_id = uuid.uuid4() + existing_id = ArtifactStorageId(uuid.uuid4()) + missing_id = ArtifactStorageId(uuid.uuid4()) existing_storage = self.create_mock_storage(existing_id) mock_processor = self.create_mock_processor([existing_storage]) diff --git a/tests/unit/manager/api/vfs_storage/test_dataloader.py b/tests/unit/manager/api/vfs_storage/test_dataloader.py index 17ce34c96d8..6ad8c222e0f 100644 --- a/tests/unit/manager/api/vfs_storage/test_dataloader.py +++ b/tests/unit/manager/api/vfs_storage/test_dataloader.py @@ -5,6 +5,7 @@ import uuid from unittest.mock import AsyncMock, MagicMock +from ai.backend.common.types import ArtifactStorageId from ai.backend.manager.api.gql.data_loader.vfs_storage.loader import load_vfs_storages_by_ids from ai.backend.manager.data.vfs_storage.types import VFSStorageData @@ -13,7 +14,7 @@ class TestLoadVFSStoragesByIds: """Tests for load_vfs_storages_by_ids function.""" @staticmethod - def create_mock_storage(storage_id: uuid.UUID) -> MagicMock: + def create_mock_storage(storage_id: ArtifactStorageId) -> MagicMock: return MagicMock(spec=VFSStorageData, id=storage_id) @staticmethod @@ -39,7 +40,11 @@ async def test_empty_ids_returns_empty_list(self) -> None: async def test_returns_storages_in_request_order(self) -> None: # Given - id1, id2, id3 = uuid.uuid4(), uuid.uuid4(), uuid.uuid4() + id1, id2, id3 = ( + ArtifactStorageId(uuid.uuid4()), + ArtifactStorageId(uuid.uuid4()), + ArtifactStorageId(uuid.uuid4()), + ) storage1 = self.create_mock_storage(id1) storage2 = self.create_mock_storage(id2) storage3 = self.create_mock_storage(id3) @@ -55,8 +60,8 @@ async def test_returns_storages_in_request_order(self) -> None: async def test_returns_none_for_missing_ids(self) -> None: # Given - existing_id = uuid.uuid4() - missing_id = uuid.uuid4() + existing_id = ArtifactStorageId(uuid.uuid4()) + missing_id = ArtifactStorageId(uuid.uuid4()) existing_storage = self.create_mock_storage(existing_id) mock_processor = self.create_mock_processor([existing_storage]) diff --git a/tests/unit/manager/repositories/object_storage/test_object_storage_repository.py b/tests/unit/manager/repositories/object_storage/test_object_storage_repository.py index 3be3bb1844d..95afca4fbbf 100644 --- a/tests/unit/manager/repositories/object_storage/test_object_storage_repository.py +++ b/tests/unit/manager/repositories/object_storage/test_object_storage_repository.py @@ -11,10 +11,16 @@ import pytest from ai.backend.manager.errors.object_storage import ObjectStorageNotFoundError +from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.models.object_storage import ObjectStorageRow from ai.backend.manager.models.utils import ExtendedAsyncSAEngine from ai.backend.manager.repositories.base import BatchQuerier, OffsetPagination +from ai.backend.manager.repositories.base.creator import Creator +from ai.backend.manager.repositories.base.updater import Updater +from ai.backend.manager.repositories.object_storage.creators import ObjectStorageCreatorSpec from ai.backend.manager.repositories.object_storage.repository import ObjectStorageRepository +from ai.backend.manager.repositories.object_storage.updaters import ObjectStorageUpdaterSpec +from ai.backend.manager.types import OptionalState from ai.backend.testutils.db import with_tables @@ -34,6 +40,7 @@ async def db_with_cleanup( async with with_tables( database_connection, [ + ArtifactStorageRow, ObjectStorageRow, ], ): @@ -126,6 +133,26 @@ async def object_storage_repository( repo = ObjectStorageRepository(db=db_with_cleanup) yield repo + @pytest.fixture + def object_storage_creator_spec(self) -> ObjectStorageCreatorSpec: + """Spec for creating a new object storage""" + return ObjectStorageCreatorSpec( + name="new-object-storage", + host="storage-proxy-2", + access_key="new-access-key", + secret_key="new-secret-key", + endpoint="https://s3.new.example.com", + region="ap-northeast-2", + ) + + @pytest.fixture + def object_storage_updater_spec(self) -> ObjectStorageUpdaterSpec: + """Spec for updating object storage fields""" + return ObjectStorageUpdaterSpec( + host=OptionalState.update("updated-host"), + endpoint=OptionalState.update("https://s3.updated.example.com"), + ) + # ========================================================================= # Tests - Get by ID # ========================================================================= @@ -339,3 +366,107 @@ async def test_search_with_pagination_filter_and_order( # Verify ordering is ascending result_names = [storage.name for storage in result.items] assert result_names == sorted(result_names) + + # ========================================================================= + # Tests - Create + # ========================================================================= + + async def test_create( + self, + object_storage_repository: ObjectStorageRepository, + object_storage_creator_spec: ObjectStorageCreatorSpec, + ) -> None: + """Test creating a new object storage via Creator""" + result = await object_storage_repository.create(Creator(spec=object_storage_creator_spec)) + + assert result.name == "new-object-storage" + assert result.host == "storage-proxy-2" + assert result.access_key == "new-access-key" + assert result.secret_key == "new-secret-key" + assert result.endpoint == "https://s3.new.example.com" + assert result.region == "ap-northeast-2" + assert result.id is not None + + # Verify persisted in DB + fetched = await object_storage_repository.get_by_id(result.id) + assert fetched.name == "new-object-storage" + + async def test_create_duplicate_name_raises_error( + self, + object_storage_repository: ObjectStorageRepository, + sample_storage_id: uuid.UUID, + object_storage_creator_spec: ObjectStorageCreatorSpec, + ) -> None: + """Test creating object storage with duplicate name raises error""" + duplicate_spec = ObjectStorageCreatorSpec( + name="test-object-storage", + host=object_storage_creator_spec.host, + access_key=object_storage_creator_spec.access_key, + secret_key=object_storage_creator_spec.secret_key, + endpoint=object_storage_creator_spec.endpoint, + region=object_storage_creator_spec.region, + ) + with pytest.raises(Exception): + await object_storage_repository.create(Creator(spec=duplicate_spec)) + + # ========================================================================= + # Tests - Update + # ========================================================================= + + async def test_update( + self, + object_storage_repository: ObjectStorageRepository, + sample_storage_id: uuid.UUID, + object_storage_updater_spec: ObjectStorageUpdaterSpec, + ) -> None: + """Test updating an existing object storage via Updater""" + updater = Updater(spec=object_storage_updater_spec, pk_value=sample_storage_id) + result = await object_storage_repository.update(updater) + + assert result.id == sample_storage_id + assert result.host == "updated-host" + assert result.endpoint == "https://s3.updated.example.com" + # Unchanged fields remain the same + assert result.name == "test-object-storage" + assert result.access_key == "test-access-key" + + async def test_update_not_found( + self, + object_storage_repository: ObjectStorageRepository, + ) -> None: + """Test updating non-existent object storage raises error""" + not_found_updater = Updater( + spec=ObjectStorageUpdaterSpec( + host=OptionalState.update("updated-host"), + ), + pk_value=uuid.uuid4(), + ) + + with pytest.raises(ObjectStorageNotFoundError): + await object_storage_repository.update(not_found_updater) + + # ========================================================================= + # Tests - Delete + # ========================================================================= + + async def test_delete( + self, + object_storage_repository: ObjectStorageRepository, + sample_storage_id: uuid.UUID, + ) -> None: + """Test deleting an existing object storage""" + deleted_id = await object_storage_repository.delete(sample_storage_id) + + assert deleted_id == sample_storage_id + + # Verify it no longer exists + with pytest.raises(ObjectStorageNotFoundError): + await object_storage_repository.get_by_id(sample_storage_id) + + async def test_delete_not_found( + self, + object_storage_repository: ObjectStorageRepository, + ) -> None: + """Test deleting non-existent object storage raises error""" + with pytest.raises(ObjectStorageNotFoundError): + await object_storage_repository.delete(uuid.uuid4()) diff --git a/tests/unit/manager/repositories/vfs_storage/test_vfs_storage_repository.py b/tests/unit/manager/repositories/vfs_storage/test_vfs_storage_repository.py index 7b139770028..2ed6b310b2f 100644 --- a/tests/unit/manager/repositories/vfs_storage/test_vfs_storage_repository.py +++ b/tests/unit/manager/repositories/vfs_storage/test_vfs_storage_repository.py @@ -10,10 +10,17 @@ import pytest +from ai.backend.manager.errors.vfs_storage import VFSStorageNotFoundError +from ai.backend.manager.models.artifact_storages import ArtifactStorageRow from ai.backend.manager.models.utils import ExtendedAsyncSAEngine from ai.backend.manager.models.vfs_storage import VFSStorageRow from ai.backend.manager.repositories.base import BatchQuerier, OffsetPagination +from ai.backend.manager.repositories.base.creator import Creator +from ai.backend.manager.repositories.base.updater import Updater +from ai.backend.manager.repositories.vfs_storage.creators import VFSStorageCreatorSpec from ai.backend.manager.repositories.vfs_storage.repository import VFSStorageRepository +from ai.backend.manager.repositories.vfs_storage.updaters import VFSStorageUpdaterSpec +from ai.backend.manager.types import OptionalState from ai.backend.testutils.db import with_tables @@ -32,7 +39,7 @@ async def db_with_cleanup( """Database connection with tables created. TRUNCATE CASCADE handles cleanup.""" async with with_tables( database_connection, - [VFSStorageRow], + [ArtifactStorageRow, VFSStorageRow], ): yield database_connection @@ -64,6 +71,23 @@ async def sample_vfs_storage_id( yield storage_id + @pytest.fixture + def vfs_storage_creator_spec(self) -> VFSStorageCreatorSpec: + """Spec for creating a new VFS storage""" + return VFSStorageCreatorSpec( + name="new-vfs-storage", + host="storage-host-1", + base_path="/mnt/nfs/new", + ) + + @pytest.fixture + def vfs_storage_updater_spec(self) -> VFSStorageUpdaterSpec: + """Spec for updating VFS storage fields""" + return VFSStorageUpdaterSpec( + host=OptionalState.update("updated-host"), + base_path=OptionalState.update("/mnt/vfs/updated"), + ) + @pytest.fixture async def sample_vfs_storages_for_filtering( self, @@ -309,3 +333,164 @@ async def test_search_vfs_storages_with_pagination_filter_and_order( # Verify ordering is ascending result_names = [storage.name for storage in result.items] assert result_names == sorted(result_names) + + # ========================================================================= + # Tests - Get by ID + # ========================================================================= + + async def test_get_by_id( + self, + vfs_storage_repository: VFSStorageRepository, + sample_vfs_storage_id: uuid.UUID, + ) -> None: + """Test retrieving VFS storage by ID""" + result = await vfs_storage_repository.get_by_id(sample_vfs_storage_id) + + assert result.id == sample_vfs_storage_id + assert result.name == "test-vfs-storage" + assert result.host == "localhost" + assert str(result.base_path) == "/mnt/vfs/test" + + async def test_get_by_id_not_found( + self, + vfs_storage_repository: VFSStorageRepository, + ) -> None: + """Test retrieving non-existent VFS storage raises error""" + with pytest.raises(VFSStorageNotFoundError): + await vfs_storage_repository.get_by_id(uuid.uuid4()) + + # ========================================================================= + # Tests - Get by Name + # ========================================================================= + + async def test_get_by_name( + self, + vfs_storage_repository: VFSStorageRepository, + sample_vfs_storage_id: uuid.UUID, + ) -> None: + """Test retrieving VFS storage by name""" + result = await vfs_storage_repository.get_by_name("test-vfs-storage") + + assert result.id == sample_vfs_storage_id + assert result.name == "test-vfs-storage" + + async def test_get_by_name_not_found( + self, + vfs_storage_repository: VFSStorageRepository, + ) -> None: + """Test retrieving non-existent VFS storage by name raises error""" + with pytest.raises(VFSStorageNotFoundError): + await vfs_storage_repository.get_by_name("non-existent-storage") + + # ========================================================================= + # Tests - Create + # ========================================================================= + + async def test_create( + self, + vfs_storage_repository: VFSStorageRepository, + vfs_storage_creator_spec: VFSStorageCreatorSpec, + ) -> None: + """Test creating a new VFS storage via Creator""" + result = await vfs_storage_repository.create(Creator(spec=vfs_storage_creator_spec)) + + assert result.name == "new-vfs-storage" + assert result.host == "storage-host-1" + assert str(result.base_path) == "/mnt/nfs/new" + assert result.id is not None + + # Verify persisted in DB + fetched = await vfs_storage_repository.get_by_id(result.id) + assert fetched.name == "new-vfs-storage" + + async def test_create_duplicate_name_raises_error( + self, + vfs_storage_repository: VFSStorageRepository, + sample_vfs_storage_id: uuid.UUID, + vfs_storage_creator_spec: VFSStorageCreatorSpec, + ) -> None: + """Test creating VFS storage with duplicate name raises error""" + duplicate_spec = VFSStorageCreatorSpec( + name="test-vfs-storage", + host=vfs_storage_creator_spec.host, + base_path=vfs_storage_creator_spec.base_path, + ) + with pytest.raises(Exception): + await vfs_storage_repository.create(Creator(spec=duplicate_spec)) + + # ========================================================================= + # Tests - Update + # ========================================================================= + + async def test_update( + self, + vfs_storage_repository: VFSStorageRepository, + sample_vfs_storage_id: uuid.UUID, + vfs_storage_updater_spec: VFSStorageUpdaterSpec, + ) -> None: + """Test updating an existing VFS storage via Updater""" + updater = Updater(spec=vfs_storage_updater_spec, pk_value=sample_vfs_storage_id) + result = await vfs_storage_repository.update(updater) + + assert result.id == sample_vfs_storage_id + assert result.host == "updated-host" + assert str(result.base_path) == "/mnt/vfs/updated" + # Unchanged fields remain the same + assert result.name == "test-vfs-storage" + + async def test_update_partial( + self, + vfs_storage_repository: VFSStorageRepository, + sample_vfs_storage_id: uuid.UUID, + ) -> None: + """Test partial update only changes specified fields""" + partial_spec = VFSStorageUpdaterSpec( + host=OptionalState.update("partial-updated-host"), + ) + updater = Updater(spec=partial_spec, pk_value=sample_vfs_storage_id) + result = await vfs_storage_repository.update(updater) + + assert result.host == "partial-updated-host" + # base_path should remain unchanged + assert str(result.base_path) == "/mnt/vfs/test" + + async def test_update_not_found( + self, + vfs_storage_repository: VFSStorageRepository, + ) -> None: + """Test updating non-existent VFS storage raises error""" + not_found_updater = Updater( + spec=VFSStorageUpdaterSpec( + host=OptionalState.update("updated-host"), + ), + pk_value=uuid.uuid4(), + ) + + with pytest.raises(VFSStorageNotFoundError): + await vfs_storage_repository.update(not_found_updater) + + # ========================================================================= + # Tests - Delete + # ========================================================================= + + async def test_delete( + self, + vfs_storage_repository: VFSStorageRepository, + sample_vfs_storage_id: uuid.UUID, + ) -> None: + """Test deleting an existing VFS storage""" + deleted_id = await vfs_storage_repository.delete(sample_vfs_storage_id) + + assert deleted_id == sample_vfs_storage_id + + # Verify it no longer exists + with pytest.raises(VFSStorageNotFoundError): + await vfs_storage_repository.get_by_id(sample_vfs_storage_id) + + async def test_delete_not_found( + self, + vfs_storage_repository: VFSStorageRepository, + ) -> None: + """Test deleting non-existent VFS storage raises error""" + with pytest.raises(VFSStorageNotFoundError): + await vfs_storage_repository.delete(uuid.uuid4()) diff --git a/tests/unit/manager/services/object_storage/test_object_storage_service.py b/tests/unit/manager/services/object_storage/test_object_storage_service.py index b4e53eaa5f8..cb894826e6c 100644 --- a/tests/unit/manager/services/object_storage/test_object_storage_service.py +++ b/tests/unit/manager/services/object_storage/test_object_storage_service.py @@ -10,6 +10,8 @@ import pytest +from ai.backend.common.data.storage.types import ArtifactStorageType +from ai.backend.common.types import ArtifactStorageId from ai.backend.manager.data.object_storage.types import ( ObjectStorageData, ObjectStorageListResult, @@ -74,8 +76,9 @@ def object_storage_service( def sample_object_storage_data(self) -> ObjectStorageData: """Create sample Object storage data""" return ObjectStorageData( - id=uuid4(), + id=ArtifactStorageId(uuid4()), name="test-object-storage", + type=ArtifactStorageType.OBJECT_STORAGE, host="storage-proxy-1", access_key="test-access-key", secret_key="test-secret-key", diff --git a/tests/unit/manager/services/vfs_storage/test_vfs_storage_service.py b/tests/unit/manager/services/vfs_storage/test_vfs_storage_service.py index 3ef710b5ab6..1fac5cdf065 100644 --- a/tests/unit/manager/services/vfs_storage/test_vfs_storage_service.py +++ b/tests/unit/manager/services/vfs_storage/test_vfs_storage_service.py @@ -11,6 +11,8 @@ import pytest +from ai.backend.common.data.storage.types import ArtifactStorageType +from ai.backend.common.types import ArtifactStorageId from ai.backend.manager.data.vfs_storage.types import VFSStorageData, VFSStorageListResult from ai.backend.manager.repositories.base import BatchQuerier, OffsetPagination from ai.backend.manager.repositories.vfs_storage.repository import VFSStorageRepository @@ -40,8 +42,9 @@ def vfs_storage_service( def sample_vfs_storage_data(self) -> VFSStorageData: """Create sample VFS storage data""" return VFSStorageData( - id=uuid4(), + id=ArtifactStorageId(uuid4()), name="test-vfs-storage", + type=ArtifactStorageType.VFS_STORAGE, host="localhost", base_path=Path("/mnt/vfs/test"), )