From 4e0c4092a0ce1439c1f9c37a84ab77662a215f51 Mon Sep 17 00:00:00 2001
From: Abhishek Agrawal
Date: Tue, 28 Apr 2026 02:20:13 -0700
Subject: [PATCH] Add optimized GCS loading for Safetensors using gcsfuse.

PiperOrigin-RevId: 906830773
---
 .../orbax/checkpoint/_src/path/gcs_utils.py   |  47 ++++
 .../v1/_src/layout/safetensors_layout.py      | 268 +++++++++++++++---
 2 files changed, 296 insertions(+), 19 deletions(-)

diff --git a/checkpoint/orbax/checkpoint/_src/path/gcs_utils.py b/checkpoint/orbax/checkpoint/_src/path/gcs_utils.py
index f023dffa1..ef312dd80 100644
--- a/checkpoint/orbax/checkpoint/_src/path/gcs_utils.py
+++ b/checkpoint/orbax/checkpoint/_src/path/gcs_utils.py
@@ -28,6 +28,53 @@ def is_gcs_path(path: epath.Path) -> bool:
   return path.as_posix().startswith(_GCS_PATH_PREFIX)
 
 
+def to_gs_path(path: epath.PathLike) -> str:
+  """Converts a GCS path to a gs:// path string.
+
+  Accepts either a gs:// URL, which is returned unchanged, or a gcsfuse-style
+  /gcs/ path, which is rewritten to the equivalent gs:// URL. This mirrors
+  `to_gcsfuse_path`, so the two functions round-trip.
+
+  Args:
+    path: A GCS path which can be a string or epath.Path.
+
+  Returns:
+    A GCS path string starting with gs://.
+
+  Raises:
+    ValueError: If path is not a GCS path.
+  """
+  path_str = str(path)
+  if path_str.startswith('gs://'):
+    return path_str
+  if path_str.startswith('/gcs/'):
+    return 'gs://' + path_str[len('/gcs/'):]
+  raise ValueError(f'Path is not a GCS path: {path}')
+
+
+def to_gcsfuse_path(path: epath.PathLike) -> str:
+  """Converts a GCS path to a gcsfuse path string.
+
+  GCSfuse paths start with /gcs/ and are accessible via File API when gcsfuse
+  is enabled.
+
+  Args:
+    path: A GCS path which can be a string or epath.Path.
+
+  Returns:
+    A gcsfuse path string starting with /gcs/.
+
+  Raises:
+    ValueError: If path is not a GCS path.
+  """
+  path_str = str(path)
+  if path_str.startswith('/gcs/'):
+    return path_str
+  if path_str.startswith('gs://'):
+    return '/gcs/' + path_str[len('gs://'):]
+  raise ValueError(f'Path is not a GCS path: {path}')
+
+
 def parse_gcs_path(path: epath.PathLike) -> tuple[str, str]:
   parsed = parse.urlparse(str(path))
   assert parsed.scheme == 'gs', f'Unsupported scheme for GCS: {parsed.scheme}'
diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py
index cbaa5ce02..0c779b9b0 100644
--- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py
+++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py
@@ -15,8 +15,13 @@
 """Defines `SafetensorsLayout`, a class to handle Safetensors checkpoint formats."""
 
 import asyncio
+from concurrent import futures
 import dataclasses
 import json
+import mmap
+import os
+import subprocess
+import tempfile
 import time
 from typing import Any, Awaitable, cast
 
@@ -26,6 +31,7 @@
 import numpy as np
 from orbax.checkpoint._src.multihost import multihost
 from orbax.checkpoint._src.path import async_path
+from orbax.checkpoint._src.path import gcs_utils
 from orbax.checkpoint._src.tree import utils as tree_utils
 from orbax.checkpoint.experimental.v1._src.context import context as context_lib
 from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
@@ -306,6 +312,73 @@ def _reshard_transient_array(
   )(global_transient_array)
 
 
