Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/gcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,52 @@ def is_gcs_path(path: epath.Path) -> bool:
return path.as_posix().startswith(_GCS_PATH_PREFIX)


def to_gs_path(path: 'epath.PathLike') -> str:
  """Converts a GCS path to a gs:// path string.

  GCS paths can start with any of the prefixes in _GCS_PATH_PREFIX (e.g.
  gs:// or the gcsfuse mount prefix /gcs/). This function normalizes them
  to gs:// format.

  Args:
    path: A GCS path which can be a string or epath.Path.

  Returns:
    A GCS path string starting with gs://.

  Raises:
    ValueError: If path is not a GCS path.
  """
  path_str = str(path)
  if path_str.startswith('gs://'):
    return path_str
  if path_str.startswith('/gcs/'):
    # Mirror of to_gcsfuse_path: translate the gcsfuse mount prefix back to
    # the gs:// scheme so callers holding a /gcs/ path don't get a spurious
    # ValueError.
    return 'gs://' + path_str[len('/gcs/'):]
  raise ValueError(f'Path is not a GCS path: {path}')


def to_gcsfuse_path(path: epath.PathLike) -> str:
  """Converts a GCS path to a gcsfuse path string.

  GCSfuse paths start with /gcs/ and are accessible via File API when gcsfuse
  is enabled.

  Args:
    path: A GCS path which can be a string or epath.Path.

  Returns:
    A gcsfuse path string starting with /gcs/.

  Raises:
    ValueError: If path is not a GCS path.
  """
  as_text = str(path)
  # The two recognized prefixes are mutually exclusive, so branch order is
  # immaterial.
  if as_text.startswith('/gcs/'):
    # Already in gcsfuse form; nothing to translate.
    return as_text
  if as_text.startswith('gs://'):
    return as_text.replace('gs://', '/gcs/', 1)
  raise ValueError(f'Path is not a GCS path: {path}')


def parse_gcs_path(path: epath.PathLike) -> tuple[str, str]:
parsed = parse.urlparse(str(path))
assert parsed.scheme == 'gs', f'Unsupported scheme for GCS: {parsed.scheme}'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@
"""Defines `SafetensorsLayout`, a class to handle Safetensors checkpoint formats."""

import asyncio
from concurrent import futures
import dataclasses
import json
import mmap
import os
import subprocess
import time
from typing import Any, Awaitable, cast

Expand All @@ -26,6 +30,7 @@
import numpy as np
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import async_path
from orbax.checkpoint._src.path import gcs_utils
from orbax.checkpoint._src.tree import utils as tree_utils
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
Expand Down Expand Up @@ -306,6 +311,55 @@ def _reshard_transient_array(
)(global_transient_array)


def _get_tensor_bounds(
header: dict[str, Any],
tensor_names: list[str],
) -> tuple[list[str], int, int]:
"""Calculates tensor bounds and returns tensors to load, min_start and max_end offsets."""
min_start = float("inf")
max_end = 0

tensors_to_load = []
for t_name in tensor_names:
if t_name == "__metadata__":
continue
if t_name not in header:
continue
tensors_to_load.append(t_name)
start, end = header[t_name]["data_offsets"]
if start < min_start:
min_start = start
if end > max_end:
max_end = end
if not tensors_to_load:
return [], 0, 0
return tensors_to_load, int(min_start), int(max_end)


def _process_data_bytes(
    data_bytes: bytes,
    header: dict[str, Any],
    tensor_names: list[str],
    min_start_offset: int,
) -> dict[str, np.ndarray]:
  """Extracts tensors from data bytes."""
  view = memoryview(data_bytes)  # zero-copy slicing of the shared buffer
  result: dict[str, np.ndarray] = {}
  for name in tensor_names:
    if name == "__metadata__":
      continue
    info = header[name]
    shape, dtype = _get_array_properties(info)
    begin, end = info["data_offsets"]
    # Header offsets are absolute within the data section, but data_bytes
    # starts at min_start_offset — rebase before slicing.
    raw = view[begin - min_start_offset : end - min_start_offset]
    array = np.frombuffer(raw, dtype=dtype).reshape(shape)
    if not np.isfinite(array).all():
      raise ValueError(f"Non-finite values found in tensor {name}.")
    result[name] = array
  return result


