Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions checkpoint/orbax/checkpoint/experimental/tiering_service/db_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tiering Service Database CLI."""

from collections.abc import Sequence

from absl import app
from absl import flags
from orbax.checkpoint.experimental.tiering_service import db_lib
from orbax.checkpoint.experimental.tiering_service import server_config

from .pyglib import appcommands


class InitializeDbCmd(appcommands.Cmd):
  """Subcommand that sets up, or validates, the Tiering Service database.

  The expected database contents come from a YAML server configuration whose
  location is given by the required `yaml_path` flag.
  """

  def __init__(self, name: str, flag_values: flags.FlagValues, **kwargs):
    """Registers the command together with its `yaml_path` flag.

    Args:
      name: The name of the command.
      flag_values: The FlagValues instance on which `yaml_path` is defined.
      **kwargs: Additional keyword arguments forwarded to the base class.
    """
    super().__init__(name, flag_values, **kwargs)
    self._yaml_path = flags.DEFINE_string(
        name="yaml_path",
        default=None,
        help="Path to the YAML configuration file.",
        flag_values=flag_values,
        required=True,
    )

  def Run(self, argv: Sequence[str]) -> None:
    """Runs the initialize_db command.

    Loads the YAML configuration, then either verifies an already initialized
    database against it or initializes the database from it.

    Args:
      argv: Command line arguments.

    Raises:
      app.UsageError: If too many command-line arguments are provided or if the
        configuration file cannot be opened.
      ValueError: If existing database entries do not match the configuration.
    """
    if len(argv) > 1:
      raise app.UsageError("Too many command-line arguments.")

    yaml_path = self._yaml_path.value
    try:
      config = server_config.load_config(yaml_path)
    except OSError as err:
      raise app.UsageError(
          f"Failed to open configuration file: {err}"
      ) from err

    if db_lib.is_db_initialized(config):
      # Entries already exist: only check them against the configuration.
      db_lib.verify_db(config)
      return
    db_lib.initialize_db(config)


def main(argv: Sequence[str]) -> None:
  """Registers the `initialize_db` subcommand with appcommands.

  Args:
    argv: Command line arguments; unused.
  """
  del argv
  appcommands.AddCmd("initialize_db", InitializeDbCmd)


# Script entry point: hand control to the appcommands runner.
# NOTE(review): presumably appcommands.Run() calls main() above to register
# the subcommands before dispatching - confirm against appcommands.
if __name__ == "__main__":
  appcommands.Run()
220 changes: 220 additions & 0 deletions checkpoint/orbax/checkpoint/experimental/tiering_service/db_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Database initialization utilities for Tiering Service."""

import asyncio

from orbax.checkpoint.experimental.tiering_service import db_schema
from orbax.checkpoint.experimental.tiering_service import server_config
from sqlalchemy.exc import OperationalError
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.future import select
from sqlalchemy.orm import sessionmaker


def _get_async_engine(config: server_config.ServerConfig) -> AsyncEngine:
  """Builds an AsyncEngine for the database named in the server config.

  A connection string that starts with the `psql://` scheme has that scheme
  rewritten to `postgresql+asyncpg://`; any other connection string is used
  unchanged.

  Args:
    config: The server configuration providing `db_connection_str`.

  Returns:
    An AsyncEngine created from the (possibly rewritten) connection URL.
  """
  short_scheme = "psql://"
  connection_url = config.db_connection_str
  if connection_url.startswith(short_scheme):
    remainder = connection_url[len(short_scheme):]
    connection_url = "postgresql+asyncpg://" + remainder
  return create_async_engine(connection_url)


def _get_backend_type(type_str: str) -> db_schema.BackendType:
  """Translates a tier type string into its BackendType enum member.

  The lookup is case-insensitive.

  Args:
    type_str: The tier type, e.g. "lustre" or "gcs".

  Returns:
    The BackendType member corresponding to `type_str`.

  Raises:
    ValueError: If `type_str` does not name a known storage backend.
  """
  known_types = {
      "LUSTRE": db_schema.BackendType.BACKEND_TYPE_LUSTRE,
      "GCS": db_schema.BackendType.BACKEND_TYPE_GCS,
  }
  normalized = type_str.upper()
  if normalized not in known_types:
    raise ValueError(f"Unknown storage backend type: {type_str!r}")
  return known_types[normalized]


def _get_backend_key(
level: int,
zone: str | None,
region: str | None,
multi_regions: list[str] | None,
) -> tuple[int, str | None, str | None, tuple[str, ...] | None]:
"""Generates a unique key for a StorageBackend based on level and location."""
multi_regions_tuple = (
tuple(sorted(multi_regions)) if multi_regions else None
)
return (level, zone, region, multi_regions_tuple)