+def _get_tensor_bounds(
+    header: dict[str, Any],
+    tensor_names: list[str],
+) -> tuple[list[str], int, int]:
+  """Finds which requested tensors a file holds and the byte range they span.
+
+  Args:
+    header: Parsed safetensors header mapping tensor names to their metadata.
+    tensor_names: Names of the tensors requested by the caller.
+
+  Returns:
+    A tuple `(tensors_to_load, min_start, max_end)`: the requested names that
+    are present in `header` (excluding `__metadata__`), followed by the
+    smallest start and largest end `data_offsets` among them. Returns
+    `([], 0, 0)` when none of the requested tensors is present.
+  """
+  tensors_to_load = [
+      name
+      for name in tensor_names
+      if name != "__metadata__" and name in header
+  ]
+  if not tensors_to_load:
+    return [], 0, 0
+  offsets = [header[name]["data_offsets"] for name in tensors_to_load]
+  min_start = min(start for start, _ in offsets)
+  max_end = max(end for _, end in offsets)
+  return tensors_to_load, int(min_start), int(max_end)
+
+
+def _process_data_bytes(
+    data_bytes: bytes,
+    header: dict[str, Any],
+    tensor_names: list[str],
+    min_start_offset: int,
+) -> dict[str, np.ndarray]:
+  """Extracts tensors from a contiguous byte range of the data section.
+
+  Args:
+    data_bytes: Bytes of the data section starting at `min_start_offset`.
+    header: Parsed safetensors header mapping tensor names to their metadata.
+    tensor_names: Names of the tensors to extract; each must be in `header`.
+    min_start_offset: Data-section offset that `data_bytes[0]` corresponds to.
+
+  Returns:
+    A dict mapping each tensor name to a NumPy array backed by `data_bytes`.
+
+  Raises:
+    ValueError: If a tensor contains non-finite values.
+  """
+  tensors = {}
+  # A memoryview makes each per-tensor slice a zero-copy view of `data_bytes`.
+  data_mv = memoryview(data_bytes)
+  for name in tensor_names:
+    if name == "__metadata__":
+      continue
+    shape, dtype = _get_array_properties(header[name])
+    start_offset, end_offset = header[name]["data_offsets"]
+    tensor_bytes = data_mv[
+        start_offset - min_start_offset : end_offset - min_start_offset
+    ]
+    np_array = np.frombuffer(tensor_bytes, dtype=dtype).reshape(shape)
+    if not np.isfinite(np_array).all():
+      raise ValueError(f"Non-finite values found in tensor {name}.")
+    tensors[name] = np_array
+  return tensors
+
+
 @dataclasses.dataclass
 class _LoadContext:
   host_id: int
