From be0f3ee3a0234fd1f48c18b8e3624d7b3f434d25 Mon Sep 17 00:00:00 2001 From: Daniel Ng Date: Fri, 8 May 2026 08:37:19 -0700 Subject: [PATCH] Internal Change PiperOrigin-RevId: 912542167 --- .../experimental/tiering_service/db_cli.py | 80 +++++++ .../experimental/tiering_service/db_lib.py | 220 ++++++++++++++++++ .../tiering_service/db_lib_test.py | 154 ++++++++++++ .../experimental/tiering_service/db_schema.py | 7 +- .../tiering_service/db_schema_test.py | 27 ++- .../proto/tiering_service.proto | 1 + .../experimental/tiering_service/server.py | 58 ++++- .../tiering_service/server_config.py | 153 ++++++++++++ .../tiering_service/server_test.py | 181 ++++++++++---- checkpoint/pyproject.toml | 1 + 10 files changed, 824 insertions(+), 58 deletions(-) create mode 100644 checkpoint/orbax/checkpoint/experimental/tiering_service/db_cli.py create mode 100644 checkpoint/orbax/checkpoint/experimental/tiering_service/db_lib.py create mode 100644 checkpoint/orbax/checkpoint/experimental/tiering_service/db_lib_test.py create mode 100644 checkpoint/orbax/checkpoint/experimental/tiering_service/server_config.py diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/db_cli.py b/checkpoint/orbax/checkpoint/experimental/tiering_service/db_cli.py new file mode 100644 index 000000000..92bdf2ac3 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/db_cli.py @@ -0,0 +1,80 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tiering Service Database CLI.""" + +from collections.abc import Sequence + +from absl import app +from absl import flags +from orbax.checkpoint.experimental.tiering_service import db_lib +from orbax.checkpoint.experimental.tiering_service import server_config + +from .pyglib import appcommands + + +class InitializeDbCmd(appcommands.Cmd): + """Initializes the Tiering Service database from a YAML configuration.""" + + def __init__(self, name: str, flag_values: flags.FlagValues, **kwargs): + """Initializes the InitializeDbCmd. + + Args: + name: The name of the command. + flag_values: The FlagValues instance with which flags will be registered. + **kwargs: Additional keyword arguments. + """ + super().__init__(name, flag_values, **kwargs) + self._yaml_path = flags.DEFINE_string( + "yaml_path", + None, + "Path to the YAML configuration file.", + required=True, + flag_values=flag_values, + ) + + def Run(self, argv: Sequence[str]) -> None: + """Executes the initialize_db command. + + Initializes the database based on the provided YAML configuration if + uninitialized, otherwise verifies that existing database entries match + the configuration. + + Args: + argv: Command line arguments. + + Raises: + app.UsageError: If too many command-line arguments are provided or if the + configuration file cannot be opened. + ValueError: If existing database entries do not match the configuration. + """ + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + try: + config = server_config.load_config(self._yaml_path.value) + except OSError as e: + raise app.UsageError(f"Failed to open configuration file: {e}") from e + if not db_lib.is_db_initialized(config): + db_lib.initialize_db(config) + else: + db_lib.verify_db(config) + + +def main(argv: Sequence[str]) -> None: + del argv + appcommands.AddCmd("initialize_db", InitializeDbCmd) + + +if __name__ == "__main__": + appcommands.Run() diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/db_lib.py b/checkpoint/orbax/checkpoint/experimental/tiering_service/db_lib.py new file mode 100644 index 000000000..45058027b --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/db_lib.py @@ -0,0 +1,220 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Database initialization utilities for Tiering Service.""" + +import asyncio +from orbax.checkpoint.experimental.tiering_service import db_schema +from orbax.checkpoint.experimental.tiering_service import server_config +from sqlalchemy.exc import OperationalError +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.future import select +from sqlalchemy.orm import sessionmaker + + +def _get_async_engine(config: server_config.ServerConfig) -> AsyncEngine: + """Returns an AsyncEngine configured from ServerConfig.""" + input_url = config.db_connection_str + url = ( + input_url.replace("psql://", "postgresql+asyncpg://", 1) + if input_url.startswith("psql://") + else input_url + ) + return create_async_engine(url) + + +def _get_backend_type(type_str: str) -> db_schema.BackendType: + """Maps a tier type string to BackendType enum.""" + type_str_upper = type_str.upper() + if type_str_upper == "LUSTRE": + return db_schema.BackendType.BACKEND_TYPE_LUSTRE + elif type_str_upper == "GCS": + return db_schema.BackendType.BACKEND_TYPE_GCS + else: + raise ValueError(f"Unknown storage backend type: {type_str!r}") + + +def _get_backend_key( + level: int, + zone: str | None, + region: str | None, + multi_regions: list[str] | None, +) -> tuple[int, str | None, str | None, tuple[str, ...] | None]: + """Generates a unique key for a StorageBackend based on level and location.""" + multi_regions_tuple = ( + tuple(sorted(multi_regions)) if multi_regions else None + ) + return (level, zone, region, multi_regions_tuple) + + +async def async_initialize_db(config: server_config.ServerConfig) -> None: + """Initializes the database with the schema and initial data. + + If the database is uninitialized, this function connects to the database, + creates all necessary tables based on the `db_schema.Base` metadata, and + populates the `StorageBackend` table with data from the provided server + configuration if it's empty. + + Args: + config: The server configuration containing tier information and DB URL. + """ + engine = _get_async_engine(config) + try: + async with engine.begin() as conn: + await conn.run_sync(db_schema.Base.metadata.create_all) + + session_maker = sessionmaker( + engine, expire_on_commit=False, class_=AsyncSession + ) + async with session_maker() as session: + result = await session.execute(select(db_schema.StorageBackend)) + existing = result.scalars().first() + + if existing is not None: + return + + for tier in config.tiers: + level = tier.level + + backend_type = _get_backend_type(tier.type) + + for instance in tier.instances: + location_fields = sum([ + 1 if instance.zone else 0, + 1 if instance.region else 0, + 1 if instance.multi_regions else 0, + ]) + if location_fields != 1: + raise ValueError( + "Exactly one of 'zone', 'region', or 'multi_regions' must be" + f" provided for instance: {instance!r}" + ) + + backend = db_schema.StorageBackend( + level=level, + backend_type=backend_type, + prefix=instance.prefix, + ) + if instance.zone: + backend.zone = instance.zone + elif instance.region: + backend.region = instance.region + elif instance.multi_regions: + backend.multi_regions = instance.multi_regions + session.add(backend) + if session.new: + await session.commit() + finally: + await engine.dispose() + + +def initialize_db(config: server_config.ServerConfig) -> None: + """Initializes the database with the schema and initial data. + + This is the synchronous version of `async_initialize_db`. + + Args: + config: The server configuration containing tier information and DB URL. + """ + asyncio.run(async_initialize_db(config)) + + +async def async_is_db_initialized(config: server_config.ServerConfig) -> bool: + """Returns whether the database is already initialized with StorageBackend entries.""" + engine = _get_async_engine(config) + session_maker = sessionmaker( + engine, expire_on_commit=False, class_=AsyncSession + ) + try: + async with session_maker() as session: + result = await session.execute(select(db_schema.StorageBackend)) + return result.scalars().first() is not None + except OperationalError: # If tables do not exist yet. + return False + finally: + await engine.dispose() + + +def is_db_initialized(config: server_config.ServerConfig) -> bool: + """Returns whether the database is already initialized with StorageBackend entries.""" + return asyncio.run(async_is_db_initialized(config)) + + +async def async_verify_db(config: server_config.ServerConfig) -> None: + """Verifies that the database StorageBackend table matches ServerConfig. + + Args: + config: The server configuration containing tier information and DB URL. + + Raises: + ValueError: If there is any mismatch between configuration and database. + """ + engine = _get_async_engine(config) + try: + session_maker = sessionmaker( + engine, expire_on_commit=False, class_=AsyncSession + ) + async with session_maker() as session: + result = await session.execute(select(db_schema.StorageBackend)) + db_backends = result.scalars().all() + finally: + await engine.dispose() + + expected_count = sum(len(t.instances) for t in config.tiers) + if len(db_backends) != expected_count: + raise ValueError( + f"Mismatch in total StorageBackend count: DB has {len(db_backends)}," + f" config expects {expected_count}" + ) + + backend_by_key = { + _get_backend_key(b.level, b.zone, b.region, b.multi_regions): b + for b in db_backends + } + + for tier in config.tiers: + level = tier.level + expected_type = _get_backend_type(tier.type) + + for instance in tier.instances: + instance_key = _get_backend_key( + level, instance.zone, instance.region, instance.multi_regions + ) + db_backend = backend_by_key.get(instance_key) + + if db_backend is None: + raise ValueError( + f"Configuration expects StorageBackend with key {instance_key!r}" + f" (prefix={instance.prefix!r}) but not found in Database." + ) + + if db_backend.backend_type != expected_type: + raise ValueError( + f"Backend with key {instance_key!r} mismatch" + f" backend_type: DB has {db_backend.backend_type.name}, config" + f" expects {expected_type.name}" + ) + if db_backend.prefix != instance.prefix: + raise ValueError( + f"Backend with key {instance_key!r} mismatch" + f" prefix: DB has {db_backend.prefix!r}, config expects" + f" {instance.prefix!r}" + ) + + +def verify_db(config: server_config.ServerConfig) -> None: + """Verifies that the database StorageBackend table exactly matches ServerConfig.""" + asyncio.run(async_verify_db(config)) diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/db_lib_test.py b/checkpoint/orbax/checkpoint/experimental/tiering_service/db_lib_test.py new file mode 100644 index 000000000..0a8adf49c --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/db_lib_test.py @@ -0,0 +1,154 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from absl.testing import absltest +import aiosqlite # pylint: disable=unused-import +import greenlet # pylint: disable=unused-import +from orbax.checkpoint.experimental.tiering_service import db_lib +from orbax.checkpoint.experimental.tiering_service import server_config +import yaml + + +class DbLibTest(absltest.TestCase): + + def test_initialize_db_from_yaml(self): + tmp_file = self.create_tempfile() + db_url = f"sqlite+aiosqlite:///{tmp_file.full_path}" + yaml_content = textwrap.dedent(f""" + db_connection_str: {db_url} + tiers: + - level: 0 + type: Lustre + instances: + - prefix: /mnt/lustre + zone: us-central1-a + - prefix: /mnt/lustre2 + zone: us-central1-b + - level: 1 + type: GCS + instances: + - prefix: gs://my-bucket + region: us-central1 + - prefix: gs://my-bucket2 + region: us-west1 + - level: 2 + type: GCS + instances: + - prefix: gs://my-bucket3 + multi_regions: [us-central1, us-east1] + """) + config_dict = yaml.safe_load(yaml_content) + config = server_config.parse_config(config_dict) + + db_lib.initialize_db(config) + db_lib.verify_db(config) + + def test_verify_db_mismatch_raises(self): + tmp_file = self.create_tempfile() + db_url = f"sqlite+aiosqlite:///{tmp_file.full_path}" + yaml_content = textwrap.dedent(f""" + db_connection_str: {db_url} + tiers: + - level: 1 + type: GCS + instances: + - prefix: gs://my-bucket + region: us-central1 + """) + config_dict = yaml.safe_load(yaml_content) + config = server_config.parse_config(config_dict) + db_lib.initialize_db(config) + + # Modify config to expect a different region + yaml_content_mod = textwrap.dedent(f""" + db_connection_str: {db_url} + tiers: + - level: 1 + type: GCS + instances: + - prefix: gs://my-bucket + region: us-east1 + """) + config_mod = server_config.parse_config(yaml.safe_load(yaml_content_mod)) + with self.assertRaisesRegex( + ValueError, + "Configuration expects StorageBackend with key", + ): + db_lib.verify_db(config_mod) + + # Modify config to expect a different prefix in the same region + yaml_content_prefix = textwrap.dedent(f""" + db_connection_str: {db_url} + tiers: + - level: 1 + type: GCS + instances: + - prefix: gs://other-bucket + region: us-central1 + """) + config_prefix = server_config.parse_config( + yaml.safe_load(yaml_content_prefix) + ) + with self.assertRaisesRegex( + ValueError, + "Backend with key .* mismatch prefix", + ): + db_lib.verify_db(config_prefix) + + def test_initialize_db_missing_location_rejected(self): + tmp_file = self.create_tempfile() + db_url = f"sqlite+aiosqlite:///{tmp_file.full_path}" + yaml_content = textwrap.dedent(f""" + db_connection_str: {db_url} + tiers: + - level: 1 + type: GCS + instances: + - prefix: gs://my-bucket + """) + config_dict = yaml.safe_load(yaml_content) + config = server_config.parse_config(config_dict) + + with self.assertRaisesRegex( + ValueError, + "Exactly one of 'zone', 'region', or 'multi_regions' must be provided", + ): + db_lib.initialize_db(config) + + def test_is_db_initialized(self): + tmp_file = self.create_tempfile() + db_url = f"sqlite+aiosqlite:///{tmp_file.full_path}" + yaml_content = textwrap.dedent(f""" + db_connection_str: {db_url} + tiers: + - level: 1 + type: GCS + instances: + - prefix: gs://my-bucket + region: us-central1 + """) + config_dict = yaml.safe_load(yaml_content) + config = server_config.parse_config(config_dict) + + self.assertFalse(db_lib.is_db_initialized(config)) + + db_lib.initialize_db(config) + + self.assertTrue(db_lib.is_db_initialized(config)) + + +if __name__ == "__main__": + absltest.main() diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema.py b/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema.py index 52d09f431..3007fc7b8 100644 --- a/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema.py +++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema.py @@ -159,6 +159,7 @@ class StorageBackend(Base): region: The region where the storage backend resides. multi_regions: A list of regions forming a multi-region deployment. backend_type: The type of storage (e.g., Lustre, GCS). + prefix: Storage backend prefix (e.g., gs://bucket-name, /mnt/lustre/). tier_paths: Relationship to the TierPath objects utilizing this backend. """ @@ -174,6 +175,7 @@ class StorageBackend(Base): backend_type = sqlalchemy.Column( sqlalchemy.Enum(BackendType), default=BackendType.BACKEND_TYPE_UNSPECIFIED ) + prefix = sqlalchemy.Column(sqlalchemy.String, nullable=False) tier_paths = sqlalchemy.orm.relationship( "TierPath", back_populates="storage_backend", cascade="all, delete-orphan" @@ -199,8 +201,9 @@ def __repr__(self): else: location = "None" return ( - f"StorageBackend(id={self.id}, level={self.level}, " - f"backend_type={self.backend_type.name!r}, {location})" + f"StorageBackend(id={self.id}, level={self.level}," + f" backend_type={self.backend_type.name!r}, prefix={self.prefix!r}," + f" {location})" ) def validate_pre_commit(self) -> None: diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema_test.py b/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema_test.py index 2942e31f5..8ee716a37 100644 --- a/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema_test.py +++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema_test.py @@ -116,11 +116,13 @@ async def test_add_tier_path(self) -> None: level=0, zone="us-east5-a", backend_type=db_schema.BackendType.BACKEND_TYPE_LUSTRE, + prefix="/lustre", ) backend1 = db_schema.StorageBackend( level=1, multi_regions=["us-central1", "us-east1"], backend_type=db_schema.BackendType.BACKEND_TYPE_GCS, + prefix="gs://gcs-bucket", ) tier_path0 = db_schema.TierPath( asset_uuid="uuid-789", @@ -130,7 +132,7 @@ async def test_add_tier_path(self) -> None: tier_path1 = db_schema.TierPath( asset_uuid="uuid-789", storage_backend=backend1, - path="/gcs/path/2", + path="gs://gcs-bucket/path/2", ) session.add(asset) session.add(backend0) @@ -158,7 +160,7 @@ async def test_add_tier_path(self) -> None: ) self.assertEqual(tp0.path, "/lustre/path/1") self.assertEqual(tp0.storage_backend.zone, "us-east5-a") - self.assertEqual(tp1.path, "/gcs/path/2") + self.assertEqual(tp1.path, "gs://gcs-bucket/path/2") self.assertEqual( tp1.storage_backend.multi_regions, ["us-central1", "us-east1"], @@ -170,6 +172,7 @@ async def test_add_tier_path_fails_multiple_locations(self) -> None: level=0, zone="us-central1-a", backend_type=db_schema.BackendType.BACKEND_TYPE_GCS, + prefix="gs://gcs-bucket", ) session.add(backend) await session.commit() @@ -212,11 +215,13 @@ async def test_storage_backend_fails_multiple_locations_zone(self) -> None: level=0, zone="us-central1-a", backend_type=db_schema.BackendType.BACKEND_TYPE_LUSTRE, + prefix="/lustre", ) b2 = db_schema.StorageBackend( level=0, zone="us-central1-a", backend_type=db_schema.BackendType.BACKEND_TYPE_LUSTRE, + prefix="/lustre", ) tp1 = db_schema.TierPath(storage_backend=b1, path="/path1") tp2 = db_schema.TierPath(storage_backend=b2, path="/path2") @@ -236,11 +241,13 @@ async def test_storage_backend_fails_multiple_locations_region(self) -> None: level=0, region="us-central1", backend_type=db_schema.BackendType.BACKEND_TYPE_GCS, + prefix="gs://gcs-bucket", ) b2 = db_schema.StorageBackend( level=0, region="us-central1", backend_type=db_schema.BackendType.BACKEND_TYPE_GCS, + prefix="gs://gcs-bucket", ) tp1 = db_schema.TierPath(storage_backend=b1, path="/path1") tp2 = db_schema.TierPath(storage_backend=b2, path="/path2") @@ -262,12 +269,14 @@ async def test_storage_backend_fails_multiple_locations_multi_regions( level=0, multi_regions=["us-central1", "us-east1"], backend_type=db_schema.BackendType.BACKEND_TYPE_GCS, + prefix="gs://gcs-bucket", ) # Order of regions shouldn't matter b2 = db_schema.StorageBackend( level=0, multi_regions=["us-east1", "us-central1"], backend_type=db_schema.BackendType.BACKEND_TYPE_GCS, + prefix="gs://gcs-bucket", ) tp1 = db_schema.TierPath(storage_backend=b1, path="/path1") tp2 = db_schema.TierPath(storage_backend=b2, path="/path2") @@ -283,6 +292,7 @@ async def test_add_tier_path_fails_no_locations(self) -> None: ): invalid_backend_empty = db_schema.StorageBackend( level=0, + prefix="test-empty", ) session.add(invalid_backend_empty) await session.commit() @@ -298,6 +308,7 @@ async def test_asset_job_queue(self) -> None: level=0, zone="us-central1-a", backend_type=db_schema.BackendType.BACKEND_TYPE_GCS, + prefix="gs://gcs-bucket", ) tier_path = db_schema.TierPath( asset_uuid="uuid-queue", storage_backend=backend, path="/path1" @@ -395,11 +406,13 @@ async def test_create_asset_duplicates_blocked_for_active_stored( level=1, zone="us-central1-a", backend_type=db_schema.BackendType.BACKEND_TYPE_LUSTRE, + prefix="/lustre", ), backend2=db_schema.StorageBackend( level=1, zone="us-central1-b", backend_type=db_schema.BackendType.BACKEND_TYPE_GCS, + prefix="gs://gcs-bucket", ), expected_exception=ValueError, expected_regex="same backend_type", @@ -410,11 +423,13 @@ async def test_create_asset_duplicates_blocked_for_active_stored( level=1, zone="us-central1-a", backend_type=db_schema.BackendType.BACKEND_TYPE_LUSTRE, + prefix="/lustre", ), backend2=db_schema.StorageBackend( level=1, zone="us-central1-a", backend_type=db_schema.BackendType.BACKEND_TYPE_LUSTRE, + prefix="/lustre", ), expected_exception=ValueError, expected_regex="Duplicate zone", @@ -425,11 +440,13 @@ async def test_create_asset_duplicates_blocked_for_active_stored( level=1, region="us-central1", backend_type=db_schema.BackendType.BACKEND_TYPE_GCS, + prefix="gs://gcs-bucket", ), backend2=db_schema.StorageBackend( level=1, region="us-central1", backend_type=db_schema.BackendType.BACKEND_TYPE_GCS, + prefix="gs://gcs-bucket", ), expected_exception=ValueError, expected_regex="Duplicate region", @@ -440,11 +457,13 @@ async def test_create_asset_duplicates_blocked_for_active_stored( level=1, multi_regions=["us-central1", "us-east1"], backend_type=db_schema.BackendType.BACKEND_TYPE_GCS, + prefix="gs://gcs-bucket", ), backend2=db_schema.StorageBackend( level=1, multi_regions=["us-east1", "us-central1"], backend_type=db_schema.BackendType.BACKEND_TYPE_GCS, + prefix="gs://gcs-bucket", ), expected_exception=ValueError, expected_regex="Duplicate multi_regions", @@ -543,7 +562,9 @@ async def _setup(): path="/experiment/queue-multi", user="testuser", ) - sb = db_schema.StorageBackend(level=0, zone="us-central1-a") + sb = db_schema.StorageBackend( + level=0, zone="us-central1-a", prefix="gs://gcs-bucket" + ) tp = db_schema.TierPath( asset_uuid="uuid-queue-multi", storage_backend=sb, path="/path1" ) diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/proto/tiering_service.proto b/checkpoint/orbax/checkpoint/experimental/tiering_service/proto/tiering_service.proto index 2e0a20d08..0279cf493 100644 --- a/checkpoint/orbax/checkpoint/experimental/tiering_service/proto/tiering_service.proto +++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/proto/tiering_service.proto @@ -36,6 +36,7 @@ message StorageBackend { MultipleRegions multi_regions = 5; // e.g. GCS (dual/multi-regions) } BackendType backend_type = 6; + string prefix = 7; // GCS bucket name or Lustre mount point. } message TierPath { diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/server.py b/checkpoint/orbax/checkpoint/experimental/tiering_service/server.py index 65872c6b7..4e4b026b4 100644 --- a/checkpoint/orbax/checkpoint/experimental/tiering_service/server.py +++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/server.py @@ -19,8 +19,12 @@ import os import uuid +from absl import app +from absl import flags from absl import logging import grpc +from orbax.checkpoint.experimental.tiering_service import db_lib +from orbax.checkpoint.experimental.tiering_service import server_config from orbax.checkpoint.experimental.tiering_service.proto import tiering_service_pb2 from orbax.checkpoint.experimental.tiering_service.proto import tiering_service_pb2_grpc @@ -114,9 +118,7 @@ def Reserve( # TODO: b/503445654 - Fake a TierPath for now asset_uuid = str(uuid.uuid4()) - gcs_path = ( - f"gs://checkpoint-tiering/{request.user}/{request.path}/{asset_uuid}" - ) + storage_path = f"/mnt/lustre/{request.path}" if not request.HasField("zone") and not request.HasField("region"): logging.error( @@ -143,17 +145,24 @@ def Reserve( updated_at=now, ) - tp = asset.tier_paths.add( - path=gcs_path, + # TODO: b/503445654 - Look up nearest backend to reserve path for requestor. + storage_backend_kwargs = { + "level": 0, + "backend_type": tiering_service_pb2.BACKEND_TYPE_LUSTRE, + # database/config. + "prefix": "/mnt/lustre", + } + if request.HasField("zone"): + storage_backend_kwargs["zone"] = request.zone + elif request.HasField("region"): + storage_backend_kwargs["region"] = request.region + + asset.tier_paths.add( + path=storage_path, storage_backend=tiering_service_pb2.StorageBackend( - level=1, - backend_type=tiering_service_pb2.BACKEND_TYPE_GCS, + **storage_backend_kwargs ), ) - if request.HasField("zone"): - tp.storage_backend.zone = request.zone - elif request.HasField("region"): - tp.storage_backend.region = request.region _assets_by_uuid[asset_uuid] = asset logging.info("Reserved asset with UUID: %s", asset_uuid) @@ -317,5 +326,30 @@ def serve() -> None: server.wait_for_termination() -if __name__ == "__main__": +def setup_storage_backends(config: server_config.ServerConfig) -> None: + """Initializes the database if uninitialized, otherwise verifies it matches configuration.""" + if not db_lib.is_db_initialized(config): + db_lib.initialize_db(config) + else: + db_lib.verify_db(config) + + +_YAML_PATH = flags.DEFINE_string( + "yaml_path", + None, + "Path to the YAML configuration file.", +) + + +def main(argv: Sequence[str]) -> None: + """Main entry point for CTS server.""" + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + config = server_config.load_config(_YAML_PATH.value) + setup_storage_backends(config) serve() + + +if __name__ == "__main__": + flags.mark_flag_as_required("yaml_path") + app.run(main) diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/server_config.py b/checkpoint/orbax/checkpoint/experimental/tiering_service/server_config.py new file mode 100644 index 000000000..d9f61bd67 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/server_config.py @@ -0,0 +1,153 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Server configuration dataclasses.""" + +from collections.abc import Mapping +import dataclasses +import datetime +from typing import Any +import pytimeparse +import yaml + + +@dataclasses.dataclass +class StorageInstanceConfig: + """Configuration for a single storage instance. + + Attributes: + prefix: Storage path prefix (e.g. '/mnt/lustre' or 'gs://my-bucket'). + zone: Optional GCP zone name. + region: Optional GCP region name. + multi_regions: Optional Dual or multi-regions GCS deployment. + """ + + prefix: str + zone: str | None = None + region: str | None = None + multi_regions: list[str] | None = None + + +@dataclasses.dataclass +class TierConfig: + """Configuration for a storage tier. + + Attributes: + level: The integer tier level (e.g. 0 for hottest storage). + type: Storage backend type (e.g. 'Lustre', 'GCS'). + instances: List of storage instances belonging to this tier. + """ + + level: int + type: str + instances: list[StorageInstanceConfig] = dataclasses.field( + default_factory=list + ) + + +@dataclasses.dataclass +class ServerConfig: + """Server configuration. + + Attributes: + client_keep_alive_interval: Duration for client keep-alive checks. + db_connection_str: Database connection string / URL. + tiers: List of configured storage tiers. + """ + + client_keep_alive_interval: datetime.timedelta = dataclasses.field( + default_factory=lambda: datetime.timedelta(minutes=30) + ) + db_connection_str: str = "sqlite+aiosqlite:///:memory:" + tiers: list[TierConfig] = dataclasses.field(default_factory=list) + + +def _parse_storage_instance( + inst_data: Mapping[str, Any], +) -> StorageInstanceConfig: + """Parses a dictionary into a StorageInstanceConfig dataclass.""" + prefix = inst_data.get("prefix") + if prefix is None: + raise ValueError("StorageInstanceConfig missing required key: 'prefix'") + + return StorageInstanceConfig( + prefix=prefix, + zone=inst_data.get("zone"), + region=inst_data.get("region"), + multi_regions=inst_data.get("multi_regions"), + ) + + +def _parse_timedelta(val: str) -> datetime.timedelta: + """Parses a duration string (e.g. '1s', '30m', '1h') into a timedelta.""" + if not isinstance(val, str): + raise ValueError( + f"Invalid duration type for client_keep_alive_interval: {type(val)}," + " expected str." + ) + seconds = pytimeparse.parse(val) + if seconds is None: + raise ValueError( + f"Invalid duration format for client_keep_alive_interval: {val}" + ) + return datetime.timedelta(seconds=seconds) + + +def parse_config(data: Mapping[str, Any]) -> ServerConfig: + """Parses a dictionary into a ServerConfig dataclass. + + Args: + data: A dictionary (usually loaded from YAML) containing server + configuration parameters. + + Returns: + A ServerConfig instance populated with the parsed data. + """ + tiers_data = data.get("tiers", []) + tiers = [] + for t in tiers_data: + level = t.get("level") + if level is None: + raise ValueError("Tier configuration missing required key: 'level'") + + tier_type = t.get("type") + if not tier_type: + raise ValueError("Tier configuration missing required key: 'type'") + + instances_data = t.get("instances", []) + instances = [_parse_storage_instance(inst) for inst in instances_data] + tiers.append(TierConfig(level=level, type=tier_type, instances=instances)) + + kwargs = {"tiers": tiers} + if "client_keep_alive_interval" in data: + kwargs["client_keep_alive_interval"] = _parse_timedelta( + data["client_keep_alive_interval"] + ) + if "db_connection_str" in data: + kwargs["db_connection_str"] = data["db_connection_str"] + return ServerConfig(**kwargs) + + +def load_config(yaml_path: str) -> ServerConfig: + """Loads and parses a ServerConfig from a YAML file. + + Args: + yaml_path: Path to the YAML configuration file. + + Returns: + A ServerConfig instance populated with the parsed data. + """ + with open(yaml_path, "r") as f: + config_dict = yaml.safe_load(f) + return parse_config(config_dict) diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/server_test.py b/checkpoint/orbax/checkpoint/experimental/tiering_service/server_test.py index fd6039bbd..53f90cce2 100644 --- a/checkpoint/orbax/checkpoint/experimental/tiering_service/server_test.py +++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/server_test.py @@ -12,22 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datetime from unittest import mock from absl.testing import absltest +from absl.testing import parameterized import grpc +from orbax.checkpoint.experimental.tiering_service import db_lib from orbax.checkpoint.experimental.tiering_service import server +from orbax.checkpoint.experimental.tiering_service import server_config from orbax.checkpoint.experimental.tiering_service.proto import tiering_service_pb2 from google.protobuf import timestamp_pb2 -class TieringServiceTest(absltest.TestCase): +class TieringServiceTest(parameterized.TestCase): def setUp(self): super().setUp() self.servicer = server.TieringServiceServicer() - self.context = mock.create_autospec(grpc.ServicerContext, instance=True) + self.context = mock.create_autospec( + grpc.ServicerContext, instance=True, spec_set=True + ) # Mock metadata for OAuth token self.context.invocation_metadata.return_value = ( ("authorization", "Bearer valid-mock-token"), @@ -35,6 +41,19 @@ def setUp(self): # Clear internal state between tests server._assets_by_uuid = {} + def _reserve_asset(self): + reserve_req = tiering_service_pb2.ReserveRequest( + path="test/path", user="test-user", zone="us-central1-a" + ) + reserve_res = self.servicer.Reserve(reserve_req, self.context) + return reserve_res.asset.uuid + + def _setup_config(self, config_dict): + config = server_config.parse_config(config_dict) + tmp_file = self.create_tempfile() + config.db_connection_str = f"sqlite+aiosqlite:///{tmp_file.full_path}" + return config + def test_reserve_success(self): request = tiering_service_pb2.ReserveRequest( path="test/path", @@ -50,7 +69,7 @@ def test_reserve_success(self): response.asset.state, tiering_service_pb2.ASSET_STATE_ACTIVE_WRITE ) self.assertLen(response.asset.tier_paths, 1) - self.assertTrue(response.asset.tier_paths[0].path.startswith("gs://")) + self.assertTrue(response.asset.tier_paths[0].path.startswith("/mnt/lustre")) def test_reserve_keep_alive_not_found(self): request = tiering_service_pb2.ReserveKeepAliveRequest(uuid="invalid-uuid") @@ -72,14 +91,7 @@ def test_reserve_invalid_argument(self): ) def test_finalize_success(self): - # First reserve - reserve_req = tiering_service_pb2.ReserveRequest( - path="test/path", user="test-user", zone="us-central1-a" - ) - reserve_res = self.servicer.Reserve(reserve_req, self.context) - asset_uuid = reserve_res.asset.uuid - - # Then finalize + asset_uuid = self._reserve_asset() finalize_req = tiering_service_pb2.FinalizeRequest(uuid=asset_uuid) finalize_res = self.servicer.Finalize(finalize_req, self.context) @@ -88,12 +100,7 @@ def test_finalize_success(self): ) def test_finalize_failed_precondition(self): - # Reserve and then finalize once - reserve_req = tiering_service_pb2.ReserveRequest( - path="test/path", user="test-user", zone="us-central1-a" - ) - reserve_res = self.servicer.Reserve(reserve_req, self.context) - asset_uuid = reserve_res.asset.uuid + asset_uuid = self._reserve_asset() self.servicer.Finalize( tiering_service_pb2.FinalizeRequest(uuid=asset_uuid), self.context ) @@ -107,36 +114,21 @@ def test_finalize_failed_precondition(self): ) def test_delete_success(self): - reserve_req = tiering_service_pb2.ReserveRequest( - path="test/path", user="test-user", zone="us-central1-a" - ) - reserve_res = self.servicer.Reserve(reserve_req, self.context) - asset_uuid = reserve_res.asset.uuid - + asset_uuid = self._reserve_asset() self.servicer.Delete( tiering_service_pb2.DeleteRequest(uuid=asset_uuid), self.context ) self.assertNotIn(asset_uuid, server._assets_by_uuid) def test_info_success(self): - reserve_req = tiering_service_pb2.ReserveRequest( - path="test/path", user="test-user", zone="us-central1-a" - ) - reserve_res = self.servicer.Reserve(reserve_req, self.context) - asset_uuid = reserve_res.asset.uuid - + asset_uuid = self._reserve_asset() response = self.servicer.Info( tiering_service_pb2.InfoRequest(uuid=asset_uuid), self.context ) self.assertEqual(response.asset.uuid, asset_uuid) def test_prefetch_success(self): - reserve_req = tiering_service_pb2.ReserveRequest( - path="test/path", user="test-user", zone="us-central1-a" - ) - reserve_res = self.servicer.Reserve(reserve_req, self.context) - asset_uuid = reserve_res.asset.uuid - + asset_uuid = self._reserve_asset() prefetch_req = tiering_service_pb2.PrefetchRequest( uuid=asset_uuid, zone="us-central1-a" ) @@ -145,12 +137,7 @@ def test_prefetch_success(self): self.assertEqual(response.asset.uuid, asset_uuid) def test_prefetch_invalid_argument(self): - reserve_req = tiering_service_pb2.ReserveRequest( - path="test/path", user="test-user", zone="us-central1-a" - ) - reserve_res = self.servicer.Reserve(reserve_req, self.context) - asset_uuid = reserve_res.asset.uuid - + asset_uuid = self._reserve_asset() prefetch_req = tiering_service_pb2.PrefetchRequest(uuid=asset_uuid) self.servicer.Prefetch(prefetch_req, self.context) @@ -186,6 +173,118 @@ def test_reserve_permission_denied(self): grpc.StatusCode.PERMISSION_DENIED, "Insufficient GCS permissions" ) + @parameterized.named_parameters( + ( + "uninitialized", + [ + { + "level": 0, + "type": "Lustre", + "instances": [{ + "prefix": "/mnt/lustre", + "zone": "us-central1-a", + }], + }, + { + "level": 1, + "type": "GCS", + "instances": [{ + "prefix": "gs://my-bucket", + "region": "us-central1", + }], + }, + ], + True, + ), + ( + "matching", + [ + { + "level": 1, + "type": "GCS", + "instances": [{ + "prefix": "gs://my-bucket", + "region": "us-central1", + }], + }, + ], + False, + ), + ) + def test_setup_storage_backends_success( + self, tiers_config, check_uninitialized + ): + config = self._setup_config({"tiers": tiers_config}) + + if check_uninitialized: + server.setup_storage_backends(config) + self.assertTrue(db_lib.is_db_initialized(config)) + db_lib.verify_db(config) + else: + server.setup_storage_backends(config) + server.setup_storage_backends(config) + + def test_setup_storage_backends_mismatch(self): + config_dict = { + "tiers": [ + { + "level": 1, + "type": "GCS", + "instances": [{ + "prefix": "gs://my-bucket", + "region": "us-central1", + }], + }, + ] + } + config = self._setup_config(config_dict) + + # Initialize DB + server.setup_storage_backends(config) + + # Mismatching config + config_mod_dict = { + "tiers": [ + { + "level": 1, + "type": "GCS", + "instances": [{ + "prefix": "gs://my-bucket", + "region": "us-east1", + }], + }, + ] + } + config_mod = server_config.parse_config(config_mod_dict) + config_mod.db_connection_str = config.db_connection_str + + with self.assertRaisesRegex( + ValueError, "Configuration expects StorageBackend with key" + ): + server.setup_storage_backends(config_mod) + + def test_parse_timedelta(self): + self.assertEqual( + server_config._parse_timedelta("1s"), + datetime.timedelta(seconds=1), + ) + self.assertEqual( + server_config._parse_timedelta("30m"), + datetime.timedelta(minutes=30), + ) + self.assertEqual( + server_config._parse_timedelta("1h"), + datetime.timedelta(hours=1), + ) + with self.assertRaisesRegex( + ValueError, "Invalid duration format for client_keep_alive_interval:" + ): + server_config._parse_timedelta("invalid") + with self.assertRaisesRegex( + ValueError, "Invalid duration type for client_keep_alive_interval:" + ): + server_config._parse_timedelta(123) # type: ignore + if __name__ == "__main__": absltest.main() diff --git a/checkpoint/pyproject.toml b/checkpoint/pyproject.toml index ab1d9fd67..b650c0879 100644 --- a/checkpoint/pyproject.toml +++ b/checkpoint/pyproject.toml @@ -81,6 +81,7 @@ tiering_service = [ 'greenlet', 'grpcio-tools>=1.80.0', 'pysqlite3', + 'pytimeparse', 'sqlalchemy>=1.4.0', ]