async def async_initialize_db(config: server_config.ServerConfig) -> None:
  """Creates the schema and seeds the StorageBackend table from the config.

  All tables declared on `db_schema.Base` are created first. If the
  `StorageBackend` table already holds at least one row, nothing else is
  done; otherwise one row is added per tier instance in the configuration.

  Args:
    config: The server configuration containing tier information and DB URL.

  Raises:
    ValueError: If a tier has an unknown type, or if an instance does not set
      exactly one of `zone`, `region` or `multi_regions`.
  """
  engine = _get_async_engine(config)
  try:
    async with engine.begin() as conn:
      await conn.run_sync(db_schema.Base.metadata.create_all)

    make_session = sessionmaker(
        engine, expire_on_commit=False, class_=AsyncSession
    )
    async with make_session() as session:
      rows = await session.execute(select(db_schema.StorageBackend))
      if rows.scalars().first() is not None:
        # Already populated: leave the existing rows untouched.
        return

      for tier in config.tiers:
        tier_level = tier.level
        tier_backend_type = _get_backend_type(tier.type)

        for instance in tier.instances:
          # Count how many of the mutually exclusive location fields are set.
          num_locations = sum(
              bool(value)
              for value in (
                  instance.zone,
                  instance.region,
                  instance.multi_regions,
              )
          )
          if num_locations != 1:
            raise ValueError(
                "Exactly one of 'zone', 'region', or 'multi_regions' must be"
                f" provided for instance: {instance!r}"
            )

          new_backend = db_schema.StorageBackend(
              level=tier_level,
              backend_type=tier_backend_type,
              prefix=instance.prefix,
          )
          # Copy over the single location field that is set.
          for field_name in ("zone", "region", "multi_regions"):
            field_value = getattr(instance, field_name)
            if field_value:
              setattr(new_backend, field_name, field_value)
              break
          session.add(new_backend)

      # Commit only when rows were actually queued.
      if session.new:
        await session.commit()
  finally:
    await engine.dispose()


def initialize_db(config: server_config.ServerConfig) -> None:
  """Synchronous wrapper around `async_initialize_db`.

  Runs the coroutine to completion with `asyncio.run`.

  Args:
    config: The server configuration containing tier information and DB URL.
  """
  init_coroutine = async_initialize_db(config)
  asyncio.run(init_coroutine)


async def async_is_db_initialized(config: server_config.ServerConfig) -> bool:
  """Returns whether the database already holds StorageBackend entries.

  A database whose `StorageBackend` table has not been created yet is
  reported as uninitialized instead of raising.

  Args:
    config: The server configuration containing the DB URL.

  Returns:
    True if at least one StorageBackend row exists, False if the table is
    empty or the query fails with an OperationalError or ProgrammingError.
  """
  engine = _get_async_engine(config)
  session_maker = sessionmaker(
      engine, expire_on_commit=False, class_=AsyncSession
  )
  try:
    async with session_maker() as session:
      # Only existence matters, so ask the database for at most one row.
      result = await session.execute(
          select(db_schema.StorageBackend).limit(1)
      )
      return result.scalars().first() is not None
  except (OperationalError, ProgrammingError):
    # If tables do not exist yet. The error class depends on the driver:
    # SQLite reports a missing table as OperationalError, whereas the asyncpg
    # dialect reports PostgreSQL's "undefined table" as ProgrammingError.
    return False
  finally:
    await engine.dispose()


def is_db_initialized(config: server_config.ServerConfig) -> bool:
  """Synchronous wrapper around `async_is_db_initialized`.

  Args:
    config: The server configuration containing the DB URL.

  Returns:
    Whether the database is already initialized with StorageBackend entries.
  """
  check_coroutine = async_is_db_initialized(config)
  return asyncio.run(check_coroutine)


async def async_verify_db(config: server_config.ServerConfig) -> None:
  """Checks that the StorageBackend rows in the database mirror the config.

  All rows are read first and the engine is disposed; the comparison then
  runs on the loaded objects. The check covers the total number of backends
  and, per configured instance, the presence of a row with the same
  level/location key, its backend type and its prefix.

  Args:
    config: The server configuration containing tier information and DB URL.

  Raises:
    ValueError: If there is any mismatch between configuration and database.
  """
  engine = _get_async_engine(config)
  try:
    make_session = sessionmaker(
        engine, expire_on_commit=False, class_=AsyncSession
    )
    async with make_session() as session:
      rows = await session.execute(select(db_schema.StorageBackend))
      stored_backends = rows.scalars().all()
  finally:
    await engine.dispose()

  num_stored = len(stored_backends)
  num_expected = 0
  for tier in config.tiers:
    num_expected += len(tier.instances)
  if num_stored != num_expected:
    raise ValueError(
        f"Mismatch in total StorageBackend count: DB has {num_stored},"
        f" config expects {num_expected}"
    )

  # Index the stored rows by their level/location key for direct lookup.
  stored_by_key = {}
  for stored in stored_backends:
    stored_key = _get_backend_key(
        stored.level, stored.zone, stored.region, stored.multi_regions
    )
    stored_by_key[stored_key] = stored

  for tier in config.tiers:
    tier_level = tier.level
    wanted_type = _get_backend_type(tier.type)

    for instance in tier.instances:
      wanted_key = _get_backend_key(
          tier_level, instance.zone, instance.region, instance.multi_regions
      )
      match = stored_by_key.get(wanted_key)

      if match is None:
        raise ValueError(
            f"Configuration expects StorageBackend with key {wanted_key!r}"
            f" (prefix={instance.prefix!r}) but not found in Database."
        )

      if match.backend_type != wanted_type:
        raise ValueError(
            f"Backend with key {wanted_key!r} mismatch"
            f" backend_type: DB has {match.backend_type.name}, config"
            f" expects {wanted_type.name}"
        )

      if match.prefix != instance.prefix:
        raise ValueError(
            f"Backend with key {wanted_key!r} mismatch"
            f" prefix: DB has {match.prefix!r}, config expects"
            f" {instance.prefix!r}"
        )


def verify_db(config: server_config.ServerConfig) -> None:
  """Synchronous wrapper around `async_verify_db`.

  Args:
    config: The server configuration containing tier information and DB URL.

  Raises:
    ValueError: If there is any mismatch between configuration and database.
  """
  verify_coroutine = async_verify_db(config)
  asyncio.run(verify_coroutine)
Loading
Loading