@@ -438,24 +511,181 @@ async def _read_bundle(
       bundle_start_offset = 0
     return bundle_bytes, bundle_start_offset
 
-  async def load_single_host(self) -> dict[str, np.ndarray]:
-    """Loads tensors from a safetensors file into host NumPy arrays."""
-    header, data_start_offset = await self.read_header()
-    tensors = {}
-    async with async_path.open_file(self.path, mode="rb") as f:
-      await f.seek(data_start_offset)
-      data_bytes = await f.read()
-    for name, info in header.items():
-      if name == "__metadata__":
-        continue
-      shape, dtype = _get_array_properties(info)
-      start_offset, end_offset = info["data_offsets"]
-      tensor_bytes = data_bytes[start_offset:end_offset]
-      np_array = np.frombuffer(tensor_bytes, dtype=dtype).reshape(shape)
-      if not np.isfinite(np_array).all():
-        raise ValueError(f"Non-finite values found in tensor {name}.")
-      tensors[name] = np_array
-    return tensors
+  def _read_single_chunk(
+      self, gcs_path_str: str, chunk_data: tuple[int, int]
+  ) -> bytes:
+    """Reads one `(chunk_size, offset)` chunk from a gcsfuse-mounted file.
+
+    Args:
+      gcs_path_str: Local (gcsfuse) path of the file to read.
+      chunk_data: A `(chunk_size, offset)` tuple describing the chunk.
+
+    Returns:
+      Exactly `chunk_size` bytes starting at `offset`.
+
+    Raises:
+      EOFError: If the file ends before `chunk_size` bytes were read.
+    """
+    chunk_size, offset = chunk_data
+    with open(gcs_path_str, "rb") as f:
+      bytes_read = 0
+      chunk_pieces = []
+      # `os.pread` may return fewer bytes than requested, so keep reading
+      # until the chunk is complete.
+      while bytes_read < chunk_size:
+        piece = os.pread(
+            f.fileno(), chunk_size - bytes_read, offset + bytes_read
+        )
+        if not piece:
+          raise EOFError(
+              f"Unexpected end of file at offset {offset + bytes_read} "
+              f"in file {gcs_path_str}. Expected {chunk_size} bytes, "
+              f"got {bytes_read}."
+          )
+        chunk_pieces.append(piece)
+        bytes_read += len(piece)
+    return b"".join(chunk_pieces)
+
+  def _read_range_via_gcloud(self, start: int, end: int) -> bytes:
+    """Downloads the file to RAM disk and returns bytes [start, end).
+
+    Used as a fallback when the file is not visible through a gcsfuse mount.
+
+    Args:
+      start: Absolute file offset of the first byte to return.
+      end: Absolute file offset one past the last byte to return.
+
+    Returns:
+      The requested bytes of the file.
+
+    Raises:
+      subprocess.CalledProcessError: If the `gcloud storage cp` command fails.
+    """
+    # A private temporary directory gives every call its own file name, so
+    # concurrent loads cannot clobber each other, and guarantees the copy is
+    # deleted even when the download or the read below raises.
+    with tempfile.TemporaryDirectory(dir="/dev/shm") as tmp_dir:
+      ram_disk_path = os.path.join(tmp_dir, "checkpoint.bin")
+      subprocess.run(
+          [
+              "gcloud",
+              "storage",
+              "cp",
+              gcs_utils.to_gs_path(self.path),
+              ram_disk_path,
+          ],
+          check=True,
+      )
+      with open(ram_disk_path, "rb") as f:
+        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
+          return mm[start:end]
+
+  def _read_range_via_gcsfuse(
+      self, gcs_path_str: str, start: int, end: int
+  ) -> bytes:
+    """Reads bytes [start, end) of a gcsfuse file with parallel chunk reads.
+
+    Args:
+      gcs_path_str: Local (gcsfuse) path of the file to read.
+      start: Absolute file offset of the first byte to return.
+      end: Absolute file offset one past the last byte to return.
+
+    Returns:
+      The requested bytes of the file.
+
+    Raises:
+      EOFError: If the file ends before `end`.
+    """
+    chunk_size = 1024 * 1024 * 1024
+    max_workers = 16
+    chunks = [
+        (min(chunk_size, end - offset), offset)
+        for offset in range(start, end, chunk_size)
+    ]
+    with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+      read_chunks = executor.map(
+          lambda chunk: self._read_single_chunk(gcs_path_str, chunk), chunks
+      )
+      return b"".join(read_chunks)
+
+  async def load_single_host_gcs(
+      self,
+      *,
+      data_start_offset: int,
+      min_start: int,
+      max_end: int,
+  ) -> bytes:
+    """Reads the byte range covering the requested tensors from GCS.
+
+    When the file is visible through a gcsfuse mount, the range is read with
+    parallel `os.pread` calls issued from a thread pool. Otherwise the file is
+    copied to RAM disk with `gcloud storage cp` and the range is read from
+    that copy. Either way the blocking work runs in a worker thread so that
+    other coroutines on the event loop keep making progress.
+
+    Args:
+      data_start_offset: The offset where the tensor data begins.
+      min_start: The minimum start offset of the tensors to load.
+      max_end: The maximum end offset of the tensors to load.
+
+    Returns:
+      A bytes object containing the loaded tensor data.
+
+    Raises:
+      EOFError: If the file is truncated or reading fails unexpectedly.
+      subprocess.CalledProcessError: If the `gcloud` fallback copy fails.
+    """
+    start = data_start_offset + min_start
+    end = data_start_offset + max_end
+    gcs_path_str = gcs_utils.to_gcsfuse_path(self.path)
+
+    def _read() -> bytes:
+      if os.path.exists(gcs_path_str):
+        return self._read_range_via_gcsfuse(gcs_path_str, start, end)
+      return self._read_range_via_gcloud(start, end)
+
+    return await asyncio.to_thread(_read)
+
+  async def load_single_host(
+      self, abstract_pytree: dict[str, Any] | None = None
+  ) -> dict[str, np.ndarray]:
+    """Loads tensors from a safetensors file into host NumPy arrays.
+
+    Args:
+      abstract_pytree: If given, only tensors whose names are keys of this
+        dict are loaded, and names missing from this file are skipped. If
+        None, every tensor in the file is loaded.
+
+    Returns:
+      A dict mapping tensor names to NumPy arrays; empty if the file holds
+      none of the requested tensors.
+
+    Raises:
+      ValueError: If a loaded tensor contains non-finite values.
+    """
+    header, data_start_offset = await self.read_header()
+    if abstract_pytree is None:
+      tensor_names = list(header.keys())
+    else:
+      tensor_names = list(abstract_pytree.keys())
+    tensors_to_load, min_start, max_end = _get_tensor_bounds(
+        header, tensor_names
+    )
+    if not tensors_to_load:
+      return {}
+    if gcs_utils.is_gcs_path(self.path):
+      data_bytes = await self.load_single_host_gcs(
+          data_start_offset=data_start_offset,
+          min_start=min_start,
+          max_end=max_end,
+      )
+    else:
+      async with async_path.open_file(self.path, mode="rb") as f:
+        await f.seek(data_start_offset + min_start)
+        data_bytes = await f.read(max_end - min_start)
+    return _process_data_bytes(
+        data_bytes, header, tensors_to_load, min_start
+    )
 
   async def load_multi_host(
       self, abstract_pytree: dict[str, Any]
@@ -585,7 +815,7 @@ async def _load_single_host(self, abstract_pytree: dict[str, Any]) -> Any:
     start = time.time()
     load_ops = []
     for loader in await self._get_loaders():
-      load_ops.append(loader.load_single_host())
+      load_ops.append(loader.load_single_host(abstract_pytree))
 
     restored_pytree = {}
     for file_tensors in await asyncio.gather(*load_ops):