@dataclasses.dataclass
class _LoadContext:
host_id: int
Expand Down Expand Up @@ -438,24 +492,127 @@ async def _read_bundle(
bundle_start_offset = 0
return bundle_bytes, bundle_start_offset

async def load_single_host(self) -> dict[str, np.ndarray]:
def _read_single_chunk(
self,
gcs_path_str: str,
chunk_data: tuple[int, int]
) -> bytes:
"""Reads a single chunk of data from a GCS file."""
chunk_size, offset = chunk_data
with open(gcs_path_str, "rb") as f:
bytes_read = 0
chunk_pieces = []
while bytes_read < chunk_size:
piece = os.pread(
f.fileno(), chunk_size - bytes_read, offset + bytes_read
)
if not piece:
raise EOFError(
f"Unexpected end of file at offset {offset + bytes_read} "
f"in file {gcs_path_str}. Expected {chunk_size} bytes, "
f"got {bytes_read}."
)
chunk_pieces.append(piece)
bytes_read += len(piece)
return b"".join(chunk_pieces)

async def load_single_host_gcs(
self,
*,
data_start_offset: int,
min_start: int,
max_end: int,
) -> bytes:
"""Downloads tensors from Google Cloud Storage using high-bandwidth parallel reads.

This method uses `os.pread` with a thread pool to achieve high-bandwidth
parallel downloads from GCS via gcsfuse. It first calculates the bounding
box of the required tensor data and then reads chunks within that range.

Args:
data_start_offset: The offset where the tensor data begins.
min_start: The minimum start offset of the tensors to load.
max_end: The maximum end offset of the tensors to load.

Returns:
A bytes object containing the loaded tensor data.

Raises:
EOFError: If the file is truncated or reading fails unexpectedly.
ValueError: If non-finite values are found in a loaded tensor.
"""

gcs_path_str = gcs_utils.to_gcsfuse_path(self.path)
if not os.path.exists(gcs_path_str):
_, blob_name = gcs_utils.parse_gcs_path(gcs_utils.to_gs_path(self.path))
blob_name = blob_name.rstrip("/")
safe_temp_name = blob_name.replace("/", "_")
ram_disk_path = f"/dev/shm/{safe_temp_name}_temp.bin"
subprocess.run(
[
"gcloud",
"storage",
"cp",
gcs_utils.to_gs_path(self.path),
ram_disk_path,
],
check=True,
)
with open(ram_disk_path, "rb") as f:
with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
data_bytes = mm[
data_start_offset + min_start : data_start_offset + max_end
]
os.remove(ram_disk_path)
return data_bytes
chunk_size = 1024 * 1024 * 1024
max_workers = 16
offset = data_start_offset + min_start
length = max_end - min_start
chunks = []
bytes_read = 0
while bytes_read < length:
current_chunk_size = min(chunk_size, length - bytes_read)
current_offset = offset + bytes_read
chunks.append((current_chunk_size, current_offset))
bytes_read += current_chunk_size

# 2. Execute the parallel reads
with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
read_chunks = executor.map(
lambda chunk: self._read_single_chunk(gcs_path_str, chunk), chunks
)

data_bytes = b"".join(read_chunks)
return data_bytes

async def load_single_host(
self, abstract_pytree: dict[str, Any] | None
) -> dict[str, np.ndarray]:
"""Loads tensors from a safetensors file into host NumPy arrays."""
header, data_start_offset = await self.read_header()
tensors = {}
async with async_path.open_file(self.path, mode="rb") as f:
await f.seek(data_start_offset)
data_bytes = await f.read()
for name, info in header.items():
if name == "__metadata__":
continue
shape, dtype = _get_array_properties(info)
start_offset, end_offset = info["data_offsets"]
tensor_bytes = data_bytes[start_offset:end_offset]
np_array = np.frombuffer(tensor_bytes, dtype=dtype).reshape(shape)
if not np.isfinite(np_array).all():
raise ValueError(f"Non-finite values found in tensor {name}.")
tensors[name] = np_array
return tensors
if abstract_pytree is None:
tensor_names = list(header.keys())
else:
tensor_names = list(abstract_pytree.keys())
tensors_to_load, min_start, max_end = _get_tensor_bounds(
header, tensor_names
)
if not tensors_to_load:
return {}
if gcs_utils.is_gcs_path(self.path):
data_bytes = await self.load_single_host_gcs(
data_start_offset=data_start_offset,
min_start=min_start,
max_end=max_end,
)
else:
async with async_path.open_file(self.path, mode="rb") as f:
await f.seek(data_start_offset + min_start)
data_bytes = await f.read(max_end - min_start)
return _process_data_bytes(
data_bytes, header, tensors_to_load, min_start
)

async def load_multi_host(
self, abstract_pytree: dict[str, Any]
Expand Down Expand Up @@ -585,7 +742,7 @@ async def _load_single_host(self, abstract_pytree: dict[str, Any]) -> Any:
start = time.time()
load_ops = []
for loader in await self._get_loaders():
load_ops.append(loader.load_single_host())
load_ops.append(loader.load_single_host(abstract_pytree))

restored_pytree = {}
for file_tensors in await asyncio.gather(*load_ops):
Expand Down
Loading