From e35f2dc57e67e420fd6e73952837642f4921eebb Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Wed, 14 Jan 2026 09:21:43 +0100 Subject: [PATCH 1/3] Improving logging for data-prep Signed-off-by: Marc Romeyn --- src/nemotron/data_prep/__init__.py | 9 ++ src/nemotron/data_prep/config.py | 4 + src/nemotron/data_prep/console.py | 123 +++++++++++++----- src/nemotron/data_prep/pipeline.py | 4 + src/nemotron/kit/cli/recipe.py | 5 + .../config/data_prep/default.yaml | 10 +- .../config/data_prep/tiny.yaml | 7 +- .../nano3/stage0_pretrain/data_prep.py | 3 +- 8 files changed, 130 insertions(+), 35 deletions(-) diff --git a/src/nemotron/data_prep/__init__.py b/src/nemotron/data_prep/__init__.py index ae22c5fb7..adf0459cb 100644 --- a/src/nemotron/data_prep/__init__.py +++ b/src/nemotron/data_prep/__init__.py @@ -63,6 +63,7 @@ import os from dataclasses import dataclass, field from pathlib import Path +from typing import Literal from nemotron.data_prep.blend import DataBlend, Dataset from nemotron.data_prep.config import ( @@ -182,6 +183,12 @@ class DataPrepConfig: ray_data_max_tasks_in_flight: int = 2 """Max tasks in flight per actor (pipelining depth)""" + console_mode: Literal["rich", "simple"] = "simple" + """Console output mode: 'rich' for animated progress bars, 'simple' for periodic text updates""" + + simple_log_interval_sec: int = 30 + """Interval in seconds between status updates in simple console mode (default: 30)""" + def run_data_prep( config: DataPrepConfig, *, artifact_class: type = PretrainBlendsArtifact @@ -321,6 +328,8 @@ def run_data_prep( split=config.split, per_split=config.per_split, ray_data=ray_data_config, + console_mode=config.console_mode, + simple_log_interval_sec=config.simple_log_interval_sec, ) # Run processing pipeline diff --git a/src/nemotron/data_prep/config.py b/src/nemotron/data_prep/config.py index 60e839b9f..f1aaa08af 100644 --- a/src/nemotron/data_prep/config.py +++ b/src/nemotron/data_prep/config.py @@ -288,6 +288,8 @@ class 
PipelineConfig: per_split: Per-split output configuration for Megatron-Bridge per_split_data_args_path ray_data: Ray Data execution configuration. When enabled and ray_data.enabled=True, uses Ray Data's ActorPoolStrategy for shard processing instead of manual actors. + console_mode: Console output mode ('rich' or 'simple') + simple_log_interval_sec: Interval in seconds for simple mode status updates """ output: OutputConfig @@ -298,6 +300,8 @@ class PipelineConfig: split: str | None = None # Deprecated - use per_split instead per_split: PerSplitConfig | None = None ray_data: RayDataConfig | None = None + console_mode: str = "simple" + simple_log_interval_sec: int = 30 # ============================================================================ diff --git a/src/nemotron/data_prep/console.py b/src/nemotron/data_prep/console.py index 6303186c0..833f8a7eb 100644 --- a/src/nemotron/data_prep/console.py +++ b/src/nemotron/data_prep/console.py @@ -347,6 +347,8 @@ class LiveExecutionStatus: datasets: list[DatasetStatus] = field(default_factory=list) run_hash: str = "" + console_mode: str = "simple" # "rich" or "simple" (default: simple) + simple_log_interval_sec: int = 30 # Configurable interval for simple mode _live: Live | None = field(default=None, repr=False) _progress: Progress | None = field(default=None, repr=False) _overall_task_id: int | None = field(default=None, repr=False) @@ -355,6 +357,7 @@ class LiveExecutionStatus: _wandb_step: int = field(default=0, repr=False) _last_wandb_log_time: float = field(default=0.0, repr=False) _wandb_log_interval: float = field(default=10.0, repr=False) # Log every 10 seconds + _last_simple_log_time: float = field(default=0.0, repr=False) # For simple mode throttling _start_time: float = field(default=0.0, repr=False) # Pipeline start time _total_tokens: int = field(default=0, repr=False) # Cumulative tokens processed _max_display: int = field(default=3, repr=False) # Max datasets to show per page @@ -475,6 +478,34 @@ def 
_log_progress_to_wandb(self, force: bool = False) -> None: except Exception as e: logger.warning(f"[W&B] Failed to log metrics: {e}") + def _print_simple_status(self) -> None: + """Print simple text status update (for simple console mode).""" + import time as time_module + + done, cached, pending, processing = self._get_summary_counts() + total_completed, total_shards = self._get_total_shards_progress() + + elapsed = time_module.time() - self._start_time if self._start_time > 0 else 0 + elapsed_str = f"{int(elapsed // 60)}m {int(elapsed % 60)}s" + + pct = (total_completed / total_shards * 100) if total_shards > 0 else 0 + + # Single line status update + console.print( + f"[{elapsed_str}] Progress: {total_completed}/{total_shards} shards ({pct:.1f}%) | " + f"Datasets: {done + cached}/{len(self.datasets)} complete " + f"({processing} active, {pending} pending) | " + f"Tokens: {self._total_tokens:,}" + ) + + # Show active datasets + active = [ds for ds in self.datasets if ds.status == "processing"] + if active: + active_names = ", ".join(ds.name[:30] for ds in active[:5]) # Show first 5 + if len(active) > 5: + active_names += f", +{len(active) - 5} more" + console.print(f" Active: {active_names}") + def _build_summary_line(self) -> Text: """Build a compact summary line.""" done, cached, pending, processing = self._get_summary_counts() @@ -656,41 +687,53 @@ def start(self) -> None: self._start_time = time_module.time() - # Calculate total shards across all datasets - total_shards = sum(ds.total_shards for ds in self.datasets) - - # Create overall progress bar - self._progress = Progress( - SpinnerColumn(), - TextColumn("[bold blue]Overall[/bold blue]"), - BarColumn(bar_width=40), - MofNCompleteColumn(), - TaskProgressColumn(), - TimeElapsedColumn(), - console=console, - transient=True, - ) - self._overall_task_id = self._progress.add_task("Processing", total=total_shards) - - self._live = Live( - self._build_display(), - console=console, - refresh_per_second=4, - 
transient=False, - ) - self._live.start() + if self.console_mode == "rich": + # Rich mode: Create animated progress bars + # Calculate total shards across all datasets + total_shards = sum(ds.total_shards for ds in self.datasets) + + # Create overall progress bar + self._progress = Progress( + SpinnerColumn(), + TextColumn("[bold blue]Overall[/bold blue]"), + BarColumn(bar_width=40), + MofNCompleteColumn(), + TaskProgressColumn(), + TimeElapsedColumn(), + console=console, + transient=True, + ) + self._overall_task_id = self._progress.add_task("Processing", total=total_shards) + + self._live = Live( + self._build_display(), + console=console, + refresh_per_second=4, + transient=False, + ) + self._live.start() + else: + # Simple mode: Print initial status + console.print("\n[bold]Starting data preparation...[/bold]") + self._print_simple_status() def stop(self) -> None: """Stop the live display.""" - if self._live: - self._live.stop() - self._live = None - self._progress = None - self._overall_task_id = None + if self.console_mode == "rich": + if self._live: + self._live.stop() + self._live = None + self._progress = None + self._overall_task_id = None + else: + # Simple mode: Print final status + self._print_simple_status() + console.print("[bold green]✓ Data preparation complete[/bold green]\n") def refresh(self) -> None: """Refresh the live display and cycle pages.""" - if self._live: + if self.console_mode == "rich" and self._live: + # Rich mode: Update animated display with page cycling # Auto-cycle pages every ~2 seconds (8 refresh calls at 4 fps) self._page_cycle_counter += 1 if self._page_cycle_counter >= 8: @@ -701,6 +744,17 @@ def refresh(self) -> None: self._current_page = (self._current_page + 1) % total_pages self._live.update(self._build_display()) + self._log_progress_to_wandb() + elif self.console_mode == "simple": + # Simple mode: Periodic text updates with configurable interval + import time as time_module + + current_time = time_module.time() + if 
(current_time - self._last_simple_log_time) >= self.simple_log_interval_sec: + self._last_simple_log_time = current_time + self._print_simple_status() + # Still log to W&B for dashboards (has built-in throttling) + self._log_progress_to_wandb() def start_dataset(self, name: str) -> None: """Mark a dataset as processing (for parallel execution).""" @@ -823,14 +877,23 @@ def report_phase(self, name: str, phase: str, detail: str = "") -> None: self.refresh() -def create_live_status(datasets: list[tuple[str, int]], run_hash: str) -> LiveExecutionStatus: +def create_live_status( + datasets: list[tuple[str, int]], + run_hash: str, + console_mode: str = "simple", + simple_log_interval_sec: int = 30, +) -> LiveExecutionStatus: """Create a live execution status tracker. Args: datasets: List of (name, total_shards) tuples run_hash: The run hash to display + console_mode: Console output mode ('rich' or 'simple') + simple_log_interval_sec: Interval in seconds for simple mode updates """ return LiveExecutionStatus( datasets=[DatasetStatus(name=name, total_shards=total) for name, total in datasets], run_hash=run_hash, + console_mode=console_mode, + simple_log_interval_sec=simple_log_interval_sec, ) diff --git a/src/nemotron/data_prep/pipeline.py b/src/nemotron/data_prep/pipeline.py index 61ea18399..0b29e69d9 100644 --- a/src/nemotron/data_prep/pipeline.py +++ b/src/nemotron/data_prep/pipeline.py @@ -679,6 +679,8 @@ def _process_split( for ep in execution_plans ], run_hash=run_hash, + console_mode=config.console_mode, + simple_log_interval_sec=config.simple_log_interval_sec, ) live_status.start() @@ -2161,6 +2163,8 @@ def _process_chat_sft_blend(blend: DataBlend, config: PipelineConfig) -> Pipelin (dataset.name, num_shards) for dataset, _, _, files in dataset_plans if files ], run_hash=run_hash, + console_mode=config.console_mode, + simple_log_interval_sec=config.simple_log_interval_sec, ) live_status.start() diff --git a/src/nemotron/kit/cli/recipe.py 
b/src/nemotron/kit/cli/recipe.py index 48d2a7f66..866261677 100644 --- a/src/nemotron/kit/cli/recipe.py +++ b/src/nemotron/kit/cli/recipe.py @@ -485,6 +485,11 @@ def _execute_nemo_run( runtime_env_yaml=runtime_env_yaml, ) + # Copy config.yaml to remote code directory since it's excluded by .gitignore + # This ensures the config file with rewritten paths is available remotely + remote_code_dir = f"{executor.tunnel.job_dir}/{job_name}/code" + executor.tunnel.put(str(repo_config), f"{remote_code_dir}/config.yaml") + # Workaround for nemo-run bug: when reusing an existing cluster, # SlurmRayCluster.create() returns None instead of the job_id. if ray_job.backend.job_id is None: diff --git a/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/default.yaml b/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/default.yaml index c6e3d3b7b..fa43ea45f 100644 --- a/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/default.yaml +++ b/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/default.yaml @@ -47,8 +47,14 @@ max_doc_tokens: null # Limit rows per dataset for quick tests (null = no limit) sample: null -# Ray actors for parallel processing (null = auto) -num_actors: null +# Ray Data executor settings (limit actors to avoid OOM on high-CPU nodes) +ray_data_max_actors: 32 + +# Console output mode: 'simple' for periodic text updates, 'rich' for animated progress +console_mode: simple + +# Interval in seconds for simple mode status updates (default: 30) +simple_log_interval_sec: 30 # Force new run, ignoring cache force: false diff --git a/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/tiny.yaml b/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/tiny.yaml index 9a7a2b25e..763186bae 100644 --- a/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/tiny.yaml +++ b/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/tiny.yaml @@ -46,8 +46,11 @@ max_doc_tokens: null # Limit rows per dataset for quick tests - 
small sample for tiny sample: 1000 -# Ray actors for parallel processing (null = auto) -num_actors: null +# Console output mode: 'simple' for periodic text updates, 'rich' for animated progress +console_mode: simple + +# Interval in seconds for simple mode status updates (default: 30) +simple_log_interval_sec: 30 # Force new run, ignoring cache force: false diff --git a/src/nemotron/recipes/nano3/stage0_pretrain/data_prep.py b/src/nemotron/recipes/nano3/stage0_pretrain/data_prep.py index 463a93c28..f0a5f666f 100644 --- a/src/nemotron/recipes/nano3/stage0_pretrain/data_prep.py +++ b/src/nemotron/recipes/nano3/stage0_pretrain/data_prep.py @@ -178,9 +178,10 @@ def run_data_prep_main(cfg: PreTrainDataPrepConfig) -> PretrainBlendsArtifact: min_doc_chars=cfg.min_doc_chars, max_doc_tokens=cfg.max_doc_tokens, sample=cfg.sample, - num_actors=cfg.num_actors, force=cfg.force, artifact_name=artifact_name, + console_mode=getattr(cfg, "console_mode", "simple"), + simple_log_interval_sec=getattr(cfg, "simple_log_interval_sec", 30), ) artifact = run_data_prep(data_prep_config) print_step_complete(data_prep=artifact) From 7c9f78d32b4c99019196c3587359133343903481 Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Mon, 19 Jan 2026 15:59:44 +0100 Subject: [PATCH 2/3] Migrating data-prep pipelines to cosmos-xenna Signed-off-by: Marc Romeyn --- docs/design/nemotron-prep-port.md | 313 ++++++++ pyproject.toml | 2 + scripts/test_pack_memory.py | 321 ++++++++ src/nemotron/cli/nano3/data/prep/app.py | 36 + .../cli/nano3/data/prep/pretrain_xenna.py | 37 + src/nemotron/cli/nano3/data/prep/rl_xenna.py | 36 + src/nemotron/cli/nano3/data/prep/sft_xenna.py | 36 + src/nemotron/data_prep/__init__.py | 136 +++- src/nemotron/data_prep/chat_sft_processor.py | 425 +--------- src/nemotron/data_prep/chat_sft_shard_core.py | 738 ++++++++++++++++++ src/nemotron/data_prep/config.py | 18 + src/nemotron/data_prep/downloader.py | 621 +++++++++++++++ src/nemotron/data_prep/jsonl_processor.py | 204 +---- 
src/nemotron/data_prep/jsonl_shard_core.py | 224 ++++++ src/nemotron/data_prep/packed_processor.py | 4 +- .../data_prep/packing/bin_assignment.py | 104 +++ src/nemotron/data_prep/packing/materialize.py | 98 +++ src/nemotron/data_prep/packing/spool.py | 385 +++++++++ src/nemotron/data_prep/pipeline.py | 636 ++++++++++----- src/nemotron/data_prep/ray_data/executor.py | 34 +- src/nemotron/data_prep/shard_processor.py | 21 +- src/nemotron/data_prep/xenna/__init__.py | 26 + src/nemotron/data_prep/xenna/runner.py | 662 ++++++++++++++++ src/nemotron/data_prep/xenna/stages.py | 619 +++++++++++++++ src/nemotron/data_prep/xenna/work_items.py | 77 ++ src/nemotron/kit/run.py | 7 + src/nemotron/kit/templates/ray_cpu.sub.j2 | 19 + .../config/data_prep/default.yaml | 5 +- .../nano3/stage0_pretrain/data_prep.py | 4 + .../nano3/stage0_pretrain/prep_xenna.py | 147 ++++ .../config/data_prep/data_blend_raw.json | 11 +- .../stage1_sft/config/data_prep/default.yaml | 2 +- .../nano3/stage1_sft/data_prep_xenna.py | 221 ++++++ .../stage2_rl/config/data_prep/default.yaml | 2 +- .../nano3/stage2_rl/data_prep_xenna.py | 280 +++++++ tests/data_prep/test_ray_data.py | 8 +- uv.lock | 422 +++++++++- 37 files changed, 6095 insertions(+), 846 deletions(-) create mode 100644 docs/design/nemotron-prep-port.md create mode 100644 scripts/test_pack_memory.py create mode 100644 src/nemotron/cli/nano3/data/prep/pretrain_xenna.py create mode 100644 src/nemotron/cli/nano3/data/prep/rl_xenna.py create mode 100644 src/nemotron/cli/nano3/data/prep/sft_xenna.py create mode 100644 src/nemotron/data_prep/chat_sft_shard_core.py create mode 100644 src/nemotron/data_prep/downloader.py create mode 100644 src/nemotron/data_prep/jsonl_shard_core.py create mode 100644 src/nemotron/data_prep/packing/bin_assignment.py create mode 100644 src/nemotron/data_prep/packing/materialize.py create mode 100644 src/nemotron/data_prep/packing/spool.py create mode 100644 src/nemotron/data_prep/xenna/__init__.py create mode 100644 
src/nemotron/data_prep/xenna/runner.py create mode 100644 src/nemotron/data_prep/xenna/stages.py create mode 100644 src/nemotron/data_prep/xenna/work_items.py create mode 100644 src/nemotron/recipes/nano3/stage0_pretrain/prep_xenna.py create mode 100644 src/nemotron/recipes/nano3/stage1_sft/data_prep_xenna.py create mode 100644 src/nemotron/recipes/nano3/stage2_rl/data_prep_xenna.py diff --git a/docs/design/nemotron-prep-port.md b/docs/design/nemotron-prep-port.md new file mode 100644 index 000000000..8329e8127 --- /dev/null +++ b/docs/design/nemotron-prep-port.md @@ -0,0 +1,313 @@ +# Nemotron Data Prep Xenna Port + +Status: Draft +Owner: Nemotron team +Last updated: TBD + +## Summary + +Port the Stage 0 (pretrain) data preparation execution engine from the current +Ray actor pool to Cosmos-Xenna pipelines. Keep the existing data prep planning, +shard formats, and output artifacts, while introducing a thin Xenna-based +execution layer that accepts a dict dataset spec (local paths or hf:// sources), +performs optimized HuggingFace pre-downloads, and emits Megatron-compatible +.bin/.idx shards plus blend.json. + +## Goals + +- Use Cosmos-Xenna pipeline execution for shard processing with minimal new + abstraction. +- Introduce a dict-based dataset input for the pretrain entrypoint. +- Keep HuggingFace Hub as a first-class source with parallel pre-downloads. +- Preserve current outputs (bin/idx + blend.json) and receipts for resuming. +- Keep the existing planning model (ShardPlan, determinism, caching). + +## Non-Goals + +- Rework SFT/RL data prep formats (JSONL, packed, chat SFT) in this effort. +- Change the on-disk artifact layout or receipt schema. +- Replace fsspec or the current discovery/planning logic. +- Add new distributed download backends beyond HuggingFace and existing fsspec. + +## Current State (Stage 0) + +- `run_data_prep(DataPrepConfig)` loads a blend.json, builds a `PipelineConfig`, + then calls `last_mile_process()` in `data_prep/pipeline.py`. 
+- `_process_split()` creates a ShardPlan per dataset, applies caching, and + runs pending shards through a Ray actor pool. +- Shard execution uses `process_binidx_shard_core()` in + `data_prep/shard_processor.py`, writing .bin/.idx plus receipt JSON. +- HuggingFace datasets use discovery + optional parallel predownload + (`downloader.parallel_predownload()`). + +## Proposed Design + +### 1) Dict Dataset Input + +Add a new entrypoint that accepts a dict structure describing datasets and +converts it to `DataBlend` in-memory (no required JSON file). + +Proposed schema (aligned with `DataBlend`/`Dataset`): + +```json +{ + "datasets": [ + { + "name": "pile", + "path": "hf://EleutherAI/pile", + "weight": 1.0, + "split": "train", + "subset": null, + "text_field": "text" + } + ] +} +``` + +Per-split mode remains supported: + +```json +{ + "train": [ ... ], + "valid": [ ... ], + "test": [ ... ] +} +``` + +Notes: +- `path` supports local paths/globs, `s3://`, `gs://`, and `hf://`. +- `split` is required for `hf://` paths and optional otherwise. +- Missing fields use current defaults (weight=1.0, text_field="text"). + +This entrypoint will live alongside `run_data_prep()` to preserve backward +compatibility. Stage 0 pretrain will switch to the dict-based entrypoint. + +### 2) Xenna Execution Layer (Thin Wrapper) + +Introduce a new Xenna-based shard runner that reuses the existing planning and +processing logic: + +- Planning: keep `create_shard_plan()` and `get_pending_shards()`. +- Execution: replace the manual Ray actor pool with a Xenna `Stage` that calls + `process_binidx_shard_core()`. + +New flow (per split): + +1. Convert dict input -> `DataBlend` +2. Create/validate shard plans (existing logic) +3. Build a list of `ShardTask` items (one per shard index) +4. Optionally predownload HF files (see below) +5. Run `cosmos_xenna.pipelines.run_pipeline()` with one Stage +6. 
Aggregate receipts -> stats -> blend.json (existing logic) + +### 3) Xenna Stage Definition (nemotron.data_prep.xenna) + +All Xenna-facing wrappers live under `nemotron.data_prep.xenna` to keep the +integration boundary clean. This includes stage classes, work item types, and +the pipeline runner. + +Define a single stage class (CPU only): + +- `required_resources`: `Resources(cpus=1.0, gpus=0.0)` (configurable). +- `setup()`: initialize tokenizer using `create_tokenizer()` and store + resolved config; set output fs once (fsspec `url_to_fs`). +- `process_data()`: accept one `ShardTask` dict, invoke + `process_binidx_shard_core()` with the cached tokenizer and output fs. + +`ShardTask` fields (minimal): + +``` +{ + "dataset_name": str, + "shard_index": int, + "assignment_json": str, + "plan_hash": str, + "output_dir": str, + "receipts_dir": str, + "text_field": str, + "dtype": "int32|int64|uint16", + "min_doc_chars": int|null, + "max_doc_tokens": int|null, + "max_rows": int|null +} +``` + +The stage is intentionally thin: no new shard logic or I/O semantics, just a +bridge from Xenna tasks to existing core processing. + +### 3.1) Proposed Xenna Stages (Nemotron Data Prep) + +This effort defines a minimal set of Xenna stages. Stage 0 (pretrain) is the +first target; the other formats remain future work. 
+ +#### Stage: PretrainShardStage (bin/idx) + +Responsibilities: +- Process a `ShardTask` (one shard) using `process_binidx_shard_core()` +- Write `.bin/.idx` and receipt JSON + +Sketch: + +```py +import cosmos_xenna.pipelines.v1 as pipelines_v1 +from nemotron.data_prep.shard_processor import process_binidx_shard_core +from nemotron.data_prep.providers import create_tokenizer +from nemotron.data_prep.filesystem import get_filesystem + +class PretrainShardStage(pipelines_v1.Stage): + @property + def stage_batch_size(self) -> int: + return 1 # one shard per task + + @property + def required_resources(self) -> pipelines_v1.Resources: + return pipelines_v1.Resources(gpus=0, cpus=1.0) + + def setup(self, worker_metadata: pipelines_v1.WorkerMetadata) -> None: + self._tokenize = create_tokenizer(self._resolved_tokenizer) + self._output_fs, _ = get_filesystem(self._output_dir) + + def process_data(self, tasks: list[dict]) -> list[dict]: + results = [] + for task in tasks: + stats = process_binidx_shard_core( + tokenize=self._tokenize, + text_field=task["text_field"], + min_doc_chars=task["min_doc_chars"], + max_doc_tokens=task["max_doc_tokens"], + dtype=task["dtype"], + max_rows=task["max_rows"], + shard_index=task["shard_index"], + assignment=task["assignment"], + plan_hash=task["plan_hash"], + output_dir=task["output_dir"], + receipts_dir=task["receipts_dir"], + output_fs=self._output_fs, + ) + results.append(stats) + return results +``` + +Notes: +- All Xenna scaffolding (stage class, work item types, runner) is located in + `nemotron.data_prep.xenna` (e.g., `xenna/stages.py`, `xenna/work_items.py`, + `xenna/runner.py`). +- `assignment` is parsed from `assignment_json` before pipeline submission. +- `_resolved_tokenizer` and `_output_dir` are injected at stage construction time. +- The stage is intentionally thin and delegates all shard semantics to existing code. + +#### Stage (future): JsonlStage + +Not part of this effort. 
Placeholder if JSONL output migrates later: +- Wrap `jsonl_processor` logic into a Xenna stage. +- Preserve existing transform semantics. + +#### Stage (future): PackedStage / ChatSftStage + +Not part of this effort. Placeholders for packed SFT and chat SFT formats. + +### 4) HuggingFace Optimized Downloading + +Continue using the existing HF downloader, but move it earlier in the flow: + +- Collect all HF files from shard assignments. +- Run `parallel_predownload()` with `max_concurrent_downloads`. +- Store in HF cache (HF_HOME/hub). +- Shard processing uses `hf_hub_download(..., local_files_only=True)` to + avoid network I/O and ensure determinism. + +This keeps HF as a first-class source while minimizing new download machinery. +Cosmos-Xenna distributed download is not used for HF (it targets object stores). + +### 5) Output and Artifacts + +Outputs remain unchanged: + +- `.bin`/`.idx` shards under `runs/<run_hash>/datasets/<dataset_name>/<plan_hash>/` +- `blend.json` with `train`/`valid`/`test` or `data_paths` + split +- Receipts per shard for resuming and cache hits + +`PretrainBlendsArtifact` creation stays in `run_data_prep()` (or new entrypoint). + +## API and Configuration Changes + +### New entrypoint + +``` +run_data_prep_from_dict( + datasets: dict, + *, + output_dir: Path, + tokenizer_model: str, + ... # existing DataPrepConfig fields +) +``` + +### Execution engine flag + +Add a config flag (default to Xenna for stage0): + +``` +execution_engine: Literal["ray", "xenna"] = "xenna" +``` + +This allows phased rollout and easy fallback. + +## Integration Points + +- `nemotron/data_prep/pipeline.py`: + - add a Xenna execution path (replacing `_process_all_shards_parallel`) + - keep planning, receipts, and manifest generation. +- `nemotron/data_prep/xenna/`: + - new Xenna integration module (stages, work items, runner). +- `nemotron/data_prep/shard_processor.py`: + - reuse `process_binidx_shard_core()` as-is. 
+- `nemotron/data_prep/downloader.py`: + - reuse `parallel_predownload()` before Xenna execution. +- `recipes/nano3/stage0_pretrain/data_prep.py`: + - accept dict dataset spec and call new entrypoint. +- `recipes/nano3/stage0_pretrain/prep_xenna.py`: + - new Xenna-specific entrypoint for staged validation. +- `cli/nano3/data/prep/pretrain_xenna.py`: + - new recipe command module (mirrors `pretrain.py` style). +- `cli/nano3/data/prep/app.py`: + - register a new `prep xenna` or `prep pretrain-xenna` command via `make_recipe_command`. + +## Rollout Plan + +1. Add `prep_xenna.py` for pretrain and wire it into the CLI as an opt-in path + using the existing recipe command pattern. +2. Land the Xenna execution path behind `execution_engine` flag (default stays Ray). +3. Validate Xenna path in recipes and tests. +4. Switch Stage 0 pretrain to Xenna by default once stable. +5. Keep legacy Ray path for regression fallback until removed. +6. Add docs update describing dict input and Xenna engine selection. + +## Testing Strategy + +- Unit tests: + - dict schema -> DataBlend conversion + - ShardTask mapping (assignment_json, plan_hash, output dirs) +- Integration tests (existing): + - `tests/recipes/nano3/stage0_pretrain/test_data_prep_train_integration.py` + should pass for both engines. +- HF download tests: + - Mock HF cache predownload; verify local_files_only path works. + +## Risks and Mitigations + +- **Xenna pipeline expectations**: Stage `process_data()` must be pure and + pickle-friendly. Mitigate by keeping ShardTask data primitive types only. +- **HF cache not warmed**: Ensure predownload step runs before Xenna pipeline, + or fallback to allowing direct download if cache misses occur. +- **Resource sizing**: Xenna defaults may oversubscribe CPUs. Provide explicit + config mapping to `PipelineConfig` and `Resources`. +- **Behavioral drift**: Keep shard planning + core processor unchanged and + preserve receipts to avoid output changes. 
+ +## Open Questions + +- Should dict input be accepted by `run_data_prep()` directly (optional), or + only by a new explicit entrypoint? +- Should Xenna execution be enabled for non-pretrain formats later, or keep + scope limited to bin/idx until validated? diff --git a/pyproject.toml b/pyproject.toml index 2d16415c6..6cda1e63f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ wandb = ["wandb>=0.15.0"] s3 = ["s3fs>=2024.0.0"] gcs = ["gcsfs>=2024.0.0"] sentencepiece = ["sentencepiece>=0.2.0"] +xenna = ["cosmos-xenna"] dev = [ "pytest>=7.0.0", "pytest-cov>=4.0.0", @@ -60,6 +61,7 @@ all = [ "s3fs>=2024.0.0", "gcsfs>=2024.0.0", "sentencepiece>=0.2.0", + "cosmos-xenna", ] # Note: megatron-bridge is required for training but not listed as a dependency diff --git a/scripts/test_pack_memory.py b/scripts/test_pack_memory.py new file mode 100644 index 000000000..10fcbd20c --- /dev/null +++ b/scripts/test_pack_memory.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Test script to investigate memory usage in the ChatSFT packing stage. +# Creates synthetic spool data and exercises the pack pipeline to identify OOM sources. +# +# Usage: +# uv run python scripts/test_pack_memory.py --num-sequences 150000 --avg-seq-length 2000 +# +# This mimics a single shard from the full SFT pipeline (~440M tokens per shard). 
+ +from __future__ import annotations + +import argparse +import gc +import os +import shutil +import tempfile +import time +from pathlib import Path + +import numpy as np + + +def log_memory(label: str) -> float: + """Log current memory usage and return RSS in GB.""" + try: + import psutil + process = psutil.Process() + rss_gb = process.memory_info().rss / (1024**3) + print(f"[Memory] {label}: RSS={rss_gb:.2f} GB") + return rss_gb + except ImportError: + print(f"[Memory] {label}: psutil not available") + return 0.0 + + +def create_synthetic_spool( + spool_dir: str, + num_sequences: int, + avg_seq_length: int, + length_std: int = 500, + seed: int = 42, +) -> dict: + """Create synthetic spool data mimicking real tokenized sequences.""" + from fsspec.implementations.local import LocalFileSystem + from nemotron.data_prep.packing.spool import SequenceSpoolPaths, SequenceSpoolWriter + + print(f"\n=== Creating synthetic spool ===") + print(f" Sequences: {num_sequences:,}") + print(f" Avg length: {avg_seq_length}") + print(f" Spool dir: {spool_dir}") + + log_memory("Before spool creation") + + fs = LocalFileSystem() + paths = SequenceSpoolPaths.for_root(spool_dir) + writer = SequenceSpoolWriter(fs=fs, paths=paths) + + rng = np.random.default_rng(seed) + total_tokens = 0 + + # Generate sequences in batches to avoid memory spikes + batch_size = 10000 + for batch_start in range(0, num_sequences, batch_size): + batch_end = min(batch_start + batch_size, num_sequences) + batch_count = batch_end - batch_start + + # Generate random lengths for this batch + lengths = rng.normal(avg_seq_length, length_std, batch_count).astype(int) + lengths = np.clip(lengths, 100, avg_seq_length * 3) # Clamp to reasonable range + + for length in lengths: + # Generate random tokens and mask + input_ids = rng.integers(0, 50000, size=length, dtype=np.int32) + loss_mask = rng.integers(0, 2, size=length, dtype=np.uint8) + writer.append(input_ids, loss_mask) + total_tokens += length + + if (batch_end % 
50000) == 0 or batch_end == num_sequences: + print(f" Created {batch_end:,}/{num_sequences:,} sequences ({total_tokens:,} tokens)") + + log_memory("After writing sequences") + + manifest = writer.finalize(extra_manifest={ + "shard_id": "test_shard", + "shard_index": 0, + "pack_size": 4096, + "algorithm": "first_fit_decreasing", + "dtype": "int32", + "tokenization_stats": { + "num_input_rows": num_sequences, + "num_output_sequences": num_sequences, + } + }) + + log_memory("After spool finalize") + + return { + "num_sequences": num_sequences, + "total_tokens": total_tokens, + "spool_dir": spool_dir, + } + + +def test_pack_from_spool( + spool_dir: str, + output_dir: str, + pack_size: int = 4096, + algorithm: str = "first_fit_decreasing", +) -> dict: + """Test the packing stage using the actual core function.""" + from fsspec.implementations.local import LocalFileSystem + from nemotron.data_prep.chat_sft_shard_core import process_chat_sft_pack_from_spool_core + + print(f"\n=== Testing pack from spool (using actual core function) ===") + log_memory("Start of pack test") + + fs = LocalFileSystem() + os.makedirs(output_dir, exist_ok=True) + receipts_dir = os.path.join(output_dir, ".receipts") + os.makedirs(receipts_dir, exist_ok=True) + + print("\n--- Calling process_chat_sft_pack_from_spool_core ---") + log_memory("Before core function") + + stats = process_chat_sft_pack_from_spool_core( + shard_index=0, + output_dir=output_dir, + receipts_dir=receipts_dir, + spool_dir=spool_dir, + output_fs=fs, + pack_size=pack_size, + algorithm=algorithm, + dtype=np.dtype("int32"), + seed=42, + ) + + log_memory("After core function returned") + + # Additional cleanup + gc.collect() + log_memory("After gc.collect()") + + npy_path = os.path.join(output_dir, "shard_000000.npy") + npy_size = os.path.getsize(npy_path) if os.path.exists(npy_path) else 0 + + return { + "num_sequences": stats.get("num_sequences", 0), + "num_bins": stats.get("num_packed_sequences", 0), + "npy_size": npy_size, + 
"total_tokens": stats.get("total_tokens", 0), + } + + +def test_ray_simulation( + spool_dir: str, + output_dir: str, + pack_size: int = 4096, +) -> None: + """Simulate the Ray pipeline behavior with memory logging.""" + import ray + + print(f"\n=== Simulating Ray pipeline behavior ===") + log_memory("Before ray.init()") + + # Initialize Ray locally + if not ray.is_initialized(): + ray.init(num_cpus=4, object_store_memory=2 * 1024**3) # 2GB object store + log_memory("After ray.init()") + + # Run the pack as a Ray task (simulates actor behavior) + @ray.remote + def pack_task(spool_dir: str, output_dir: str, pack_size: int) -> dict: + return test_pack_from_spool(spool_dir, output_dir, pack_size) + + print("\n--- Running pack as Ray task ---") + result_ref = pack_task.remote(spool_dir, output_dir, pack_size) + result = ray.get(result_ref) + log_memory("After ray.get() of pack task") + + # Simulate what happens after pipeline + print("\n--- Simulating post-pipeline ---") + del result_ref + log_memory("After del result_ref") + + gc.collect() + log_memory("After gc.collect()") + + # Wait to see if delayed cleanup causes issues + print("\n--- Waiting 30s to observe delayed memory behavior ---") + for i in range(6): + time.sleep(5) + log_memory(f"After {(i+1)*5}s wait") + + ray.shutdown() + log_memory("After ray.shutdown()") + + +def test_multiple_shards( + num_shards: int, + num_sequences: int, + avg_seq_length: int, + pack_size: int = 4096, +) -> None: + """Test processing multiple shards sequentially to check memory accumulation.""" + print(f"\n{'='*60}") + print(f"Testing {num_shards} sequential shards") + print(f"{'='*60}") + + log_memory("Initial state before multi-shard test") + + temp_dir = tempfile.mkdtemp(prefix="multi_shard_test_") + + try: + for shard_idx in range(num_shards): + print(f"\n--- Shard {shard_idx + 1}/{num_shards} ---") + spool_dir = os.path.join(temp_dir, f"spool_{shard_idx}") + output_dir = os.path.join(temp_dir, f"output_{shard_idx}") + + # 
Create spool + create_synthetic_spool( + spool_dir=spool_dir, + num_sequences=num_sequences, + avg_seq_length=avg_seq_length, + ) + log_memory(f"After creating spool {shard_idx}") + + # Process shard + test_pack_from_spool(spool_dir, output_dir, pack_size) + log_memory(f"After processing shard {shard_idx}") + + # Force cleanup + gc.collect() + log_memory(f"After gc.collect() for shard {shard_idx}") + + finally: + print(f"\nCleaning up temp dir: {temp_dir}") + shutil.rmtree(temp_dir, ignore_errors=True) + + log_memory("Final state after multi-shard test") + + +def main(): + parser = argparse.ArgumentParser(description="Test memory usage in ChatSFT packing") + parser.add_argument("--num-sequences", type=int, default=150000, + help="Number of sequences to generate (default: 150000, ~1 shard)") + parser.add_argument("--avg-seq-length", type=int, default=2000, + help="Average sequence length in tokens (default: 2000)") + parser.add_argument("--pack-size", type=int, default=4096, + help="Pack size for packing (default: 4096)") + parser.add_argument("--with-ray", action="store_true", + help="Also test with Ray to simulate actor behavior") + parser.add_argument("--keep-temp", action="store_true", + help="Keep temporary files for inspection") + parser.add_argument("--multi-shard", type=int, default=0, + help="Test multiple shards sequentially (0=disabled)") + args = parser.parse_args() + + print("=" * 60) + print("ChatSFT Packing Memory Test") + print("=" * 60) + print(f"Configuration:") + print(f" Sequences: {args.num_sequences:,}") + print(f" Avg length: {args.avg_seq_length}") + print(f" Expected tokens: ~{args.num_sequences * args.avg_seq_length:,}") + print(f" Pack size: {args.pack_size}") + print(f" With Ray: {args.with_ray}") + print(f" Multi-shard: {args.multi_shard}") + + log_memory("Initial state") + + # Multi-shard test mode + if args.multi_shard > 0: + test_multiple_shards( + num_shards=args.multi_shard, + num_sequences=args.num_sequences, + 
avg_seq_length=args.avg_seq_length, + pack_size=args.pack_size, + ) + return + + # Create temp directories + temp_dir = tempfile.mkdtemp(prefix="pack_memory_test_") + spool_dir = os.path.join(temp_dir, "spool") + output_dir = os.path.join(temp_dir, "output") + + try: + # Create synthetic spool + spool_info = create_synthetic_spool( + spool_dir=spool_dir, + num_sequences=args.num_sequences, + avg_seq_length=args.avg_seq_length, + ) + print(f"\nSpool created: {spool_info['total_tokens']:,} tokens") + + gc.collect() + log_memory("After spool creation + gc") + + if args.with_ray: + # Test with Ray + test_ray_simulation(spool_dir, output_dir, args.pack_size) + else: + # Test without Ray + pack_info = test_pack_from_spool(spool_dir, output_dir, args.pack_size) + print(f"\nPacking complete: {pack_info['num_bins']:,} bins") + + print("\n" + "=" * 60) + print("Test complete!") + log_memory("Final state") + + finally: + if not args.keep_temp: + print(f"\nCleaning up temp dir: {temp_dir}") + shutil.rmtree(temp_dir, ignore_errors=True) + else: + print(f"\nKeeping temp dir: {temp_dir}") + + +if __name__ == "__main__": + main() diff --git a/src/nemotron/cli/nano3/data/prep/app.py b/src/nemotron/cli/nano3/data/prep/app.py index 44f8221b1..da488bc2a 100644 --- a/src/nemotron/cli/nano3/data/prep/app.py +++ b/src/nemotron/cli/nano3/data/prep/app.py @@ -19,8 +19,11 @@ import typer from nemotron.cli.nano3.data.prep.pretrain import pretrain +from nemotron.cli.nano3.data.prep.pretrain_xenna import pretrain_xenna from nemotron.cli.nano3.data.prep.rl import rl +from nemotron.cli.nano3.data.prep.rl_xenna import rl_xenna from nemotron.cli.nano3.data.prep.sft import sft +from nemotron.cli.nano3.data.prep.sft_xenna import sft_xenna from nemotron.cli.nano3.help import make_recipe_command # Create prep app @@ -43,6 +46,17 @@ ), )(pretrain) +prep_app.command( + name="pretrain-xenna", + context_settings={ + "allow_extra_args": True, + "ignore_unknown_options": True, + }, + 
cls=make_recipe_command( + config_dir="src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep", + ), +)(pretrain_xenna) + prep_app.command( name="sft", context_settings={ @@ -54,6 +68,17 @@ ), )(sft) +prep_app.command( + name="sft-xenna", + context_settings={ + "allow_extra_args": True, + "ignore_unknown_options": True, + }, + cls=make_recipe_command( + config_dir="src/nemotron/recipes/nano3/stage1_sft/config/data_prep", + ), +)(sft_xenna) + prep_app.command( name="rl", context_settings={ @@ -64,3 +89,14 @@ config_dir="src/nemotron/recipes/nano3/stage2_rl/config/data_prep", ), )(rl) + +prep_app.command( + name="rl-xenna", + context_settings={ + "allow_extra_args": True, + "ignore_unknown_options": True, + }, + cls=make_recipe_command( + config_dir="src/nemotron/recipes/nano3/stage2_rl/config/data_prep", + ), +)(rl_xenna) diff --git a/src/nemotron/cli/nano3/data/prep/pretrain_xenna.py b/src/nemotron/cli/nano3/data/prep/pretrain_xenna.py new file mode 100644 index 000000000..e099a84f9 --- /dev/null +++ b/src/nemotron/cli/nano3/data/prep/pretrain_xenna.py @@ -0,0 +1,37 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Pretrain data preparation command (Xenna execution).""" + +from __future__ import annotations + +import typer + +from nemotron.kit.cli.recipe import recipe + + +@recipe( + name="nano3/data/prep/pretrain-xenna", + script_path="src/nemotron/recipes/nano3/stage0_pretrain/prep_xenna.py", + config_dir="src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep", + default_config="default", + torchrun=False, + ray=True, + packager="code", + # Use --extra xenna so Ray's uv hook propagates cosmos-xenna to workers + run_command="uv run --extra xenna python {script} --config {config}", +) +def pretrain_xenna(ctx: typer.Context) -> None: + """Tokenize data for pretraining (bin/idx) using Xenna.""" + ... diff --git a/src/nemotron/cli/nano3/data/prep/rl_xenna.py b/src/nemotron/cli/nano3/data/prep/rl_xenna.py new file mode 100644 index 000000000..5d19bde8d --- /dev/null +++ b/src/nemotron/cli/nano3/data/prep/rl_xenna.py @@ -0,0 +1,36 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""RL data preparation command (Xenna execution).""" + +from __future__ import annotations + +import typer + +from nemotron.kit.cli.recipe import recipe + + +@recipe( + name="nano3/data/prep/rl-xenna", + script_path="src/nemotron/recipes/nano3/stage2_rl/data_prep_xenna.py", + config_dir="src/nemotron/recipes/nano3/stage2_rl/config/data_prep", + default_config="default", + torchrun=False, + ray=True, + packager="code", + run_command="uv run --extra xenna python {script} --config {config}", +) +def rl_xenna(ctx: typer.Context) -> None: + """Prepare data for RL (JSONL chat format) using Xenna.""" + ... diff --git a/src/nemotron/cli/nano3/data/prep/sft_xenna.py b/src/nemotron/cli/nano3/data/prep/sft_xenna.py new file mode 100644 index 000000000..0eadd77fc --- /dev/null +++ b/src/nemotron/cli/nano3/data/prep/sft_xenna.py @@ -0,0 +1,36 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""SFT data preparation command (Xenna execution).""" + +from __future__ import annotations + +import typer + +from nemotron.kit.cli.recipe import recipe + + +@recipe( + name="nano3/data/prep/sft-xenna", + script_path="src/nemotron/recipes/nano3/stage1_sft/data_prep_xenna.py", + config_dir="src/nemotron/recipes/nano3/stage1_sft/config/data_prep", + default_config="default", + torchrun=False, + ray=True, + packager="code", + run_command="uv run --extra xenna python {script} --config {config}", +) +def sft_xenna(ctx: typer.Context) -> None: + """Prepare data for SFT (packed .npy format) using Xenna.""" + ... diff --git a/src/nemotron/data_prep/__init__.py b/src/nemotron/data_prep/__init__.py index adf0459cb..ecb1accb0 100644 --- a/src/nemotron/data_prep/__init__.py +++ b/src/nemotron/data_prep/__init__.py @@ -180,8 +180,15 @@ class DataPrepConfig: ray_data_cpus_per_actor: float = 1.0 """CPUs per actor for Ray Data executor""" - ray_data_max_tasks_in_flight: int = 2 - """Max tasks in flight per actor (pipelining depth)""" + ray_data_max_tasks_in_flight: int = 4 + """Max tasks in flight per actor (pipelining depth for better I/O overlap)""" + + max_concurrent_downloads: int = 64 + """Maximum parallel HuggingFace file downloads during pre-download phase. + Higher values increase throughput but may overwhelm HF servers or network.""" + + cleanup_hf_cache: bool = False + """Delete HuggingFace cache after processing. 
Useful for one-off jobs.""" console_mode: Literal["rich", "simple"] = "simple" """Console output mode: 'rich' for animated progress bars, 'simple' for periodic text updates""" @@ -189,6 +196,21 @@ class DataPrepConfig: simple_log_interval_sec: int = 30 """Interval in seconds between status updates in simple console mode (default: 30)""" + execution_engine: Literal["ray", "xenna"] = "ray" + """Execution backend for shard processing.""" + + wandb_log_downloads: bool = False + """Log download progress metrics to W&B (Xenna path only).""" + + wandb_download_log_interval_sec: int = 30 + """Interval (seconds) for W&B download progress logging.""" + + hf_download_timeout_sec: int = 300 + """Per-file HF download timeout in seconds (Xenna path only).""" + + hf_download_max_retries: int = 3 + """Max retries for HF downloads before giving up (Xenna path only).""" + def run_data_prep( config: DataPrepConfig, *, artifact_class: type = PretrainBlendsArtifact @@ -240,31 +262,40 @@ def run_data_prep( # Resolve output_dir to absolute path for W&B artifact storage output_dir = config.output_dir.resolve() if hasattr(config.output_dir, 'resolve') else Path(config.output_dir).resolve() - # Initialize Ray early so we can query cluster resources - import ray - - if not ray.is_initialized(): - runtime_env = { - "excludes": [ - "output/", - "outputs/", - "wandb/", - "data/", - "checkpoints/", - "*.bin", - "*.idx", - "*.npy", - "__pycache__/", - ".git/", - ".venv/", - "*.egg-info/", - ] - } - ray.init(address="auto", ignore_reinit_error=True, runtime_env=runtime_env) + # Initialize Ray for download tasks (Xenna and Ray executors both use Ray) + if config.execution_engine in ("ray", "xenna"): + import ray + + if not ray.is_initialized(): + runtime_env = { + "excludes": [ + "output/", + "outputs/", + "wandb/", + "data/", + "checkpoints/", + "*.bin", + "*.idx", + "*.npy", + "__pycache__/", + ".git/", + ".venv/", + "*.egg-info/", + ], + "env_vars": {}, + } + # Pass HF_HOME to Ray actors for 
persistent dataset caching on Lustre + if os.environ.get("HF_HOME"): + runtime_env["env_vars"]["HF_HOME"] = os.environ["HF_HOME"] + # Pass HF_TOKEN for private dataset access + if os.environ.get("HF_TOKEN"): + runtime_env["env_vars"]["HF_TOKEN"] = os.environ["HF_TOKEN"] + + ray.init(address="auto", ignore_reinit_error=True, runtime_env=runtime_env) # Build Ray Data config if enabled, auto-detecting cluster resources ray_data_config = None - if config.ray_data_enabled: + if config.execution_engine == "ray" and config.ray_data_enabled: from nemotron.data_prep.config import RayDataConfig # Auto-detect available CPUs from Ray cluster @@ -277,11 +308,36 @@ def run_data_prep( # Use the highest available CPU count (Ray may report fewer due to config issues) available_cpus = max(int(ray_cpus), slurm_cpus, os_cpus) - # Use most of available CPUs for actors (leave some headroom) - # min_actors = start with good parallelism - # max_actors = allow scaling up to use all CPUs + # CPU-based limit (use 90% of CPUs) cpus_per_actor = config.ray_data_cpus_per_actor - auto_max_actors = int(available_cpus * 0.9 / cpus_per_actor) # Use 90% of CPUs + cpu_based_limit = int(available_cpus * 0.9 / cpus_per_actor) + + # Memory-based limit to prevent OOM + # Each actor loads a tokenizer (~1GB) + needs working memory (~1GB) = ~2GB total + # Ray's object_store_memory is pre-allocated from system RAM + # Worker memory = total system RAM - object store - overhead + ray_memory = cluster_resources.get("memory", 0) + object_store = cluster_resources.get("object_store_memory", 0) + + if ray_memory > 0 and object_store > 0: + # Worker memory is roughly: total - object_store - 10% overhead + total_memory_gb = ray_memory / (1024**3) + object_store_gb = object_store / (1024**3) + worker_memory_gb = (total_memory_gb - object_store_gb) * 0.9 + # Estimate: tokenizer (~1-2GB) + working memory (~1GB) = ~3GB per actor + memory_per_actor_gb = 3.0 + memory_based_limit = max(4, int(worker_memory_gb / 
memory_per_actor_gb)) + else: + # Fallback: use conservative 50% of CPUs when memory info unavailable + memory_based_limit = int(available_cpus * 0.5 / cpus_per_actor) + total_memory_gb = 0 + object_store_gb = 0 + worker_memory_gb = 0 + + # Use the more restrictive limit to prevent OOM + auto_max_actors = min(cpu_based_limit, memory_based_limit) + + # Apply user override if specified if config.ray_data_max_actors is not None: max_actors = min(config.ray_data_max_actors, auto_max_actors) else: @@ -291,6 +347,9 @@ def run_data_prep( # Log resource detection for debugging print(f"Ray cluster resources: {cluster_resources}") print(f"CPU detection: Ray={ray_cpus}, SLURM={slurm_cpus}, os={os_cpus} -> using {available_cpus}") + if ray_memory > 0: + print(f"Memory detection: total={total_memory_gb:.1f}GB, object_store={object_store_gb:.1f}GB, worker={worker_memory_gb:.1f}GB") + print(f"Actor limits: CPU-based={cpu_based_limit}, Memory-based={memory_based_limit}") print(f"Ray Data config: min_actors={min_actors}, max_actors={max_actors}") # Log W&B status for debugging @@ -309,6 +368,8 @@ def run_data_prep( max_actors=max_actors, cpus_per_actor=cpus_per_actor, max_tasks_in_flight_per_actor=config.ray_data_max_tasks_in_flight, + max_concurrent_downloads=config.max_concurrent_downloads, + cleanup_hf_cache=config.cleanup_hf_cache, ) pipeline_config = PipelineConfig( @@ -330,6 +391,12 @@ def run_data_prep( ray_data=ray_data_config, console_mode=config.console_mode, simple_log_interval_sec=config.simple_log_interval_sec, + execution_engine=config.execution_engine, + max_concurrent_downloads=config.max_concurrent_downloads, + wandb_log_downloads=config.wandb_log_downloads, + wandb_download_log_interval_sec=config.wandb_download_log_interval_sec, + hf_download_timeout_sec=config.hf_download_timeout_sec, + hf_download_max_retries=config.hf_download_max_retries, ) # Run processing pipeline @@ -386,6 +453,19 @@ def run_data_prep( ) artifact.save() + # Cleanup HuggingFace cache if 
requested + if config.cleanup_hf_cache: + hf_home = os.environ.get("HF_HOME") + if hf_home and os.path.isdir(hf_home): + import shutil + + print(f"Cleaning up HF cache: {hf_home}") + try: + shutil.rmtree(hf_home) + print(f"HF cache deleted: {hf_home}") + except Exception as e: + print(f"Failed to delete HF cache: {e}") + # Mark wandb run as successful (before Ray shutdown to avoid socket noise) finish_wandb(exit_code=0) diff --git a/src/nemotron/data_prep/chat_sft_processor.py b/src/nemotron/data_prep/chat_sft_processor.py index 40f234add..f2b282ed6 100644 --- a/src/nemotron/data_prep/chat_sft_processor.py +++ b/src/nemotron/data_prep/chat_sft_processor.py @@ -12,52 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""ChatSftShardProcessor Ray actor for parallel chat-templated SFT output processing. - -Applies chat templates to OpenAI-format messages, tokenizes with role-based -loss masking, and outputs packed .npy files compatible with GPTSFTPackedDataset. - -Pipeline: -1. Apply materialize.py chat template logic -> role-labeled chunks -2. Tokenize chunks -> input_ids -3. Build loss_mask based on role (0=system/user, 1=assistant) -4. 
Pack sequences -> .npy output -""" +"""ChatSftShardProcessor Ray actor for parallel chat-templated SFT output processing.""" from __future__ import annotations -import json import logging -import time -from collections.abc import Iterator from pathlib import Path import numpy as np -import pyarrow.parquet as pq import ray from fsspec import filesystem -from nemotron.data_prep.chat_template import ( - create_masked_messages, - replace_json_args, - split_system_user_chunks, - validate_conversation, -) +from nemotron.data_prep.chat_sft_shard_core import process_chat_sft_shard_core from nemotron.data_prep.config import FileInfo -from nemotron.data_prep.filesystem import ensure_dir, write_json -from nemotron.data_prep.packing.builder import PackedSequenceBuilder logger = logging.getLogger(__name__) @ray.remote class ChatSftShardProcessor: - """Ray actor for chat-templated SFT output with loss masking. - - Reads input files with OpenAI-format messages, applies chat templates - to generate loss-masked sequences, packs them, and writes to .npy files - compatible with Megatron-Bridge's GPTSFTPackedDataset. - """ + """Ray actor for chat-templated SFT output with loss masking.""" def __init__( self, @@ -74,22 +48,6 @@ def __init__( used_in_filter: str | None = None, used_in_field: str = "used_in", ): - """Initialize chat SFT processor. - - Args: - resolved_tokenizer: Tokenizer configuration dict with resolved SHA. - messages_field: Field name for messages in input records. - tools_field: Field name for tools in input records. - pack_size: Maximum tokens per packed sequence. - algorithm: Packing algorithm. - dtype: Token dtype for output. - chat_template: "nano3", path to .jinja file, or inline template string. - max_doc_tokens: Truncate sequences longer than this. - max_rows: Maximum rows to process per shard. - seed: Random seed for shuffle-based algorithms. - used_in_filter: Filter to only include records where used_in contains this value. 
- used_in_field: Field name for used_in filtering (default: "used_in"). - """ from transformers import AutoTokenizer self.messages_field = messages_field @@ -103,26 +61,21 @@ def __init__( self.used_in_filter = used_in_filter self.used_in_field = used_in_field - # Load HuggingFace tokenizer with full chat template support self._tokenizer = AutoTokenizer.from_pretrained( resolved_tokenizer["model"], revision=resolved_tokenizer.get("resolved_revision"), trust_remote_code=resolved_tokenizer.get("trust_remote_code", False), ) - # Load chat template if chat_template: if chat_template == "nano3": - # Load bundled template template_path = Path(__file__).parent / "templates" / "nano3.jinja" with open(template_path) as f: self._tokenizer.chat_template = f.read() elif Path(chat_template).exists(): - # Load from file path with open(chat_template) as f: self._tokenizer.chat_template = f.read() else: - # Assume inline template string self._tokenizer.chat_template = chat_template def process_shard( @@ -133,365 +86,25 @@ def process_shard( receipts_dir: str, fs_protocol: str, ) -> dict: - """Process files to a single packed shard with loss masks. - - Args: - shard_index: Index of this shard. - files: List of FileInfo dicts to process. - output_dir: Output directory for .npy files. - receipts_dir: Directory for receipt files. - fs_protocol: Filesystem protocol (e.g., "file", "s3"). - - Returns: - Shard statistics dict. 
- """ + """Process files to a single packed shard with loss masks.""" fs = filesystem(fs_protocol) - shard_id = f"shard_{shard_index:06d}" - npy_path = f"{output_dir}/{shard_id}.npy" - receipt_path = f"{receipts_dir}/{shard_id}.json" - - # Ensure directories - ensure_dir(fs, output_dir) - ensure_dir(fs, receipts_dir) - - # Stats tracking - stats = { - "num_input_rows": 0, - "num_output_sequences": 0, - "num_filtered": 0, - "num_validation_errors": 0, - "num_truncated": 0, - "num_errors": 0, - } - - # Convert file dicts back to FileInfo - file_infos = [FileInfo(**f) for f in files] - input_file_paths = [f.path for f in file_infos] - - # Handle empty assignment - if not file_infos: - return self._write_empty_receipt( - shard_id, - shard_index, - input_file_paths, - stats, - receipt_path, - fs, - ) - - # Create packing builder - builder = PackedSequenceBuilder( + return process_chat_sft_shard_core( + shard_index=shard_index, + files=[FileInfo(**f) for f in files], + output_dir=output_dir, + receipts_dir=receipts_dir, + output_fs=fs, + tokenizer=self._tokenizer, + messages_field=self.messages_field, + tools_field=self.tools_field, pack_size=self.pack_size, algorithm=self.algorithm, + dtype=self.dtype, + chat_template=None, + max_doc_tokens=self.max_doc_tokens, + max_rows=self.max_rows, seed=self.seed, - dtype=str(self.dtype), - ) - - # Track rows processed across files for max_rows limit - rows_processed = 0 - - # Process files SEQUENTIALLY for determinism - for file_info in file_infos: - rows_processed = self._process_file(file_info, builder, stats, fs, rows_processed) - # Stop if we've hit max_rows - if self.max_rows and rows_processed >= self.max_rows: - break - - # Finalize packing - packed_data, packing_metadata = builder.finalize() - - # Handle empty result (all rows filtered) - if not packed_data: - return self._write_empty_receipt( - shard_id, - shard_index, - input_file_paths, - stats, - receipt_path, - fs, - ) - - # Save packed data as .npy - with 
fs.open(npy_path, "wb") as f: - np.save(f, packed_data, allow_pickle=True) - - # Get file size - npy_bytes = fs.size(npy_path) - - # Write receipt (commits the shard) - receipt = { - "shard_id": shard_id, - "shard_index": shard_index, - "status": "completed", - "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), - "input_files": input_file_paths, - "output_file": f"{shard_id}.npy", - "npy_bytes": npy_bytes, - "packing": packing_metadata, - "stats": { - "num_sequences": packing_metadata["num_sequences"], - "num_packed_sequences": packing_metadata["num_packed_sequences"], - "total_tokens": packing_metadata["total_tokens"], - **stats, - }, - } - - write_json(fs, receipt_path, receipt) - return receipt["stats"] - - def _process_file( - self, - file_info: FileInfo, - builder: PackedSequenceBuilder, - stats: dict, - fs, - rows_processed: int = 0, - ) -> int: - """Process a single file, adding sequences to builder. - - Returns the total number of rows processed (for max_rows tracking). 
- """ - # Resolve file path - handle HF deferred download - local_path = self._resolve_file_path(file_info) - - # Determine file type and iterate records - is_parquet = local_path.endswith(".parquet") or not ( - local_path.endswith(".jsonl") or local_path.endswith(".json") + used_in_filter=self.used_in_filter, + used_in_field=self.used_in_field, ) - - if is_parquet: - record_iter = self._iter_parquet_records(local_path, fs) - else: - record_iter = self._iter_jsonl_records(local_path, fs) - - for record in record_iter: - # Check max_rows limit - if self.max_rows and rows_processed >= self.max_rows: - break - - stats["num_input_rows"] += 1 - rows_processed += 1 - - self._process_record(record, builder, stats) - - return rows_processed - - def _process_record( - self, - record: dict, - builder: PackedSequenceBuilder, - stats: dict, - ) -> None: - """Process a single record using materialize.py logic.""" - # Apply used_in filter if configured - if self.used_in_filter: - used_in = record.get(self.used_in_field) - if not self._matches_used_in_filter(used_in): - stats["num_filtered"] += 1 - return - - messages = record.get(self.messages_field) - tools = record.get(self.tools_field) - - # Skip if no messages - if not messages: - stats["num_filtered"] += 1 - return - - # Step 1: Validate (materialize_fast.py checks) - is_valid, error = validate_conversation(messages, tools) - if not is_valid: - stats["num_filtered"] += 1 - stats["num_validation_errors"] += 1 - return - - # Step 2: Pre-process (materialize.py) - try: - messages = replace_json_args(messages) - except (json.JSONDecodeError, KeyError, TypeError) as e: - stats["num_filtered"] += 1 - stats["num_errors"] += 1 - logger.debug(f"Error in replace_json_args: {e}") - return - - # Step 3: Apply chat template, get role-labeled chunks (materialize.py) - try: - masked_results = create_masked_messages(messages, self._tokenizer, tools) - except Exception as e: - stats["num_filtered"] += 1 - stats["num_errors"] += 1 - 
logger.debug(f"Error in create_masked_messages: {e}") - return - - # Step 4: For each output sequence (may be multiple due to multi-turn splitting) - for chunks, _ in masked_results: - # Post-process: split system/user (materialize_fast.py) - processed_chunks = split_system_user_chunks(chunks) - - # Step 5: Tokenize and build loss_mask - try: - input_ids, loss_mask = self._tokenize_chunks_with_mask(processed_chunks) - except Exception as e: - stats["num_errors"] += 1 - logger.debug(f"Error tokenizing chunks: {e}") - continue - - # Skip empty sequences - if not input_ids: - continue - - # Truncate if needed - if self.max_doc_tokens and len(input_ids) > self.max_doc_tokens: - input_ids = input_ids[: self.max_doc_tokens] - loss_mask = loss_mask[: self.max_doc_tokens] - stats["num_truncated"] += 1 - - # Step 6: Add to packer - builder.add_sequence(input_ids, loss_mask=loss_mask) - stats["num_output_sequences"] += 1 - - def _tokenize_chunks_with_mask(self, chunks: list[dict]) -> tuple[list[int], list[int]]: - """Tokenize chunks and generate loss_mask based on role. - - Loss mask: 0 for system/user chunks, 1 for assistant chunks. - - Args: - chunks: List of chunks with 'role' and 'content' fields. - - Returns: - Tuple of (input_ids, loss_mask). 
- """ - all_input_ids: list[int] = [] - all_loss_mask: list[int] = [] - - for chunk in chunks: - # Tokenize the pre-rendered content (no special tokens - already in template) - tokens = self._tokenizer.encode(chunk["content"], add_special_tokens=False) - - # Build mask based on role: assistant = 1, others = 0 - mask_value = 1 if chunk["role"] == "assistant" else 0 - mask = [mask_value] * len(tokens) - - all_input_ids.extend(tokens) - all_loss_mask.extend(mask) - - return all_input_ids, all_loss_mask - - def _resolve_file_path(self, file_info: FileInfo) -> str: - """Resolve file to a local path, downloading from HF if needed.""" - if file_info.hf_repo_id is not None: - from huggingface_hub import hf_hub_download - - local_path = hf_hub_download( - repo_id=file_info.hf_repo_id, - filename=file_info.hf_filename, - revision=file_info.hf_revision, - repo_type="dataset", - local_files_only=False, - ) - return local_path - - return file_info.local_path or file_info.path - - def _iter_parquet_records(self, path: str, fs) -> Iterator[dict]: - """Iterate records from parquet file.""" - if self._is_remote_path(path): - with fs.open(path, "rb") as f: - parquet_file = pq.ParquetFile(f) - yield from self._iter_parquet_batches_as_dicts(parquet_file) - else: - parquet_file = pq.ParquetFile(path) - yield from self._iter_parquet_batches_as_dicts(parquet_file) - - def _iter_parquet_batches_as_dicts(self, parquet_file: pq.ParquetFile) -> Iterator[dict]: - """Iterate batches from parquet file as dicts.""" - for batch in parquet_file.iter_batches(batch_size=1000): - table = batch.to_pydict() - # Transpose from column-oriented to row-oriented - keys = list(table.keys()) - num_rows = len(table[keys[0]]) if keys else 0 - for i in range(num_rows): - yield {k: table[k][i] for k in keys} - - def _iter_jsonl_records(self, path: str, fs) -> Iterator[dict]: - """Iterate records from JSONL file.""" - if self._is_remote_path(path): - with fs.open(path, "r") as f: - for line in f: - if 
line.strip(): - yield json.loads(line) - else: - with open(path) as f: - for line in f: - if line.strip(): - yield json.loads(line) - - def _is_remote_path(self, path: str) -> bool: - """Check if path is a remote path (S3/GCS/etc).""" - return path.startswith(("s3://", "gs://", "gcs://", "az://", "abfs://")) - - def _matches_used_in_filter(self, used_in: str | list | None) -> bool: - """Check if record's used_in field matches the filter. - - Args: - used_in: Value of the used_in field (can be string, list, or None). - - Returns: - True if the filter matches, False otherwise. - """ - if used_in is None: - return False - - # Handle list format (e.g., ["nano_v3", "prod_v1"]) - if isinstance(used_in, list): - return self.used_in_filter in used_in - - # Handle string format (e.g., "nano_v3" or "nano_v3,prod_v1") - if isinstance(used_in, str): - # Check for exact match first - if used_in == self.used_in_filter: - return True - # Check comma-separated values - values = [v.strip() for v in used_in.split(",")] - return self.used_in_filter in values - - return False - - def _write_empty_receipt( - self, - shard_id: str, - shard_index: int, - input_files: list[str], - stats: dict, - receipt_path: str, - fs, - ) -> dict: - """Write receipt for empty shard.""" - receipt = { - "shard_id": shard_id, - "shard_index": shard_index, - "status": "completed", - "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), - "input_files": input_files, - "output_file": None, - "npy_bytes": 0, - "packing": { - "pack_size": self.pack_size, - "algorithm": self.algorithm, - "num_sequences": 0, - "num_packed_sequences": 0, - "packing_factor": 0, - "packing_efficiency": 0, - "total_tokens": 0, - }, - "stats": { - "num_sequences": 0, - "num_packed_sequences": 0, - "total_tokens": 0, - **stats, - }, - } - - write_json(fs, receipt_path, receipt) - return receipt["stats"] diff --git a/src/nemotron/data_prep/chat_sft_shard_core.py b/src/nemotron/data_prep/chat_sft_shard_core.py new file mode 
def process_chat_sft_shard_core(
    *,
    shard_index: int,
    files: list[dict] | list[FileInfo],
    output_dir: str,
    receipts_dir: str,
    output_fs: AbstractFileSystem,
    tokenizer: PreTrainedTokenizerBase,
    messages_field: str,
    tools_field: str,
    pack_size: int,
    algorithm: str,
    dtype: np.dtype,
    chat_template: str | None,
    max_doc_tokens: int | None,
    max_rows: int | None,
    seed: int | None,
    used_in_filter: str | None,
    used_in_field: str,
) -> dict[str, Any]:
    """Tokenize, mask and pack one ChatSFT shard, committing the result atomically.

    The packed array is written to ``<shard>.npy.tmp`` and renamed into place,
    and the JSON receipt is written last. A readable receipt whose status is
    ``"completed"`` short-circuits the call, which makes retries cheap.

    Args:
        shard_index: Index used to derive the ``shard_NNNNNN`` identifier.
        files: Input files, either ``FileInfo`` objects or dicts of their fields.
        output_dir: Directory receiving the ``.npy`` output.
        receipts_dir: Directory receiving the JSON receipt.
        output_fs: Filesystem used for every output operation.
        tokenizer: Tokenizer; its ``chat_template`` is replaced when
            ``chat_template`` is truthy.
        messages_field: Record key holding the conversation messages.
        tools_field: Record key holding the tool definitions.
        pack_size: Pack size handed to the sequence builder.
        algorithm: Packing algorithm name handed to the sequence builder.
        dtype: Output dtype; passed to the builder as ``str(dtype)``.
        chat_template: Template selector forwarded to ``_apply_chat_template``.
        max_doc_tokens: Sequences longer than this are truncated (falsy = off).
        max_rows: Cap on input rows across all files (falsy = no cap).
        seed: Seed handed to the sequence builder.
        used_in_filter: When truthy, only records matching it are kept.
        used_in_field: Record key consulted by ``used_in_filter``.

    Returns:
        The ``stats`` mapping stored in the shard receipt.
    """
    shard_id = f"shard_{shard_index:06d}"
    receipt_path = f"{receipts_dir}/{shard_id}.json"
    final_npy = f"{output_dir}/{shard_id}.npy"
    staging_npy = f"{final_npy}.tmp"

    # Retry safety: a previous attempt that got as far as a "completed"
    # receipt has nothing left to do. An unreadable receipt is ignored on
    # purpose so the shard is simply regenerated.
    if output_fs.exists(receipt_path):
        try:
            prior = read_json(output_fs, receipt_path)
            if prior.get("status") == "completed":
                return prior.get("stats", {})
        except Exception:
            pass

    for directory in (output_dir, receipts_dir):
        ensure_dir(output_fs, directory)

    file_infos = [entry if not isinstance(entry, dict) else FileInfo(**entry) for entry in files]
    input_file_paths = [info.path for info in file_infos]

    stats: dict[str, Any] = dict.fromkeys(
        (
            "num_input_rows",
            "num_output_sequences",
            "num_filtered",
            "num_validation_errors",
            "num_truncated",
            "num_errors",
        ),
        0,
    )

    def _commit_empty() -> dict[str, Any]:
        # Shared by the "no input files" and "nothing survived filtering" exits.
        return _write_empty_receipt(
            shard_id=shard_id,
            shard_index=shard_index,
            input_files=input_file_paths,
            stats=stats,
            receipt_path=receipt_path,
            output_fs=output_fs,
            pack_size=pack_size,
            algorithm=algorithm,
        )

    if not file_infos:
        return _commit_empty()

    if chat_template:
        _apply_chat_template(tokenizer, chat_template)

    builder = PackedSequenceBuilder(
        pack_size=pack_size,
        algorithm=algorithm,
        seed=seed,
        dtype=str(dtype),
    )

    rows_seen = 0
    for info in file_infos:
        rows_seen = _process_file(
            file_info=info,
            builder=builder,
            stats=stats,
            tokenizer=tokenizer,
            messages_field=messages_field,
            tools_field=tools_field,
            max_doc_tokens=max_doc_tokens,
            max_rows=max_rows,
            rows_processed=rows_seen,
            used_in_filter=used_in_filter,
            used_in_field=used_in_field,
        )
        if max_rows and rows_seen >= max_rows:
            break

    packed_data, packing_metadata = builder.finalize()
    if not packed_data:
        return _commit_empty()

    # Write to the staging name first; the rename publishes the finished file.
    with output_fs.open(staging_npy, "wb") as sink:
        np.save(sink, packed_data, allow_pickle=True)
    output_fs.rename(staging_npy, final_npy)

    shard_stats = {
        "num_sequences": packing_metadata.get("num_sequences", 0),
        "num_packed_sequences": packing_metadata.get("num_packed_sequences", 0),
        "total_tokens": packing_metadata.get("total_tokens", 0),
        **stats,
    }
    receipt = {
        "shard_id": shard_id,
        "shard_index": shard_index,
        "status": "completed",
        "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "input_files": input_file_paths,
        "output_file": f"{shard_id}.npy",
        "npy_bytes": output_fs.size(final_npy),
        "packing": packing_metadata,
        "stats": shard_stats,
    }

    # The receipt is written last: its presence marks the shard as committed.
    write_json(output_fs, receipt_path, receipt)
    return receipt["stats"]
def _apply_chat_template(tokenizer: PreTrainedTokenizerBase, chat_template: str) -> None:
    """Install a chat template on ``tokenizer`` (mutates it in place).

    ``chat_template`` is resolved in this order:

    1. The literal ``"nano3"`` loads ``templates/nano3.jinja`` next to this module.
    2. A path that exists on the local filesystem is read and its contents used.
    3. Anything else is used verbatim as an inline template string.

    Args:
        tokenizer: Object whose ``chat_template`` attribute is overwritten.
        chat_template: Template name, template file path, or inline template.
    """
    if chat_template == "nano3":
        template_path = Path(__file__).parent / "templates" / "nano3.jinja"
        with open(template_path) as f:
            tokenizer.chat_template = f.read()
        return

    # An inline Jinja template is usually far longer than the OS allows for a
    # file name. Before Python 3.12, Path.exists() lets that OSError
    # (ENAMETOOLONG) escape instead of returning False, so probing the
    # filesystem must not be allowed to abort the inline-template case.
    try:
        is_existing_path = Path(chat_template).exists()
    except OSError:
        is_existing_path = False

    if is_existing_path:
        with open(chat_template) as f:
            tokenizer.chat_template = f.read()
    else:
        tokenizer.chat_template = chat_template
def _process_file(
    *,
    file_info: FileInfo,
    builder: PackedSequenceBuilder,
    stats: dict[str, Any],
    tokenizer: PreTrainedTokenizerBase,
    messages_field: str,
    tools_field: str,
    max_doc_tokens: int | None,
    max_rows: int | None,
    rows_processed: int,
    used_in_filter: str | None,
    used_in_field: str,
) -> int:
    """Feed every record of one input file through ``_process_record``.

    Args:
        file_info: Input file description.
        builder: Sink receiving the tokenized sequences.
        stats: Counter mapping, updated in place.
        tokenizer: Tokenizer forwarded to ``_process_record``.
        messages_field: Record key holding the conversation messages.
        tools_field: Record key holding the tool definitions.
        max_doc_tokens: Per-sequence truncation limit (falsy = off).
        max_rows: Global row cap shared across files (falsy = no cap).
        rows_processed: Rows already consumed before this file.
        used_in_filter: Optional ``used_in`` filter value.
        used_in_field: Record key consulted by the filter.

    Returns:
        The updated number of rows consumed so far.
    """
    resolved = _resolve_file_path(file_info)
    if file_info.hf_repo_id is not None:
        input_path = resolved
    else:
        input_path = file_info.local_path or file_info.path
    input_fs, normalized = get_filesystem(input_path)

    # Use original filename for format detection (hf_hub_download returns blob
    # path without extension). Anything not named .jsonl/.json is read as parquet.
    name_for_format = normalized
    if file_info.hf_repo_id:
        name_for_format = file_info.hf_filename or normalized
    read_records = (
        _iter_jsonl_records
        if name_for_format.endswith((".jsonl", ".json"))
        else _iter_parquet_records
    )

    for record in read_records(normalized, input_fs):
        if max_rows and rows_processed >= max_rows:
            break

        stats["num_input_rows"] += 1
        rows_processed += 1

        _process_record(
            record=record,
            builder=builder,
            stats=stats,
            tokenizer=tokenizer,
            messages_field=messages_field,
            tools_field=tools_field,
            max_doc_tokens=max_doc_tokens,
            used_in_filter=used_in_filter,
            used_in_field=used_in_field,
        )

    return rows_processed


def _process_record(
    *,
    record: dict,
    builder: PackedSequenceBuilder,
    stats: dict[str, Any],
    tokenizer: PreTrainedTokenizerBase,
    messages_field: str,
    tools_field: str,
    max_doc_tokens: int | None,
    used_in_filter: str | None,
    used_in_field: str,
) -> None:
    """Validate, template, tokenize and enqueue the sequences of one record.

    A record is dropped (``num_filtered`` incremented) when it fails the
    ``used_in`` filter, has no messages, fails validation, or cannot be
    rendered. Each surviving chunk group becomes one sequence handed to
    ``builder.add_sequence`` together with its loss mask.

    Args:
        record: Raw input row.
        builder: Sink exposing ``add_sequence(input_ids, loss_mask=...)``.
        stats: Counter mapping, updated in place.
        tokenizer: Tokenizer used for templating and encoding.
        messages_field: Record key holding the conversation messages.
        tools_field: Record key holding the tool definitions.
        max_doc_tokens: Sequences longer than this are truncated (falsy = off).
        used_in_filter: Optional ``used_in`` filter value.
        used_in_field: Record key consulted by the filter.
    """

    def _drop(*extra_counters: str) -> None:
        # Every rejected record counts as filtered, plus any specific reason.
        stats["num_filtered"] += 1
        for counter in extra_counters:
            stats[counter] += 1

    if used_in_filter and not _matches_used_in_filter(record.get(used_in_field), used_in_filter):
        _drop()
        return

    messages = record.get(messages_field)
    tools = record.get(tools_field)
    if not messages:
        _drop()
        return

    is_valid, _ = validate_conversation(messages, tools)
    if not is_valid:
        _drop("num_validation_errors")
        return

    try:
        messages = replace_json_args(messages)
    except (json.JSONDecodeError, KeyError, TypeError):
        _drop("num_errors")
        return

    try:
        masked_results = create_masked_messages(messages, tokenizer, tools)
    except Exception:
        _drop("num_errors")
        return

    for chunks, _ in masked_results:
        try:
            input_ids, loss_mask = _tokenize_chunks_with_mask(
                tokenizer, split_system_user_chunks(chunks)
            )
        except Exception:
            # A tokenization failure skips this sequence only, not the record.
            stats["num_errors"] += 1
            continue

        if not input_ids:
            continue

        if max_doc_tokens and len(input_ids) > max_doc_tokens:
            input_ids = input_ids[:max_doc_tokens]
            loss_mask = loss_mask[:max_doc_tokens]
            stats["num_truncated"] += 1

        builder.add_sequence(input_ids, loss_mask=loss_mask)
        stats["num_output_sequences"] += 1
def _tokenize_chunks_with_mask(
    tokenizer: PreTrainedTokenizerBase,
    chunks: list[dict],
) -> tuple[list[int], list[int]]:
    """Encode pre-rendered chunks and build the matching loss mask.

    Each chunk's ``content`` is encoded without special tokens. Tokens from
    ``assistant`` chunks get mask value 1, all other roles get 0.

    Returns:
        ``(input_ids, loss_mask)``, two lists of equal length.
    """
    token_ids: list[int] = []
    mask_bits: list[int] = []

    for chunk in chunks:
        piece = tokenizer.encode(chunk["content"], add_special_tokens=False)
        trainable = int(chunk["role"] == "assistant")
        token_ids += piece
        mask_bits += [trainable] * len(piece)

    return token_ids, mask_bits


def _resolve_file_path(file_info: FileInfo) -> str:
    """Return a readable path for ``file_info``.

    Non-HuggingFace files resolve to ``local_path`` when set, else ``path``.
    HuggingFace files are looked up in the local cache only.
    """
    if file_info.hf_repo_id is None:
        return file_info.local_path or file_info.path

    from huggingface_hub import hf_hub_download

    return hf_hub_download(
        repo_id=file_info.hf_repo_id,
        filename=file_info.hf_filename,
        revision=file_info.hf_revision,
        repo_type="dataset",
        local_files_only=True,  # Files should be pre-downloaded by HfPredownloadStage
    )


def _iter_parquet_records(path: str, fs: AbstractFileSystem) -> Iterator[dict]:
    """Yield the rows of a parquet file as dicts, reading 1000-row batches.

    Raises:
        RuntimeError: If opening or reading the file fails; the original
            exception is chained as the cause.
    """
    try:
        with fs.open(path, "rb") as handle:
            for batch in pq.ParquetFile(handle).iter_batches(batch_size=1000):
                columns = batch.to_pydict()
                names = list(columns)
                # Transpose column-oriented data into one dict per row.
                for row in zip(*(columns[name] for name in names)):
                    yield dict(zip(names, row))
    except Exception as e:
        raise RuntimeError(f"Failed to read parquet file: {path}") from e
def _iter_jsonl_records(path: str, fs: AbstractFileSystem) -> Iterator[dict]:
    """Yield one parsed JSON object per non-blank line of ``path``."""
    with fs.open(path, "r") as handle:
        yield from (json.loads(raw) for raw in handle if raw.strip())


def _matches_used_in_filter(used_in: str | list | None, used_in_filter: str) -> bool:
    """Tell whether a record's ``used_in`` value selects ``used_in_filter``.

    Args:
        used_in: ``None``, a list of tags, or a string holding either a single
            tag or comma-separated tags.
        used_in_filter: Tag to look for.

    Returns:
        True when the tag is present; False for ``None`` and for any other type.
    """
    if isinstance(used_in, list):
        return used_in_filter in used_in

    if isinstance(used_in, str):
        # The whole string is compared first, then each comma-separated piece
        # with surrounding whitespace removed.
        if used_in == used_in_filter:
            return True
        return used_in_filter in [piece.strip() for piece in used_in.split(",")]

    return False


def _write_empty_receipt(
    *,
    shard_id: str,
    shard_index: int,
    input_files: list[str],
    stats: dict[str, Any],
    receipt_path: str,
    output_fs: AbstractFileSystem,
    pack_size: int,
    algorithm: str,
) -> dict[str, Any]:
    """Write a ``completed`` receipt for a shard that produced no output.

    Returns:
        The receipt's ``stats`` mapping: zeroed sequence/token counters
        overlaid with the caller-supplied ``stats``.
    """
    zero_counts = {
        "num_sequences": 0,
        "num_packed_sequences": 0,
        "total_tokens": 0,
    }
    packing = {
        "pack_size": pack_size,
        "algorithm": algorithm,
        "num_sequences": 0,
        "num_packed_sequences": 0,
        "packing_factor": 0,
        "packing_efficiency": 0,
        "total_tokens": 0,
    }
    receipt = {
        "shard_id": shard_id,
        "shard_index": shard_index,
        "status": "completed",
        "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "input_files": input_files,
        "output_file": None,
        "npy_bytes": 0,
        "packing": packing,
        "stats": {**zero_counts, **stats},
    }

    write_json(output_fs, receipt_path, receipt)
    return receipt["stats"]
class _SpoolSequenceSink:
    """Expose a ``SequenceSpoolWriter`` through the ``add_sequence`` call used by ``_process_record``."""

    def __init__(self, writer: SequenceSpoolWriter) -> None:
        self._writer = writer

    def add_sequence(self, input_ids: list[int], loss_mask: list[int]) -> None:
        """Append one tokenized sequence and its loss mask to the spool."""
        self._writer.append(input_ids, loss_mask)


def process_chat_sft_spool_core(
    *,
    shard_index: int,
    files: list[dict] | list[FileInfo],
    output_dir: str,
    receipts_dir: str,
    spool_dir: str | None,
    output_fs: AbstractFileSystem,
    tokenizer: PreTrainedTokenizerBase,
    messages_field: str,
    tools_field: str,
    pack_size: int,
    algorithm: str,
    dtype: np.dtype,
    chat_template: str | None,
    max_doc_tokens: int | None,
    max_rows: int | None,
    seed: int | None,
    used_in_filter: str | None,
    used_in_field: str,
) -> dict[str, Any]:
    """Tokenize+mask a ChatSFT shard into a SequenceSpool intermediate.

    Record handling (filtering, validation, templating, tokenization,
    truncation and counters) is delegated to ``_process_file``, the same path
    used by ``process_chat_sft_shard_core``. The only difference is the sink:
    sequences are appended to a ``SequenceSpoolWriter`` instead of a packer.

    Retry safety:
    - The spool is considered committed when manifest.json exists.
    - SequenceSpoolWriter writes data to *.tmp then renames + writes manifest last.

    Args:
        shard_index: Index used to derive the ``shard_NNNNNN`` identifier.
        files: Input files, either ``FileInfo`` objects or dicts of their fields.
        output_dir: Output directory; also the parent of the default spool root.
        receipts_dir: Receipt directory (created here, written by the pack step).
        spool_dir: Explicit spool root; defaults to ``<output_dir>/spool/<shard_id>``.
        output_fs: Filesystem used for every output operation.
        tokenizer: Tokenizer; its ``chat_template`` is replaced when
            ``chat_template`` is truthy.
        messages_field: Record key holding the conversation messages.
        tools_field: Record key holding the tool definitions.
        pack_size: Recorded in the manifest.
        algorithm: Recorded in the manifest.
        dtype: Recorded in the manifest as ``str(dtype)``.
        chat_template: Template selector forwarded to ``_apply_chat_template``.
        max_doc_tokens: Sequences longer than this are truncated (falsy = off).
        max_rows: Cap on input rows across all files (falsy = no cap).
        seed: Recorded in the manifest.
        used_in_filter: When truthy, only records matching it are kept.
        used_in_field: Record key consulted by ``used_in_filter``.

    Returns:
        The tokenization stats, either freshly computed or read back from an
        existing manifest.
    """
    shard_id = f"shard_{shard_index:06d}"
    spool_root = spool_dir or f"{output_dir.rstrip('/')}/spool/{shard_id}"
    paths = SequenceSpoolPaths.for_root(spool_root)

    # If the spool manifest exists, treat it as completed.
    if output_fs.exists(paths.manifest_path):
        try:
            manifest = read_json(output_fs, paths.manifest_path)
            tokenization_stats = manifest.get("tokenization_stats", {})
            return tokenization_stats if isinstance(tokenization_stats, dict) else {}
        except Exception:
            # Fall through to regenerate spool if manifest is unreadable.
            pass

    ensure_dir(output_fs, output_dir)
    ensure_dir(output_fs, receipts_dir)
    ensure_dir(output_fs, spool_root)

    file_infos = [FileInfo(**f) if isinstance(f, dict) else f for f in files]
    input_file_paths = [f.path for f in file_infos]

    if chat_template:
        _apply_chat_template(tokenizer, chat_template)

    stats: dict[str, Any] = {
        "num_input_rows": 0,
        "num_output_sequences": 0,
        "num_filtered": 0,
        "num_validation_errors": 0,
        "num_truncated": 0,  # truncation due to max_doc_tokens
        "num_errors": 0,
    }

    writer = SequenceSpoolWriter(fs=output_fs, paths=paths)
    sink = _SpoolSequenceSink(writer)

    rows_processed = 0
    for file_info in file_infos:
        if max_rows and rows_processed >= max_rows:
            break

        rows_processed = _process_file(
            file_info=file_info,
            builder=sink,  # duck-typed: only add_sequence() is called on it
            stats=stats,
            tokenizer=tokenizer,
            messages_field=messages_field,
            tools_field=tools_field,
            max_doc_tokens=max_doc_tokens,
            max_rows=max_rows,
            rows_processed=rows_processed,
            used_in_filter=used_in_filter,
            used_in_field=used_in_field,
        )

    # finalize() writes the manifest, which is what marks the spool committed.
    writer.finalize(
        extra_manifest={
            "shard_id": shard_id,
            "shard_index": shard_index,
            "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "input_files": input_file_paths,
            "messages_field": messages_field,
            "tools_field": tools_field,
            "chat_template": chat_template,
            "max_doc_tokens": max_doc_tokens,
            "max_rows": max_rows,
            "seed": seed,
            "used_in_filter": used_in_filter,
            "used_in_field": used_in_field,
            "pack_size": pack_size,
            "algorithm": algorithm,
            "dtype": str(dtype),
            "tokenization_stats": stats,
        }
    )

    return stats
def process_chat_sft_pack_from_spool_core(
    *,
    shard_index: int,
    output_dir: str,
    receipts_dir: str,
    spool_dir: str | None,
    output_fs: AbstractFileSystem,
    pack_size: int,
    algorithm: str,
    dtype: np.dtype,
    seed: int | None,
) -> dict[str, Any]:
    """Pack the sequences of a committed spool into ``<shard>.npy`` and write its receipt.

    Pass one reads only the sequence lengths and computes a bin assignment;
    pass two materializes the packed samples from the spool. Lengths above
    ``pack_size`` are clamped for packing and counted as truncated.

    Args:
        shard_index: Index used to derive the ``shard_NNNNNN`` identifier.
        output_dir: Directory receiving the ``.npy`` output.
        receipts_dir: Directory receiving the JSON receipt.
        spool_dir: Explicit spool root; defaults to ``<output_dir>/spool/<shard_id>``.
        output_fs: Filesystem used for spool reads and all output operations.
        pack_size: Maximum tokens per packed sequence.
        algorithm: Packing algorithm name passed to ``get_packer``.
        dtype: Accepted for signature parity; not read by this function.
        seed: Seed passed to ``get_packer``.

    Returns:
        The ``stats`` mapping stored in the shard receipt.

    Raises:
        RuntimeError: If the spool manifest is missing or unreadable.
    """
    shard_id = f"shard_{shard_index:06d}"
    receipt_path = f"{receipts_dir}/{shard_id}.json"
    final_npy = f"{output_dir}/{shard_id}.npy"
    staging_npy = f"{final_npy}.tmp"

    # A readable "completed" receipt means the shard is already packed.
    if output_fs.exists(receipt_path):
        try:
            prior = read_json(output_fs, receipt_path)
            if prior.get("status") == "completed":
                return prior.get("stats", {})
        except Exception:
            pass

    for directory in (output_dir, receipts_dir):
        ensure_dir(output_fs, directory)

    spool_root = spool_dir or f"{output_dir.rstrip('/')}/spool/{shard_id}"
    paths = SequenceSpoolPaths.for_root(spool_root)

    if not output_fs.exists(paths.manifest_path):
        raise RuntimeError(f"Missing spool manifest for shard {shard_id}: {paths.manifest_path}")

    try:
        manifest = read_json(output_fs, paths.manifest_path)
    except Exception as e:
        raise RuntimeError(f"Failed to read spool manifest for shard {shard_id}: {paths.manifest_path}") from e

    # Malformed manifest entries degrade to empty values instead of failing.
    tokenization_stats = manifest.get("tokenization_stats", {})
    if not isinstance(tokenization_stats, dict):
        tokenization_stats = {}

    input_files = manifest.get("input_files", [])
    if not isinstance(input_files, list):
        input_files = []

    def _commit_empty() -> dict[str, Any]:
        # Shared by the "empty spool" and "nothing materialized" exits.
        return _write_empty_receipt(
            shard_id=shard_id,
            shard_index=shard_index,
            input_files=[str(x) for x in input_files],
            stats=tokenization_stats,
            receipt_path=receipt_path,
            output_fs=output_fs,
            pack_size=pack_size,
            algorithm=algorithm,
        )

    reader = SequenceSpoolReader(fs=output_fs, paths=paths)

    try:
        _, lengths = reader.load_offsets_and_lengths()
        num_sequences = int(lengths.shape[0])

        if num_sequences == 0:
            return _commit_empty()

        lengths_i64 = lengths.astype(np.int64)
        lengths_clamped = np.minimum(lengths_i64, int(pack_size))
        num_truncated_to_pack_size = int((lengths_i64 > int(pack_size)).sum())
        del lengths_i64

        packer = get_packer(algorithm, pack_size, seed=seed)
        bins, _ = packer.pack(lengths_clamped.tolist())

        assignment = BinAssignment.from_bins(bins=bins, num_sequences=num_sequences)

        packed_data: list[dict] = list(
            materialize_packed_samples(
                spool_reader=reader,
                assignment=assignment,
                pack_size=pack_size,
            )
        )
    finally:
        # Closing is best-effort; a close failure must not mask the real outcome.
        try:
            reader.close()
        except Exception:
            pass

    if not packed_data:
        return _commit_empty()

    with output_fs.open(staging_npy, "wb") as sink:
        np.save(sink, packed_data, allow_pickle=True)

    # Explicitly free memory from large data structures before computing stats
    # This prevents memory accumulation when processing multiple shards sequentially
    num_bins = int(assignment.num_bins)
    total_tokens = int(lengths_clamped.sum())

    del packed_data
    del bins
    del assignment
    del lengths_clamped

    import gc
    gc.collect()

    # The rename publishes the finished file under its final name.
    output_fs.rename(staging_npy, final_npy)
    npy_bytes = output_fs.size(final_npy)

    if num_bins:
        packing_factor = round(num_sequences / num_bins, 2)
        packing_efficiency = round((total_tokens / (num_bins * pack_size)) * 100, 1)
    else:
        packing_factor = 0.0
        packing_efficiency = 0.0

    packing_metadata = {
        "pack_size": pack_size,
        "algorithm": str(algorithm),
        "num_sequences": num_sequences,
        "num_packed_sequences": num_bins,
        "packing_factor": packing_factor,
        "packing_efficiency": packing_efficiency,
        "num_truncated": num_truncated_to_pack_size,
        "total_tokens": total_tokens,
    }

    receipt = {
        "shard_id": shard_id,
        "shard_index": shard_index,
        "status": "completed",
        "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "input_files": [str(x) for x in input_files],
        "output_file": f"{shard_id}.npy",
        "npy_bytes": npy_bytes,
        "packing": packing_metadata,
        "stats": {
            "num_sequences": num_sequences,
            "num_packed_sequences": num_bins,
            "total_tokens": total_tokens,
            "num_truncated_to_pack_size": num_truncated_to_pack_size,
            **tokenization_stats,
        },
    }

    # The receipt is written last: its presence marks the shard as committed.
    write_json(output_fs, receipt_path, receipt)
    return receipt["stats"]
""" enabled: bool = False @@ -222,6 +229,8 @@ class RayDataConfig: max_actors: int | None = None # None = use all available CPUs cpus_per_actor: float = 1.0 max_tasks_in_flight_per_actor: int = 2 + max_concurrent_downloads: int = 64 + cleanup_hf_cache: bool = False @dataclass(frozen=True) @@ -290,6 +299,8 @@ class PipelineConfig: uses Ray Data's ActorPoolStrategy for shard processing instead of manual actors. console_mode: Console output mode ('rich' or 'simple') simple_log_interval_sec: Interval in seconds for simple mode status updates + execution_engine: Execution backend ("ray" or "xenna") + max_concurrent_downloads: Max parallel HF downloads (used by Xenna path) """ output: OutputConfig @@ -302,6 +313,13 @@ class PipelineConfig: ray_data: RayDataConfig | None = None console_mode: str = "simple" simple_log_interval_sec: int = 30 + execution_engine: Literal["ray", "xenna"] = "ray" + max_concurrent_downloads: int = 64 + wandb_log_downloads: bool = False + wandb_download_log_interval_sec: int = 30 + hf_download_timeout_sec: int = 300 + hf_download_max_retries: int = 3 + num_actors: int | None = None # ============================================================================ diff --git a/src/nemotron/data_prep/downloader.py b/src/nemotron/data_prep/downloader.py new file mode 100644 index 000000000..5e15d8059 --- /dev/null +++ b/src/nemotron/data_prep/downloader.py @@ -0,0 +1,621 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@dataclass
class DownloadStats:
    """Counters describing the outcome of a parallel download phase."""

    total_files: int  # unique files that were requested
    downloaded_files: int  # files reported as downloaded
    cached_files: int  # files reported as cache hits
    failed_files: int  # files whose download failed
    elapsed_sec: float  # elapsed seconds for the phase

    @property
    def success_rate(self) -> float:
        """Share of files downloaded or cached, in percent (100.0 when there are no files)."""
        if not self.total_files:
            return 100.0
        succeeded = self.downloaded_files + self.cached_files
        return succeeded / self.total_files * 100
@ray.remote(num_cpus=0.1)  # Light on CPU, mostly I/O bound
def _download_hf_file(
    repo_id: str,
    filename: str,
    revision: str | None,
    cache_dir: str | None = None,
) -> dict[str, Any]:
    """Download a single HuggingFace dataset file.

    Failures are reported in the returned dict rather than raised, so one bad
    file does not fail the Ray task.

    Args:
        repo_id: HuggingFace repository ID (e.g., "nvidia/Nemotron-CC-v2.1")
        filename: Path within the repository (e.g., "data/part_00001.parquet")
        revision: Git revision/commit SHA for determinism
        cache_dir: HuggingFace cache directory passed to ``hf_hub_download``

    Returns:
        Dict with keys ``status``, ``repo_id``, ``filename``, ``revision``,
        ``local_path`` and ``error``. ``status`` is ``"downloaded"`` on success
        (cache hits are not distinguished here) and ``"failed"`` otherwise.
    """
    outcome: dict[str, Any] = {
        "status": "failed",
        "repo_id": repo_id,
        "filename": filename,
        "revision": revision,
        "local_path": None,
        "error": None,
    }

    try:
        from huggingface_hub import hf_hub_download

        outcome["local_path"] = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            revision=revision,
            repo_type="dataset",
            local_files_only=False,
            cache_dir=cache_dir,
        )
    except Exception as e:
        logger.warning(f"Failed to download {repo_id}/{filename}: {e}")
        outcome["error"] = str(e)
    else:
        outcome["status"] = "downloaded"

    return outcome


def _make_file_key(repo_id: str, filename: str, revision: str | None) -> str:
    """Build the ``repo:filename:revision`` key identifying one HuggingFace file.

    A missing revision is rendered as an empty string.
    """
    return ":".join((repo_id, filename, revision or ""))
@dataclass
class OverlappedDownloader:
    """Download HuggingFace files in parallel and release shards as they become ready.

    The constructor builds a file -> shard dependency map from each task's
    ``assignment_json``. ``iter_ready_batches`` then runs the downloads as Ray
    tasks and yields a shard task as soon as every file it needs has arrived,
    so processing can overlap with the remaining downloads.

    Shards that depend on a file whose download failed are never yielded; a
    warning with their count is logged when iteration finishes.

    Attributes:
        shard_tasks: List of ShardTask objects to process
        max_concurrent_downloads: Maximum parallel downloads (rate limiting)
        on_progress: Optional callback for progress updates
        cache_dir: HuggingFace cache directory (defaults to HF_HOME/hub when
            HF_HOME is set)
    """

    shard_tasks: list
    max_concurrent_downloads: int = 64
    on_progress: Callable[[dict[str, Any]], None] | None = None
    cache_dir: str | None = None

    # Internal state (initialized in __post_init__)
    _file_to_shards: dict[str, set[int]] = field(default_factory=dict, init=False)
    _shard_to_files: dict[int, set[str]] = field(default_factory=dict, init=False)
    _shard_pending_files: dict[int, set[str]] = field(default_factory=dict, init=False)
    _downloaded_files: set[str] = field(default_factory=set, init=False)
    _failed_files: set[str] = field(default_factory=set, init=False)
    _ready_shards: list[int] = field(default_factory=list, init=False)
    _all_unique_files: list[dict[str, str]] = field(default_factory=list, init=False)
    _shard_index_to_task: dict[int, Any] = field(default_factory=dict, init=False)
    # Elapsed seconds of the latest iter_ready_batches() run, surfaced by get_stats().
    _elapsed_sec: float = field(default=0.0, init=False, repr=False, compare=False)

    def __post_init__(self) -> None:
        """Build file→shard dependency mappings."""
        import os

        # Compute cache_dir from HF_HOME if not specified
        if self.cache_dir is None:
            hf_home = os.environ.get("HF_HOME")
            if hf_home:
                self.cache_dir = os.path.join(hf_home, "hub")

        self._file_to_shards = defaultdict(set)
        self._shard_to_files = defaultdict(set)
        self._shard_pending_files = {}
        self._downloaded_files = set()
        self._failed_files = set()
        self._ready_shards = []
        self._shard_index_to_task = {}

        seen_files: set[str] = set()
        unique_files: list[dict[str, str]] = []

        for idx, task in enumerate(self.shard_tasks):
            self._shard_index_to_task[idx] = task
            assignment = json.loads(task.assignment_json)

            shard_file_keys: set[str] = set()
            for file_info in assignment.get("files", []):
                repo_id = file_info.get("hf_repo_id")
                if repo_id is None:
                    continue  # Skip non-HF files

                filename = file_info.get("hf_filename")
                revision = file_info.get("hf_revision")
                file_key = _make_file_key(repo_id, filename, revision)

                self._file_to_shards[file_key].add(idx)
                shard_file_keys.add(file_key)

                # The same file may feed several shards; download it once.
                if file_key not in seen_files:
                    seen_files.add(file_key)
                    unique_files.append(
                        {
                            "repo_id": repo_id,
                            "filename": filename,
                            "revision": revision,
                        }
                    )

            self._shard_to_files[idx] = shard_file_keys
            self._shard_pending_files[idx] = shard_file_keys.copy()

            # If shard has no HF files, it's ready immediately
            if not shard_file_keys:
                self._ready_shards.append(idx)

        self._all_unique_files = unique_files

    @property
    def total_files(self) -> int:
        """Total unique files to download."""
        return len(self._all_unique_files)

    @property
    def total_shards(self) -> int:
        """Total shards to process."""
        return len(self.shard_tasks)

    def _mark_file_downloaded(self, file_key: str) -> list[int]:
        """Mark a file as downloaded and return newly-ready shard indices."""
        if file_key in self._downloaded_files:
            return []

        self._downloaded_files.add(file_key)
        newly_ready = []

        # Update all shards that needed this file
        for shard_idx in self._file_to_shards.get(file_key, []):
            pending = self._shard_pending_files.get(shard_idx)
            if pending is not None:
                pending.discard(file_key)
                if not pending:
                    # All files for this shard are downloaded!
                    newly_ready.append(shard_idx)
                    del self._shard_pending_files[shard_idx]

        return newly_ready

    def _mark_file_failed(self, file_key: str) -> None:
        """Record a failed download; shards needing this file will never become ready."""
        self._failed_files.add(file_key)

    def _submit_downloads(self, pending_files: Any, futures: list, future_to_file: dict) -> None:
        """Launch queued downloads until the concurrency cap is reached.

        ``pending_files`` is a deque of file descriptors; ``futures`` and
        ``future_to_file`` are updated in place.
        """
        while pending_files and len(futures) < self.max_concurrent_downloads:
            file_info = pending_files.popleft()
            future = _download_hf_file.remote(
                file_info["repo_id"],
                file_info["filename"],
                file_info["revision"],
                self.cache_dir,
            )
            futures.append(future)
            future_to_file[future] = _make_file_key(
                file_info["repo_id"],
                file_info["filename"],
                file_info["revision"],
            )

    def _print_progress(self, downloaded: int, total_files: int, shards_yielded: int, elapsed: float) -> None:
        """Print one progress line with the current download rate."""
        rate = downloaded / max(elapsed, 0.001)
        print(f"[Download+Process] {downloaded}/{total_files} files, "
              f"{shards_yielded}/{self.total_shards} shards ready ({rate:.1f} files/s)")
        sys.stdout.flush()

    def iter_ready_batches(
        self,
        min_batch_size: int = 1,
        max_wait_sec: float = 1.0,
    ) -> Iterator[list]:
        """Yield batches of shard tasks ready for processing.

        Downloads files in parallel and yields shard tasks as soon as
        all their required files are downloaded. Processing can start
        while downloads continue.

        Args:
            min_batch_size: Minimum shards to accumulate before yielding
            max_wait_sec: Maximum time to wait for batch to fill

        Yields:
            Lists of ShardTask objects ready for processing
        """
        from collections import deque  # O(1) popleft for the submit queue

        start_time = time.perf_counter()
        total_files = self.total_files

        if total_files == 0:
            print("[Download+Process] No HuggingFace files - all shards ready immediately")
            sys.stdout.flush()
            self._elapsed_sec = time.perf_counter() - start_time
            # All shards are ready immediately
            yield self.shard_tasks
            return

        print(f"[Download+Process] Starting overlapped download of {total_files} files "
              f"for {self.total_shards} shards (max_concurrent={self.max_concurrent_downloads})")
        sys.stdout.flush()

        # Track statistics
        downloaded = 0
        failed = 0
        shards_yielded = 0

        # Submit initial batch of downloads
        pending_files = deque(self._all_unique_files)
        futures: list = []
        future_to_file: dict = {}
        self._submit_downloads(pending_files, futures, future_to_file)

        # Yield any shards that were ready immediately (no HF files)
        if self._ready_shards:
            ready_tasks = [self._shard_index_to_task[i] for i in self._ready_shards]
            shards_yielded += len(ready_tasks)
            self._ready_shards.clear()
            self._elapsed_sec = time.perf_counter() - start_time
            yield ready_tasks

        # Process downloads and yield ready shards
        last_progress_time = start_time
        batch_start_time = time.perf_counter()
        ready_batch: list[int] = []

        while futures or ready_batch:
            # Wait for downloads (short timeout if shards are already waiting)
            if futures:
                timeout = 0.1 if ready_batch else max_wait_sec
                done, futures = ray.wait(futures, num_returns=1, timeout=timeout)

                for future in done:
                    file_key = future_to_file.pop(future, None)
                    try:
                        result = ray.get(future)
                        if result["status"] in ("downloaded", "cached"):
                            downloaded += 1
                            if file_key:
                                ready_batch.extend(self._mark_file_downloaded(file_key))
                        else:
                            failed += 1
                            if file_key:
                                self._mark_file_failed(file_key)
                    except Exception as e:
                        failed += 1
                        logger.warning(f"Download task failed: {e}")
                        if file_key:
                            self._mark_file_failed(file_key)

            # Submit more downloads
            self._submit_downloads(pending_files, futures, future_to_file)

            # Yield ready batch if we have enough or waited long enough
            batch_wait = time.perf_counter() - batch_start_time
            if ready_batch and (len(ready_batch) >= min_batch_size or batch_wait >= max_wait_sec or not futures):
                ready_tasks = [self._shard_index_to_task[i] for i in ready_batch]
                shards_yielded += len(ready_tasks)
                ready_batch.clear()
                batch_start_time = time.perf_counter()

                # Progress update before yielding
                self._elapsed_sec = time.perf_counter() - start_time
                self._print_progress(downloaded, total_files, shards_yielded, self._elapsed_sec)

                yield ready_tasks

            # Progress reporting (every 5 seconds)
            current_time = time.perf_counter()
            if current_time - last_progress_time >= 5.0:
                last_progress_time = current_time
                elapsed = current_time - start_time
                self._elapsed_sec = elapsed
                self._print_progress(downloaded, total_files, shards_yielded, elapsed)

                if self.on_progress:
                    self.on_progress({
                        "phase": "downloading",
                        "detail": f"{downloaded}/{total_files} files, {shards_yielded} shards ready",
                        "elapsed_sec": elapsed,
                        "files_completed": downloaded,
                        "files_total": total_files,
                        "shards_ready": shards_yielded,
                    })

        elapsed = time.perf_counter() - start_time
        self._elapsed_sec = elapsed
        print(f"[Download+Process] Complete: {downloaded}/{total_files} files, "
              f"{shards_yielded}/{self.total_shards} shards in {elapsed:.1f}s "
              f"({failed} failed)")
        sys.stdout.flush()

        # Shards still waiting on a file were blocked by failed downloads and
        # were never handed to the caller; report them instead of dropping
        # them silently.
        blocked_shards = sum(1 for pending in self._shard_pending_files.values() if pending)
        if blocked_shards:
            logger.warning(
                "%d shard(s) were never yielded because %d file download(s) failed",
                blocked_shards,
                failed,
            )

    def get_stats(self) -> DownloadStats:
        """Get current download statistics."""
        return DownloadStats(
            total_files=self.total_files,
            downloaded_files=len(self._downloaded_files),
            cached_files=0,  # Can't distinguish from downloaded
            failed_files=len(self._failed_files),
            elapsed_sec=self._elapsed_sec,  # Updated during iteration
        )
+ + Args: + shard_tasks: List of ShardTask objects with assignment_json + + Returns: + List of unique file dicts with repo_id, filename, revision + """ + seen = set() + unique_files = [] + + for task in shard_tasks: + # Parse the assignment JSON + assignment = json.loads(task.assignment_json) + files = assignment.get("files", []) + + for file_info in files: + # Only include HuggingFace files (have hf_repo_id) + repo_id = file_info.get("hf_repo_id") + if repo_id is None: + continue + + filename = file_info.get("hf_filename") + revision = file_info.get("hf_revision") + + # Create unique key for deduplication + key = (repo_id, filename, revision or "") + if key in seen: + continue + seen.add(key) + + unique_files.append({ + "repo_id": repo_id, + "filename": filename, + "revision": revision, + }) + + return unique_files + + +def parallel_predownload( + shard_tasks: list, + *, + max_concurrent: int = 64, + on_progress: Callable[[dict[str, Any]], None] | None = None, + cache_dir: str | None = None, +) -> DownloadStats: + """Pre-download all HuggingFace files in parallel before shard processing. + + This function collects all unique HuggingFace files from the shard tasks + and downloads them in parallel using Ray tasks. The files are stored in + the HuggingFace cache (controlled by cache_dir or HF_HOME), so subsequent + access during shard processing will be instant cache hits. + + Rate limiting is implemented to avoid overwhelming the HuggingFace servers + or the local network. The max_concurrent parameter controls how many + downloads can be in flight simultaneously. 
+ + Args: + shard_tasks: List of ShardTask objects with assignment_json + max_concurrent: Maximum concurrent downloads (rate limiting) + on_progress: Optional callback for progress updates + cache_dir: HuggingFace cache directory (defaults to HF_HOME/hub or ~/.cache/huggingface/hub) + + Returns: + DownloadStats with counts and timing information + """ + import os + start_time = time.perf_counter() + + # Use HF_HOME if cache_dir not specified + if cache_dir is None: + hf_home = os.environ.get("HF_HOME") + if hf_home: + cache_dir = os.path.join(hf_home, "hub") + + # Collect unique files + unique_files = _collect_unique_hf_files(shard_tasks) + total_files = len(unique_files) + + if total_files == 0: + print("[Pre-download] No HuggingFace files found - skipping download phase") + sys.stdout.flush() + return DownloadStats( + total_files=0, + downloaded_files=0, + cached_files=0, + failed_files=0, + elapsed_sec=0.0, + ) + + print(f"[Pre-download] Starting download of {total_files} unique files (max_concurrent={max_concurrent})") + sys.stdout.flush() + + if on_progress: + on_progress({ + "phase": "downloading", + "detail": f"0/{total_files} files", + "elapsed_sec": 0.0, + "files_completed": 0, + "files_total": total_files, + }) + + # Track statistics + downloaded = 0 + cached = 0 + failed = 0 + completed = 0 + + # Submit downloads with rate limiting + futures = [] + pending_files = list(unique_files) # Copy to avoid modifying original + + # Initial batch + while pending_files and len(futures) < max_concurrent: + file_info = pending_files.pop(0) + future = _download_hf_file.remote( + file_info["repo_id"], + file_info["filename"], + file_info["revision"], + cache_dir, + ) + futures.append(future) + + # Process completions and submit new downloads + last_progress_time = start_time + while futures: + # Wait for at least one to complete + done, futures = ray.wait(futures, num_returns=1, timeout=1.0) + + for future in done: + try: + result = ray.get(future) + completed += 1 + + 
if result["status"] == "downloaded": + downloaded += 1 + elif result["status"] == "cached": + cached += 1 + else: + failed += 1 + if result["error"]: + logger.warning(f"Download failed: {result['repo_id']}/{result['filename']}: {result['error']}") + except Exception as e: + completed += 1 + failed += 1 + logger.warning(f"Download task failed: {e}") + + # Submit more downloads if we have capacity + while pending_files and len(futures) < max_concurrent: + file_info = pending_files.pop(0) + future = _download_hf_file.remote( + file_info["repo_id"], + file_info["filename"], + file_info["revision"], + cache_dir, + ) + futures.append(future) + + # Progress reporting (every 5 seconds or on completion) + current_time = time.perf_counter() + if current_time - last_progress_time >= 5.0 or completed == total_files: + last_progress_time = current_time + elapsed = current_time - start_time + rate = completed / max(elapsed, 0.001) + print(f"[Pre-download] {completed}/{total_files} files ({rate:.1f}/s, {elapsed:.0f}s elapsed)") + sys.stdout.flush() + if on_progress: + on_progress({ + "phase": "downloading", + "detail": f"{completed}/{total_files} files ({rate:.1f}/s)", + "elapsed_sec": elapsed, + "files_completed": completed, + "files_total": total_files, + }) + + elapsed = time.perf_counter() - start_time + + stats = DownloadStats( + total_files=total_files, + downloaded_files=downloaded, + cached_files=cached, + failed_files=failed, + elapsed_sec=elapsed, + ) + + print( + f"[Pre-download] Complete: {downloaded + cached}/{total_files} files in {elapsed:.1f}s " + f"({(downloaded + cached) / max(elapsed, 0.001):.1f} files/s, {failed} failed)" + ) + sys.stdout.flush() + + return stats + + +__all__ = [ + "parallel_predownload", + "OverlappedDownloader", + "DownloadStats", +] diff --git a/src/nemotron/data_prep/jsonl_processor.py b/src/nemotron/data_prep/jsonl_processor.py index 54c654d4a..fc265a31a 100644 --- a/src/nemotron/data_prep/jsonl_processor.py +++ 
b/src/nemotron/data_prep/jsonl_processor.py @@ -14,29 +14,21 @@ """JsonlShardProcessor Ray actor for parallel JSONL output processing.""" -import json import logging -import time -from collections.abc import Callable, Iterator +from collections.abc import Callable -import pyarrow.parquet as pq import ray from fsspec import filesystem from nemotron.data_prep.config import FileInfo -from nemotron.data_prep.filesystem import ensure_dir, write_json -from nemotron.data_prep.formats.jsonl_dataset import JsonlDatasetBuilder +from nemotron.data_prep.jsonl_shard_core import process_jsonl_shard_core logger = logging.getLogger(__name__) @ray.remote class JsonlShardProcessor: - """Ray actor for processing data files to JSONL output. - - Reads input files (parquet or jsonl), applies optional transform, - and writes to JSONL output (optionally compressed). - """ + """Ray actor for processing data files to JSONL output.""" def __init__( self, @@ -45,14 +37,6 @@ def __init__( compression: str = "none", max_rows: int | None = None, ): - """Initialize JSONL processor. - - Args: - text_field: Field name for text in input records. - transform: Optional callable to transform records. - compression: Output compression ("none" or "zstd"). - max_rows: Maximum rows to process per shard. - """ self.text_field = text_field self.transform = transform self.compression = compression @@ -64,174 +48,20 @@ def process_shard( files: list[dict], # FileInfo as dicts for Ray serialization output_dir: str, fs_protocol: str, + receipts_dir: str | None = None, ) -> dict: - """Process files to a single JSONL shard. - - Args: - shard_index: Index of this shard. - files: List of FileInfo dicts to process. - output_dir: Output directory for JSONL files. - fs_protocol: Filesystem protocol (e.g., "file", "s3"). - - Returns: - Shard statistics dict. 
- """ + """Process files to a single JSONL shard.""" fs = filesystem(fs_protocol) - - shard_id = f"shard_{shard_index:06d}" - ext = ".jsonl.zst" if self.compression == "zstd" else ".jsonl" - jsonl_path = f"{output_dir}/{shard_id}{ext}" - receipt_path = f"{output_dir}/{shard_id}.receipt.json" - - # Ensure directory exists - ensure_dir(fs, output_dir) - - # Convert file dicts back to FileInfo - file_infos = [FileInfo(**f) for f in files] - - # Handle empty assignment - if not file_infos: - return self._write_empty_receipt(shard_id, shard_index, receipt_path, fs) - - # Process files and write JSONL - with fs.open(jsonl_path, "wb") as f: - builder = JsonlDatasetBuilder( - file=f, - transform=self.transform, - compression=self.compression, - ) - - rows_processed = 0 - for file_info in file_infos: - if self.max_rows and rows_processed >= self.max_rows: - break - rows_processed = self._process_file(file_info, builder, fs, rows_processed) - - builder.finalize() - total_bytes, checksum = builder.get_info() - stats = builder.get_stats() - - # Write receipt - receipt = { - "shard_id": shard_id, - "shard_index": shard_index, - "status": "completed", - "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), - "input_files": [f.path for f in file_infos], - "output_file": f"{shard_id}{ext}", - "total_bytes": total_bytes, - "checksum": checksum, - **stats, - } - - write_json(fs, receipt_path, receipt) - return stats - - def _process_file( - self, - file_info: FileInfo, - builder: JsonlDatasetBuilder, - fs, - rows_processed: int, - ) -> int: - """Process a single file, writing records to builder. - - Returns the total number of rows processed. 
- """ - local_path = self._resolve_file_path(file_info) - - # Determine file type - is_parquet = local_path.endswith(".parquet") or not ( - local_path.endswith(".jsonl") or local_path.endswith(".json") + resolved_receipts_dir = receipts_dir or f"{output_dir}/receipts" + + return process_jsonl_shard_core( + shard_index=shard_index, + files=[FileInfo(**f) for f in files], + output_dir=output_dir, + receipts_dir=resolved_receipts_dir, + output_fs=fs, + text_field=self.text_field, + transform=self.transform, + compression=self.compression, + max_rows=self.max_rows, ) - - if is_parquet: - for record in self._iter_parquet_records(local_path, fs): - if self.max_rows and rows_processed >= self.max_rows: - break - builder.add_record(record) - rows_processed += 1 - else: - for record in self._iter_jsonl_records(local_path, fs): - if self.max_rows and rows_processed >= self.max_rows: - break - builder.add_record(record) - rows_processed += 1 - - return rows_processed - - def _resolve_file_path(self, file_info: FileInfo) -> str: - """Resolve file to a local path, downloading from HF if needed.""" - if file_info.hf_repo_id is not None: - from huggingface_hub import hf_hub_download - - local_path = hf_hub_download( - repo_id=file_info.hf_repo_id, - filename=file_info.hf_filename, - revision=file_info.hf_revision, - repo_type="dataset", - local_files_only=False, - ) - return local_path - - return file_info.local_path or file_info.path - - def _iter_parquet_records(self, path: str, fs) -> Iterator[dict]: - """Iterate records from parquet file.""" - if self._is_remote_path(path): - with fs.open(path, "rb") as f: - parquet_file = pq.ParquetFile(f) - yield from self._iter_parquet_batches_as_dicts(parquet_file) - else: - parquet_file = pq.ParquetFile(path) - yield from self._iter_parquet_batches_as_dicts(parquet_file) - - def _iter_parquet_batches_as_dicts(self, parquet_file: pq.ParquetFile) -> Iterator[dict]: - """Iterate records from parquet file as dicts.""" - for batch in 
parquet_file.iter_batches(batch_size=10000): - # Convert batch to list of dicts - table = batch.to_pydict() - num_rows = len(next(iter(table.values()))) - for i in range(num_rows): - yield {k: v[i] for k, v in table.items()} - - def _iter_jsonl_records(self, path: str, fs) -> Iterator[dict]: - """Iterate records from JSONL file.""" - if self._is_remote_path(path): - with fs.open(path, "r") as f: - for line in f: - if line.strip(): - yield json.loads(line) - else: - with open(path) as f: - for line in f: - if line.strip(): - yield json.loads(line) - - def _is_remote_path(self, path: str) -> bool: - """Check if path is a remote path (S3/GCS/etc).""" - return path.startswith(("s3://", "gs://", "gcs://", "az://", "abfs://")) - - def _write_empty_receipt( - self, - shard_id: str, - shard_index: int, - receipt_path: str, - fs, - ) -> dict: - """Write receipt for empty shard.""" - receipt = { - "shard_id": shard_id, - "shard_index": shard_index, - "status": "completed", - "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), - "input_files": [], - "output_file": None, - "total_bytes": 0, - "checksum": "xxh64:empty", - "num_records": 0, - "num_skipped": 0, - } - - write_json(fs, receipt_path, receipt) - return {"num_records": 0, "num_skipped": 0, "total_bytes": 0} diff --git a/src/nemotron/data_prep/jsonl_shard_core.py b/src/nemotron/data_prep/jsonl_shard_core.py new file mode 100644 index 000000000..990d23965 --- /dev/null +++ b/src/nemotron/data_prep/jsonl_shard_core.py @@ -0,0 +1,224 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Core JSONL shard processing (retry-safe, engine-agnostic).""" + +from __future__ import annotations + +import json +import time +from collections.abc import Callable, Iterator +from typing import Any, Literal + +import pyarrow.parquet as pq +from fsspec import AbstractFileSystem + +from nemotron.data_prep.config import FileInfo +from nemotron.data_prep.filesystem import ensure_dir, get_filesystem, read_json, write_json +from nemotron.data_prep.formats.jsonl_dataset import JsonlDatasetBuilder + +Transform = Callable[[dict], dict | None] + + +def process_jsonl_shard_core( + *, + shard_index: int, + files: list[dict] | list[FileInfo], + output_dir: str, + receipts_dir: str, + output_fs: AbstractFileSystem, + text_field: str, + transform: Transform | None, + compression: Literal["none", "zstd"], + max_rows: int | None, +) -> dict[str, Any]: + """Process a JSONL shard with retry-safe atomic commits.""" + shard_id = f"shard_{shard_index:06d}" + ext = ".jsonl.zst" if compression == "zstd" else ".jsonl" + jsonl_path = f"{output_dir}/{shard_id}{ext}" + jsonl_tmp = f"{jsonl_path}.tmp" + receipt_path = f"{receipts_dir}/{shard_id}.json" + + if output_fs.exists(receipt_path): + try: + receipt = read_json(output_fs, receipt_path) + if receipt.get("status") == "completed": + return receipt.get("stats", {}) + except Exception: + pass + + ensure_dir(output_fs, output_dir) + ensure_dir(output_fs, receipts_dir) + + file_infos = [FileInfo(**f) if isinstance(f, dict) else f for f in files] + input_file_paths = [f.path for f in file_infos] + + if not file_infos: + return 
_write_empty_receipt( + shard_id=shard_id, + shard_index=shard_index, + input_files=input_file_paths, + receipt_path=receipt_path, + output_fs=output_fs, + ) + + rows_processed = 0 + with output_fs.open(jsonl_tmp, "wb") as f: + builder = JsonlDatasetBuilder( + file=f, + transform=transform, + compression=compression, + ) + + for file_info in file_infos: + if max_rows and rows_processed >= max_rows: + break + rows_processed = _process_file( + file_info=file_info, + builder=builder, + rows_processed=rows_processed, + max_rows=max_rows, + ) + + builder.finalize() + total_bytes, checksum = builder.get_info() + stats = builder.get_stats() + + if stats.get("num_records", 0) == 0: + try: + output_fs.rm(jsonl_tmp) + except Exception: + pass + return _write_empty_receipt( + shard_id=shard_id, + shard_index=shard_index, + input_files=input_file_paths, + receipt_path=receipt_path, + output_fs=output_fs, + ) + + output_fs.rename(jsonl_tmp, jsonl_path) + + receipt = { + "shard_id": shard_id, + "shard_index": shard_index, + "status": "completed", + "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + "input_files": input_file_paths, + "output_file": f"{shard_id}{ext}", + "total_bytes": total_bytes, + "checksum": checksum, + "stats": { + "total_tokens": 0, + "num_records": stats.get("num_records", 0), + "num_skipped": stats.get("num_skipped", 0), + "total_bytes": stats.get("total_bytes", 0), + }, + } + + write_json(output_fs, receipt_path, receipt) + return receipt["stats"] + + +def _process_file( + *, + file_info: FileInfo, + builder: JsonlDatasetBuilder, + rows_processed: int, + max_rows: int | None, +) -> int: + local_path = _resolve_file_path(file_info) + input_path = local_path if file_info.hf_repo_id is not None else (file_info.local_path or file_info.path) + input_fs, normalized = get_filesystem(input_path) + + # Use original filename for format detection (hf_hub_download returns blob path without extension) + format_check_path = (file_info.hf_filename or 
normalized) if file_info.hf_repo_id else normalized + is_parquet = format_check_path.endswith(".parquet") or not ( + format_check_path.endswith(".jsonl") or format_check_path.endswith(".json") + ) + + if is_parquet: + record_iter = _iter_parquet_records(normalized, input_fs) + else: + record_iter = _iter_jsonl_records(normalized, input_fs) + + for record in record_iter: + if max_rows and rows_processed >= max_rows: + break + builder.add_record(record) + rows_processed += 1 + + return rows_processed + + +def _resolve_file_path(file_info: FileInfo) -> str: + if file_info.hf_repo_id is not None: + from huggingface_hub import hf_hub_download + + return hf_hub_download( + repo_id=file_info.hf_repo_id, + filename=file_info.hf_filename, + revision=file_info.hf_revision, + repo_type="dataset", + local_files_only=True, # Files should be pre-downloaded by HfPredownloadStage + ) + + return file_info.local_path or file_info.path + + +def _iter_parquet_records(path: str, fs: AbstractFileSystem) -> Iterator[dict]: + with fs.open(path, "rb") as f: + parquet_file = pq.ParquetFile(f) + for batch in parquet_file.iter_batches(batch_size=10000): + table = batch.to_pydict() + num_rows = len(next(iter(table.values()))) if table else 0 + for i in range(num_rows): + yield {k: v[i] for k, v in table.items()} + + +def _iter_jsonl_records(path: str, fs: AbstractFileSystem) -> Iterator[dict]: + with fs.open(path, "r") as f: + for line in f: + if line.strip(): + yield json.loads(line) + + +def _write_empty_receipt( + *, + shard_id: str, + shard_index: int, + input_files: list[str], + receipt_path: str, + output_fs: AbstractFileSystem, +) -> dict[str, Any]: + receipt = { + "shard_id": shard_id, + "shard_index": shard_index, + "status": "completed", + "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + "input_files": input_files, + "output_file": None, + "total_bytes": 0, + "checksum": "xxh64:empty", + "stats": { + "total_tokens": 0, + "num_records": 0, + "num_skipped": 0, + 
"total_bytes": 0, + }, + } + + write_json(output_fs, receipt_path, receipt) + return receipt["stats"] + diff --git a/src/nemotron/data_prep/packed_processor.py b/src/nemotron/data_prep/packed_processor.py index 87a578cbf..4f87f2c16 100644 --- a/src/nemotron/data_prep/packed_processor.py +++ b/src/nemotron/data_prep/packed_processor.py @@ -319,7 +319,7 @@ def _process_jsonl_file( return rows_processed def _resolve_file_path(self, file_info: FileInfo) -> str: - """Resolve file to a local path, downloading from HF if needed.""" + """Resolve file to a local path, using HF cache (no download).""" if file_info.hf_repo_id is not None: from huggingface_hub import hf_hub_download @@ -328,7 +328,7 @@ def _resolve_file_path(self, file_info: FileInfo) -> str: filename=file_info.hf_filename, revision=file_info.hf_revision, repo_type="dataset", - local_files_only=False, + local_files_only=True, # Only use cached files ) return local_path diff --git a/src/nemotron/data_prep/packing/bin_assignment.py b/src/nemotron/data_prep/packing/bin_assignment.py new file mode 100644 index 000000000..c7e32b37e --- /dev/null +++ b/src/nemotron/data_prep/packing/bin_assignment.py @@ -0,0 +1,104 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Compact bin assignment representation for packed sequence materialization. 
+ +BinAssignment stores a list-of-lists bin structure (bin_id -> [seq_index...]) +in a CSR-like representation: + +- bin_offsets: int64 array of length (num_bins + 1) +- bin_seq_indices: int32 array of length (total_assigned_sequences) + +For bin i, the sequence indices live in: + bin_seq_indices[bin_offsets[i] : bin_offsets[i+1]] + +This structure is designed to: +- avoid large nested Python lists during the materialization phase +- allow cheap slicing per bin +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Sequence + +import numpy as np + + +@dataclass(frozen=True) +class BinAssignment: + """CSR-like representation of bin assignments.""" + + bin_offsets: np.ndarray + bin_seq_indices: np.ndarray + num_bins: int + num_sequences: int + + @classmethod + def from_bins(cls, *, bins: Sequence[Sequence[int]], num_sequences: int) -> "BinAssignment": + """Build a BinAssignment from a Python bins structure. + + Args: + bins: Sequence of bins, each bin is a sequence of seq indices. + num_sequences: Total number of sequences in the shard/spool (for metadata). + + Returns: + BinAssignment with int64 offsets and int32 indices. + + Raises: + ValueError: If indices are out of range. 
+ """ + num_bins = int(len(bins)) + if num_bins == 0: + return cls( + bin_offsets=np.zeros((1,), dtype=np.int64), + bin_seq_indices=np.zeros((0,), dtype=np.int32), + num_bins=0, + num_sequences=int(num_sequences), + ) + + total_entries = 0 + for b in bins: + total_entries += len(b) + + offsets = np.zeros((num_bins + 1,), dtype=np.int64) + indices = np.zeros((total_entries,), dtype=np.int32) + + cursor = 0 + for i, b in enumerate(bins): + offsets[i] = cursor + for idx in b: + if idx < 0 or idx >= num_sequences: + raise ValueError(f"Sequence index out of range in bins: {idx} (num_sequences={num_sequences})") + indices[cursor] = np.int32(idx) + cursor += 1 + offsets[num_bins] = cursor + + return cls( + bin_offsets=offsets, + bin_seq_indices=indices, + num_bins=num_bins, + num_sequences=int(num_sequences), + ) + + def bin_indices(self, bin_id: int) -> np.ndarray: + """Return the seq indices for a given bin as a view.""" + if bin_id < 0 or bin_id >= self.num_bins: + raise IndexError(f"bin_id out of range: {bin_id}") + start = int(self.bin_offsets[bin_id]) + end = int(self.bin_offsets[bin_id + 1]) + return self.bin_seq_indices[start:end] + + +__all__ = ["BinAssignment"] \ No newline at end of file diff --git a/src/nemotron/data_prep/packing/materialize.py b/src/nemotron/data_prep/packing/materialize.py new file mode 100644 index 000000000..35dc36fdb --- /dev/null +++ b/src/nemotron/data_prep/packing/materialize.py @@ -0,0 +1,98 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Materialize packed samples from a SequenceSpool + BinAssignment. + +This module performs the "reduce" step: +- given a bin assignment (bin_id -> sequence indices) +- and a spool providing random-access sequences (tokens + masks) +it produces packed dict items compatible with the existing .npy pickle-of-dicts +format used by GPTSFTPackedDataset. + +Truncation semantics match PackedSequenceBuilder._build_packed_sequence: +- If a sequence is longer than pack_size, it is truncated to pack_size. +- If adding a sequence would exceed pack_size, it is truncated to the remaining space. +- The loss_mask is rolled by 1: [0] + mask[:-1] +""" + +from __future__ import annotations + +from collections.abc import Iterator + +import numpy as np + +from nemotron.data_prep.packing.bin_assignment import BinAssignment +from nemotron.data_prep.packing.spool import SequenceSpoolReader + + +def materialize_packed_samples( + *, + spool_reader: SequenceSpoolReader, + assignment: BinAssignment, + pack_size: int, +) -> Iterator[dict]: + """Yield packed items one bin at a time. + + Args: + spool_reader: Reader for the SequenceSpool (random-access sequences). + assignment: CSR-like bin assignment. + pack_size: Maximum tokens per packed sample. + + Yields: + Dicts with keys: input_ids, loss_mask, seq_start_id + """ + if pack_size <= 0: + raise ValueError(f"pack_size must be positive, got {pack_size}") + + for bin_id in range(assignment.num_bins): + seq_indices = assignment.bin_indices(bin_id) + + all_input_ids: list[int] = [] + all_loss_mask: list[int] = [] + seq_start_ids: list[int] = [0] + + for seq_index in seq_indices: + input_ids_arr, loss_mask_arr = spool_reader.read_sequence(int(seq_index)) + + # Truncate if needed (builder truncates per-seq to pack_size). 
+ if input_ids_arr.shape[0] > pack_size: + input_ids_arr = input_ids_arr[:pack_size] + loss_mask_arr = loss_mask_arr[:pack_size] + + current_len = len(all_input_ids) + if current_len >= pack_size: + break + + if current_len + int(input_ids_arr.shape[0]) > pack_size: + remaining = pack_size - current_len + input_ids_arr = input_ids_arr[:remaining] + loss_mask_arr = loss_mask_arr[:remaining] + + if input_ids_arr.shape[0] == 0: + continue + + all_input_ids.extend([int(x) for x in input_ids_arr.tolist()]) + all_loss_mask.extend([int(x) for x in loss_mask_arr.tolist()]) + seq_start_ids.append(len(all_input_ids)) + + rolled_loss_mask = [0] + all_loss_mask[:-1] if all_loss_mask else [] + + yield { + "input_ids": all_input_ids, + "loss_mask": rolled_loss_mask, + "seq_start_id": seq_start_ids[:-1], + } + + +__all__ = ["materialize_packed_samples"] \ No newline at end of file diff --git a/src/nemotron/data_prep/packing/spool.py b/src/nemotron/data_prep/packing/spool.py new file mode 100644 index 000000000..3dd088116 --- /dev/null +++ b/src/nemotron/data_prep/packing/spool.py @@ -0,0 +1,385 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SequenceSpool intermediate format for low-memory ChatSFT processing. 
+ +This module defines an append-only intermediate representation for tokenized +sequences (input_ids + loss_mask) that is efficient to write in streaming +fashion and efficient to read for a later "central pack" finalizer step. + +Spool layout (per shard): +- tokens.bin : flat int32 token ids (concatenated) +- masks.bin : flat uint8 loss masks (concatenated) +- offsets.bin : uint64 token offsets (start index per sequence) +- lengths.bin : uint32 sequence lengths (tokens per sequence) +- manifest.json: metadata and validation info written on finalize + +Notes: +- This format is designed to minimize Python object overhead (no list-of-lists). +- Random-access reads require a seekable file object (most local/Lustre paths). +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from typing import Any + +import numpy as np +from fsspec import AbstractFileSystem + +from nemotron.data_prep.filesystem import ensure_dir, read_json + + +def _join_path(root_dir: str, filename: str) -> str: + root = root_dir.rstrip("/") + return f"{root}/{filename}" + + +def _rename(fs: AbstractFileSystem, src: str, dst: str) -> None: + try: + fs.rename(src, dst) + return + except Exception: + pass + try: + fs.mv(src, dst) + return + except Exception as e: + raise RuntimeError(f"Failed to rename/move '{src}' -> '{dst}'") from e + + +def _rm_if_exists(fs: AbstractFileSystem, path: str) -> None: + try: + if fs.exists(path): + try: + fs.rm(path) + except Exception: + fs.delete(path) + except Exception: + # Best-effort cleanup only + pass + + +def _write_json_atomic(fs: AbstractFileSystem, path: str, payload: dict[str, Any]) -> None: + tmp_path = f"{path}.tmp" + parent = str(path).rsplit("/", 1)[0] if "/" in str(path) else "" + if parent: + ensure_dir(fs, parent) + + with fs.open(tmp_path, "w") as f: + json.dump(payload, f, indent=2, sort_keys=True) + _rename(fs, tmp_path, path) + + +@dataclass(frozen=True) +class SequenceSpoolPaths: + """Concrete paths for a 
SequenceSpool instance.""" + + root_dir: str + tokens_path: str + masks_path: str + offsets_path: str + lengths_path: str + manifest_path: str + + @classmethod + def for_root(cls, root_dir: str) -> "SequenceSpoolPaths": + root = root_dir.rstrip("/") + return cls( + root_dir=root, + tokens_path=_join_path(root, "tokens.bin"), + masks_path=_join_path(root, "masks.bin"), + offsets_path=_join_path(root, "offsets.bin"), + lengths_path=_join_path(root, "lengths.bin"), + manifest_path=_join_path(root, "manifest.json"), + ) + + def with_suffix(self, suffix: str) -> "SequenceSpoolPaths": + return SequenceSpoolPaths( + root_dir=self.root_dir, + tokens_path=f"{self.tokens_path}{suffix}", + masks_path=f"{self.masks_path}{suffix}", + offsets_path=f"{self.offsets_path}{suffix}", + lengths_path=f"{self.lengths_path}{suffix}", + manifest_path=f"{self.manifest_path}{suffix}", + ) + + def tmp(self) -> "SequenceSpoolPaths": + return self.with_suffix(".tmp") + + +class SequenceSpoolWriter: + """Append-only writer for SequenceSpool. + + Usage: + paths = SequenceSpoolPaths.for_root("/path/to/spool/shard_000000") + w = SequenceSpoolWriter(fs=output_fs, paths=paths) + w.append(input_ids, loss_mask) + ... 
+ manifest = w.finalize(extra_manifest={"pack_size": 4096}) + """ + + def __init__( + self, + *, + fs: AbstractFileSystem, + paths: SequenceSpoolPaths, + tokens_dtype: np.dtype = np.dtype("int32"), + masks_dtype: np.dtype = np.dtype("uint8"), + overwrite_tmp: bool = True, + ) -> None: + self._fs = fs + self._final_paths = paths + self._tmp_paths = paths.tmp() + self._tokens_dtype = np.dtype(tokens_dtype) + self._masks_dtype = np.dtype(masks_dtype) + + if self._tokens_dtype != np.dtype("int32"): + raise ValueError(f"tokens_dtype must be int32, got {self._tokens_dtype}") + if self._masks_dtype != np.dtype("uint8"): + raise ValueError(f"masks_dtype must be uint8, got {self._masks_dtype}") + + ensure_dir(self._fs, self._final_paths.root_dir) + + if overwrite_tmp: + _rm_if_exists(self._fs, self._tmp_paths.tokens_path) + _rm_if_exists(self._fs, self._tmp_paths.masks_path) + _rm_if_exists(self._fs, self._tmp_paths.offsets_path) + _rm_if_exists(self._fs, self._tmp_paths.lengths_path) + _rm_if_exists(self._fs, self._tmp_paths.manifest_path) + + # Open append handles once (faster than per-append open/close). 
+ self._tokens_f = self._fs.open(self._tmp_paths.tokens_path, "ab") + self._masks_f = self._fs.open(self._tmp_paths.masks_path, "ab") + self._offsets_f = self._fs.open(self._tmp_paths.offsets_path, "ab") + self._lengths_f = self._fs.open(self._tmp_paths.lengths_path, "ab") + + self._num_sequences = 0 + self._total_tokens = 0 + self._closed = False + + @property + def num_sequences(self) -> int: + return self._num_sequences + + @property + def total_tokens(self) -> int: + return self._total_tokens + + def append(self, input_ids: np.ndarray | list[int], loss_mask: np.ndarray | list[int] | None) -> None: + if self._closed: + raise RuntimeError("SequenceSpoolWriter is closed") + + input_arr = np.asarray(input_ids, dtype=self._tokens_dtype) + if input_arr.ndim != 1: + raise ValueError(f"input_ids must be 1D, got shape={input_arr.shape}") + + if input_arr.size == 0: + return + + if loss_mask is None: + mask_arr = np.ones((input_arr.size,), dtype=self._masks_dtype) + else: + mask_arr = np.asarray(loss_mask, dtype=self._masks_dtype) + + if mask_arr.ndim != 1: + raise ValueError(f"loss_mask must be 1D, got shape={mask_arr.shape}") + if mask_arr.size != input_arr.size: + raise ValueError( + f"loss_mask length mismatch: loss_mask={mask_arr.size}, input_ids={input_arr.size}" + ) + + offset = np.array([self._total_tokens], dtype=np.uint64) + length = np.array([input_arr.size], dtype=np.uint32) + + # Write offsets/lengths first (small), then bulk token/mask bytes. + self._offsets_f.write(offset.tobytes(order="C")) + self._lengths_f.write(length.tobytes(order="C")) + self._tokens_f.write(input_arr.tobytes(order="C")) + self._masks_f.write(mask_arr.tobytes(order="C")) + + self._num_sequences += 1 + self._total_tokens += int(input_arr.size) + + def finalize(self, *, extra_manifest: dict[str, Any] | None = None) -> dict[str, Any]: + if self._closed: + raise RuntimeError("SequenceSpoolWriter is already finalized/closed") + + self._closed = True + + # Close all handles before rename. 
+ try: + self._tokens_f.close() + finally: + try: + self._masks_f.close() + finally: + try: + self._offsets_f.close() + finally: + self._lengths_f.close() + + # Promote tmp files to final names. + _rename(self._fs, self._tmp_paths.tokens_path, self._final_paths.tokens_path) + _rename(self._fs, self._tmp_paths.masks_path, self._final_paths.masks_path) + _rename(self._fs, self._tmp_paths.offsets_path, self._final_paths.offsets_path) + _rename(self._fs, self._tmp_paths.lengths_path, self._final_paths.lengths_path) + + manifest: dict[str, Any] = { + "version": "spool_v1", + "num_sequences": int(self._num_sequences), + "total_tokens": int(self._total_tokens), + "tokens_dtype": str(np.dtype(self._tokens_dtype)), + "mask_dtype": str(np.dtype(self._masks_dtype)), + "offsets_dtype": "uint64", + "lengths_dtype": "uint32", + } + if extra_manifest: + manifest.update(extra_manifest) + + _write_json_atomic(self._fs, self._final_paths.manifest_path, manifest) + return manifest + + +class SequenceSpoolReader: + """Reader for SequenceSpool. + + This reader supports loading offsets/lengths and reading arbitrary sequences + by index using byte-range seeks into tokens.bin and masks.bin. 
+ """ + + def __init__( + self, + *, + fs: AbstractFileSystem, + paths: SequenceSpoolPaths, + tokens_dtype: np.dtype = np.dtype("int32"), + masks_dtype: np.dtype = np.dtype("uint8"), + ) -> None: + self._fs = fs + self._paths = paths + self._tokens_dtype = np.dtype(tokens_dtype) + self._masks_dtype = np.dtype(masks_dtype) + + self._offsets: np.ndarray | None = None + self._lengths: np.ndarray | None = None + self._tokens_f = None + self._masks_f = None + + def read_manifest(self) -> dict[str, Any] | None: + try: + if not self._fs.exists(self._paths.manifest_path): + return None + return read_json(self._fs, self._paths.manifest_path) + except Exception: + return None + + def load_offsets_and_lengths(self) -> tuple[np.ndarray, np.ndarray]: + if self._offsets is not None and self._lengths is not None: + return self._offsets, self._lengths + + with self._fs.open(self._paths.offsets_path, "rb") as f: + offsets_bytes = f.read() + with self._fs.open(self._paths.lengths_path, "rb") as f: + lengths_bytes = f.read() + + offsets = np.frombuffer(offsets_bytes, dtype=np.uint64) + lengths = np.frombuffer(lengths_bytes, dtype=np.uint32) + + if offsets.shape[0] != lengths.shape[0]: + raise ValueError( + f"Spool offsets/lengths mismatch: offsets={offsets.shape[0]}, lengths={lengths.shape[0]}" + ) + + self._offsets = offsets + self._lengths = lengths + return offsets, lengths + + @property + def num_sequences(self) -> int: + _, lengths = self.load_offsets_and_lengths() + return int(lengths.shape[0]) + + @property + def total_tokens(self) -> int: + _, lengths = self.load_offsets_and_lengths() + return int(lengths.sum()) + + def _ensure_open(self) -> None: + if self._tokens_f is None: + self._tokens_f = self._fs.open(self._paths.tokens_path, "rb") + if self._masks_f is None: + self._masks_f = self._fs.open(self._paths.masks_path, "rb") + + # Validate seek support (required for random access reads). 
+ if not hasattr(self._tokens_f, "seek") or not hasattr(self._masks_f, "seek"): + raise RuntimeError("SequenceSpoolReader requires seekable file objects") + + def close(self) -> None: + if self._tokens_f is not None: + try: + self._tokens_f.close() + finally: + self._tokens_f = None + if self._masks_f is not None: + try: + self._masks_f.close() + finally: + self._masks_f = None + + def read_sequence(self, seq_index: int) -> tuple[np.ndarray, np.ndarray]: + offsets, lengths = self.load_offsets_and_lengths() + + if seq_index < 0 or seq_index >= int(lengths.shape[0]): + raise IndexError(f"seq_index out of range: {seq_index}") + + self._ensure_open() + + offset_tokens = int(offsets[seq_index]) + length_tokens = int(lengths[seq_index]) + + tok_byte_offset = offset_tokens * self._tokens_dtype.itemsize + tok_byte_len = length_tokens * self._tokens_dtype.itemsize + + mask_byte_offset = offset_tokens * self._masks_dtype.itemsize + mask_byte_len = length_tokens * self._masks_dtype.itemsize + + self._tokens_f.seek(tok_byte_offset) + tok_bytes = self._tokens_f.read(tok_byte_len) + + self._masks_f.seek(mask_byte_offset) + mask_bytes = self._masks_f.read(mask_byte_len) + + input_ids = np.frombuffer(tok_bytes, dtype=self._tokens_dtype) + loss_mask = np.frombuffer(mask_bytes, dtype=self._masks_dtype) + + # Defensive validation (helps catch corrupt/incomplete spools early). 
+ if input_ids.shape[0] != length_tokens: + raise RuntimeError( + f"Failed to read tokens for seq_index={seq_index}: got={input_ids.shape[0]}, expected={length_tokens}" + ) + if loss_mask.shape[0] != length_tokens: + raise RuntimeError( + f"Failed to read masks for seq_index={seq_index}: got={loss_mask.shape[0]}, expected={length_tokens}" + ) + + return input_ids, loss_mask + + +__all__ = [ + "SequenceSpoolPaths", + "SequenceSpoolWriter", + "SequenceSpoolReader", +] \ No newline at end of file diff --git a/src/nemotron/data_prep/pipeline.py b/src/nemotron/data_prep/pipeline.py index 0b29e69d9..793f6c0cf 100644 --- a/src/nemotron/data_prep/pipeline.py +++ b/src/nemotron/data_prep/pipeline.py @@ -696,10 +696,17 @@ def _process_split( _process_all_shards_parallel( execution_plans=[ep for ep in execution_plans if ep.pending_indices], output_config=output_config, + output_root=str(config.output.dir), fs=fs, live_status=live_status, results=results, ray_data_config=config.ray_data, + execution_engine=config.execution_engine, + max_concurrent_downloads=config.max_concurrent_downloads, + wandb_log_downloads=config.wandb_log_downloads, + wandb_download_log_interval_sec=config.wandb_download_log_interval_sec, + hf_download_timeout_sec=config.hf_download_timeout_sec, + hf_download_max_retries=config.hf_download_max_retries, ) finally: live_status.stop() @@ -829,10 +836,17 @@ def _load_or_create_plan( def _process_all_shards_parallel( execution_plans: list[_DatasetExecutionPlan], output_config: InternalOutputConfig, + output_root: str, fs, live_status, results: dict, ray_data_config: RayDataConfig | None = None, + execution_engine: str = "ray", + max_concurrent_downloads: int = 64, + wandb_log_downloads: bool = False, + wandb_download_log_interval_sec: int = 30, + hf_download_timeout_sec: int = 300, + hf_download_max_retries: int = 3, ) -> None: """Process ALL pending shards from ALL datasets in parallel. 
@@ -845,6 +859,25 @@ def _process_all_shards_parallel( if not execution_plans: return + # Dispatch to Xenna executor if requested + if execution_engine == "xenna": + from nemotron.data_prep.xenna.runner import run_xenna_pipeline + + run_xenna_pipeline( + execution_plans=execution_plans, + output_config=output_config, + output_root=output_root, + fs=fs, + live_status=live_status, + results=results, + max_concurrent_downloads=max_concurrent_downloads, + wandb_log_downloads=wandb_log_downloads, + wandb_download_log_interval_sec=wandb_download_log_interval_sec, + hf_download_timeout_sec=hf_download_timeout_sec, + hf_download_max_retries=hf_download_max_retries, + ) + return + # Dispatch to Ray Data executor if enabled if ray_data_config is not None and ray_data_config.enabled: _process_shards_ray_data( @@ -1150,7 +1183,28 @@ def on_progress(p: dict) -> None: if dataset_completed_counts.get(ds_name, 0) < dataset_pending_counts.get(ds_name, 0): live_status.report_phase(ds_name, phase, detail) - # Execute tasks via Ray Data + # Pre-download all HF files before processing starts + # This ensures processing actors only use cached files (no on-demand downloads) + from nemotron.data_prep.downloader import parallel_predownload + + def download_progress(p: dict) -> None: + """Progress callback for downloads.""" + phase = p.get("phase", "downloading") + detail = p.get("detail", "") + print(f"[Pre-download] {detail}") + + print("[Pre-download] Starting parallel download of HuggingFace files...") + download_stats = parallel_predownload( + tasks, + max_concurrent=ray_data_config.max_concurrent_downloads, + on_progress=download_progress, + ) + print( + f"[Pre-download] Complete: {download_stats.downloaded_files} downloaded, " + f"{download_stats.cached_files} cached, {download_stats.failed_files} failed" + ) + + # Now process shards - all files should be cached execute_shard_tasks( tasks, udf_cls=BinIdxShardTaskUDF, @@ -1461,7 +1515,7 @@ def _process_jsonl_blend(blend: DataBlend, 
config: PipelineConfig) -> PipelineRe # Planning phase: discover files and check cache for all datasets from nemotron.data_prep.discovery import discover_input_files, get_dataset_metadata - dataset_plans: list[tuple] = [] # (dataset, dataset_dir, files, cached_stats) + dataset_plans: list[tuple] = [] # (dataset, dataset_dir, receipts_dir, shard_assignments, pending_indices, cached_stats) plan_infos = [] for dataset in blend.datasets: @@ -1470,6 +1524,8 @@ def _process_jsonl_blend(blend: DataBlend, config: PipelineConfig) -> PipelineRe # Create dataset directory dataset_dir = f"{run_dir}/datasets/{name}" ensure_dir(fs, dataset_dir) + receipts_dir = f"{dataset_dir}/receipts" + ensure_dir(fs, receipts_dir) # Get files for this dataset dataset_config = DatasetConfig( @@ -1480,11 +1536,20 @@ def _process_jsonl_blend(blend: DataBlend, config: PipelineConfig) -> PipelineRe text_field=dataset.text_field, ) files = discover_input_files(dataset_config, fs) + files = sorted(files, key=lambda f: f.path) + + if files: + shard_assignments = _assign_files_round_robin(files, num_shards) + completed_indices = _get_completed_jsonl_shards(dataset_dir, receipts_dir, fs) + pending_indices = [i for i in range(num_shards) if i not in completed_indices] + else: + shard_assignments = {i: [] for i in range(num_shards)} + pending_indices = [] # Check cached stats cached_stats = _aggregate_jsonl_stats(dataset_dir, num_shards, fs) cached_shards = cached_stats.get("num_shards_completed", 0) - pending_shards = num_shards - cached_shards + pending_shards = len(pending_indices) # Fetch HuggingFace metadata (non-blocking, best-effort) hf_metadata = get_dataset_metadata(dataset_config) @@ -1506,46 +1571,134 @@ def _process_jsonl_blend(blend: DataBlend, config: PipelineConfig) -> PipelineRe ) ) - dataset_plans.append((dataset, dataset_dir, files, cached_stats)) + dataset_plans.append( + (dataset, dataset_dir, receipts_dir, shard_assignments, pending_indices, cached_stats) + ) # Show plan summary 
(auto-detect workers from cluster) con.plan_summary(plan_infos, run_hash) # Execution phase - has_work = any( - num_shards - cached_stats.get("num_shards_completed", 0) > 0 - for _, _, _, cached_stats in dataset_plans - ) + has_work = any(pending_indices for _, _, _, _, pending_indices, _ in dataset_plans) if has_work: con.execution_header() - for dataset, dataset_dir, files, cached_stats in dataset_plans: - name = dataset.name + if has_work and config.execution_engine == "xenna": + from dataclasses import asdict - # Process with actors - if files: - _process_jsonl_shards_with_actors( - files=files, - num_shards=num_shards, - dataset_dir=dataset_dir, - text_field=dataset.text_field, - transform=format_config.transform, - compression=format_config.compression, - max_rows=config.output.max_rows, - fs=fs, - ) + from nemotron.data_prep.xenna.runner import run_xenna_jsonl_pipeline + from nemotron.data_prep.xenna.work_items import JsonlShardWorkItem - # Aggregate stats - stats = _aggregate_jsonl_stats(dataset_dir, num_shards, fs) - results[name] = stats + tasks: list[JsonlShardWorkItem] = [] + dataset_infos: list[dict] = [] - # Build data_paths - weight = dataset.weight - if weight > 0: - prefix = f"{dataset_dir}/shard" - data_paths.append(str(weight)) - data_paths.append(prefix) + live_status = con.create_live_status( + datasets=[(dataset.name, num_shards) for dataset, *_ in dataset_plans], + run_hash=run_hash, + console_mode=config.console_mode, + simple_log_interval_sec=config.simple_log_interval_sec, + ) + live_status.start() + + try: + for dataset, dataset_dir, receipts_dir, shard_assignments, pending_indices, _ in dataset_plans: + assignment_dicts = { + shard_idx: { + "shard_index": shard_idx, + "files": [asdict(f) for f in shard_assignments[shard_idx]], + "total_bytes": sum(f.size for f in shard_assignments[shard_idx]), + } + for shard_idx in range(num_shards) + } + + if pending_indices: + for shard_idx in pending_indices: + tasks.append( + JsonlShardWorkItem( + 
dataset_name=dataset.name, + shard_index=shard_idx, + assignment=assignment_dicts[shard_idx], + output_dir=dataset_dir, + receipts_dir=receipts_dir, + text_field=dataset.text_field, + compression=format_config.compression, + max_rows=config.output.max_rows, + resolve_hf_placeholders=format_config.resolve_hf_placeholders, + ) + ) + + dataset_infos.append( + { + "name": dataset.name, + "dataset_dir": dataset_dir, + "receipts_dir": receipts_dir, + "num_shards": num_shards, + } + ) + + for info in dataset_infos: + live_status.start_dataset(info["name"]) + + if tasks: + run_xenna_jsonl_pipeline( + tasks=tasks, + dataset_infos=dataset_infos, + output_root=str(config.output.dir), + fs=fs, + live_status=live_status, + results=results, + text_field=dataset_plans[0][0].text_field if dataset_plans else "text", + transform=format_config.transform, + compression=format_config.compression, + max_rows=config.output.max_rows, + resolve_hf_placeholders=format_config.resolve_hf_placeholders, + max_concurrent_downloads=config.max_concurrent_downloads, + wandb_log_downloads=config.wandb_log_downloads, + wandb_download_log_interval_sec=config.wandb_download_log_interval_sec, + hf_download_timeout_sec=config.hf_download_timeout_sec, + hf_download_max_retries=config.hf_download_max_retries, + ) + else: + for info in dataset_infos: + stats = _aggregate_jsonl_stats(info["dataset_dir"], num_shards, fs) + results[info["name"]] = stats + live_status.complete_dataset(info["name"]) + + for dataset, dataset_dir, _, _, _, _ in dataset_plans: + weight = dataset.weight + if weight > 0: + prefix = f"{dataset_dir}/shard" + data_paths.append(str(weight)) + data_paths.append(prefix) + finally: + live_status.stop() + else: + for dataset, dataset_dir, receipts_dir, shard_assignments, pending_indices, _ in dataset_plans: + name = dataset.name + + if pending_indices: + _process_jsonl_shards_with_actors( + shard_assignments=shard_assignments, + pending_indices=pending_indices, + dataset_dir=dataset_dir, + 
receipts_dir=receipts_dir, + text_field=dataset.text_field, + transform=format_config.transform, + compression=format_config.compression, + max_rows=config.output.max_rows, + fs=fs, + num_actors=config.num_actors, + ) + + stats = _aggregate_jsonl_stats(dataset_dir, num_shards, fs) + results[name] = stats + + weight = dataset.weight + if weight > 0: + prefix = f"{dataset_dir}/shard" + data_paths.append(str(weight)) + data_paths.append(prefix) # Generate blend.json blend_data: dict = {"data_paths": data_paths} @@ -1613,15 +1766,70 @@ def _estimate_blend_bytes(blend: DataBlend, fs) -> int: return total or 1 # Avoid division by zero +def _assign_files_round_robin(files: list, num_shards: int) -> dict[int, list]: + shard_assignments: dict[int, list] = {i: [] for i in range(num_shards)} + for i, file_info in enumerate(files): + shard_idx = i % num_shards + shard_assignments[shard_idx].append(file_info) + return shard_assignments + + +def _get_completed_jsonl_shards(dataset_dir: str, receipts_dir: str, fs) -> set[int]: + completed: set[int] = set() + patterns = [ + f"{receipts_dir}/shard_*.json", + f"{dataset_dir}/shard_*.receipt.json", + ] + for pattern in patterns: + try: + receipt_files = fs.glob(pattern) + except Exception: + continue + for receipt_file in receipt_files: + filename = str(receipt_file).split("/")[-1] + if filename.startswith("shard_"): + if filename.endswith(".receipt.json"): + suffix = ".receipt.json" + else: + suffix = ".json" + try: + shard_str = filename[len("shard_") : -len(suffix)] + completed.add(int(shard_str)) + except ValueError: + continue + return completed + + +def _get_completed_packed_shards(receipts_dir: str, fs) -> set[int]: + completed: set[int] = set() + try: + receipt_files = fs.glob(f"{receipts_dir}/shard_*.json") + except Exception: + return completed + + for receipt_file in receipt_files: + filename = str(receipt_file).split("/")[-1] + if not filename.startswith("shard_") or not filename.endswith(".json"): + continue + try: + 
shard_str = filename[len("shard_") : -len(".json")] + completed.add(int(shard_str)) + except ValueError: + continue + return completed + + def _process_jsonl_shards_with_actors( - files: list, - num_shards: int, + shard_assignments: dict[int, list], + pending_indices: list[int], dataset_dir: str, + receipts_dir: str, text_field: str, transform, compression: str, max_rows: int | None, fs, + num_actors: int | None, ) -> None: """Process files to JSONL shards using Ray actors.""" from nemotron.data_prep.jsonl_processor import JsonlShardProcessor @@ -1633,7 +1841,7 @@ def _process_jsonl_shards_with_actors( fs_protocol = protocol if protocol != "file" else "file" # Auto-detect num_actors from cluster - num_actors = get_num_actors_from_cluster() + num_actors = num_actors or get_num_actors_from_cluster() # Create actor pool actors = [ @@ -1646,16 +1854,15 @@ def _process_jsonl_shards_with_actors( for _ in range(num_actors) ] - # Distribute files across shards (round-robin for now) - # TODO: Could use smarter distribution based on file sizes - shard_assignments: dict[int, list] = {i: [] for i in range(num_shards)} - for i, file_info in enumerate(files): - shard_idx = i % num_shards - shard_assignments[shard_idx].append(file_info) + serialized_assignments: dict[int, list] = {} + for shard_idx, files in shard_assignments.items(): + serialized_assignments[shard_idx] = [ + asdict(f) if hasattr(f, "__dict__") else f for f in files + ] # Submit tasks with backpressure max_in_flight = num_actors * 2 - shard_queue = list(range(num_shards)) + shard_queue = list(pending_indices) actor_idx = 0 pending_list: list = [] future_to_shard: dict = {} @@ -1671,6 +1878,7 @@ def submit_task(shard_index: int) -> None: ], output_dir=dataset_dir, fs_protocol=fs_protocol, + receipts_dir=receipts_dir, ) pending_list.append(future) future_to_shard[future] = shard_index @@ -1700,21 +1908,40 @@ def _aggregate_jsonl_stats(dataset_dir: str, num_shards: int, fs) -> dict: "num_records": 0, "num_skipped": 
0, "total_bytes": 0, + "total_tokens": 0, } + receipt_files: list[str] = [] try: - receipt_files = fs.glob(f"{dataset_dir}/shard_*.receipt.json") + receipt_files.extend(fs.glob(f"{dataset_dir}/receipts/shard_*.json")) except Exception: - return stats + pass + try: + receipt_files.extend(fs.glob(f"{dataset_dir}/shard_*.receipt.json")) + except Exception: + pass + seen_indices: set[int] = set() for receipt_file in receipt_files: try: + filename = str(receipt_file).split("/")[-1] + if filename.startswith("shard_"): + suffix = ".receipt.json" if filename.endswith(".receipt.json") else ".json" + shard_str = filename[len("shard_") : -len(suffix)] + shard_index = int(shard_str) + if shard_index in seen_indices: + continue + seen_indices.add(shard_index) + receipt = read_json(fs, receipt_file) - if receipt.get("status") == "completed": - stats["num_shards_completed"] += 1 - stats["num_records"] += receipt.get("num_records", 0) - stats["num_skipped"] += receipt.get("num_skipped", 0) - stats["total_bytes"] += receipt.get("total_bytes", 0) + if receipt.get("status") != "completed": + continue + stats["num_shards_completed"] += 1 + receipt_stats = receipt.get("stats", receipt) + stats["num_records"] += receipt_stats.get("num_records", 0) + stats["num_skipped"] += receipt_stats.get("num_skipped", 0) + stats["total_bytes"] += receipt_stats.get("total_bytes", 0) + stats["total_tokens"] += receipt_stats.get("total_tokens", 0) except Exception: pass @@ -1940,7 +2167,7 @@ def submit_task(shard_index: int) -> None: actor_idx += 1 future = actor.process_shard.remote( shard_index=shard_index, - files=shard_assignments[shard_index], + files=serialized_assignments[shard_index], output_dir=dataset_dir, receipts_dir=receipts_dir, fs_protocol=fs_protocol, @@ -2074,7 +2301,7 @@ def _process_chat_sft_blend(blend: DataBlend, config: PipelineConfig) -> Pipelin con.planning_header() # Discover files for all datasets - dataset_plans: list[tuple] = [] # (dataset, dataset_dir, receipts_dir, files) 
+ dataset_plans: list[tuple] = [] # (dataset, dataset_dir, receipts_dir, shard_assignments, pending_indices) for dataset in blend.datasets: name = dataset.name @@ -2093,15 +2320,23 @@ def _process_chat_sft_blend(blend: DataBlend, config: PipelineConfig) -> Pipelin text_field=dataset.text_field, ) files = discover_input_files(dataset_config, fs) + files = sorted(files, key=lambda f: f.path) + + if files: + shard_assignments = _assign_files_round_robin(files, num_shards) + completed_indices = _get_completed_packed_shards(receipts_dir, fs) + pending_indices = [i for i in range(num_shards) if i not in completed_indices] + else: + shard_assignments = {i: [] for i in range(num_shards)} + pending_indices = [] - # Display discovered info logger.info(f"Discovered dataset '{name}' with {len(files)} files") - dataset_plans.append((dataset, dataset_dir, receipts_dir, files)) + dataset_plans.append((dataset, dataset_dir, receipts_dir, shard_assignments, pending_indices)) # Build plan info for display plan_infos = [] - for dataset, dataset_dir, receipts_dir, files in dataset_plans: + for dataset, dataset_dir, receipts_dir, shard_assignments, pending_indices in dataset_plans: # Check cached stats cached_stats = _aggregate_packed_stats(dataset_dir, receipts_dir, fs) cached_shards = cached_stats.get("num_shards_completed", 0) @@ -2112,7 +2347,7 @@ def _process_chat_sft_blend(blend: DataBlend, config: PipelineConfig) -> Pipelin plan_hash=run_hash[:8], num_shards=num_shards, num_files=len(files), - pending=num_shards - cached_shards if files else 0, + pending=len(pending_indices) if shard_assignments else 0, cached=cached_shards, cached_tokens=cached_stats.get("total_tokens", 0), cached_sequences=cached_stats.get("num_sequences", 0), @@ -2126,142 +2361,173 @@ def _process_chat_sft_blend(blend: DataBlend, config: PipelineConfig) -> Pipelin con.plan_summary(plan_infos, run_hash) # Execution phase - has_work = any(len(files) > 0 for _, _, _, files in dataset_plans) + has_work = 
any(pending_indices for _, _, _, _, pending_indices in dataset_plans) if has_work: con.execution_header() - # Create actor pool ONCE and reuse across all datasets - from dataclasses import asdict + if config.execution_engine == "xenna": + from dataclasses import asdict - from nemotron.data_prep.chat_sft_processor import ChatSftShardProcessor + from nemotron.data_prep.xenna.runner import run_xenna_chat_sft_pipeline + from nemotron.data_prep.xenna.work_items import ChatSftShardWorkItem - # Auto-detect num_actors from cluster - num_actors = get_num_actors_from_cluster() + tasks: list[ChatSftShardWorkItem] = [] + dataset_infos: list[dict] = [] - actors = [ - ChatSftShardProcessor.remote( - resolved_tokenizer=resolved_tokenizer, - messages_field=format_config.messages_field, - tools_field=format_config.tools_field, - pack_size=format_config.pack_size, - algorithm=format_config.algorithm, - dtype=format_config.dtype, - chat_template=format_config.chat_template, - max_doc_tokens=config.output.max_doc_tokens, - max_rows=config.output.max_rows, - seed=42, - used_in_filter=format_config.used_in_filter, - used_in_field=format_config.used_in_field, + live_status = con.create_live_status( + datasets=[(dataset.name, num_shards) for dataset, *_ in dataset_plans], + run_hash=run_hash, + console_mode=config.console_mode, + simple_log_interval_sec=config.simple_log_interval_sec, ) - for _ in range(num_actors) - ] + live_status.start() - # Create live status panel with all datasets - live_status = con.create_live_status( - datasets=[ - (dataset.name, num_shards) for dataset, _, _, files in dataset_plans if files - ], - run_hash=run_hash, - console_mode=config.console_mode, - simple_log_interval_sec=config.simple_log_interval_sec, - ) - live_status.start() + try: + for dataset, dataset_dir, receipts_dir, shard_assignments, pending_indices in dataset_plans: + assignment_dicts = { + shard_idx: { + "shard_index": shard_idx, + "files": [asdict(f) for f in shard_assignments[shard_idx]], + 
"total_bytes": sum(f.size for f in shard_assignments[shard_idx]), + } + for shard_idx in range(num_shards) + } + + if pending_indices: + for shard_idx in pending_indices: + tasks.append( + ChatSftShardWorkItem( + dataset_name=dataset.name, + shard_index=shard_idx, + assignment=assignment_dicts[shard_idx], + output_dir=dataset_dir, + receipts_dir=receipts_dir, + max_rows=config.output.max_rows, + ) + ) + + dataset_infos.append( + { + "name": dataset.name, + "dataset_dir": dataset_dir, + "receipts_dir": receipts_dir, + } + ) - try: - # Determine filesystem protocol - protocol = fs.protocol - if isinstance(protocol, tuple): - protocol = protocol[0] - fs_protocol = protocol if protocol != "file" else "file" - - # Build all tasks upfront - process ALL datasets in parallel - all_tasks: list[ - tuple[str, str, str, int, list] - ] = [] # (name, dataset_dir, receipts_dir, shard_idx, files) - for dataset, dataset_dir, receipts_dir, files in dataset_plans: - if not files: - continue - # Each dataset gets 1 shard (since num_shards is computed per-dataset with 1 file) - # Convert files to dicts for Ray serialization - files_as_dicts = [asdict(f) if hasattr(f, "__dict__") else f for f in files] - all_tasks.append((dataset.name, dataset_dir, receipts_dir, 0, files_as_dicts)) - live_status.start_dataset(dataset.name) - - # Submit all tasks with backpressure - num_actors = len(actors) - max_in_flight = num_actors * 2 - task_queue = list(all_tasks) - actor_idx = 0 - pending_list: list = [] - future_to_task: dict = {} - - def submit_task(task: tuple) -> None: - nonlocal actor_idx - name, dataset_dir, receipts_dir, shard_idx, files_dicts = task - actor = actors[actor_idx % num_actors] - actor_idx += 1 - future = actor.process_shard.remote( - shard_index=shard_idx, - files=files_dicts, - output_dir=dataset_dir, - receipts_dir=receipts_dir, - fs_protocol=fs_protocol, + for info in dataset_infos: + live_status.start_dataset(info["name"]) + + if tasks: + run_xenna_chat_sft_pipeline( + 
tasks=tasks, + dataset_infos=dataset_infos, + output_root=str(config.output.dir), + fs=fs, + live_status=live_status, + results=results, + resolved_tokenizer=resolved_tokenizer, + messages_field=format_config.messages_field, + tools_field=format_config.tools_field, + pack_size=format_config.pack_size, + algorithm=format_config.algorithm, + dtype=format_config.dtype, + chat_template=format_config.chat_template, + max_doc_tokens=config.output.max_doc_tokens, + max_rows=config.output.max_rows, + seed=42, + used_in_filter=format_config.used_in_filter, + used_in_field=format_config.used_in_field, + max_concurrent_downloads=config.max_concurrent_downloads, + wandb_log_downloads=config.wandb_log_downloads, + wandb_download_log_interval_sec=config.wandb_download_log_interval_sec, + hf_download_timeout_sec=config.hf_download_timeout_sec, + hf_download_max_retries=config.hf_download_max_retries, + ) + else: + for info in dataset_infos: + stats = _aggregate_packed_stats( + info["dataset_dir"], info["receipts_dir"], fs + ) + results[info["name"]] = stats + live_status.complete_dataset(info["name"]) + + for dataset, dataset_dir, _, _, _ in dataset_plans: + weight = dataset.weight + if weight > 0: + prefix = f"{dataset_dir}/shard" + data_paths.append(str(weight)) + data_paths.append(prefix) + finally: + live_status.stop() + else: + from nemotron.data_prep.chat_sft_processor import ChatSftShardProcessor + + num_actors = config.num_actors or get_num_actors_from_cluster() + + actors = [ + ChatSftShardProcessor.remote( + resolved_tokenizer=resolved_tokenizer, + messages_field=format_config.messages_field, + tools_field=format_config.tools_field, + pack_size=format_config.pack_size, + algorithm=format_config.algorithm, + dtype=format_config.dtype, + chat_template=format_config.chat_template, + max_doc_tokens=config.output.max_doc_tokens, + max_rows=config.output.max_rows, + seed=42, + used_in_filter=format_config.used_in_filter, + used_in_field=format_config.used_in_field, ) - 
pending_list.append(future) - future_to_task[future] = task - - # Initial submission - while task_queue and len(pending_list) < max_in_flight: - submit_task(task_queue.pop(0)) - - # Process with backpressure - while pending_list: - done, pending_list = ray.wait(pending_list, num_returns=1, timeout=60) - for future in done: - task = future_to_task.pop(future) - name = task[0] - dataset_dir = task[1] - receipts_dir = task[2] - try: - ray.get(future) - except Exception as e: - logger.error(f"Chat SFT shard for {name} failed: {e}") + for _ in range(num_actors) + ] - # Update progress - live_status.advance_dataset(name) + live_status = con.create_live_status( + datasets=[(dataset.name, num_shards) for dataset, *_ in dataset_plans], + run_hash=run_hash, + console_mode=config.console_mode, + simple_log_interval_sec=config.simple_log_interval_sec, + ) + live_status.start() + + try: + for dataset, dataset_dir, receipts_dir, shard_assignments, pending_indices in dataset_plans: + live_status.start_dataset(dataset.name) + if pending_indices: + _process_chat_sft_shards_with_actors_pool( + actors=actors, + shard_assignments=shard_assignments, + pending_indices=pending_indices, + dataset_dir=dataset_dir, + receipts_dir=receipts_dir, + max_rows=config.output.max_rows, + fs=fs, + on_progress=lambda name=dataset.name: live_status.advance_dataset(name), + ) - # Aggregate stats for this dataset stats = _aggregate_packed_stats(dataset_dir, receipts_dir, fs) - results[name] = stats + results[dataset.name] = stats live_status.report_metrics( - name, + dataset.name, rows=stats.get("num_sequences", 0), tokens=stats.get("total_tokens", 0), ) - live_status.complete_dataset(name) - - # Submit next task if available - if task_queue: - submit_task(task_queue.pop(0)) - - # Build data_paths for all completed datasets - for dataset, dataset_dir, receipts_dir, files in dataset_plans: - if not files: - continue - weight = dataset.weight - if weight > 0: - prefix = f"{dataset_dir}/shard" - 
data_paths.append(str(weight)) - data_paths.append(prefix) - finally: - live_status.stop() - # Clean up actors - for actor in actors: - ray.kill(actor) + live_status.complete_dataset(dataset.name) + + for dataset, dataset_dir, receipts_dir, _, _ in dataset_plans: + weight = dataset.weight + if weight > 0: + prefix = f"{dataset_dir}/shard" + data_paths.append(str(weight)) + data_paths.append(prefix) + finally: + live_status.stop() + for actor in actors: + ray.kill(actor) else: # No work to do - all datasets empty or cached - for dataset, dataset_dir, receipts_dir, files in dataset_plans: + for dataset, dataset_dir, receipts_dir, _, _ in dataset_plans: stats = _aggregate_packed_stats(dataset_dir, receipts_dir, fs) results[dataset.name] = stats weight = dataset.weight @@ -2313,8 +2579,8 @@ def submit_task(task: tuple) -> None: def _process_chat_sft_shards_with_actors_pool( actors: list, - files: list, - num_shards: int, + shard_assignments: dict[int, list], + pending_indices: list[int], dataset_dir: str, receipts_dir: str, max_rows: int | None, @@ -2325,8 +2591,6 @@ def _process_chat_sft_shards_with_actors_pool( This version takes a pre-created actor pool to allow reuse across datasets. 
""" - from dataclasses import asdict - # Determine filesystem protocol protocol = fs.protocol if isinstance(protocol, tuple): @@ -2335,19 +2599,15 @@ def _process_chat_sft_shards_with_actors_pool( num_actors = len(actors) - # Distribute files across shards (round-robin) - shard_assignments: dict[int, list] = {i: [] for i in range(num_shards)} - for i, file_info in enumerate(files): - shard_idx = i % num_shards - # Convert FileInfo to dict for Ray serialization - if hasattr(file_info, "__dict__"): - shard_assignments[shard_idx].append(asdict(file_info)) - else: - shard_assignments[shard_idx].append(file_info) + serialized_assignments: dict[int, list] = {} + for shard_idx, files in shard_assignments.items(): + serialized_assignments[shard_idx] = [ + f.__dict__ if hasattr(f, "__dict__") else f for f in files + ] # Submit tasks with backpressure max_in_flight = num_actors * 2 - shard_queue = list(range(num_shards)) + shard_queue = list(pending_indices) actor_idx = 0 pending_list: list = [] future_to_shard: dict = {} @@ -2358,7 +2618,7 @@ def submit_task(shard_index: int) -> None: actor_idx += 1 future = actor.process_shard.remote( shard_index=shard_index, - files=shard_assignments[shard_index], + files=serialized_assignments[shard_index], output_dir=dataset_dir, receipts_dir=receipts_dir, fs_protocol=fs_protocol, @@ -2431,10 +2691,12 @@ def _process_chat_sft_shards_with_actors( ] try: + shard_assignments = _assign_files_round_robin(files, num_shards) + pending_indices = [i for i in range(num_shards)] _process_chat_sft_shards_with_actors_pool( actors=actors, - files=files, - num_shards=num_shards, + shard_assignments=shard_assignments, + pending_indices=pending_indices, dataset_dir=dataset_dir, receipts_dir=receipts_dir, max_rows=max_rows, diff --git a/src/nemotron/data_prep/ray_data/executor.py b/src/nemotron/data_prep/ray_data/executor.py index 9af24c8b9..634307610 100644 --- a/src/nemotron/data_prep/ray_data/executor.py +++ 
b/src/nemotron/data_prep/ray_data/executor.py @@ -123,7 +123,7 @@ class RayDataExecConfig: min_actors: int = 2 max_actors: int = 32 cpus_per_actor: float = 1.0 - max_tasks_in_flight_per_actor: int = 2 + max_tasks_in_flight_per_actor: int = 4 # Increased from 2 for better CPU utilization def execute_shard_tasks( @@ -190,6 +190,20 @@ def execute_shard_tasks( max_tasks_in_flight_per_actor=exec_cfg.max_tasks_in_flight_per_actor, ) + # Build runtime_env with HF cache settings for actors + # This ensures HuggingFace downloads go to persistent Lustre storage + import os + actor_env_vars = {} + hf_home = os.environ.get("HF_HOME") + hf_token = os.environ.get("HF_TOKEN") + + if hf_home: + actor_env_vars["HF_HOME"] = hf_home + if hf_token: + actor_env_vars["HF_TOKEN"] = hf_token + + actor_runtime_env = {"env_vars": actor_env_vars} if actor_env_vars else None + # Execute with explicit CPU allocation # batch_size=1 means one shard task per UDF call # Default batch_format is dict-of-numpy-arrays (NumPy is DEFAULT_BATCH_FORMAT) @@ -198,13 +212,17 @@ def execute_shard_tasks( # FAULT TOLERANCE: We rely on idempotent atomic commit in the UDF rather than # disabling retries. Ray Data defaults to max_restarts=-1, max_task_retries=-1. # The atomic write protocol (temp -> rename -> receipt) makes retries safe. 
- stats_ds = ds.map_batches( - udf_cls, - fn_constructor_kwargs=udf_constructor_kwargs, - batch_size=1, - compute=compute, - num_cpus=exec_cfg.cpus_per_actor, - ) + map_batches_kwargs = { + "fn_constructor_kwargs": udf_constructor_kwargs, + "batch_size": 1, + "compute": compute, + "num_cpus": exec_cfg.cpus_per_actor, + } + if actor_runtime_env: + map_batches_kwargs["runtime_env"] = actor_runtime_env + logger.info(f"Ray actors will use HF_HOME: {actor_env_vars.get('HF_HOME', 'not set')}") + + stats_ds = ds.map_batches(udf_cls, **map_batches_kwargs) # Stream results to handle callback and collect stats all_stats: list[dict[str, Any]] = [] diff --git a/src/nemotron/data_prep/shard_processor.py b/src/nemotron/data_prep/shard_processor.py index cbc54a63c..56ffa4bea 100644 --- a/src/nemotron/data_prep/shard_processor.py +++ b/src/nemotron/data_prep/shard_processor.py @@ -370,8 +370,12 @@ def _process_file_core( def _resolve_file_path_core(file_info: FileInfo) -> str: - """Resolve file to a local path, downloading from HF if needed.""" - # HF files need deferred download + """Resolve file to a local path, using HF cache (no download). + + Files should be pre-downloaded by parallel_predownload() before processing. + This function only looks up cached files to avoid network I/O during processing. + """ + # HF files - use cache only (should be pre-downloaded) if file_info.hf_repo_id is not None: from huggingface_hub import hf_hub_download @@ -380,7 +384,7 @@ def _resolve_file_path_core(file_info: FileInfo) -> str: filename=file_info.hf_filename, revision=file_info.hf_revision, repo_type="dataset", - local_files_only=False, + local_files_only=True, # Only use cached files ) return local_path @@ -934,13 +938,12 @@ def _process_jsonl_file( return rows_processed def _resolve_file_path(self, file_info: FileInfo) -> str: - """ - Resolve file to a local path, downloading from HF if needed. + """Resolve file to a local path, using HF cache (no download). 
- For HF files, downloads to local cache (node-local). - For other files, returns local_path or path. + Files should be pre-downloaded by parallel_predownload() before processing. + This method only looks up cached files to avoid network I/O during processing. """ - # HF files need deferred download + # HF files - use cache only (should be pre-downloaded) if file_info.hf_repo_id is not None: from huggingface_hub import hf_hub_download @@ -949,7 +952,7 @@ def _resolve_file_path(self, file_info: FileInfo) -> str: filename=file_info.hf_filename, revision=file_info.hf_revision, repo_type="dataset", - local_files_only=False, + local_files_only=True, # Only use cached files ) return local_path diff --git a/src/nemotron/data_prep/xenna/__init__.py b/src/nemotron/data_prep/xenna/__init__.py new file mode 100644 index 000000000..73ba56ce8 --- /dev/null +++ b/src/nemotron/data_prep/xenna/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Xenna integration for Nemotron data prep.""" + +from nemotron.data_prep.xenna.runner import run_xenna_pipeline +from nemotron.data_prep.xenna.stages import HfPredownloadStage, PretrainShardStage +from nemotron.data_prep.xenna.work_items import ShardWorkItem + +__all__ = [ + "HfPredownloadStage", + "PretrainShardStage", + "ShardWorkItem", + "run_xenna_pipeline", +] diff --git a/src/nemotron/data_prep/xenna/runner.py b/src/nemotron/data_prep/xenna/runner.py new file mode 100644 index 000000000..aafdb045f --- /dev/null +++ b/src/nemotron/data_prep/xenna/runner.py @@ -0,0 +1,662 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Xenna pipeline runner for Nemotron data prep.""" + +from __future__ import annotations + +from dataclasses import asdict +import gc +import json +import threading +import time +from typing import TYPE_CHECKING + +import cosmos_xenna.pipelines.v1 as pipelines_v1 + + +def _log_memory_status(label: str) -> None: + """Log memory usage for debugging OOM issues.""" + try: + import psutil + process = psutil.Process() + rss_gb = process.memory_info().rss / (1024**3) + print(f"[Memory] {label}: RSS={rss_gb:.2f} GB") + except ImportError: + print(f"[Memory] {label}: psutil not available") + + try: + import ray + if ray.is_initialized(): + resources = ray.available_resources() + obj_store = resources.get("object_store_memory", 0) / (1024**3) + print(f"[Ray] {label}: object_store_available={obj_store:.2f} GB") + except Exception as e: + print(f"[Ray] {label}: error getting status - {e}") + +if TYPE_CHECKING: + from cosmos_xenna.pipelines.private.monitoring_types import PipelineStats + +from nemotron.data_prep.xenna.stages import ( + ChatSftCentralPackStage, + ChatSftSpoolStage, + HfPredownloadStage, + JsonlShardStage, + PretrainShardStage, +) +from nemotron.data_prep.xenna.work_items import ( + ChatSftShardWorkItem, + JsonlShardWorkItem, + ShardWorkItem, +) + + +def run_xenna_pipeline( + *, + execution_plans: list, + output_config, + output_root: str, + fs, + live_status, + results: dict, + max_concurrent_downloads: int = 64, + wandb_log_downloads: bool = False, + wandb_log_pipeline_stats: bool = False, + wandb_download_log_interval_sec: int = 30, + hf_download_timeout_sec: int = 300, + hf_download_max_retries: int = 3, +) -> None: + """Run shard processing via Xenna pipeline.""" + if not execution_plans: + return + + resolved_tokenizer = execution_plans[0].plan.resolved_tokenizer + for ep in execution_plans[1:]: + if ep.plan.resolved_tokenizer != resolved_tokenizer: + raise ValueError( + f"Tokenizer mismatch: dataset '{ep.name}' uses different tokenizer. 
" + "Xenna executor requires uniform tokenizer across datasets in v1." + ) + + tasks: list[ShardWorkItem] = [] + for ep in execution_plans: + live_status.start_dataset(ep.name) + live_status.report_phase(ep.name, "processing", "xenna") + + assignment_dicts = {} + for a in ep.plan.file_assignments: + assignment_dicts[a.shard_index] = { + "shard_index": a.shard_index, + "files": [asdict(f) for f in a.files], + "total_bytes": a.total_bytes, + } + + for shard_idx in ep.pending_indices: + tasks.append( + ShardWorkItem( + dataset_name=ep.name, + plan_hash=ep.plan.plan_hash, + shard_index=shard_idx, + assignment=assignment_dicts[shard_idx], + output_dir=ep.dataset_dir, + receipts_dir=ep.receipts_dir, + text_field=ep.config.text_field, + dtype=output_config.dtype, + min_doc_chars=output_config.min_doc_chars, + max_doc_tokens=output_config.max_doc_tokens, + max_rows=output_config.max_rows, + ) + ) + + if not tasks: + return + + print(f"[Xenna] Launching pipeline for {len(tasks)} shard(s)") + + pipeline_spec = pipelines_v1.PipelineSpec( + input_data=tasks, + stages=[ + pipelines_v1.StageSpec( + HfPredownloadStage( + max_concurrent_downloads=max_concurrent_downloads, + output_root=output_root, + download_timeout_sec=hf_download_timeout_sec, + max_retries=hf_download_max_retries, + ), + num_workers_per_node=1, + ), + pipelines_v1.StageSpec( + PretrainShardStage( + resolved_tokenizer=resolved_tokenizer, + output_root=output_root, + ), + ) + ], + config=pipelines_v1.PipelineConfig( + execution_mode=pipelines_v1.ExecutionMode.STREAMING, + return_last_stage_outputs=False, + ), + ) + + dataset_pending_counts = {ep.name: len(ep.pending_indices) for ep in execution_plans} + + stop_event = threading.Event() + + def _poll_receipts() -> None: + last_counts: dict[str, int] = {ep.name: 0 for ep in execution_plans} + seen_receipts: dict[str, set[str]] = {ep.name: set() for ep in execution_plans} + tokens_by_dataset: dict[str, int] = {ep.name: 0 for ep in execution_plans} + while not 
stop_event.is_set(): + for ep in execution_plans: + try: + if not fs.exists(ep.receipts_dir): + continue + entries = [p for p in fs.ls(ep.receipts_dir, detail=False) if str(p).endswith(".json")] + count = len(entries) + except Exception: + continue + + last = last_counts.get(ep.name, 0) + if count > last: + new_entries = [] + for p in entries: + if p not in seen_receipts[ep.name]: + seen_receipts[ep.name].add(p) + new_entries.append(p) + + for receipt_path in new_entries: + try: + with fs.open(receipt_path, "r") as f: + receipt = json.load(f) + tokens_by_dataset[ep.name] += _extract_tokens(receipt) + except Exception: + pass + + for _ in range(count - last): + live_status.advance_dataset(ep.name) + last_counts[ep.name] = count + live_status.report_tokens(ep.name, tokens_by_dataset[ep.name]) + stop_event.wait(10.0) + + poll_thread = threading.Thread(target=_poll_receipts, daemon=True) + poll_thread.start() + + wandb_thread = None + if wandb_log_downloads: + wandb_thread = threading.Thread( + target=_poll_download_stats, + args=(fs, output_root, stop_event, wandb_download_log_interval_sec), + daemon=True, + ) + wandb_thread.start() + + try: + pipelines_v1.run_pipeline(pipeline_spec) + finally: + stop_event.set() + poll_thread.join(timeout=2.0) + if wandb_thread is not None: + wandb_thread.join(timeout=2.0) + + for ep in execution_plans: + results[ep.name] = _aggregate_stats_from_receipts(ep.receipts_dir, ep.plan, fs) + live_status.report_metrics( + ep.name, + rows=results[ep.name].get("total_sequences", 0), + tokens=results[ep.name].get("total_tokens", 0), + ) + live_status.complete_dataset(ep.name) + + +def _aggregate_stats_from_receipts(receipts_dir: str, plan, fs) -> dict: + """Import-free wrapper; actual implementation lives in pipeline.py.""" + from nemotron.data_prep.pipeline import _aggregate_stats_from_receipts as _agg + + return _agg(receipts_dir, plan, fs) + + +def run_xenna_jsonl_pipeline( + *, + tasks: list[JsonlShardWorkItem], + dataset_infos: 
list[dict], + output_root: str, + fs, + live_status, + results: dict, + text_field: str, + transform, + compression: str, + max_rows: int | None, + resolve_hf_placeholders: bool, + max_concurrent_downloads: int = 64, + wandb_log_downloads: bool = False, + wandb_log_pipeline_stats: bool = False, + wandb_download_log_interval_sec: int = 30, + hf_download_timeout_sec: int = 300, + hf_download_max_retries: int = 3, +) -> None: + if not tasks: + return + + pipeline_spec = pipelines_v1.PipelineSpec( + input_data=tasks, + stages=[ + pipelines_v1.StageSpec( + HfPredownloadStage( + max_concurrent_downloads=max_concurrent_downloads, + output_root=output_root, + download_timeout_sec=hf_download_timeout_sec, + max_retries=hf_download_max_retries, + ), + num_workers_per_node=1, + ), + pipelines_v1.StageSpec( + JsonlShardStage( + output_root=output_root, + text_field=text_field, + transform=transform, + compression=compression, + max_rows=max_rows, + resolve_hf_placeholders=resolve_hf_placeholders, + ), + # Limit workers to prevent OOM on large datasets + num_workers=4, + ), + ], + config=pipelines_v1.PipelineConfig( + execution_mode=pipelines_v1.ExecutionMode.STREAMING, + return_last_stage_outputs=False, + ), + ) + + stop_event = threading.Event() + + def _poll_receipts() -> None: + last_counts: dict[str, int] = {info["name"]: 0 for info in dataset_infos} + seen_receipts: dict[str, set[str]] = {info["name"]: set() for info in dataset_infos} + tokens_by_dataset: dict[str, int] = {info["name"]: 0 for info in dataset_infos} + while not stop_event.is_set(): + for info in dataset_infos: + name = info["name"] + receipts_dir = info["receipts_dir"] + try: + if not fs.exists(receipts_dir): + continue + entries = [p for p in fs.ls(receipts_dir, detail=False) if str(p).endswith(".json")] + count = len(entries) + except Exception: + continue + + last = last_counts.get(name, 0) + if count > last: + new_entries = [] + for p in entries: + if p not in seen_receipts[name]: + 
seen_receipts[name].add(p) + new_entries.append(p) + + for receipt_path in new_entries: + try: + with fs.open(receipt_path, "r") as f: + receipt = json.load(f) + tokens_by_dataset[name] += _extract_tokens(receipt) + except Exception: + pass + + for _ in range(count - last): + live_status.advance_dataset(name) + last_counts[name] = count + live_status.report_tokens(name, tokens_by_dataset[name]) + stop_event.wait(10.0) + + poll_thread = threading.Thread(target=_poll_receipts, daemon=True) + poll_thread.start() + + wandb_thread = None + if wandb_log_downloads: + wandb_thread = threading.Thread( + target=_poll_download_stats, + args=(fs, output_root, stop_event, wandb_download_log_interval_sec), + daemon=True, + ) + wandb_thread.start() + + try: + pipelines_v1.run_pipeline(pipeline_spec) + finally: + stop_event.set() + poll_thread.join(timeout=2.0) + if wandb_thread is not None: + wandb_thread.join(timeout=2.0) + + for info in dataset_infos: + name = info["name"] + stats = _aggregate_jsonl_stats_from_receipts( + dataset_dir=info["dataset_dir"], + num_shards=info["num_shards"], + fs=fs, + ) + results[name] = stats + live_status.report_metrics( + name, + rows=stats.get("num_records", 0), + tokens=stats.get("total_tokens", 0), + ) + live_status.complete_dataset(name) + + +def run_xenna_chat_sft_pipeline( + *, + tasks: list[ChatSftShardWorkItem], + dataset_infos: list[dict], + output_root: str, + fs, + live_status, + results: dict, + resolved_tokenizer: dict, + messages_field: str, + tools_field: str, + pack_size: int, + algorithm: str, + dtype: str, + chat_template: str | None, + max_doc_tokens: int | None, + max_rows: int | None, + seed: int | None, + used_in_filter: str | None, + used_in_field: str, + max_concurrent_downloads: int = 64, + wandb_log_downloads: bool = False, + wandb_log_pipeline_stats: bool = False, + wandb_download_log_interval_sec: int = 30, + hf_download_timeout_sec: int = 300, + hf_download_max_retries: int = 3, +) -> None: + if not tasks: + return + 
+ stages: list[pipelines_v1.StageSpec] = [ + pipelines_v1.StageSpec( + HfPredownloadStage( + max_concurrent_downloads=max_concurrent_downloads, + output_root=output_root, + download_timeout_sec=hf_download_timeout_sec, + max_retries=hf_download_max_retries, + ), + num_workers_per_node=1, + ), + pipelines_v1.StageSpec( + ChatSftSpoolStage( + resolved_tokenizer=resolved_tokenizer, + output_root=output_root, + messages_field=messages_field, + tools_field=tools_field, + pack_size=pack_size, + algorithm=algorithm, + dtype=dtype, + chat_template=chat_template, + max_doc_tokens=max_doc_tokens, + max_rows=max_rows, + seed=seed, + used_in_filter=used_in_filter, + used_in_field=used_in_field, + ), + # Let Xenna auto-scale - spool stage is memory-efficient + ), + pipelines_v1.StageSpec( + ChatSftCentralPackStage( + output_root=output_root, + pack_size=pack_size, + algorithm=algorithm, + dtype=dtype, + seed=seed, + ), + num_workers=1, # Must be single-worker for centralized packing + ), + ] + + pipeline_spec = pipelines_v1.PipelineSpec( + input_data=tasks, + stages=stages, + config=pipelines_v1.PipelineConfig( + execution_mode=pipelines_v1.ExecutionMode.STREAMING, + return_last_stage_outputs=False, + ), + ) + + stop_event = threading.Event() + + def _poll_receipts() -> None: + last_counts: dict[str, int] = {info["name"]: 0 for info in dataset_infos} + seen_receipts: dict[str, set[str]] = {info["name"]: set() for info in dataset_infos} + tokens_by_dataset: dict[str, int] = {info["name"]: 0 for info in dataset_infos} + while not stop_event.is_set(): + for info in dataset_infos: + name = info["name"] + receipts_dir = info["receipts_dir"] + try: + if not fs.exists(receipts_dir): + continue + entries = [p for p in fs.ls(receipts_dir, detail=False) if str(p).endswith(".json")] + count = len(entries) + except Exception: + continue + + last = last_counts.get(name, 0) + if count > last: + new_entries = [] + for p in entries: + if p not in seen_receipts[name]: + 
seen_receipts[name].add(p) + new_entries.append(p) + + for receipt_path in new_entries: + try: + with fs.open(receipt_path, "r") as f: + receipt = json.load(f) + tokens_by_dataset[name] += _extract_tokens(receipt) + except Exception: + pass + + for _ in range(count - last): + live_status.advance_dataset(name) + last_counts[name] = count + live_status.report_tokens(name, tokens_by_dataset[name]) + stop_event.wait(10.0) + + poll_thread = threading.Thread(target=_poll_receipts, daemon=True) + poll_thread.start() + + wandb_thread = None + if wandb_log_downloads: + wandb_thread = threading.Thread( + target=_poll_download_stats, + args=(fs, output_root, stop_event, wandb_download_log_interval_sec), + daemon=True, + ) + wandb_thread.start() + + _log_memory_status("Before run_pipeline") + try: + pipelines_v1.run_pipeline(pipeline_spec) + finally: + _log_memory_status("After run_pipeline (in finally)") + stop_event.set() + poll_thread.join(timeout=2.0) + if wandb_thread is not None: + wandb_thread.join(timeout=2.0) + + _log_memory_status("After thread cleanup") + + # Force garbage collection to release memory from pipeline + gc.collect() + _log_memory_status("After gc.collect()") + + for info in dataset_infos: + name = info["name"] + _log_memory_status(f"Before aggregating {name}") + stats = _aggregate_packed_stats_from_receipts( + dataset_dir=info["dataset_dir"], + receipts_dir=info["receipts_dir"], + fs=fs, + ) + results[name] = stats + live_status.report_metrics( + name, + rows=stats.get("num_sequences", 0), + tokens=stats.get("total_tokens", 0), + ) + live_status.complete_dataset(name) + + _log_memory_status("After all aggregation - pipeline complete") + + +def _aggregate_jsonl_stats_from_receipts(*, dataset_dir: str, num_shards: int, fs) -> dict: + from nemotron.data_prep.pipeline import _aggregate_jsonl_stats as _agg + + return _agg(dataset_dir, num_shards, fs) + + +def _aggregate_packed_stats_from_receipts(*, dataset_dir: str, receipts_dir: str, fs) -> dict: + from 
nemotron.data_prep.pipeline import _aggregate_packed_stats as _agg + + return _agg(dataset_dir, receipts_dir, fs) + + +def _extract_tokens(receipt: dict) -> int: + return int(receipt.get("stats", {}).get("total_tokens", 0)) + + +def _make_wandb_stats_callback(): + """Create a callback function for logging pipeline stats to wandb. + + Returns a callback if wandb is active, None otherwise. + """ + try: + import wandb + except ImportError: + return None + + if wandb.run is None: + return None + + def _log_stats(stats: "PipelineStats") -> None: + """Log PipelineStats to wandb.""" + metrics = { + # Overall pipeline progress + "data_prep/pipeline_duration_min": stats.pipeline_duration_s / 60, + "data_prep/inputs_initial": stats.num_initial_input_tasks, + "data_prep/inputs_remaining": stats.num_input_tasks_remaining, + "data_prep/outputs_total": stats.num_outputs, + "data_prep/main_loop_rate_hz": stats.main_loop_rate_hz, + # Cluster resources + "data_prep/cluster_cpus_total": stats.cluster.total.num_cpus, + "data_prep/cluster_cpus_available": stats.cluster.available.num_cpus, + "data_prep/cluster_gpus_total": stats.cluster.total.num_gpus, + "data_prep/cluster_gpus_available": stats.cluster.available.num_gpus, + "data_prep/cluster_memory_total_gb": stats.cluster.total.memory / 1e9, + "data_prep/cluster_memory_available_gb": stats.cluster.available.memory / 1e9, + } + + # Progress percentage + if stats.num_initial_input_tasks > 0: + progress = 1.0 - (stats.num_input_tasks_remaining / stats.num_initial_input_tasks) + metrics["data_prep/pipeline_progress"] = progress + + # Per-stage resource usage + for stage_name, usage in stats.resource_usage_per_stage.items(): + safe_name = stage_name.replace(" ", "_").replace("-", "_") + metrics[f"data_prep/stage_{safe_name}_cpu_pct"] = usage.cpu_utilization + metrics[f"data_prep/stage_{safe_name}_memory_gb"] = usage.memory_usage / 1e9 + metrics[f"data_prep/stage_{safe_name}_actor_count"] = usage.actor_count + + # Per-stage state from 
actor pools + for pool_stats in stats.actor_pools: + safe_name = pool_stats.name.replace(" ", "_").replace("-", "_") + # Actor counts + metrics[f"data_prep/stage_{safe_name}_actors_target"] = pool_stats.actor_stats.target + metrics[f"data_prep/stage_{safe_name}_actors_ready"] = pool_stats.actor_stats.ready + metrics[f"data_prep/stage_{safe_name}_actors_running"] = pool_stats.actor_stats.running + metrics[f"data_prep/stage_{safe_name}_actors_idle"] = pool_stats.actor_stats.idle + # Task stats + metrics[f"data_prep/stage_{safe_name}_tasks_completed"] = pool_stats.task_stats.total_completed + metrics[f"data_prep/stage_{safe_name}_input_queue_size"] = pool_stats.task_stats.input_queue_size + metrics[f"data_prep/stage_{safe_name}_output_queue_size"] = pool_stats.task_stats.output_queue_size + # Slot stats + metrics[f"data_prep/stage_{safe_name}_slots_used"] = pool_stats.slot_stats.num_used + metrics[f"data_prep/stage_{safe_name}_slots_empty"] = pool_stats.slot_stats.num_empty + # Speed + if pool_stats.processing_speed_tasks_per_second is not None: + metrics[f"data_prep/stage_{safe_name}_speed_tasks_per_sec"] = pool_stats.processing_speed_tasks_per_second + + wandb.log(metrics) + + return _log_stats + + +def _poll_download_stats(fs, output_root: str, stop_event: threading.Event, interval_sec: int) -> None: + try: + import wandb + except ImportError: + return + + if wandb.run is None: + return + + progress_dir = f"{output_root.rstrip('/')}/.xenna/downloads" + last_logged = 0.0 + + while not stop_event.is_set(): + now = time.time() + if now - last_logged < interval_sec: + stop_event.wait(1.0) + continue + last_logged = now + + try: + if not fs.exists(progress_dir): + continue + entries = fs.ls(progress_dir, detail=False) + except Exception: + continue + + total_completed = 0 + total_files = 0 + max_elapsed = 0.0 + max_rate = 0.0 + + for path in entries: + if not str(path).endswith(".json"): + continue + try: + with fs.open(path, "r") as f: + data = json.load(f) + except 
Exception: + continue + total_completed += int(data.get("completed", 0)) + total_files += int(data.get("total", 0)) + max_elapsed = max(max_elapsed, float(data.get("elapsed_sec", 0.0))) + max_rate = max(max_rate, float(data.get("rate", 0.0))) + + if total_files == 0: + continue + + wandb.log( + { + "data_prep/hf_download_completed": total_completed, + "data_prep/hf_download_total": total_files, + "data_prep/hf_download_rate": max_rate, + "data_prep/hf_download_elapsed_sec": max_elapsed, + } + ) diff --git a/src/nemotron/data_prep/xenna/stages.py b/src/nemotron/data_prep/xenna/stages.py new file mode 100644 index 000000000..c328c868c --- /dev/null +++ b/src/nemotron/data_prep/xenna/stages.py @@ -0,0 +1,619 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Xenna stages for Nemotron data prep.""" + +from __future__ import annotations + +import json +import os +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Any + +import cosmos_xenna.pipelines.v1 as pipelines_v1 +from cosmos_xenna.ray_utils.runtime_envs import RuntimeEnv +import numpy as np + + +def _get_hf_runtime_env() -> RuntimeEnv: + """Create RuntimeEnv with HF_HOME and HF_TOKEN for worker processes.""" + env_vars = {} + if os.environ.get("HF_HOME"): + env_vars["HF_HOME"] = os.environ["HF_HOME"] + if os.environ.get("HF_TOKEN"): + env_vars["HF_TOKEN"] = os.environ["HF_TOKEN"] + return RuntimeEnv(extra_env_vars=env_vars) if env_vars else RuntimeEnv() + +from nemotron.data_prep.chat_sft_shard_core import ( + process_chat_sft_pack_from_spool_core, + process_chat_sft_shard_core, + process_chat_sft_spool_core, +) +from nemotron.data_prep.filesystem import ensure_dir, get_filesystem +from nemotron.data_prep.formats.transforms import resolve_hf_placeholders +from nemotron.data_prep.hf_placeholder import HFPlaceholderResolver +from nemotron.data_prep.jsonl_shard_core import process_jsonl_shard_core +from nemotron.data_prep.providers import create_tokenizer +from nemotron.data_prep.shard_processor import process_binidx_shard_core +from nemotron.data_prep.xenna.work_items import ( + ChatSftShardWorkItem, + ChatSftSpoolWorkItem, + JsonlShardWorkItem, + ShardWorkItem, +) + + +class PretrainShardStage(pipelines_v1.Stage[ShardWorkItem, dict]): + """Process bin/idx shards using Xenna.""" + + def __init__(self, *, resolved_tokenizer: dict, output_root: str) -> None: + self._resolved_tokenizer = resolved_tokenizer + self._output_root = output_root + self._tokenize = None + self._output_fs = None + + @property + def stage_batch_size(self) -> int: + return 1 + + @property + def required_resources(self) -> pipelines_v1.Resources: + return pipelines_v1.Resources(gpus=0, cpus=1.0) + + @property + def 
env_info(self) -> RuntimeEnv: + return _get_hf_runtime_env() + + def setup(self, worker_metadata: pipelines_v1.WorkerMetadata) -> None: + self._tokenize = create_tokenizer(self._resolved_tokenizer) + self._output_fs, _ = get_filesystem(self._output_root) + + def process_data(self, tasks: list[ShardWorkItem]) -> list[dict[str, Any]]: + results: list[dict[str, Any]] = [] + for task in tasks: + stats = process_binidx_shard_core( + tokenize=self._tokenize, + text_field=task.text_field, + min_doc_chars=task.min_doc_chars, + max_doc_tokens=task.max_doc_tokens, + dtype=task.dtype, + max_rows=task.max_rows, + shard_index=task.shard_index, + assignment=task.assignment, + plan_hash=task.plan_hash, + output_dir=task.output_dir, + receipts_dir=task.receipts_dir, + output_fs=self._output_fs, + ) + results.append( + { + "dataset_name": task.dataset_name, + "shard_index": task.shard_index, + "stats": stats, + } + ) + return results + + +class JsonlShardStage(pipelines_v1.Stage[JsonlShardWorkItem, dict]): + """Process JSONL shards using Xenna.""" + + def __init__( + self, + *, + output_root: str, + text_field: str, + transform, + compression: str, + max_rows: int | None, + resolve_hf_placeholders: bool = False, + ) -> None: + self._output_root = output_root + self._text_field = text_field + self._transform = transform + self._compression = compression + self._max_rows = max_rows + self._resolve_hf_placeholders = resolve_hf_placeholders + self._output_fs = None + + @property + def stage_batch_size(self) -> int: + return 1 + + @property + def required_resources(self) -> pipelines_v1.Resources: + return pipelines_v1.Resources(gpus=0, cpus=0.5) + + @property + def env_info(self) -> RuntimeEnv: + return _get_hf_runtime_env() + + def setup(self, worker_metadata: pipelines_v1.WorkerMetadata) -> None: + self._output_fs, _ = get_filesystem(self._output_root) + if self._resolve_hf_placeholders: + resolver = HFPlaceholderResolver.create() + self._transform = 
resolve_hf_placeholders(resolver=resolver) + + def process_data(self, tasks: list[JsonlShardWorkItem]) -> list[dict[str, Any]]: + results: list[dict[str, Any]] = [] + for task in tasks: + stats = process_jsonl_shard_core( + shard_index=task.shard_index, + files=task.assignment.get("files", []), + output_dir=task.output_dir, + receipts_dir=task.receipts_dir, + output_fs=self._output_fs, + text_field=self._text_field, + transform=self._transform, + compression=self._compression, + max_rows=self._max_rows, + ) + results.append( + { + "dataset_name": task.dataset_name, + "shard_index": task.shard_index, + "stats": stats, + } + ) + return results + + +class ChatSftShardStage(pipelines_v1.Stage[ChatSftShardWorkItem, dict]): + """Process ChatSFT packed shards using Xenna.""" + + def __init__( + self, + *, + resolved_tokenizer: dict, + output_root: str, + messages_field: str, + tools_field: str, + pack_size: int, + algorithm: str, + dtype: str, + chat_template: str | None, + max_doc_tokens: int | None, + max_rows: int | None, + seed: int | None, + used_in_filter: str | None, + used_in_field: str, + ) -> None: + self._resolved_tokenizer = resolved_tokenizer + self._output_root = output_root + self._messages_field = messages_field + self._tools_field = tools_field + self._pack_size = pack_size + self._algorithm = algorithm + self._dtype = dtype + self._chat_template = chat_template + self._max_doc_tokens = max_doc_tokens + self._max_rows = max_rows + self._seed = seed + self._used_in_filter = used_in_filter + self._used_in_field = used_in_field + self._tokenizer = None + self._output_fs = None + + @property + def stage_batch_size(self) -> int: + return 1 + + @property + def required_resources(self) -> pipelines_v1.Resources: + return pipelines_v1.Resources(gpus=0, cpus=1.0) + + @property + def env_info(self) -> RuntimeEnv: + return _get_hf_runtime_env() + + def setup(self, worker_metadata: pipelines_v1.WorkerMetadata) -> None: + from transformers import AutoTokenizer + + 
self._tokenizer = AutoTokenizer.from_pretrained( + self._resolved_tokenizer["model"], + revision=self._resolved_tokenizer.get("resolved_revision"), + trust_remote_code=self._resolved_tokenizer.get("trust_remote_code", False), + local_files_only=True, # Use cached files to avoid HF rate limits + ) + if self._chat_template: + if self._chat_template == "nano3": + template_path = Path(__file__).parent.parent / "templates" / "nano3.jinja" + with open(template_path) as f: + self._tokenizer.chat_template = f.read() + elif Path(self._chat_template).exists(): + with open(self._chat_template) as f: + self._tokenizer.chat_template = f.read() + else: + self._tokenizer.chat_template = self._chat_template + + self._output_fs, _ = get_filesystem(self._output_root) + + def process_data(self, tasks: list[ChatSftShardWorkItem]) -> list[dict[str, Any]]: + results: list[dict[str, Any]] = [] + for task in tasks: + stats = process_chat_sft_shard_core( + shard_index=task.shard_index, + files=task.assignment.get("files", []), + output_dir=task.output_dir, + receipts_dir=task.receipts_dir, + output_fs=self._output_fs, + tokenizer=self._tokenizer, + messages_field=self._messages_field, + tools_field=self._tools_field, + pack_size=self._pack_size, + algorithm=self._algorithm, + dtype=np.dtype(self._dtype), + chat_template=None, + max_doc_tokens=self._max_doc_tokens, + max_rows=self._max_rows, + seed=self._seed, + used_in_filter=self._used_in_filter, + used_in_field=self._used_in_field, + ) + results.append( + { + "dataset_name": task.dataset_name, + "shard_index": task.shard_index, + "stats": stats, + } + ) + return results + + +class ChatSftSpoolStage(pipelines_v1.Stage[ChatSftShardWorkItem, ChatSftSpoolWorkItem]): + """Tokenize ChatSFT shards into SequenceSpool intermediates (no packing).""" + + def __init__( + self, + *, + resolved_tokenizer: dict, + output_root: str, + messages_field: str, + tools_field: str, + pack_size: int, + algorithm: str, + dtype: str, + chat_template: str | None, 
+ max_doc_tokens: int | None, + max_rows: int | None, + seed: int | None, + used_in_filter: str | None, + used_in_field: str, + ) -> None: + self._resolved_tokenizer = resolved_tokenizer + self._output_root = output_root + self._messages_field = messages_field + self._tools_field = tools_field + self._pack_size = pack_size + self._algorithm = algorithm + self._dtype = dtype + self._chat_template = chat_template + self._max_doc_tokens = max_doc_tokens + self._max_rows = max_rows + self._seed = seed + self._used_in_filter = used_in_filter + self._used_in_field = used_in_field + self._tokenizer = None + self._output_fs = None + + @property + def stage_batch_size(self) -> int: + return 1 + + @property + def required_resources(self) -> pipelines_v1.Resources: + return pipelines_v1.Resources(gpus=0, cpus=1.0) + + @property + def env_info(self) -> RuntimeEnv: + return _get_hf_runtime_env() + + def setup(self, worker_metadata: pipelines_v1.WorkerMetadata) -> None: + from transformers import AutoTokenizer + + self._tokenizer = AutoTokenizer.from_pretrained( + self._resolved_tokenizer["model"], + revision=self._resolved_tokenizer.get("resolved_revision"), + trust_remote_code=self._resolved_tokenizer.get("trust_remote_code", False), + local_files_only=True, # Use cached files to avoid HF rate limits + ) + if self._chat_template: + if self._chat_template == "nano3": + template_path = Path(__file__).parent.parent / "templates" / "nano3.jinja" + with open(template_path) as f: + self._tokenizer.chat_template = f.read() + elif Path(self._chat_template).exists(): + with open(self._chat_template) as f: + self._tokenizer.chat_template = f.read() + else: + self._tokenizer.chat_template = self._chat_template + + self._output_fs, _ = get_filesystem(self._output_root) + + def process_data(self, tasks: list[ChatSftShardWorkItem]) -> list[ChatSftSpoolWorkItem]: + out: list[ChatSftSpoolWorkItem] = [] + for task in tasks: + shard_id = f"shard_{task.shard_index:06d}" + spool_dir = 
f"{task.output_dir.rstrip('/')}/spool/{shard_id}" + + process_chat_sft_spool_core( + shard_index=task.shard_index, + files=task.assignment.get("files", []), + output_dir=task.output_dir, + receipts_dir=task.receipts_dir, + spool_dir=spool_dir, + output_fs=self._output_fs, + tokenizer=self._tokenizer, + messages_field=self._messages_field, + tools_field=self._tools_field, + pack_size=self._pack_size, + algorithm=self._algorithm, + dtype=np.dtype(self._dtype), + chat_template=None, + max_doc_tokens=self._max_doc_tokens, + max_rows=task.max_rows if task.max_rows is not None else self._max_rows, + seed=self._seed, + used_in_filter=self._used_in_filter, + used_in_field=self._used_in_field, + ) + + out.append( + ChatSftSpoolWorkItem( + dataset_name=task.dataset_name, + shard_index=task.shard_index, + assignment=task.assignment, + output_dir=task.output_dir, + receipts_dir=task.receipts_dir, + spool_dir=spool_dir, + max_rows=task.max_rows, + ) + ) + return out + + +class ChatSftCentralPackStage(pipelines_v1.Stage[ChatSftSpoolWorkItem, dict]): + """Pack ChatSFT SequenceSpool intermediates into packed .npy shards (single-worker stage).""" + + def __init__( + self, + *, + output_root: str, + pack_size: int, + algorithm: str, + dtype: str, + seed: int | None, + ) -> None: + self._output_root = output_root + self._pack_size = pack_size + self._algorithm = algorithm + self._dtype = dtype + self._seed = seed + self._output_fs = None + + @property + def stage_batch_size(self) -> int: + return 1 + + @property + def required_resources(self) -> pipelines_v1.Resources: + return pipelines_v1.Resources(gpus=0, cpus=1.0) + + @property + def env_info(self) -> RuntimeEnv: + return _get_hf_runtime_env() + + def setup(self, worker_metadata: pipelines_v1.WorkerMetadata) -> None: + self._output_fs, _ = get_filesystem(self._output_root) + + def process_data(self, tasks: list[ChatSftSpoolWorkItem]) -> list[dict[str, Any]]: + import gc + + results: list[dict[str, Any]] = [] + for task in tasks: 
+ stats = process_chat_sft_pack_from_spool_core( + shard_index=task.shard_index, + output_dir=task.output_dir, + receipts_dir=task.receipts_dir, + spool_dir=task.spool_dir, + output_fs=self._output_fs, + pack_size=self._pack_size, + algorithm=self._algorithm, + dtype=np.dtype(self._dtype), + seed=self._seed, + ) + results.append( + { + "dataset_name": task.dataset_name, + "shard_index": task.shard_index, + "stats": stats, + } + ) + # Force garbage collection after each shard to prevent memory accumulation + # across sequential tasks in this single-worker stage + gc.collect() + return results + + +class HfPredownloadStage(pipelines_v1.Stage[ShardWorkItem, ShardWorkItem]): + """Pre-download HuggingFace files for a batch of shards.""" + + def __init__( + self, + *, + max_concurrent_downloads: int, + output_root: str, + download_timeout_sec: int = 300, + max_retries: int = 3, + ) -> None: + self._max_concurrent_downloads = max_concurrent_downloads + self._output_root = output_root + self._download_timeout_sec = download_timeout_sec + self._max_retries = max_retries + self._progress_path = None + self._output_fs = None + + def setup(self, worker_metadata: pipelines_v1.WorkerMetadata) -> None: + self._output_fs, base_path = get_filesystem(self._output_root) + progress_dir = f"{base_path.rstrip('/')}/.xenna/downloads" + ensure_dir(self._output_fs, progress_dir) + self._progress_path = f"{progress_dir}/{worker_metadata.worker_id}.json" + + @property + def stage_batch_size(self) -> int: + return 1 + + @property + def required_resources(self) -> pipelines_v1.Resources: + return pipelines_v1.Resources(gpus=0, cpus=0.5) + + @property + def env_info(self) -> RuntimeEnv: + return _get_hf_runtime_env() + + def process_data(self, tasks: list[ShardWorkItem]) -> list[ShardWorkItem]: + if not tasks: + return [] + + unique_files = _collect_unique_hf_files(tasks) + if not unique_files: + return tasks + + cache_dir = None + hf_home = os.environ.get("HF_HOME") + if hf_home: + cache_dir = 
os.path.join(hf_home, "hub") + + total_files = len(unique_files) + max_workers = min(self._max_concurrent_downloads, total_files) + print(f"[Pre-download] Starting download of {total_files} unique files (max_concurrent={max_workers})") + start_time = time.perf_counter() + completed = 0 + last_report = start_time + self._write_progress(completed, total_files, start_time) + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [ + executor.submit( + _download_hf_file, + file_info["repo_id"], + file_info["filename"], + file_info["revision"], + cache_dir, + self._download_timeout_sec, + self._max_retries, + ) + for file_info in unique_files + ] + + for future in as_completed(futures): + try: + future.result() + except Exception: + pass + completed += 1 + now = time.perf_counter() + if now - last_report >= 5.0 or completed == total_files: + rate = completed / max(now - start_time, 0.001) + print(f"[Pre-download] {completed}/{total_files} files ({rate:.1f}/s)") + last_report = now + self._write_progress(completed, total_files, start_time) + + self._write_progress(completed, total_files, start_time) + + return tasks + + def _write_progress(self, completed: int, total: int, start_time: float) -> None: + if self._output_fs is None or self._progress_path is None: + return + elapsed = time.perf_counter() - start_time + rate = completed / max(elapsed, 0.001) + payload = { + "completed": completed, + "total": total, + "elapsed_sec": elapsed, + "rate": rate, + "updated_at": time.time(), + } + try: + with self._output_fs.open(self._progress_path, "w") as f: + json.dump(payload, f) + except Exception: + pass + + +def _collect_unique_hf_files(tasks: list[ShardWorkItem]) -> list[dict[str, str]]: + seen: set[tuple[str, str, str]] = set() + unique_files: list[dict[str, str]] = [] + for task in tasks: + files = task.assignment.get("files", []) + for file_info in files: + repo_id = file_info.get("hf_repo_id") + if repo_id is None: + continue + filename = 
def _download_hf_file(
    repo_id: str,
    filename: str,
    revision: str | None,
    cache_dir: str | None,
    timeout_sec: int,
    max_retries: int,
) -> None:
    """Download one file from a HuggingFace dataset repo into the local cache.

    The call is attempted with ``etag_timeout`` and ``download_timeout`` set
    to ``timeout_sec``. If the installed ``huggingface_hub`` rejects a keyword
    with ``TypeError``, the timeout keywords are dropped one step at a time
    (``download_timeout`` first, then ``etag_timeout``) and the call is
    repeated immediately. Every other failure is retried with a linear
    back-off capped at 20 seconds, whichever keyword set is in use.

    Args:
        repo_id: Dataset repository id.
        filename: Path of the file inside the repository.
        revision: Revision to fetch, or ``None`` for the default.
        cache_dir: Cache directory passed to ``hf_hub_download``.
        timeout_sec: Timeout applied to the timeout keywords that are accepted.
        max_retries: Maximum number of attempts; values below 1 are treated
            as a single attempt.

    Raises:
        Exception: The error from the final attempt once all attempts failed.
        TypeError: If the call is rejected even without any timeout keyword.
    """
    from huggingface_hub import hf_hub_download

    timeout_kwargs: dict[str, int] = {
        "etag_timeout": timeout_sec,
        "download_timeout": timeout_sec,
    }
    attempts = max(1, max_retries)
    attempt = 1
    while True:
        try:
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                revision=revision,
                repo_type="dataset",
                local_files_only=False,
                cache_dir=cache_dir,
                **timeout_kwargs,
            )
            return
        except TypeError:
            if not timeout_kwargs:
                # Nothing optional left to drop: this is a genuine error.
                raise
            # Signature mismatch, not a transient failure: shrink the keyword
            # set and repeat at once without consuming an attempt.
            if "download_timeout" in timeout_kwargs:
                del timeout_kwargs["download_timeout"]
            else:
                timeout_kwargs.clear()
        except Exception:
            if attempt >= attempts:
                # Out of attempts: surface the error without a pointless sleep.
                raise
            time.sleep(min(5 * attempt, 20))
            attempt += 1
@dataclass
class ShardWorkItem:
    """Payload for Xenna shard processing.

    One instance describes one bin/idx shard. ``PretrainShardStage`` forwards
    every field except ``dataset_name`` to ``process_binidx_shard_core`` and
    echoes ``dataset_name`` and ``shard_index`` in its result dict.
    """

    # Name of the dataset this shard belongs to; echoed back in stage results.
    dataset_name: str
    # Hash identifying the processing plan (forwarded to the shard core as-is).
    plan_hash: str
    # Index of the shard within its dataset.
    shard_index: int
    # Shard assignment; its "files" list is read by the pre-download stage,
    # whose entries may carry hf_repo_id / hf_filename / hf_revision keys.
    assignment: dict[str, Any]
    # Directory receiving the shard output.
    output_dir: str
    # Directory receiving the shard receipt.
    receipts_dir: str
    # Name of the record field holding the text to tokenize.
    text_field: str
    # Output dtype, given as a string.
    dtype: str
    # Optional minimum document length in characters (None = no limit).
    min_doc_chars: int | None
    # Optional maximum document length in tokens (None = no limit).
    max_doc_tokens: int | None
    # Optional cap on the number of rows processed (None = no limit).
    max_rows: int | None


@dataclass
class JsonlShardWorkItem:
    """Payload for Xenna JSONL shard processing.

    NOTE(review): ``JsonlShardStage.process_data`` reads only ``dataset_name``,
    ``shard_index``, ``assignment``, ``output_dir`` and ``receipts_dir`` from
    the item; ``text_field``, ``compression`` and ``max_rows`` are taken from
    the stage's own configuration there. Confirm whether the per-item values
    are consumed elsewhere.
    """

    # Name of the dataset this shard belongs to; echoed back in stage results.
    dataset_name: str
    # Index of the shard within its dataset.
    shard_index: int
    # Shard assignment; the stage passes its "files" list to the shard core.
    assignment: dict[str, Any]
    # Directory receiving the shard output.
    output_dir: str
    # Directory receiving the shard receipt.
    receipts_dir: str
    # Name of the record field holding the text.
    text_field: str
    # Compression setting for the written JSONL.
    compression: str
    # Optional cap on the number of rows processed (None = no limit).
    max_rows: int | None
    # Whether HuggingFace placeholders should be resolved for this shard.
    resolve_hf_placeholders: bool = False


@dataclass
class ChatSftShardWorkItem:
    """Payload for Xenna ChatSFT shard processing."""

    # Name of the dataset this shard belongs to; echoed back in stage results.
    dataset_name: str
    # Index of the shard within its dataset.
    shard_index: int
    # Shard assignment; the stages pass its "files" list to the shard core.
    assignment: dict[str, Any]
    # Directory receiving the shard output.
    output_dir: str
    # Directory receiving the shard receipt.
    receipts_dir: str
    # Optional per-item row cap; ChatSftSpoolStage prefers it over the
    # stage-level max_rows when it is not None.
    max_rows: int | None


@dataclass
class ChatSftSpoolWorkItem:
    """Payload for Xenna ChatSFT SequenceSpool generation (tokenize-only).

    Produced by ``ChatSftSpoolStage`` and consumed by
    ``ChatSftCentralPackStage``.
    """

    # Name of the dataset this shard belongs to; echoed back in stage results.
    dataset_name: str
    # Index of the shard within its dataset.
    shard_index: int
    # Shard assignment, copied from the originating ChatSftShardWorkItem.
    assignment: dict[str, Any]
    # Directory receiving the packed shard output.
    output_dir: str
    # Directory receiving the shard receipt.
    receipts_dir: str
    # Spool directory; ChatSftSpoolStage sets it to
    # "<output_dir>/spool/shard_<index:06d>".
    spool_dir: str
    # Row cap copied from the originating work item.
    max_rows: int | None
100644 --- a/src/nemotron/kit/templates/ray_cpu.sub.j2 +++ b/src/nemotron/kit/templates/ray_cpu.sub.j2 @@ -16,6 +16,10 @@ set -eoux pipefail export PYTHONUNBUFFERED=1 export SLURM_UNBUFFEREDIO=1 +# cosmos_xenna recommended env vars for Ray state API (monitoring) +export RAY_MAX_LIMIT_FROM_API_SERVER=40000 +export RAY_MAX_LIMIT_FROM_DATA_SOURCE=40000 + {%- for env_var in env_vars %} {{env_var}} {%- endfor %} @@ -195,6 +199,7 @@ ray start --head \ --port=${PORT} \ --ray-client-server-port=${RAY_CLIENT_SERVER_PORT} \ --dashboard-port=${DASHBOARD_PORT} \ + --dashboard-host=0.0.0.0 \ \ --node-manager-port=${NODE_MANAGER_PORT} \ --object-manager-port=${OBJECT_MANAGER_PORT} \ @@ -383,6 +388,20 @@ COMMAND="${COMMAND:-{{ command | default('', true) }}}" COMMAND_WORKDIR={{ command_workdir | default('$CONTAINER_CWD') }} if [[ -n "$COMMAND" ]]; then + export RAY_ADDRESS="${head_node_ip}:${PORT}" + # Wait for Ray dashboard to be ready (needed by cosmos_xenna state API) + echo "[INFO] Waiting for Ray dashboard to be ready at http://${head_node_ip}:${DASHBOARD_PORT}..." + dashboard_timeout=60 + dashboard_elapsed=0 + while ! srun --overlap --container-name=ray-head --nodes=1 --ntasks=1 -w "$head_node" curl -sf "http://127.0.0.1:${DASHBOARD_PORT}/api/version" > /dev/null 2>&1; do + if [[ $dashboard_elapsed -ge $dashboard_timeout ]]; then + echo "[WARN] Dashboard not ready after ${dashboard_timeout}s, proceeding anyway..." + break + fi + sleep 2 + dashboard_elapsed=$((dashboard_elapsed + 2)) + done + echo "[INFO] Ray dashboard ready, starting job..." 
srun --no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$COMMAND_WORKDIR --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/{{ ray_log_prefix }}job.log bash -c "$COMMAND" else echo "[INFO]: Ray Cluster is idled, run this on the slurm head node to get a shell to the head node:" diff --git a/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/default.yaml b/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/default.yaml index fa43ea45f..5fd34caf8 100644 --- a/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/default.yaml +++ b/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/default.yaml @@ -47,8 +47,9 @@ max_doc_tokens: null # Limit rows per dataset for quick tests (null = no limit) sample: null -# Ray Data executor settings (limit actors to avoid OOM on high-CPU nodes) -ray_data_max_actors: 32 +# Ray Data executor settings +# Set to 48 as balance between parallelism and memory on 172GB nodes +ray_data_max_actors: 48 # Console output mode: 'simple' for periodic text updates, 'rich' for animated progress console_mode: simple diff --git a/src/nemotron/recipes/nano3/stage0_pretrain/data_prep.py b/src/nemotron/recipes/nano3/stage0_pretrain/data_prep.py index f0a5f666f..8c6141dbc 100644 --- a/src/nemotron/recipes/nano3/stage0_pretrain/data_prep.py +++ b/src/nemotron/recipes/nano3/stage0_pretrain/data_prep.py @@ -112,6 +112,9 @@ class PreTrainDataPrepConfig: num_actors: int | None = None """Ray actors for parallel processing (None = auto)""" + ray_data_max_actors: int | None = None + """Maximum Ray Data actors (None = auto-detect based on memory)""" + force: bool = False """Force new run, ignoring cache""" @@ -182,6 +185,7 @@ def run_data_prep_main(cfg: PreTrainDataPrepConfig) -> PretrainBlendsArtifact: artifact_name=artifact_name, console_mode=getattr(cfg, "console_mode", "simple"), simple_log_interval_sec=getattr(cfg, "simple_log_interval_sec", 30), + ray_data_max_actors=cfg.ray_data_max_actors, 
@dataclass
class PreTrainDataPrepConfig:
    """Settings for Nano3 pretrain data preparation.

    ``blend_path`` and ``output_dir`` may be given as strings and are
    converted to ``Path``. When ``sample`` is set, a ``sample-<n>``
    sub-directory is appended to ``output_dir``.
    """

    blend_path: Path = field(default_factory=lambda: STAGE_PATH / "config/data_blend_raw.json")
    output_dir: Path = field(default_factory=lambda: _OUTPUT_BASE / "output/nano3/stage0_pretrain")
    num_shards: int = 128
    valid_shards: int = 1
    test_shards: int = 1
    tokenizer_model: str = "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
    add_bos: bool = False
    add_eos: bool = True
    text_field: str = "text"
    min_doc_chars: int | None = None
    max_doc_tokens: int | None = None
    sample: int | None = None
    num_actors: int | None = None
    ray_data_max_actors: int | None = None
    force: bool = False
    config_name: str = "default"

    def __post_init__(self) -> None:
        # Normalise string paths (for example values loaded from YAML) to Path.
        for attr in ("blend_path", "output_dir"):
            raw = getattr(self, attr)
            if isinstance(raw, str):
                setattr(self, attr, Path(raw))

        if self.sample is None:
            return
        # Keep sampled runs apart from full runs on disk.
        self.output_dir = self.output_dir / f"sample-{self.sample}"
def main(cfg: PreTrainDataPrepConfig | None = None) -> PretrainBlendsArtifact:
    """Run Xenna pretrain data preparation, loading the config if none is given.

    Args:
        cfg: Ready-made configuration. When ``None``, the YAML config path and
            CLI overrides are parsed, the YAML is loaded, overrides are
            applied, and the result is converted to ``PreTrainDataPrepConfig``.

    Returns:
        The artifact produced by ``run_data_prep_main``.

    Exits with status 1 if the config file cannot be found.
    """
    if cfg is None:
        yaml_path, overrides = parse_config_and_overrides(default_config=DEFAULT_CONFIG_PATH)
        try:
            raw_config = load_omegaconf_yaml(yaml_path)
        except FileNotFoundError as missing:
            print(f"Error: {missing}", file=sys.stderr)
            sys.exit(1)

        merged = apply_hydra_overrides(raw_config, overrides) if overrides else raw_config
        cfg = omegaconf_to_dataclass(merged, PreTrainDataPrepConfig)

    init_wandb_from_env()
    return run_data_prep_main(cfg)
b/src/nemotron/recipes/nano3/stage1_sft/config/data_prep/default.yaml index 1d773818e..f78518339 100644 --- a/src/nemotron/recipes/nano3/stage1_sft/config/data_prep/default.yaml +++ b/src/nemotron/recipes/nano3/stage1_sft/config/data_prep/default.yaml @@ -1,6 +1,6 @@ run: env: - container: anyscale/ray:2.49.2-py311 + container: anyscale/ray:2.49.2-py312 # Default config for SFT data preparation # diff --git a/src/nemotron/recipes/nano3/stage1_sft/data_prep_xenna.py b/src/nemotron/recipes/nano3/stage1_sft/data_prep_xenna.py new file mode 100644 index 000000000..65eac1f32 --- /dev/null +++ b/src/nemotron/recipes/nano3/stage1_sft/data_prep_xenna.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Data preparation for Nano3 SFT stage using Xenna execution.""" + +from __future__ import annotations + +import logging +import os +import sys +import time +from pathlib import Path + +import ray + +from nemotron.data_prep import ( + ChatSftOutputConfig, + DataBlend, + OutputConfig, + PipelineConfig, + TokenizerConfig, + last_mile_process, +) +from nemotron.data_prep.config import DatasetConfig +from nemotron.data_prep.discovery import get_dataset_metadata +from nemotron.kit import SFTDataArtifact, print_step_complete +from nemotron.kit.trackers import InputDatasetInfo, tokenizer_to_uri +from nemotron.kit.train_script import ( + apply_hydra_overrides, + init_wandb_from_env, + load_omegaconf_yaml, + omegaconf_to_dataclass, + parse_config_and_overrides, +) +from nemotron.kit.wandb import add_wandb_tags, finish_wandb +from nemotron.recipes.nano3.stage1_sft.data_prep import ( + SFTDataPrepConfig, + _concatenate_and_split_npy, +) + +logger = logging.getLogger(__name__) + +STAGE_PATH = Path(__file__).parent +DEFAULT_CONFIG_PATH = STAGE_PATH / "config" / "data_prep.yaml" +_OUTPUT_BASE = Path(os.environ.get("NEMO_RUN_DIR", ".")) + +# Module-level flag for Ray execution (used by nemotron CLI) +RAY = True + + +def run_data_prep_main(cfg: SFTDataPrepConfig) -> SFTDataArtifact: + """Run SFT data preparation with Xenna execution.""" + start_time = time.time() + add_wandb_tags(["data-prep", "sft", "xenna"]) + + blend = DataBlend.load(cfg.blend_path) + + num_actors = cfg.num_actors + if num_actors is None: + cpu_count = os.cpu_count() or 4 + num_actors = max(2, min(32, cpu_count * 3 // 4)) + + shards_dir = cfg.output_dir / "_shards" + + format_config = ChatSftOutputConfig( + shard_size=cfg.shard_size, + pack_size=cfg.pack_size, + chat_template=cfg.chat_template, + messages_field=cfg.messages_field, + tools_field=cfg.tools_field, + used_in_filter=cfg.used_in_filter, + used_in_field=cfg.used_in_field, + ) + + pipeline_config = PipelineConfig( + output=OutputConfig( + 
dir=shards_dir, + format=format_config, + max_doc_tokens=cfg.max_doc_tokens, + max_rows=cfg.sample, + ), + tokenizer=TokenizerConfig(model=cfg.tokenizer_model), + num_actors=num_actors, + force=cfg.force, + execution_engine="xenna", + ) + + # Initialize Ray with runtime_env for HF_HOME and HF_TOKEN propagation + if not ray.is_initialized(): + runtime_env = { + "excludes": [ + "output/", + "outputs/", + "wandb/", + "data/", + "checkpoints/", + "*.bin", + "*.idx", + "*.npy", + "__pycache__/", + ".git/", + ".venv/", + "*.egg-info/", + ], + "env_vars": {}, + } + if os.environ.get("HF_HOME"): + runtime_env["env_vars"]["HF_HOME"] = os.environ["HF_HOME"] + if os.environ.get("HF_TOKEN"): + runtime_env["env_vars"]["HF_TOKEN"] = os.environ["HF_TOKEN"] + ray.init(address="auto", ignore_reinit_error=True, runtime_env=runtime_env) + + logger.info("Running pipeline to generate shards (Xenna)...") + result = last_mile_process(blend, pipeline_config) + + data_paths = result.splits["all"].data_paths if "all" in result.splits else None + logger.info("Concatenating shards and splitting by ratio...") + split_stats = _concatenate_and_split_npy( + shards_dir=shards_dir, + output_dir=cfg.output_dir, + train_ratio=cfg.train_ratio, + valid_ratio=cfg.valid_ratio, + test_ratio=cfg.test_ratio, + pack_size=cfg.pack_size, + data_paths=data_paths, + ) + + if shards_dir.exists(): + import shutil + + logger.info(f"Cleaning up intermediate shards directory: {shards_dir}") + shutil.rmtree(shards_dir) + + elapsed_sec = time.time() - start_time + + source_datasets: list[InputDatasetInfo] = [] + seen_keys: set[str] = set() + for split_datasets in blend.splits.values(): + for dataset in split_datasets: + key = f"{dataset.path}|{dataset.subset or ''}" + if key not in seen_keys: + seen_keys.add(key) + ds_config = DatasetConfig( + name=dataset.name, + path=dataset.path, + split=dataset.split, + subset=dataset.subset, + text_field=dataset.text_field, + ) + hf_metadata = get_dataset_metadata(ds_config) + 
source_datasets.append( + InputDatasetInfo( + uri=dataset.path, + name=dataset.name, + weight=dataset.weight, + split=dataset.split, + subset=dataset.subset, + text_field=dataset.text_field, + num_rows=hf_metadata.num_rows, + size_bytes=hf_metadata.size_bytes, + ) + ) + + tok_uri = tokenizer_to_uri(cfg.tokenizer_model) + + artifact = SFTDataArtifact( + path=cfg.output_dir.resolve(), + total_tokens=result.total_tokens, + total_sequences=split_stats["total_sequences"], + elapsed_sec=elapsed_sec, + pack_size=cfg.pack_size, + source_datasets=source_datasets, + tokenizer_uri=tok_uri, + training_path=split_stats["training_path"], + validation_path=split_stats["validation_path"], + test_path=split_stats["test_path"], + metadata_path=split_stats["metadata_path"], + ) + artifact.name = f"nano3/sft/data{'?sample=' + str(cfg.sample) if cfg.sample else ''}" + artifact.save() + + finish_wandb(exit_code=0) + print_step_complete(data_prep=artifact) + return artifact + + +def main(cfg: SFTDataPrepConfig | None = None) -> SFTDataArtifact: + """Entry point for Xenna SFT data preparation.""" + if cfg is None: + config_path, cli_overrides = parse_config_and_overrides(default_config=DEFAULT_CONFIG_PATH) + + try: + config = load_omegaconf_yaml(config_path) + except FileNotFoundError as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + if cli_overrides: + config = apply_hydra_overrides(config, cli_overrides) + + cfg = omegaconf_to_dataclass(config, SFTDataPrepConfig) + + init_wandb_from_env() + return run_data_prep_main(cfg) + + +if __name__ == "__main__": + main() diff --git a/src/nemotron/recipes/nano3/stage2_rl/config/data_prep/default.yaml b/src/nemotron/recipes/nano3/stage2_rl/config/data_prep/default.yaml index d695f6698..92f84beb4 100644 --- a/src/nemotron/recipes/nano3/stage2_rl/config/data_prep/default.yaml +++ b/src/nemotron/recipes/nano3/stage2_rl/config/data_prep/default.yaml @@ -1,6 +1,6 @@ run: env: - container: anyscale/ray:2.49.2-py311 + container: 
anyscale/ray:2.49.2-py312 # Config for RL data preparation with HuggingFace placeholder resolution # diff --git a/src/nemotron/recipes/nano3/stage2_rl/data_prep_xenna.py b/src/nemotron/recipes/nano3/stage2_rl/data_prep_xenna.py new file mode 100644 index 000000000..c1461e06e --- /dev/null +++ b/src/nemotron/recipes/nano3/stage2_rl/data_prep_xenna.py @@ -0,0 +1,280 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Data preparation for Nano3 RL stage using Xenna execution.""" + +from __future__ import annotations + +import json +import os +import sys +import time +from pathlib import Path + +import ray + +from nemotron.data_prep import ( + DataBlend, + Dataset, + OutputConfig, + PipelineConfig, + last_mile_process, +) +from nemotron.data_prep.config import DatasetConfig, JsonlOutputConfig +from nemotron.data_prep.discovery import get_dataset_metadata +from nemotron.data_prep.hf_placeholder import HFPlaceholderResolver +from nemotron.kit import SplitJsonlDataArtifact, print_step_complete +from nemotron.kit.trackers import InputDatasetInfo +from nemotron.kit.train_script import ( + apply_hydra_overrides, + init_wandb_from_env, + load_omegaconf_yaml, + omegaconf_to_dataclass, + parse_config_and_overrides, +) +from nemotron.kit.wandb import add_wandb_tags, finish_wandb +from nemotron.recipes.nano3.stage2_rl.data_prep import RLDataPrepConfig + +STAGE_PATH = Path(__file__).parent +DEFAULT_CONFIG_PATH = STAGE_PATH / "config" / "data_prep" / "default.yaml" + +# Module-level flag for Ray execution (used by nemotron CLI) +RAY = True + + +def _run_resolve( + blend: DataBlend, + cfg: RLDataPrepConfig, + num_actors: int, + source_datasets: list[InputDatasetInfo], +) -> SplitJsonlDataArtifact: + from datasets import get_dataset_split_names + + start_time = time.time() + total_sequences = 0 + split_paths: dict[str, Path] = {} + + if len(blend.datasets) != 1: + raise ValueError( + f"Resolve mode expects exactly one dataset in blend, got {len(blend.datasets)}." 
+ ) + + dataset = blend.datasets[0] + + dataset_path = dataset.path + if dataset_path.startswith("hf://"): + dataset_path = dataset_path[5:] + + available_splits = get_dataset_split_names(dataset_path) + + split_name_mapping = { + "train": "train", + "validation": "val", + "test": "test", + } + + for hf_split in available_splits: + output_split_name = split_name_mapping.get(hf_split, hf_split) + split_output_dir = cfg.output_dir / output_split_name + + split_blend = DataBlend( + datasets=[ + Dataset( + name=dataset.name, + path=dataset.path, + split=hf_split, + subset=dataset.subset, + weight=1.0, + text_field=dataset.text_field, + ) + ] + ) + + format_config = JsonlOutputConfig( + shard_size=cfg.shard_size, + transform=None, + resolve_hf_placeholders=True, + ) + + pipeline_config = PipelineConfig( + output=OutputConfig( + dir=split_output_dir, + format=format_config, + max_rows=cfg.sample, + ), + tokenizer=None, + num_actors=num_actors, + force=cfg.force, + execution_engine="xenna", + ) + + result = last_mile_process(split_blend, pipeline_config) + total_sequences += result.total_sequences + + shard_prefix = result.splits["all"].data_paths[1] + shard_files = sorted(Path(shard_prefix).parent.glob("shard_*.jsonl")) + if shard_files: + jsonl_path = shard_files[0] + else: + raise FileNotFoundError(f"No JSONL shard files found at {shard_prefix}") + + split_paths[output_split_name] = str(jsonl_path.resolve()) + + output_dir = cfg.output_dir.resolve() + output_dir.mkdir(parents=True, exist_ok=True) + + resolved_split_paths = {k: str(Path(v).resolve()) for k, v in split_paths.items() if v} + + manifest = { + "train": resolved_split_paths.get("train", ""), + "val": resolved_split_paths.get("val", ""), + "test": resolved_split_paths.get("test", ""), + "mode": "resolve", + "source_splits": available_splits, + } + + manifest_path = output_dir / "manifest.json" + with open(manifest_path, "w") as f: + json.dump(manifest, f, indent=2) + + elapsed = time.time() - start_time + + 
artifact = SplitJsonlDataArtifact( + path=manifest_path, + total_sequences=total_sequences, + elapsed_sec=elapsed, + source_datasets=source_datasets, + train=resolved_split_paths.get("train"), + val=resolved_split_paths.get("val"), + test=resolved_split_paths.get("test"), + ) + + return artifact + + +def _init_ray_with_hf_env() -> None: + """Initialize Ray with runtime_env for HF_HOME and HF_TOKEN propagation.""" + if not ray.is_initialized(): + runtime_env = { + "excludes": [ + "output/", + "outputs/", + "wandb/", + "data/", + "checkpoints/", + "*.bin", + "*.idx", + "*.npy", + "__pycache__/", + ".git/", + ".venv/", + "*.egg-info/", + ], + "env_vars": {}, + } + if os.environ.get("HF_HOME"): + runtime_env["env_vars"]["HF_HOME"] = os.environ["HF_HOME"] + if os.environ.get("HF_TOKEN"): + runtime_env["env_vars"]["HF_TOKEN"] = os.environ["HF_TOKEN"] + ray.init(address="auto", ignore_reinit_error=True, runtime_env=runtime_env) + + +def run_data_prep_main(cfg: RLDataPrepConfig) -> SplitJsonlDataArtifact: + """Run RL data preparation with placeholder resolution (Xenna).""" + add_wandb_tags(["data-prep", "rl", "xenna"]) + + # Initialize Ray with HF environment propagation + _init_ray_with_hf_env() + + blend = DataBlend.load(cfg.blend_path) + + num_actors = cfg.num_actors + if num_actors is None: + cpu_count = os.cpu_count() or 4 + num_actors = max(2, min(32, cpu_count * 3 // 4)) + + source_datasets: list[InputDatasetInfo] = [] + seen_keys: set[str] = set() + for dataset in blend.datasets: + key = f"{dataset.path}|{dataset.subset or ''}" + if key not in seen_keys: + seen_keys.add(key) + ds_config = DatasetConfig( + name=dataset.name, + path=dataset.path, + split=dataset.split, + subset=dataset.subset, + text_field=dataset.text_field, + ) + hf_metadata = get_dataset_metadata(ds_config) + source_datasets.append( + InputDatasetInfo( + uri=dataset.path, + name=dataset.name, + weight=dataset.weight, + split=dataset.split, + subset=dataset.subset, + text_field=dataset.text_field, 
+ num_rows=hf_metadata.num_rows, + size_bytes=hf_metadata.size_bytes, + ) + ) + + try: + resolver = HFPlaceholderResolver.create() + for ext_ds_info in resolver.get_loaded_datasets_info(): + source_datasets.append( + InputDatasetInfo( + uri=ext_ds_info["uri"], + name=ext_ds_info["name"], + split=ext_ds_info["split"], + num_rows=ext_ds_info["num_rows"], + ) + ) + except Exception: + pass + + artifact = _run_resolve(blend, cfg, num_actors, source_datasets) + + artifact.name = f"nano3/rl/data-resolved{'?sample=' + str(cfg.sample) if cfg.sample else ''}" + artifact.save() + + finish_wandb(exit_code=0) + print_step_complete(data_prep=artifact) + return artifact + + +def main(cfg: RLDataPrepConfig | None = None) -> SplitJsonlDataArtifact: + """Entry point for Xenna RL data preparation.""" + if cfg is None: + config_path, cli_overrides = parse_config_and_overrides(default_config=DEFAULT_CONFIG_PATH) + try: + config = load_omegaconf_yaml(config_path) + except FileNotFoundError as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + if cli_overrides: + config = apply_hydra_overrides(config, cli_overrides) + + cfg = omegaconf_to_dataclass(config, RLDataPrepConfig) + + init_wandb_from_env() + return run_data_prep_main(cfg) + + +if __name__ == "__main__": + main() diff --git a/tests/data_prep/test_ray_data.py b/tests/data_prep/test_ray_data.py index c3b3719bc..dab5990f4 100644 --- a/tests/data_prep/test_ray_data.py +++ b/tests/data_prep/test_ray_data.py @@ -292,7 +292,7 @@ def test_default_config(self): assert cfg.min_actors == 2 assert cfg.max_actors == 32 assert cfg.cpus_per_actor == 1.0 - assert cfg.max_tasks_in_flight_per_actor == 2 + assert cfg.max_tasks_in_flight_per_actor == 4 # Increased for better CPU utilization def test_custom_config(self): """Test custom config values.""" @@ -447,8 +447,8 @@ def test_data_prep_config_ray_data_defaults(self): cfg = DataPrepConfig() assert cfg.ray_data_enabled is True # Enabled by default - assert cfg.ray_data_min_actors == 
16 # Start with good parallelism - assert cfg.ray_data_max_actors == 64 # Allow scaling on large nodes + assert cfg.ray_data_min_actors == 2 # Start with minimal warm pool + assert cfg.ray_data_max_actors is None # Auto-detect based on CPU and memory assert cfg.ray_data_cpus_per_actor == 1.0 assert cfg.ray_data_max_tasks_in_flight == 2 @@ -1065,7 +1065,7 @@ def test_ray_data_config_with_defaults(self): assert pipeline_cfg.ray_data is not None assert pipeline_cfg.ray_data.enabled is False assert pipeline_cfg.ray_data.min_actors == 2 - assert pipeline_cfg.ray_data.max_actors == 32 + assert pipeline_cfg.ray_data.max_actors is None # Auto-detect based on CPU and memory def test_ray_data_config_enabled(self): """Test enabling RayDataConfig.""" diff --git a/uv.lock b/uv.lock index edb4618d1..5cd06c598 100644 --- a/uv.lock +++ b/uv.lock @@ -182,6 +182,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9f/4d/d22668674122c08f4d56972297c51a624e64b3ed1efaa40187607a7cb66e/aiohttp-3.13.2-cp314-cp314t-win_amd64.whl", hash = "sha256:ff0a7b0a82a7ab905cbda74006318d1b12e37c797eb1b0d4eb3e316cf47f658f", size = 498093, upload-time = "2025-10-28T20:58:52.782Z" }, ] +[[package]] +name = "aiohttp-cors" +version = "0.8.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6f/6d/d89e846a5444b3d5eb8985a6ddb0daef3774928e1bfbce8e84ec97b0ffa7/aiohttp_cors-0.8.1.tar.gz", hash = "sha256:ccacf9cb84b64939ea15f859a146af1f662a6b1d68175754a07315e305fb1403", size = 38626, upload-time = "2025-03-31T14:16:20.048Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/98/3b/40a68de458904bcc143622015fff2352b6461cd92fd66d3527bf1c6f5716/aiohttp_cors-0.8.1-py3-none-any.whl", hash = "sha256:3180cf304c5c712d626b9162b195b1db7ddf976a2a25172b35bb2448b890a80d", size = 25231, upload-time = "2025-03-31T14:16:18.478Z" }, +] + [[package]] name = "aioitertools" version = "0.13.0" @@ -406,6 
+418,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/96/d32b941a501ab566a16358d68b6eb4e4acc373fab3c3c4d7d9e649f7b4bb/catalogue-2.0.10-py3-none-any.whl", hash = "sha256:58c2de0020aa90f4a2da7dfad161bf7b3b054c86a5f09fcedc0b2b740c109a9f", size = 17325, upload-time = "2023-09-25T06:29:23.337Z" }, ] +[[package]] +name = "cattrs" +version = "25.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6e/00/2432bb2d445b39b5407f0a90e01b9a271475eea7caf913d7a86bcb956385/cattrs-25.3.0.tar.gz", hash = "sha256:1ac88d9e5eda10436c4517e390a4142d88638fe682c436c93db7ce4a277b884a", size = 509321, upload-time = "2025-10-07T12:26:08.737Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d8/2b/a40e1488fdfa02d3f9a653a61a5935ea08b3c2225ee818db6a76c7ba9695/cattrs-25.3.0-py3-none-any.whl", hash = "sha256:9896e84e0a5bf723bc7b4b68f4481785367ce07a8a02e7e9ee6eb2819bc306ff", size = 70738, upload-time = "2025-10-07T12:26:06.603Z" }, +] + [[package]] name = "certifi" version = "2025.11.12" @@ -616,6 +642,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, ] +[[package]] +name = "colorful" +version = "0.5.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/82/31/109ef4bedeb32b4202e02ddb133162457adc4eb890a9ed9c05c9dd126ed0/colorful-0.5.8.tar.gz", hash = "sha256:bb16502b198be2f1c42ba3c52c703d5f651d826076817185f0294c1a549a7445", size = 209361, upload-time = 
"2025-10-29T11:53:21.663Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/11/25cdf9d5fc21efd30134fc74c43702c6f7ef09ebae8ed927f1283403ad8d/colorful-0.5.8-py2.py3-none-any.whl", hash = "sha256:a9381fdda3337fbaba5771991020abc69676afa102646650b759927892875992", size = 201334, upload-time = "2025-10-29T11:53:20.251Z" }, +] + [[package]] name = "contextlib2" version = "21.6.0" @@ -625,6 +663,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/56/6d6872f79d14c0cb02f1646cbb4592eef935857c0951a105874b7b62a0c3/contextlib2-21.6.0-py2.py3-none-any.whl", hash = "sha256:3fbdb64466afd23abaf6c977627b75b6139a5a3e8ce38405c5b413aed7a0471f", size = 13277, upload-time = "2021-06-27T06:54:20.972Z" }, ] +[[package]] +name = "cosmos-xenna" +version = "0.1.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "cattrs" }, + { name = "jinja2" }, + { name = "loguru" }, + { name = "obstore" }, + { name = "portpicker" }, + { name = "pulp" }, + { name = "ray", extra = ["default"] }, + { name = "tabulate" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6c/a4/7e11856ffe97d3275114ef4f3007607a6eff592a5372ac3022c3311b5e35/cosmos_xenna-0.1.8.tar.gz", hash = "sha256:af00d5026835409e91c4d7cf8614dc5354e4edd6a54d9dbc24e308b8292640e1", size = 469952, upload-time = "2025-12-03T00:10:50.34Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/59/08/b020e86b04a48a149e584e080a931ff222bb2917542afa90ca5bba499c9b/cosmos_xenna-0.1.8-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f64af8d5d440ab42cc32b98ab733925c79ae5eccac4570d1b2b4fa29dedc838a", size = 4618483, upload-time = "2025-12-03T00:10:47.441Z" }, + { url = "https://files.pythonhosted.org/packages/db/fc/2479bb58e63e202ee7dea7afe3d304e880574a76f93ffa037217a089bc70/cosmos_xenna-0.1.8-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ccf4537f233543f05eb5af17c8559809c13ae1adb5067ded71f26f1c12f27587", size = 4414730, 
upload-time = "2025-12-03T00:10:48.941Z" }, +] + [[package]] name = "coverage" version = "7.13.0" @@ -823,6 +882,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/50/3d/9373ad9c56321fdab5b41197068e1d8c25883b3fea29dd361f9b55116869/dill-0.4.0-py3-none-any.whl", hash = "sha256:44f54bf6412c2c8464c14e8243eb163690a9800dbe2c367330883b19c7561049", size = 119668, upload-time = "2025-04-16T00:41:47.671Z" }, ] +[[package]] +name = "distlib" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/96/8e/709914eb2b5749865801041647dc7f4e6d00b549cfe88b65ca192995f07c/distlib-0.4.0.tar.gz", hash = "sha256:feec40075be03a04501a973d81f633735b4b69f98b05450592310c0f401a4e0d", size = 614605, upload-time = "2025-07-17T16:52:00.465Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" }, +] + [[package]] name = "docker" version = "7.1.0" @@ -1239,6 +1307,67 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/91/4c/e0ce1ef95d4000ebc1c11801f9b944fa5910ecc15b5e351865763d8657f8/graphviz-0.21-py3-none-any.whl", hash = "sha256:54f33de9f4f911d7e84e4191749cac8cc5653f815b06738c54db9a15ab8b1e42", size = 47300, upload-time = "2025-06-15T09:35:04.433Z" }, ] +[[package]] +name = "grpcio" +version = "1.76.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b6/e0/318c1ce3ae5a17894d5791e87aea147587c9e702f24122cc7a5c8bbaeeb1/grpcio-1.76.0.tar.gz", hash = "sha256:7be78388d6da1a25c0d5ec506523db58b18be22d9c37d8d3a32c08be4987bd73", size = 12785182, upload-time = "2025-10-21T16:23:12.106Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/88/17/ff4795dc9a34b6aee6ec379f1b66438a3789cd1315aac0cbab60d92f74b3/grpcio-1.76.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:65a20de41e85648e00305c1bb09a3598f840422e522277641145a32d42dcefcc", size = 5840037, upload-time = "2025-10-21T16:20:25.069Z" }, + { url = "https://files.pythonhosted.org/packages/4e/ff/35f9b96e3fa2f12e1dcd58a4513a2e2294a001d64dec81677361b7040c9a/grpcio-1.76.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:40ad3afe81676fd9ec6d9d406eda00933f218038433980aa19d401490e46ecde", size = 11836482, upload-time = "2025-10-21T16:20:30.113Z" }, + { url = "https://files.pythonhosted.org/packages/3e/1c/8374990f9545e99462caacea5413ed783014b3b66ace49e35c533f07507b/grpcio-1.76.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:035d90bc79eaa4bed83f524331d55e35820725c9fbb00ffa1904d5550ed7ede3", size = 6407178, upload-time = "2025-10-21T16:20:32.733Z" }, + { url = "https://files.pythonhosted.org/packages/1e/77/36fd7d7c75a6c12542c90a6d647a27935a1ecaad03e0ffdb7c42db6b04d2/grpcio-1.76.0-cp310-cp310-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:4215d3a102bd95e2e11b5395c78562967959824156af11fa93d18fdd18050990", size = 7075684, upload-time = "2025-10-21T16:20:35.435Z" }, + { url = "https://files.pythonhosted.org/packages/38/f7/e3cdb252492278e004722306c5a8935eae91e64ea11f0af3437a7de2e2b7/grpcio-1.76.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:49ce47231818806067aea3324d4bf13825b658ad662d3b25fada0bdad9b8a6af", size = 6611133, upload-time = "2025-10-21T16:20:37.541Z" }, + { url = "https://files.pythonhosted.org/packages/7e/20/340db7af162ccd20a0893b5f3c4a5d676af7b71105517e62279b5b61d95a/grpcio-1.76.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:8cc3309d8e08fd79089e13ed4819d0af72aa935dd8f435a195fd152796752ff2", size = 7195507, upload-time = "2025-10-21T16:20:39.643Z" }, + { url = 
"https://files.pythonhosted.org/packages/10/f0/b2160addc1487bd8fa4810857a27132fb4ce35c1b330c2f3ac45d697b106/grpcio-1.76.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:971fd5a1d6e62e00d945423a567e42eb1fa678ba89072832185ca836a94daaa6", size = 8160651, upload-time = "2025-10-21T16:20:42.492Z" }, + { url = "https://files.pythonhosted.org/packages/2c/2c/ac6f98aa113c6ef111b3f347854e99ebb7fb9d8f7bb3af1491d438f62af4/grpcio-1.76.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:9d9adda641db7207e800a7f089068f6f645959f2df27e870ee81d44701dd9db3", size = 7620568, upload-time = "2025-10-21T16:20:45.995Z" }, + { url = "https://files.pythonhosted.org/packages/90/84/7852f7e087285e3ac17a2703bc4129fafee52d77c6c82af97d905566857e/grpcio-1.76.0-cp310-cp310-win32.whl", hash = "sha256:063065249d9e7e0782d03d2bca50787f53bd0fb89a67de9a7b521c4a01f1989b", size = 3998879, upload-time = "2025-10-21T16:20:48.592Z" }, + { url = "https://files.pythonhosted.org/packages/10/30/d3d2adcbb6dd3ff59d6ac3df6ef830e02b437fb5c90990429fd180e52f30/grpcio-1.76.0-cp310-cp310-win_amd64.whl", hash = "sha256:a6ae758eb08088d36812dd5d9af7a9859c05b1e0f714470ea243694b49278e7b", size = 4706892, upload-time = "2025-10-21T16:20:50.697Z" }, + { url = "https://files.pythonhosted.org/packages/a0/00/8163a1beeb6971f66b4bbe6ac9457b97948beba8dd2fc8e1281dce7f79ec/grpcio-1.76.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:2e1743fbd7f5fa713a1b0a8ac8ebabf0ec980b5d8809ec358d488e273b9cf02a", size = 5843567, upload-time = "2025-10-21T16:20:52.829Z" }, + { url = "https://files.pythonhosted.org/packages/10/c1/934202f5cf335e6d852530ce14ddb0fef21be612ba9ecbbcbd4d748ca32d/grpcio-1.76.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:a8c2cf1209497cf659a667d7dea88985e834c24b7c3b605e6254cbb5076d985c", size = 11848017, upload-time = "2025-10-21T16:20:56.705Z" }, + { url = 
"https://files.pythonhosted.org/packages/11/0b/8dec16b1863d74af6eb3543928600ec2195af49ca58b16334972f6775663/grpcio-1.76.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:08caea849a9d3c71a542827d6df9d5a69067b0a1efbea8a855633ff5d9571465", size = 6412027, upload-time = "2025-10-21T16:20:59.3Z" }, + { url = "https://files.pythonhosted.org/packages/d7/64/7b9e6e7ab910bea9d46f2c090380bab274a0b91fb0a2fe9b0cd399fffa12/grpcio-1.76.0-cp311-cp311-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:f0e34c2079d47ae9f6188211db9e777c619a21d4faba6977774e8fa43b085e48", size = 7075913, upload-time = "2025-10-21T16:21:01.645Z" }, + { url = "https://files.pythonhosted.org/packages/68/86/093c46e9546073cefa789bd76d44c5cb2abc824ca62af0c18be590ff13ba/grpcio-1.76.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8843114c0cfce61b40ad48df65abcfc00d4dba82eae8718fab5352390848c5da", size = 6615417, upload-time = "2025-10-21T16:21:03.844Z" }, + { url = "https://files.pythonhosted.org/packages/f7/b6/5709a3a68500a9c03da6fb71740dcdd5ef245e39266461a03f31a57036d8/grpcio-1.76.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8eddfb4d203a237da6f3cc8a540dad0517d274b5a1e9e636fd8d2c79b5c1d397", size = 7199683, upload-time = "2025-10-21T16:21:06.195Z" }, + { url = "https://files.pythonhosted.org/packages/91/d3/4b1f2bf16ed52ce0b508161df3a2d186e4935379a159a834cb4a7d687429/grpcio-1.76.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:32483fe2aab2c3794101c2a159070584e5db11d0aa091b2c0ea9c4fc43d0d749", size = 8163109, upload-time = "2025-10-21T16:21:08.498Z" }, + { url = "https://files.pythonhosted.org/packages/5c/61/d9043f95f5f4cf085ac5dd6137b469d41befb04bd80280952ffa2a4c3f12/grpcio-1.76.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:dcfe41187da8992c5f40aa8c5ec086fa3672834d2be57a32384c08d5a05b4c00", size = 7626676, upload-time = "2025-10-21T16:21:10.693Z" }, + { url = 
"https://files.pythonhosted.org/packages/36/95/fd9a5152ca02d8881e4dd419cdd790e11805979f499a2e5b96488b85cf27/grpcio-1.76.0-cp311-cp311-win32.whl", hash = "sha256:2107b0c024d1b35f4083f11245c0e23846ae64d02f40b2b226684840260ed054", size = 3997688, upload-time = "2025-10-21T16:21:12.746Z" }, + { url = "https://files.pythonhosted.org/packages/60/9c/5c359c8d4c9176cfa3c61ecd4efe5affe1f38d9bae81e81ac7186b4c9cc8/grpcio-1.76.0-cp311-cp311-win_amd64.whl", hash = "sha256:522175aba7af9113c48ec10cc471b9b9bd4f6ceb36aeb4544a8e2c80ed9d252d", size = 4709315, upload-time = "2025-10-21T16:21:15.26Z" }, + { url = "https://files.pythonhosted.org/packages/bf/05/8e29121994b8d959ffa0afd28996d452f291b48cfc0875619de0bde2c50c/grpcio-1.76.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:81fd9652b37b36f16138611c7e884eb82e0cec137c40d3ef7c3f9b3ed00f6ed8", size = 5799718, upload-time = "2025-10-21T16:21:17.939Z" }, + { url = "https://files.pythonhosted.org/packages/d9/75/11d0e66b3cdf998c996489581bdad8900db79ebd83513e45c19548f1cba4/grpcio-1.76.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:04bbe1bfe3a68bbfd4e52402ab7d4eb59d72d02647ae2042204326cf4bbad280", size = 11825627, upload-time = "2025-10-21T16:21:20.466Z" }, + { url = "https://files.pythonhosted.org/packages/28/50/2f0aa0498bc188048f5d9504dcc5c2c24f2eb1a9337cd0fa09a61a2e75f0/grpcio-1.76.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d388087771c837cdb6515539f43b9d4bf0b0f23593a24054ac16f7a960be16f4", size = 6359167, upload-time = "2025-10-21T16:21:23.122Z" }, + { url = "https://files.pythonhosted.org/packages/66/e5/bbf0bb97d29ede1d59d6588af40018cfc345b17ce979b7b45424628dc8bb/grpcio-1.76.0-cp312-cp312-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:9f8f757bebaaea112c00dba718fc0d3260052ce714e25804a03f93f5d1c6cc11", size = 7044267, upload-time = "2025-10-21T16:21:25.995Z" }, + { url = 
"https://files.pythonhosted.org/packages/f5/86/f6ec2164f743d9609691115ae8ece098c76b894ebe4f7c94a655c6b03e98/grpcio-1.76.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:980a846182ce88c4f2f7e2c22c56aefd515daeb36149d1c897f83cf57999e0b6", size = 6573963, upload-time = "2025-10-21T16:21:28.631Z" }, + { url = "https://files.pythonhosted.org/packages/60/bc/8d9d0d8505feccfdf38a766d262c71e73639c165b311c9457208b56d92ae/grpcio-1.76.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f92f88e6c033db65a5ae3d97905c8fea9c725b63e28d5a75cb73b49bda5024d8", size = 7164484, upload-time = "2025-10-21T16:21:30.837Z" }, + { url = "https://files.pythonhosted.org/packages/67/e6/5d6c2fc10b95edf6df9b8f19cf10a34263b7fd48493936fffd5085521292/grpcio-1.76.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4baf3cbe2f0be3289eb68ac8ae771156971848bb8aaff60bad42005539431980", size = 8127777, upload-time = "2025-10-21T16:21:33.577Z" }, + { url = "https://files.pythonhosted.org/packages/3f/c8/dce8ff21c86abe025efe304d9e31fdb0deaaa3b502b6a78141080f206da0/grpcio-1.76.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:615ba64c208aaceb5ec83bfdce7728b80bfeb8be97562944836a7a0a9647d882", size = 7594014, upload-time = "2025-10-21T16:21:41.882Z" }, + { url = "https://files.pythonhosted.org/packages/e0/42/ad28191ebf983a5d0ecef90bab66baa5a6b18f2bfdef9d0a63b1973d9f75/grpcio-1.76.0-cp312-cp312-win32.whl", hash = "sha256:45d59a649a82df5718fd9527ce775fd66d1af35e6d31abdcdc906a49c6822958", size = 3984750, upload-time = "2025-10-21T16:21:44.006Z" }, + { url = "https://files.pythonhosted.org/packages/9e/00/7bd478cbb851c04a48baccaa49b75abaa8e4122f7d86da797500cccdd771/grpcio-1.76.0-cp312-cp312-win_amd64.whl", hash = "sha256:c088e7a90b6017307f423efbb9d1ba97a22aa2170876223f9709e9d1de0b5347", size = 4704003, upload-time = "2025-10-21T16:21:46.244Z" }, + { url = 
"https://files.pythonhosted.org/packages/fc/ed/71467ab770effc9e8cef5f2e7388beb2be26ed642d567697bb103a790c72/grpcio-1.76.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:26ef06c73eb53267c2b319f43e6634c7556ea37672029241a056629af27c10e2", size = 5807716, upload-time = "2025-10-21T16:21:48.475Z" }, + { url = "https://files.pythonhosted.org/packages/2c/85/c6ed56f9817fab03fa8a111ca91469941fb514e3e3ce6d793cb8f1e1347b/grpcio-1.76.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:45e0111e73f43f735d70786557dc38141185072d7ff8dc1829d6a77ac1471468", size = 11821522, upload-time = "2025-10-21T16:21:51.142Z" }, + { url = "https://files.pythonhosted.org/packages/ac/31/2b8a235ab40c39cbc141ef647f8a6eb7b0028f023015a4842933bc0d6831/grpcio-1.76.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:83d57312a58dcfe2a3a0f9d1389b299438909a02db60e2f2ea2ae2d8034909d3", size = 6362558, upload-time = "2025-10-21T16:21:54.213Z" }, + { url = "https://files.pythonhosted.org/packages/bd/64/9784eab483358e08847498ee56faf8ff6ea8e0a4592568d9f68edc97e9e9/grpcio-1.76.0-cp313-cp313-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:3e2a27c89eb9ac3d81ec8835e12414d73536c6e620355d65102503064a4ed6eb", size = 7049990, upload-time = "2025-10-21T16:21:56.476Z" }, + { url = "https://files.pythonhosted.org/packages/2b/94/8c12319a6369434e7a184b987e8e9f3b49a114c489b8315f029e24de4837/grpcio-1.76.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:61f69297cba3950a524f61c7c8ee12e55c486cb5f7db47ff9dcee33da6f0d3ae", size = 6575387, upload-time = "2025-10-21T16:21:59.051Z" }, + { url = "https://files.pythonhosted.org/packages/15/0f/f12c32b03f731f4a6242f771f63039df182c8b8e2cf8075b245b409259d4/grpcio-1.76.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6a15c17af8839b6801d554263c546c69c4d7718ad4321e3166175b37eaacca77", size = 7166668, upload-time = "2025-10-21T16:22:02.049Z" }, + { url = 
"https://files.pythonhosted.org/packages/ff/2d/3ec9ce0c2b1d92dd59d1c3264aaec9f0f7c817d6e8ac683b97198a36ed5a/grpcio-1.76.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:25a18e9810fbc7e7f03ec2516addc116a957f8cbb8cbc95ccc80faa072743d03", size = 8124928, upload-time = "2025-10-21T16:22:04.984Z" }, + { url = "https://files.pythonhosted.org/packages/1a/74/fd3317be5672f4856bcdd1a9e7b5e17554692d3db9a3b273879dc02d657d/grpcio-1.76.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:931091142fd8cc14edccc0845a79248bc155425eee9a98b2db2ea4f00a235a42", size = 7589983, upload-time = "2025-10-21T16:22:07.881Z" }, + { url = "https://files.pythonhosted.org/packages/45/bb/ca038cf420f405971f19821c8c15bcbc875505f6ffadafe9ffd77871dc4c/grpcio-1.76.0-cp313-cp313-win32.whl", hash = "sha256:5e8571632780e08526f118f74170ad8d50fb0a48c23a746bef2a6ebade3abd6f", size = 3984727, upload-time = "2025-10-21T16:22:10.032Z" }, + { url = "https://files.pythonhosted.org/packages/41/80/84087dc56437ced7cdd4b13d7875e7439a52a261e3ab4e06488ba6173b0a/grpcio-1.76.0-cp313-cp313-win_amd64.whl", hash = "sha256:f9f7bd5faab55f47231ad8dba7787866b69f5e93bc306e3915606779bbfb4ba8", size = 4702799, upload-time = "2025-10-21T16:22:12.709Z" }, + { url = "https://files.pythonhosted.org/packages/b4/46/39adac80de49d678e6e073b70204091e76631e03e94928b9ea4ecf0f6e0e/grpcio-1.76.0-cp314-cp314-linux_armv7l.whl", hash = "sha256:ff8a59ea85a1f2191a0ffcc61298c571bc566332f82e5f5be1b83c9d8e668a62", size = 5808417, upload-time = "2025-10-21T16:22:15.02Z" }, + { url = "https://files.pythonhosted.org/packages/9c/f5/a4531f7fb8b4e2a60b94e39d5d924469b7a6988176b3422487be61fe2998/grpcio-1.76.0-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:06c3d6b076e7b593905d04fdba6a0525711b3466f43b3400266f04ff735de0cd", size = 11828219, upload-time = "2025-10-21T16:22:17.954Z" }, + { url = 
"https://files.pythonhosted.org/packages/4b/1c/de55d868ed7a8bd6acc6b1d6ddc4aa36d07a9f31d33c912c804adb1b971b/grpcio-1.76.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fd5ef5932f6475c436c4a55e4336ebbe47bd3272be04964a03d316bbf4afbcbc", size = 6367826, upload-time = "2025-10-21T16:22:20.721Z" }, + { url = "https://files.pythonhosted.org/packages/59/64/99e44c02b5adb0ad13ab3adc89cb33cb54bfa90c74770f2607eea629b86f/grpcio-1.76.0-cp314-cp314-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:b331680e46239e090f5b3cead313cc772f6caa7d0fc8de349337563125361a4a", size = 7049550, upload-time = "2025-10-21T16:22:23.637Z" }, + { url = "https://files.pythonhosted.org/packages/43/28/40a5be3f9a86949b83e7d6a2ad6011d993cbe9b6bd27bea881f61c7788b6/grpcio-1.76.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2229ae655ec4e8999599469559e97630185fdd53ae1e8997d147b7c9b2b72cba", size = 6575564, upload-time = "2025-10-21T16:22:26.016Z" }, + { url = "https://files.pythonhosted.org/packages/4b/a9/1be18e6055b64467440208a8559afac243c66a8b904213af6f392dc2212f/grpcio-1.76.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:490fa6d203992c47c7b9e4a9d39003a0c2bcc1c9aa3c058730884bbbb0ee9f09", size = 7176236, upload-time = "2025-10-21T16:22:28.362Z" }, + { url = "https://files.pythonhosted.org/packages/0f/55/dba05d3fcc151ce6e81327541d2cc8394f442f6b350fead67401661bf041/grpcio-1.76.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:479496325ce554792dba6548fae3df31a72cef7bad71ca2e12b0e58f9b336bfc", size = 8125795, upload-time = "2025-10-21T16:22:31.075Z" }, + { url = "https://files.pythonhosted.org/packages/4a/45/122df922d05655f63930cf42c9e3f72ba20aadb26c100ee105cad4ce4257/grpcio-1.76.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:1c9b93f79f48b03ada57ea24725d83a30284a012ec27eab2cf7e50a550cbbbcc", size = 7592214, upload-time = "2025-10-21T16:22:33.831Z" }, + { url = 
"https://files.pythonhosted.org/packages/4a/6e/0b899b7f6b66e5af39e377055fb4a6675c9ee28431df5708139df2e93233/grpcio-1.76.0-cp314-cp314-win32.whl", hash = "sha256:747fa73efa9b8b1488a95d0ba1039c8e2dca0f741612d80415b1e1c560febf4e", size = 4062961, upload-time = "2025-10-21T16:22:36.468Z" }, + { url = "https://files.pythonhosted.org/packages/19/41/0b430b01a2eb38ee887f88c1f07644a1df8e289353b78e82b37ef988fb64/grpcio-1.76.0-cp314-cp314-win_amd64.whl", hash = "sha256:922fa70ba549fce362d2e2871ab542082d66e2aaf0c19480ea453905b01f384e", size = 4834462, upload-time = "2025-10-21T16:22:39.772Z" }, +] + [[package]] name = "h11" version = "0.16.0" @@ -2149,6 +2278,7 @@ dependencies = [ [package.optional-dependencies] all = [ + { name = "cosmos-xenna" }, { name = "gcsfs" }, { name = "s3fs" }, { name = "sentencepiece" }, @@ -2172,6 +2302,9 @@ sentencepiece = [ wandb = [ { name = "wandb" }, ] +xenna = [ + { name = "cosmos-xenna" }, +] [package.dev-dependencies] dev = [ @@ -2196,6 +2329,8 @@ run = [ [package.metadata] requires-dist = [ { name = "colorama", specifier = ">=0.4.6" }, + { name = "cosmos-xenna", marker = "extra == 'all'" }, + { name = "cosmos-xenna", marker = "extra == 'xenna'" }, { name = "datasets", specifier = ">=2.14.0" }, { name = "fsspec", specifier = ">=2024.0.0" }, { name = "gcsfs", marker = "extra == 'all'", specifier = ">=2024.0.0" }, @@ -2228,7 +2363,7 @@ requires-dist = [ { name = "wandb", marker = "extra == 'wandb'", specifier = ">=0.15.0" }, { name = "xxhash", specifier = ">=3.4.0" }, ] -provides-extras = ["wandb", "s3", "gcs", "sentencepiece", "dev", "all"] +provides-extras = ["wandb", "s3", "gcs", "sentencepiece", "xenna", "dev", "all"] [package.metadata.requires-dev] dev = [{ name = "pytest", specifier = ">=9.0.2" }] @@ -2455,6 +2590,93 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/be/9c/92789c596b8df838baa98fa71844d84283302f7604ed565dafe5a6b5041a/oauthlib-3.3.1-py3-none-any.whl", hash = 
"sha256:88119c938d2b8fb88561af5f6ee0eec8cc8d552b7bb1f712743136eb7523b7a1", size = 160065, upload-time = "2025-06-19T22:48:06.508Z" }, ] +[[package]] +name = "obstore" +version = "0.8.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/8c/9ec984edd0f3b72226adfaa19b1c61b15823b35b52f311ca4af36d009d15/obstore-0.8.2.tar.gz", hash = "sha256:a467bc4e97169e2ba749981b4fd0936015428d9b8f3fb83a5528536b1b6f377f", size = 168852, upload-time = "2025-09-16T15:34:55.786Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e1/e9/0a1e340ef262f225ad71f556ccba257896f85ca197f02cd228fe5e20b45a/obstore-0.8.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:49104c0d72688c180af015b02c691fbb6cf6a45b03a9d71b84059ed92dbec704", size = 3622821, upload-time = "2025-09-16T15:32:53.79Z" }, + { url = "https://files.pythonhosted.org/packages/24/86/2b53e8b0a838dbbf89ef5dfddde888770bc1a993c691698dae411a407228/obstore-0.8.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c49776abd416e4d80d003213522d82ad48ed3517bee27a6cf8ce0f0cf4e6337e", size = 3356349, upload-time = "2025-09-16T15:32:55.715Z" }, + { url = "https://files.pythonhosted.org/packages/e8/79/1ba6dc854d7de7704a2c474d723ffeb01b6884f72eea7cbe128efc472f4a/obstore-0.8.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1636372b5e171a98369612d122ea20b955661daafa6519ed8322f4f0cb43ff74", size = 3454842, upload-time = "2025-09-16T15:32:57.072Z" }, + { url = "https://files.pythonhosted.org/packages/ca/03/ca67ccc9b9e63cfc0cd069b84437807fed4ef880be1e445b3f29d11518e0/obstore-0.8.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2efed0d86ad4ebffcbe3d0c4d84f26c2c6b20287484a0a748499c169a8e1f2c4", size = 3688363, upload-time = "2025-09-16T15:32:58.164Z" }, + { url = 
"https://files.pythonhosted.org/packages/a7/2f/c78eb4352d8be64a072934fe3ff2af79a1d06f4571af7c70d96f9741766b/obstore-0.8.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:00c5542616dc5608de82ab6f6820633c9dbab6ff048e770fb8a5fcd1d30cd656", size = 3960133, upload-time = "2025-09-16T15:32:59.614Z" }, + { url = "https://files.pythonhosted.org/packages/4f/34/9e828d19194e227fd9f1d2dd70710da99c2bd2cd728686d59ea80be10b7c/obstore-0.8.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4d9df46aaf25ce80fff48c53382572adc67b6410611660b798024450281a3129", size = 3925493, upload-time = "2025-09-16T15:33:00.923Z" }, + { url = "https://files.pythonhosted.org/packages/5f/7d/9ec5967f3e2915fbc441f72c3892a7f0fb3618e3ae5c8a44181ce4aa641c/obstore-0.8.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ccf0f03a7fe453fb8640611c922bce19f021c6aaeee6ee44d6d8fb57db6be48", size = 3769401, upload-time = "2025-09-16T15:33:02.373Z" }, + { url = "https://files.pythonhosted.org/packages/85/bf/00b65013068bde630a7369610a2dae4579315cd6ce82d30e3d23315cf308/obstore-0.8.2-cp310-cp310-manylinux_2_24_aarch64.whl", hash = "sha256:ddfbfadc88c5e9740b687ef0833384329a56cea07b34f44e1c4b00a0e97d94a9", size = 3534383, upload-time = "2025-09-16T15:33:03.903Z" }, + { url = "https://files.pythonhosted.org/packages/52/39/1b684fd96c9a33974fc52f417c52b42c1d50df40b44e588853c4a14d9ab1/obstore-0.8.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:53ad53bb16e64102f39559ec470efd78a5272b5e3b84c53aa0423993ac5575c1", size = 3697939, upload-time = "2025-09-16T15:33:05.355Z" }, + { url = "https://files.pythonhosted.org/packages/85/58/93a2c78935f17fde7e22842598a6373e46a9c32d0243ec3b26b5da92df27/obstore-0.8.2-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:b0b905b46354db0961ab818cad762b9c1ac154333ae5d341934c90635a6bd7ab", size = 3681746, upload-time = "2025-09-16T15:33:09.344Z" }, + { url = 
"https://files.pythonhosted.org/packages/38/90/225c2972338d18f92e7a56f71e34df6935b0b1bd7458bb6a0d2bd4d48f92/obstore-0.8.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:fee235694406ebb2dc4178752cf5587f471d6662659b082e9786c716a0a9465c", size = 3765156, upload-time = "2025-09-16T15:33:10.457Z" }, + { url = "https://files.pythonhosted.org/packages/79/eb/aca27e895bfcbbcd2bf05ea6a2538a94b718e6f6d72986e16ab158b753ec/obstore-0.8.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:6c36faf7ace17dd0832aa454118a63ea21862e3d34f71b9297d0c788d00f4985", size = 3941190, upload-time = "2025-09-16T15:33:11.59Z" }, + { url = "https://files.pythonhosted.org/packages/33/ce/c8251a397e7507521768f05bc355b132a0daaff3739e861e51fa6abd821e/obstore-0.8.2-cp310-cp310-win_amd64.whl", hash = "sha256:948a1db1d34f88cfc7ab7e0cccdcfd84cf3977365634599c95ba03b4ef80d1c4", size = 3970041, upload-time = "2025-09-16T15:33:13.035Z" }, + { url = "https://files.pythonhosted.org/packages/2f/c4/018f90701f1e5ea3fbd57f61463f42e1ef5218e548d3adcf12b6be021c34/obstore-0.8.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:2edaa97687c191c5324bb939d72f6fe86a7aa8191c410f1648c14e8296d05c1c", size = 3622568, upload-time = "2025-09-16T15:33:14.196Z" }, + { url = "https://files.pythonhosted.org/packages/a8/62/72dd1e7d52fc554bb1fdb1a9499bda219cf3facea5865a1d97fdc00b3a1b/obstore-0.8.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c4fb7ef8108f08d14edc8bec9e9a6a2e5c4d14eddb8819f5d0da498aff6e8888", size = 3356109, upload-time = "2025-09-16T15:33:15.315Z" }, + { url = "https://files.pythonhosted.org/packages/e0/ae/089fe5b9207091252fe5ce352551214f04560f85eb8f2cc4f716a6a1a57e/obstore-0.8.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fda8f658c0edf799ab1e264f9b12c7c184cd09a5272dc645d42e987810ff2772", size = 3454588, upload-time = "2025-09-16T15:33:16.421Z" }, + { url = 
"https://files.pythonhosted.org/packages/ea/10/1865ae2d1ba45e8ae85fb0c1aada2dc9533baf60c4dfe74dab905348d74a/obstore-0.8.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:87fe2bc15ce4051ecb56abd484feca323c2416628beb62c1c7b6712114564d6e", size = 3688627, upload-time = "2025-09-16T15:33:17.604Z" }, + { url = "https://files.pythonhosted.org/packages/a6/09/5d7ba6d0aeac563ea5f5586401c677bace4f782af83522b1fdf15430e152/obstore-0.8.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2482aa2562ab6a4ca40250b26bea33f8375b59898a9b5615fd412cab81098123", size = 3959896, upload-time = "2025-09-16T15:33:18.789Z" }, + { url = "https://files.pythonhosted.org/packages/16/15/2b3eda59914761a9ff4d840e2daec5697fd29b293bd18d3dc11c593aed06/obstore-0.8.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4153b928f5d2e9c6cb645e83668a53e0b42253d1e8bcb4e16571fc0a1434599a", size = 3933162, upload-time = "2025-09-16T15:33:19.935Z" }, + { url = "https://files.pythonhosted.org/packages/14/7a/5fc63b41526587067537fb1498c59a210884664c65ccf0d1f8f823b0875a/obstore-0.8.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dbfa9c38620cc191be98c8b5558c62071e495dc6b1cc724f38293ee439aa9f92", size = 3769605, upload-time = "2025-09-16T15:33:21.389Z" }, + { url = "https://files.pythonhosted.org/packages/77/4e/2208ab6e1fc021bf8b7e117249a10ab75d0ed24e0f2de1a8d7cd67d885b5/obstore-0.8.2-cp311-cp311-manylinux_2_24_aarch64.whl", hash = "sha256:0822836eae8d52499f10daef17f26855b4c123119c6eb984aa4f2d525ec2678d", size = 3534396, upload-time = "2025-09-16T15:33:22.574Z" }, + { url = "https://files.pythonhosted.org/packages/1d/8f/a0e2882edd6bd285c82b8a5851c4ecf386c93fe75b6e340d5d9d30e809fc/obstore-0.8.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8ef6435dfd586d83b4f778e7927a5d5b0d8b771e9ba914bc809a13d7805410e6", size = 3697777, upload-time = "2025-09-16T15:33:23.723Z" }, + { url = 
"https://files.pythonhosted.org/packages/94/78/ebf0c33bed5c9a8eed3b00eefafbcc0a687eeb1e05451c76fcf199d29ff8/obstore-0.8.2-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:0f2cba91f4271ca95a932a51aa8dda1537160342b33f7836c75e1eb9d40621a2", size = 3681546, upload-time = "2025-09-16T15:33:24.935Z" }, + { url = "https://files.pythonhosted.org/packages/af/21/9bf4fb9e53fd5f01af580b6538de2eae857e31d24b0ebfc4d916c306a1e4/obstore-0.8.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:23c876d603af0627627808d19a58d43eb5d8bfd02eecd29460bc9a58030fed55", size = 3765336, upload-time = "2025-09-16T15:33:26.069Z" }, + { url = "https://files.pythonhosted.org/packages/dd/3c/7f6895c23719482d231b2d6ed328e3223fdf99785f6850fba8d2fc5a86ee/obstore-0.8.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ff3c4b5d07629b70b9dee494cd6b94fff8465c3864752181a1cb81a77190fe42", size = 3941142, upload-time = "2025-09-16T15:33:27.275Z" }, + { url = "https://files.pythonhosted.org/packages/93/a4/56ccdb756161595680a28f4b0def2c04f7048ffacf128029be8394367b26/obstore-0.8.2-cp311-cp311-win_amd64.whl", hash = "sha256:aadb2cb72de7227d07f4570f82729625ffc77522fadca5cf13c3a37fbe8c8de9", size = 3970172, upload-time = "2025-09-16T15:33:28.393Z" }, + { url = "https://files.pythonhosted.org/packages/2b/dc/60fefbb5736e69eab56657bca04ca64dc07fdeccb3814164a31b62ad066b/obstore-0.8.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:bb70ce297a47392b1d9a3e310f18d59cd5ebbb9453428210fef02ed60e4d75d1", size = 3612955, upload-time = "2025-09-16T15:33:29.527Z" }, + { url = "https://files.pythonhosted.org/packages/d2/8b/844e8f382e5a12b8a3796a05d76a03e12c7aedc13d6900419e39207d7868/obstore-0.8.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1619bf618428abf1f607e0b219b2e230a966dcf697b717deccfa0983dd91f646", size = 3346564, upload-time = "2025-09-16T15:33:30.698Z" }, + { url = 
"https://files.pythonhosted.org/packages/89/73/8537f99e09a38a54a6a15ede907aa25d4da089f767a808f0b2edd9c03cec/obstore-0.8.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a4605c3ed7c9515aeb4c619b5f7f2c9986ed4a79fe6045e536b5e59b804b1476", size = 3460809, upload-time = "2025-09-16T15:33:31.837Z" }, + { url = "https://files.pythonhosted.org/packages/b4/99/7714dec721e43f521d6325a82303a002cddad089437640f92542b84e9cc8/obstore-0.8.2-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce42670417876dd8668cbb8659e860e9725e5f26bbc86449fd259970e2dd9d18", size = 3692081, upload-time = "2025-09-16T15:33:33.028Z" }, + { url = "https://files.pythonhosted.org/packages/ec/bd/4ac4175fe95a24c220a96021c25c432bcc0c0212f618be0737184eebbaad/obstore-0.8.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4a3e893b2a06585f651c541c1972fe1e3bf999ae2a5fda052ee55eb7e6516f5", size = 3957466, upload-time = "2025-09-16T15:33:34.528Z" }, + { url = "https://files.pythonhosted.org/packages/4e/04/caa288fb735484fc5cb019bdf3d896eaccfae0ac4622e520d05692c46790/obstore-0.8.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:08462b32f95a9948ed56ed63e88406e2e5a4cae1fde198f9682e0fb8487100ed", size = 3951293, upload-time = "2025-09-16T15:33:35.733Z" }, + { url = "https://files.pythonhosted.org/packages/44/2f/d380239da2d6a1fda82e17df5dae600a404e8a93a065784518ff8325d5f6/obstore-0.8.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a0bf7763292a8fc47d01cd66e6f19002c5c6ad4b3ed4e6b2729f5e190fa8a0d", size = 3766199, upload-time = "2025-09-16T15:33:36.904Z" }, + { url = "https://files.pythonhosted.org/packages/28/41/d391be069d3da82969b54266948b2582aeca5dd735abeda4d63dba36e07b/obstore-0.8.2-cp312-cp312-manylinux_2_24_aarch64.whl", hash = "sha256:bcd47f8126cb192cbe86942b8f73b1c45a651ce7e14c9a82c5641dfbf8be7603", size = 3529678, upload-time = "2025-09-16T15:33:38.221Z" }, + { url = 
"https://files.pythonhosted.org/packages/b9/4c/4862fdd1a3abde459ee8eea699b1797df638a460af235b18ca82c8fffb72/obstore-0.8.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:57eda9fd8c757c3b4fe36cf3918d7e589cc1286591295cc10b34122fa36dd3fd", size = 3698079, upload-time = "2025-09-16T15:33:39.696Z" }, + { url = "https://files.pythonhosted.org/packages/68/ca/014e747bc53b570059c27e3565b2316fbe5c107d4134551f4cd3e24aa667/obstore-0.8.2-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:ea44442aad8992166baa69f5069750979e4c5d9ffce772e61565945eea5774b9", size = 3687154, upload-time = "2025-09-16T15:33:40.92Z" }, + { url = "https://files.pythonhosted.org/packages/6f/89/6db5f8edd93028e5b8bfbeee15e6bd3e56f72106107d31cb208b57659de4/obstore-0.8.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:41496a3ab8527402db4142aaaf0d42df9d7d354b13ba10d9c33e0e48dd49dd96", size = 3773444, upload-time = "2025-09-16T15:33:42.123Z" }, + { url = "https://files.pythonhosted.org/packages/26/e5/c9e2cc540689c873beb61246e1615d6e38301e6a34dec424f5a5c63c1afd/obstore-0.8.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:43da209803f052df96c7c3cbec512d310982efd2407e4a435632841a51143170", size = 3939315, upload-time = "2025-09-16T15:33:43.252Z" }, + { url = "https://files.pythonhosted.org/packages/4d/c9/bb53280ca50103c1ffda373cdc9b0f835431060039c2897cbc87ddd92e42/obstore-0.8.2-cp312-cp312-win_amd64.whl", hash = "sha256:1836f5dcd49f9f2950c75889ab5c51fb290d3ea93cdc39a514541e0be3af016e", size = 3978234, upload-time = "2025-09-16T15:33:44.393Z" }, + { url = "https://files.pythonhosted.org/packages/f0/5d/8c3316cc958d386d5e6ab03e9db9ddc27f8e2141cee4a6777ae5b92f3aac/obstore-0.8.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:212f033e53fe6e53d64957923c5c88949a400e9027f7038c705ec2e9038be563", size = 3612027, upload-time = "2025-09-16T15:33:45.6Z" }, + { url = 
"https://files.pythonhosted.org/packages/ea/4d/699359774ce6330130536d008bfc32827fab0c25a00238d015a5974a3d1d/obstore-0.8.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:bee21fa4ba148d08fa90e47a96df11161661ed31e09c056a373cb2154b0f2852", size = 3344686, upload-time = "2025-09-16T15:33:47.185Z" }, + { url = "https://files.pythonhosted.org/packages/82/37/55437341f10512906e02fd9fa69a8a95ad3f2f6a916d3233fda01763d110/obstore-0.8.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4c66594b59832ff1ced4c72575d9beb8b5f9b4e404ac1150a42bfb226617fd50", size = 3459860, upload-time = "2025-09-16T15:33:48.382Z" }, + { url = "https://files.pythonhosted.org/packages/7a/51/4245a616c94ee4851965e33f7a563ab4090cc81f52cc73227ff9ceca2e46/obstore-0.8.2-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:089f33af5c2fe132d00214a0c1f40601b28f23a38e24ef9f79fb0576f2730b74", size = 3691648, upload-time = "2025-09-16T15:33:49.524Z" }, + { url = "https://files.pythonhosted.org/packages/4e/f1/4e2fb24171e3ca3641a4653f006be826e7e17634b11688a5190553b00b83/obstore-0.8.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d87f658dfd340d5d9ea2d86a7c90d44da77a0db9e00c034367dca335735110cf", size = 3956867, upload-time = "2025-09-16T15:33:51.082Z" }, + { url = "https://files.pythonhosted.org/packages/42/f5/b703115361c798c9c1744e1e700d5908d904a8c2e2bd38bec759c9ffb469/obstore-0.8.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6e2e4fa92828c4fbc2d487f3da2d3588701a1b67d9f6ca3c97cc2afc912e9c63", size = 3950599, upload-time = "2025-09-16T15:33:52.173Z" }, + { url = "https://files.pythonhosted.org/packages/53/20/08c6dc0f20c1394e2324b9344838e4e7af770cdcb52c30757a475f50daeb/obstore-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab440e89c5c37a8ec230857dd65147d4b923e0cada33297135d05e0f937d696a", size = 3765865, upload-time = "2025-09-16T15:33:53.291Z" }, + { url = 
"https://files.pythonhosted.org/packages/77/20/77907765e29b2eba6bd8821872284d91170d7084f670855b2dfcb249ea14/obstore-0.8.2-cp313-cp313-manylinux_2_24_aarch64.whl", hash = "sha256:b9beed107c5c9cd995d4a73263861fcfbc414d58773ed65c14f80eb18258a932", size = 3529807, upload-time = "2025-09-16T15:33:54.535Z" }, + { url = "https://files.pythonhosted.org/packages/a5/f5/f629d39cc30d050f52b1bf927e4d65c1cc7d7ffbb8a635cd546b5c5219a0/obstore-0.8.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:b75b4e7746292c785e31edcd5aadc8b758238372a19d4c5e394db5c305d7d175", size = 3693629, upload-time = "2025-09-16T15:33:56.016Z" }, + { url = "https://files.pythonhosted.org/packages/30/ff/106763fd10f2a1cb47f2ef1162293c78ad52f4e73223d8d43fc6b755445d/obstore-0.8.2-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:f33e6c366869d05ab0b7f12efe63269e631c5450d95d6b4ba4c5faf63f69de70", size = 3686176, upload-time = "2025-09-16T15:33:57.247Z" }, + { url = "https://files.pythonhosted.org/packages/ce/0c/d2ccb6f32feeca906d5a7c4255340df5262af8838441ca06c9e4e37b67d5/obstore-0.8.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:12c885a9ce5ceb09d13cc186586c0c10b62597eff21b985f6ce8ff9dab963ad3", size = 3773081, upload-time = "2025-09-16T15:33:58.475Z" }, + { url = "https://files.pythonhosted.org/packages/fa/79/40d1cc504cefc89c9b3dd8874287f3fddc7d963a8748d6dffc5880222013/obstore-0.8.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4accc883b93349a81c9931e15dd318cc703b02bbef2805d964724c73d006d00e", size = 3938589, upload-time = "2025-09-16T15:33:59.734Z" }, + { url = "https://files.pythonhosted.org/packages/14/dd/916c6777222db3271e9fb3cf9a97ed92b3a9b3e465bdeec96de9ab809d53/obstore-0.8.2-cp313-cp313-win_amd64.whl", hash = "sha256:ec850adf9980e5788a826ccfd5819989724e2a2f712bfa3258e85966c8d9981e", size = 3977768, upload-time = "2025-09-16T15:34:01.25Z" }, + { url = 
"https://files.pythonhosted.org/packages/f1/61/66f8dc98bbf5613bbfe5bf21747b4c8091442977f4bd897945895ab7325c/obstore-0.8.2-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:1431e40e9bb4773a261e51b192ea6489d0799b9d4d7dbdf175cdf813eb8c0503", size = 3623364, upload-time = "2025-09-16T15:34:02.957Z" }, + { url = "https://files.pythonhosted.org/packages/1a/66/6d527b3027e42f625c8fc816ac7d19b0d6228f95bfe7666e4d6b081d2348/obstore-0.8.2-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:ddb39d4da303f50b959da000aa42734f6da7ac0cc0be2d5a7838b62c97055bb9", size = 3347764, upload-time = "2025-09-16T15:34:04.236Z" }, + { url = "https://files.pythonhosted.org/packages/0d/79/c00103302b620192ea447a948921ad3fed031ce3d19e989f038e1183f607/obstore-0.8.2-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e01f4e13783db453e17e005a4a3ceff09c41c262e44649ba169d253098c775e8", size = 3460981, upload-time = "2025-09-16T15:34:05.595Z" }, + { url = "https://files.pythonhosted.org/packages/3d/d9/bfe4ed4b1aebc45b56644dd5b943cf8e1673505cccb352e66878a457e807/obstore-0.8.2-cp314-cp314-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:df0fc2d0bc17caff9b538564ddc26d7616f7e8b7c65b1a3c90b5048a8ad2e797", size = 3692711, upload-time = "2025-09-16T15:34:06.796Z" }, + { url = "https://files.pythonhosted.org/packages/13/47/cd6c2cbb18e1f40c77e7957a4a03d2d83f1859a2e876a408f1ece81cad4c/obstore-0.8.2-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e439d06c99a140348f046c9f598ee349cc2dcd9105c15540a4b231f9cc48bbae", size = 3958362, upload-time = "2025-09-16T15:34:08.277Z" }, + { url = "https://files.pythonhosted.org/packages/3d/ea/5ee82bf23abd71c7d6a3f2d008197ae8f8f569d41314c26a8f75318245be/obstore-0.8.2-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0e37d9046669fcc59522d0faf1d105fcbfd09c84cccaaa1e809227d8e030f32c", size = 3957082, upload-time = "2025-09-16T15:34:09.477Z" }, + { url = 
"https://files.pythonhosted.org/packages/cb/ee/46650405e50fdaa8d95f30375491f9c91fac9517980e8a28a4a6af66927f/obstore-0.8.2-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2646fdcc4bbe92dc2bb5bcdff15574da1211f5806c002b66d514cee2a23c7cb8", size = 3775539, upload-time = "2025-09-16T15:34:10.726Z" }, + { url = "https://files.pythonhosted.org/packages/35/d6/348a7ebebe2ca3d94dfc75344ea19675ae45472823e372c1852844078307/obstore-0.8.2-cp314-cp314-manylinux_2_24_aarch64.whl", hash = "sha256:e31a7d37675056d93dfc244605089dee67f5bba30f37c88436623c8c5ad9ba9d", size = 3535048, upload-time = "2025-09-16T15:34:12.076Z" }, + { url = "https://files.pythonhosted.org/packages/41/07/b7a16cc0da91a4b902d47880ad24016abfe7880c63f7cdafda45d89a2f91/obstore-0.8.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:656313dd8170dde0f0cd471433283337a63912e8e790a121f7cc7639c83e3816", size = 3699035, upload-time = "2025-09-16T15:34:13.331Z" }, + { url = "https://files.pythonhosted.org/packages/7f/74/3269a3a58347e0b019742d888612c4b765293c9c75efa44e144b1e884c0d/obstore-0.8.2-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:329038c9645d6d1741e77fe1a53e28a14b1a5c1461cfe4086082ad39ebabf981", size = 3687307, upload-time = "2025-09-16T15:34:14.501Z" }, + { url = "https://files.pythonhosted.org/packages/01/f9/4fd4819ad6a49d2f462a45be453561f4caebded0dc40112deeffc34b89b1/obstore-0.8.2-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:1e4df99b369790c97c752d126b286dc86484ea49bff5782843a265221406566f", size = 3776076, upload-time = "2025-09-16T15:34:16.207Z" }, + { url = "https://files.pythonhosted.org/packages/14/dd/7c4f958fa0b9fc4778fb3d232e38b37db8c6b260f641022fbba48b049d7e/obstore-0.8.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:9e1c65c65e20cc990414a8a9af88209b1bbc0dd9521b5f6b0293c60e19439bb7", size = 3947445, upload-time = "2025-09-16T15:34:17.423Z" }, + { url = 
"https://files.pythonhosted.org/packages/c3/37/14bae1f5bf4369027abc5315cdba2428ad4c16e2fd3bd5d35b7ee584aa0c/obstore-0.8.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:6ea04118980a9c22fc8581225ff4507b6a161baf8949d728d96e68326ebaab59", size = 3624857, upload-time = "2025-09-16T15:34:35.601Z" }, + { url = "https://files.pythonhosted.org/packages/1a/c4/8cba91629aa20479ba86a57c2c2b3bc0a54fc6a31a4594014213603efae6/obstore-0.8.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:5f33a7570b6001b54252260fbec18c3f6d21e25d3ec57e9b6c5e7330e8290eb2", size = 3355999, upload-time = "2025-09-16T15:34:36.954Z" }, + { url = "https://files.pythonhosted.org/packages/f2/10/3e40557d6d9c38c5a0f7bac1508209b9dbb8c4da918ddfa9326ba9a1de3f/obstore-0.8.2-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:11fa78dfb749edcf5a041cd6db20eae95b3e8b09dfdd9b38d14939da40e7c115", size = 3457322, upload-time = "2025-09-16T15:34:38.143Z" }, + { url = "https://files.pythonhosted.org/packages/1d/01/dcf7988350c286683698cbdd8c15498aec43cbca72eaabad06fd77f0f34a/obstore-0.8.2-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:872bc0921ff88305884546ba05e258ccd95672a03d77db123f0d0563fd3c000b", size = 3689452, upload-time = "2025-09-16T15:34:39.638Z" }, + { url = "https://files.pythonhosted.org/packages/97/02/643eb2ede58933e47bdbc92786058c83d9aa569826d5bf6e83362d24a27a/obstore-0.8.2-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:72556a2fbf018edd921286283e5c7eec9f69a21c6d12516d8a44108eceaa526a", size = 3961171, upload-time = "2025-09-16T15:34:41.232Z" }, + { url = "https://files.pythonhosted.org/packages/d8/5d/c0b515df6089d0f54109de8031a6f6ed31271361948bee90ab8271d22f79/obstore-0.8.2-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:75fa1abf21499dfcfb0328941a175f89a9aa58245bf00e3318fe928e4b10d297", size = 3935988, upload-time = "2025-09-16T15:34:42.501Z" }, + { url = 
"https://files.pythonhosted.org/packages/7b/97/114d7bc172bb846472181d6fa3e950172ee1b1ccd11291777303c499dbdd/obstore-0.8.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f54f72f30cd608c4399679781c884bf8a0e816c1977a2fac993bf5e1fb30609f", size = 3771781, upload-time = "2025-09-16T15:34:44.405Z" }, + { url = "https://files.pythonhosted.org/packages/c3/43/4aa6de6dc406ef5e109b21a5614c34999575de638254deb456703fae24aa/obstore-0.8.2-pp310-pypy310_pp73-manylinux_2_24_aarch64.whl", hash = "sha256:b044ebf1bf7b8f7b0ca309375c1cd9e140be79e072ae8c70bbd5d9b2ad1f7678", size = 3536689, upload-time = "2025-09-16T15:34:45.649Z" }, + { url = "https://files.pythonhosted.org/packages/06/a5/870ce541aa1a9ee1d9c3e99c2187049bf5a4d278ee9678cc449aae0a4e68/obstore-0.8.2-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:b1326cd2288b64d6fe8857cc22d3a8003b802585fc0741eff2640a8dc35e8449", size = 3700560, upload-time = "2025-09-16T15:34:47.252Z" }, + { url = "https://files.pythonhosted.org/packages/7d/93/76a5fc3833aaa833b4152950d9cdfd328493a48316c24e32ddefe9b8870f/obstore-0.8.2-pp310-pypy310_pp73-musllinux_1_2_armv7l.whl", hash = "sha256:ba6863230648a9b0e11502d2745d881cf74262720238bc0093c3eabd22a3b24c", size = 3683450, upload-time = "2025-09-16T15:34:49.589Z" }, + { url = "https://files.pythonhosted.org/packages/15/3c/4c389362c187630c42f61ef9214e67fc336e44b8aafc47cf49ba9ab8007d/obstore-0.8.2-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:887615da9eeefeb2df849d87c380e04877487aa29dbeb367efc3f17f667470d3", size = 3766628, upload-time = "2025-09-16T15:34:51.937Z" }, + { url = "https://files.pythonhosted.org/packages/03/12/08547e63edf2239ec6660af434602208ab6f394955ef660a6edda13a0bee/obstore-0.8.2-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:4eec1fb32ffa4fb9fe9ad584611ff031927a5c22732b56075ee7204f0e35ebdf", size = 3944069, upload-time = "2025-09-16T15:34:54.108Z" }, +] + [[package]] name = "omegaconf" version = "2.3.0" @@ -2468,6 
+2690,95 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e3/94/1843518e420fa3ed6919835845df698c7e27e183cb997394e4a670973a65/omegaconf-2.3.0-py3-none-any.whl", hash = "sha256:7b4df175cdb08ba400f45cae3bdcae7ba8365db4d165fc65fd04b050ab63b46b", size = 79500, upload-time = "2022-12-08T20:59:19.686Z" }, ] +[[package]] +name = "opencensus" +version = "0.11.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core" }, + { name = "opencensus-context" }, + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/15/a7/a46dcffa1b63084f9f17fe3c8cb20724c4c8f91009fd0b2cfdb27d5d2b35/opencensus-0.11.4.tar.gz", hash = "sha256:cbef87d8b8773064ab60e5c2a1ced58bbaa38a6d052c41aec224958ce544eff2", size = 64966, upload-time = "2024-01-03T18:04:07.085Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b5/ed/9fbdeb23a09e430d87b7d72d430484b88184633dc50f6bfb792354b6f661/opencensus-0.11.4-py2.py3-none-any.whl", hash = "sha256:a18487ce68bc19900336e0ff4655c5a116daf10c1b3685ece8d971bddad6a864", size = 128225, upload-time = "2024-01-03T18:04:05.127Z" }, +] + +[[package]] +name = "opencensus-context" +version = "0.1.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4c/96/3b6f638f6275a8abbd45e582448723bffa29c1fb426721dedb5c72f7d056/opencensus-context-0.1.3.tar.gz", hash = "sha256:a03108c3c10d8c80bb5ddf5c8a1f033161fa61972a9917f9b9b3a18517f0088c", size = 4066, upload-time = "2022-08-03T22:20:22.359Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/68/162c97ea78c957d68ecf78a5c5041d2e25bd5562bdf5d89a6cbf7f8429bf/opencensus_context-0.1.3-py2.py3-none-any.whl", hash = "sha256:073bb0590007af276853009fac7e4bab1d523c3f03baf4cb4511ca38967c6039", size = 5060, upload-time = "2022-08-03T22:20:20.352Z" }, +] + +[[package]] +name = "opentelemetry-api" +version = "1.39.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + 
{ name = "importlib-metadata" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/97/b9/3161be15bb8e3ad01be8be5a968a9237c3027c5be504362ff800fca3e442/opentelemetry_api-1.39.1.tar.gz", hash = "sha256:fbde8c80e1b937a2c61f20347e91c0c18a1940cecf012d62e65a7caf08967c9c", size = 65767, upload-time = "2025-12-11T13:32:39.182Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cf/df/d3f1ddf4bb4cb50ed9b1139cc7b1c54c34a1e7ce8fd1b9a37c0d1551a6bd/opentelemetry_api-1.39.1-py3-none-any.whl", hash = "sha256:2edd8463432a7f8443edce90972169b195e7d6a05500cd29e6d13898187c9950", size = 66356, upload-time = "2025-12-11T13:32:17.304Z" }, +] + +[[package]] +name = "opentelemetry-exporter-prometheus" +version = "0.60b1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-sdk" }, + { name = "prometheus-client" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/14/39/7dafa6fff210737267bed35a8855b6ac7399b9e582b8cf1f25f842517012/opentelemetry_exporter_prometheus-0.60b1.tar.gz", hash = "sha256:a4011b46906323f71724649d301b4dc188aaa068852e814f4df38cc76eac616b", size = 14976, upload-time = "2025-12-11T13:32:42.944Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/0d/4be6bf5477a3eb3d917d2f17d3c0b6720cd6cb97898444a61d43cc983f5c/opentelemetry_exporter_prometheus-0.60b1-py3-none-any.whl", hash = "sha256:49f59178de4f4590e3cef0b8b95cf6e071aae70e1f060566df5546fad773b8fd", size = 13019, upload-time = "2025-12-11T13:32:23.974Z" }, +] + +[[package]] +name = "opentelemetry-proto" +version = "1.39.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/49/1d/f25d76d8260c156c40c97c9ed4511ec0f9ce353f8108ca6e7561f82a06b2/opentelemetry_proto-1.39.1.tar.gz", hash = "sha256:6c8e05144fc0d3ed4d22c2289c6b126e03bcd0e6a7da0f16cedd2e1c2772e2c8", size = 
46152, upload-time = "2025-12-11T13:32:48.681Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/51/95/b40c96a7b5203005a0b03d8ce8cd212ff23f1793d5ba289c87a097571b18/opentelemetry_proto-1.39.1-py3-none-any.whl", hash = "sha256:22cdc78efd3b3765d09e68bfbd010d4fc254c9818afd0b6b423387d9dee46007", size = 72535, upload-time = "2025-12-11T13:32:33.866Z" }, +] + +[[package]] +name = "opentelemetry-sdk" +version = "1.39.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/eb/fb/c76080c9ba07e1e8235d24cdcc4d125ef7aa3edf23eb4e497c2e50889adc/opentelemetry_sdk-1.39.1.tar.gz", hash = "sha256:cf4d4563caf7bff906c9f7967e2be22d0d6b349b908be0d90fb21c8e9c995cc6", size = 171460, upload-time = "2025-12-11T13:32:49.369Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/98/e91cf858f203d86f4eccdf763dcf01cf03f1dae80c3750f7e635bfa206b6/opentelemetry_sdk-1.39.1-py3-none-any.whl", hash = "sha256:4d5482c478513ecb0a5d938dcc61394e647066e0cc2676bee9f3af3f3f45f01c", size = 132565, upload-time = "2025-12-11T13:32:35.069Z" }, +] + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.60b1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/91/df/553f93ed38bf22f4b999d9be9c185adb558982214f33eae539d3b5cd0858/opentelemetry_semantic_conventions-0.60b1.tar.gz", hash = "sha256:87c228b5a0669b748c76d76df6c364c369c28f1c465e50f661e39737e84bc953", size = 137935, upload-time = "2025-12-11T13:32:50.487Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/5e/5958555e09635d09b75de3c4f8b9cae7335ca545d77392ffe7331534c402/opentelemetry_semantic_conventions-0.60b1-py3-none-any.whl", hash = 
"sha256:9fa8c8b0c110da289809292b0591220d3a7b53c1526a23021e977d68597893fb", size = 219982, upload-time = "2025-12-11T13:32:36.955Z" }, +] + [[package]] name = "packaging" version = "25.0" @@ -2688,6 +2999,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "portpicker" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "psutil" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4d/d0/cda2fc582f09510c84cd6b7d7b9e22a02d4e45dbad2b2ef1c6edd7847e00/portpicker-1.6.0.tar.gz", hash = "sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa", size = 25676, upload-time = "2023-08-15T04:37:08.865Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/2d/440e4d7041fff89f28f483733eb617127aa866135c2dc719e05893f089e1/portpicker-1.6.0-py3-none-any.whl", hash = "sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755", size = 16613, upload-time = "2023-08-15T04:37:07.327Z" }, +] + [[package]] name = "prometheus-client" version = "0.23.1" @@ -2863,6 +3186,58 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0e/15/4f02896cc3df04fc465010a4c6a0cd89810f54617a32a70ef531ed75d61c/protobuf-6.33.2-py3-none-any.whl", hash = "sha256:7636aad9bb01768870266de5dc009de2d1b936771b38a793f73cbbf279c91c5c", size = 170501, upload-time = "2025-12-06T00:17:52.211Z" }, ] +[[package]] +name = "psutil" +version = "7.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/73/cb/09e5184fb5fc0358d110fc3ca7f6b1d033800734d34cac10f4136cfac10e/psutil-7.2.1.tar.gz", hash = "sha256:f7583aec590485b43ca601dd9cea0dcd65bd7bb21d30ef4ddbf4ea6b5ed1bdd3", size = 490253, 
upload-time = "2025-12-29T08:26:00.169Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/8e/f0c242053a368c2aa89584ecd1b054a18683f13d6e5a318fc9ec36582c94/psutil-7.2.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:ba9f33bb525b14c3ea563b2fd521a84d2fa214ec59e3e6a2858f78d0844dd60d", size = 129624, upload-time = "2025-12-29T08:26:04.255Z" }, + { url = "https://files.pythonhosted.org/packages/26/97/a58a4968f8990617decee234258a2b4fc7cd9e35668387646c1963e69f26/psutil-7.2.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:81442dac7abfc2f4f4385ea9e12ddf5a796721c0f6133260687fec5c3780fa49", size = 130132, upload-time = "2025-12-29T08:26:06.228Z" }, + { url = "https://files.pythonhosted.org/packages/db/6d/ed44901e830739af5f72a85fa7ec5ff1edea7f81bfbf4875e409007149bd/psutil-7.2.1-cp313-cp313t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ea46c0d060491051d39f0d2cff4f98d5c72b288289f57a21556cc7d504db37fc", size = 180612, upload-time = "2025-12-29T08:26:08.276Z" }, + { url = "https://files.pythonhosted.org/packages/c7/65/b628f8459bca4efbfae50d4bf3feaab803de9a160b9d5f3bd9295a33f0c2/psutil-7.2.1-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:35630d5af80d5d0d49cfc4d64c1c13838baf6717a13effb35869a5919b854cdf", size = 183201, upload-time = "2025-12-29T08:26:10.622Z" }, + { url = "https://files.pythonhosted.org/packages/fb/23/851cadc9764edcc18f0effe7d0bf69f727d4cf2442deb4a9f78d4e4f30f2/psutil-7.2.1-cp313-cp313t-win_amd64.whl", hash = "sha256:923f8653416604e356073e6e0bccbe7c09990acef442def2f5640dd0faa9689f", size = 139081, upload-time = "2025-12-29T08:26:12.483Z" }, + { url = "https://files.pythonhosted.org/packages/59/82/d63e8494ec5758029f31c6cb06d7d161175d8281e91d011a4a441c8a43b5/psutil-7.2.1-cp313-cp313t-win_arm64.whl", hash = "sha256:cfbe6b40ca48019a51827f20d830887b3107a74a79b01ceb8cc8de4ccb17b672", size = 134767, upload-time = "2025-12-29T08:26:14.528Z" }, + { url 
= "https://files.pythonhosted.org/packages/05/c2/5fb764bd61e40e1fe756a44bd4c21827228394c17414ade348e28f83cd79/psutil-7.2.1-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:494c513ccc53225ae23eec7fe6e1482f1b8a44674241b54561f755a898650679", size = 129716, upload-time = "2025-12-29T08:26:16.017Z" }, + { url = "https://files.pythonhosted.org/packages/c9/d2/935039c20e06f615d9ca6ca0ab756cf8408a19d298ffaa08666bc18dc805/psutil-7.2.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:3fce5f92c22b00cdefd1645aa58ab4877a01679e901555067b1bd77039aa589f", size = 130133, upload-time = "2025-12-29T08:26:18.009Z" }, + { url = "https://files.pythonhosted.org/packages/77/69/19f1eb0e01d24c2b3eacbc2f78d3b5add8a89bf0bb69465bc8d563cc33de/psutil-7.2.1-cp314-cp314t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:93f3f7b0bb07711b49626e7940d6fe52aa9940ad86e8f7e74842e73189712129", size = 181518, upload-time = "2025-12-29T08:26:20.241Z" }, + { url = "https://files.pythonhosted.org/packages/e1/6d/7e18b1b4fa13ad370787626c95887b027656ad4829c156bb6569d02f3262/psutil-7.2.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d34d2ca888208eea2b5c68186841336a7f5e0b990edec929be909353a202768a", size = 184348, upload-time = "2025-12-29T08:26:22.215Z" }, + { url = "https://files.pythonhosted.org/packages/98/60/1672114392dd879586d60dd97896325df47d9a130ac7401318005aab28ec/psutil-7.2.1-cp314-cp314t-win_amd64.whl", hash = "sha256:2ceae842a78d1603753561132d5ad1b2f8a7979cb0c283f5b52fb4e6e14b1a79", size = 140400, upload-time = "2025-12-29T08:26:23.993Z" }, + { url = "https://files.pythonhosted.org/packages/fb/7b/d0e9d4513c46e46897b46bcfc410d51fc65735837ea57a25170f298326e6/psutil-7.2.1-cp314-cp314t-win_arm64.whl", hash = "sha256:08a2f175e48a898c8eb8eace45ce01777f4785bc744c90aa2cc7f2fa5462a266", size = 135430, upload-time = "2025-12-29T08:26:25.999Z" }, + { url = 
"https://files.pythonhosted.org/packages/c5/cf/5180eb8c8bdf6a503c6919f1da28328bd1e6b3b1b5b9d5b01ae64f019616/psutil-7.2.1-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:b2e953fcfaedcfbc952b44744f22d16575d3aa78eb4f51ae74165b4e96e55f42", size = 128137, upload-time = "2025-12-29T08:26:27.759Z" }, + { url = "https://files.pythonhosted.org/packages/c5/2c/78e4a789306a92ade5000da4f5de3255202c534acdadc3aac7b5458fadef/psutil-7.2.1-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:05cc68dbb8c174828624062e73078e7e35406f4ca2d0866c272c2410d8ef06d1", size = 128947, upload-time = "2025-12-29T08:26:29.548Z" }, + { url = "https://files.pythonhosted.org/packages/29/f8/40e01c350ad9a2b3cb4e6adbcc8a83b17ee50dd5792102b6142385937db5/psutil-7.2.1-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5e38404ca2bb30ed7267a46c02f06ff842e92da3bb8c5bfdadbd35a5722314d8", size = 154694, upload-time = "2025-12-29T08:26:32.147Z" }, + { url = "https://files.pythonhosted.org/packages/06/e4/b751cdf839c011a9714a783f120e6a86b7494eb70044d7d81a25a5cd295f/psutil-7.2.1-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ab2b98c9fc19f13f59628d94df5cc4cc4844bc572467d113a8b517d634e362c6", size = 156136, upload-time = "2025-12-29T08:26:34.079Z" }, + { url = "https://files.pythonhosted.org/packages/44/ad/bbf6595a8134ee1e94a4487af3f132cef7fce43aef4a93b49912a48c3af7/psutil-7.2.1-cp36-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:f78baafb38436d5a128f837fab2d92c276dfb48af01a240b861ae02b2413ada8", size = 148108, upload-time = "2025-12-29T08:26:36.225Z" }, + { url = "https://files.pythonhosted.org/packages/1c/15/dd6fd869753ce82ff64dcbc18356093471a5a5adf4f77ed1f805d473d859/psutil-7.2.1-cp36-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:99a4cd17a5fdd1f3d014396502daa70b5ec21bf4ffe38393e152f8e449757d67", size = 147402, upload-time = "2025-12-29T08:26:39.21Z" }, + { url = 
"https://files.pythonhosted.org/packages/34/68/d9317542e3f2b180c4306e3f45d3c922d7e86d8ce39f941bb9e2e9d8599e/psutil-7.2.1-cp37-abi3-win_amd64.whl", hash = "sha256:b1b0671619343aa71c20ff9767eced0483e4fc9e1f489d50923738caf6a03c17", size = 136938, upload-time = "2025-12-29T08:26:41.036Z" }, + { url = "https://files.pythonhosted.org/packages/3e/73/2ce007f4198c80fcf2cb24c169884f833fe93fbc03d55d302627b094ee91/psutil-7.2.1-cp37-abi3-win_arm64.whl", hash = "sha256:0d67c1822c355aa6f7314d92018fb4268a76668a536f133599b91edd48759442", size = 133836, upload-time = "2025-12-29T08:26:43.086Z" }, +] + +[[package]] +name = "pulp" +version = "3.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/16/1c/d880b739b841a8aa81143091c9bdda5e72e226a660aa13178cb312d4b27f/pulp-3.3.0.tar.gz", hash = "sha256:7eb99b9ce7beeb8bbb7ea9d1c919f02f003ab7867e0d1e322f2f2c26dd31c8ba", size = 16301847, upload-time = "2025-09-18T08:14:57.552Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/99/6c/64cafaceea3f99927e84b38a362ec6a8f24f33061c90bda77dfe1cd4c3c6/pulp-3.3.0-py3-none-any.whl", hash = "sha256:dd6ad2d63f196d1254eddf9dcff5cd224912c1f046120cb7c143c5b0eda63fae", size = 16387700, upload-time = "2025-09-18T08:14:53.368Z" }, +] + +[[package]] +name = "py-spy" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/19/e2/ff811a367028b87e86714945bb9ecb5c1cc69114a8039a67b3a862cef921/py_spy-0.4.1.tar.gz", hash = "sha256:e53aa53daa2e47c2eef97dd2455b47bb3a7e7f962796a86cc3e7dbde8e6f4db4", size = 244726, upload-time = "2025-07-31T19:33:25.172Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/e3/3a32500d845bdd94f6a2b4ed6244982f42ec2bc64602ea8fcfe900678ae7/py_spy-0.4.1-py2.py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:809094208c6256c8f4ccadd31e9a513fe2429253f48e20066879239ba12cd8cc", size = 3682508, 
upload-time = "2025-07-31T19:33:13.753Z" }, + { url = "https://files.pythonhosted.org/packages/4f/bf/e4d280e9e0bec71d39fc646654097027d4bbe8e04af18fb68e49afcff404/py_spy-0.4.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:1fb8bf71ab8df95a95cc387deed6552934c50feef2cf6456bc06692a5508fd0c", size = 1796395, upload-time = "2025-07-31T19:33:15.325Z" }, + { url = "https://files.pythonhosted.org/packages/df/79/9ed50bb0a9de63ed023aa2db8b6265b04a7760d98c61eb54def6a5fddb68/py_spy-0.4.1-py2.py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee776b9d512a011d1ad3907ed53ae32ce2f3d9ff3e1782236554e22103b5c084", size = 2034938, upload-time = "2025-07-31T19:33:17.194Z" }, + { url = "https://files.pythonhosted.org/packages/53/a5/36862e3eea59f729dfb70ee6f9e14b051d8ddce1aa7e70e0b81d9fe18536/py_spy-0.4.1-py2.py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:532d3525538254d1859b49de1fbe9744df6b8865657c9f0e444bf36ce3f19226", size = 2658968, upload-time = "2025-07-31T19:33:18.916Z" }, + { url = "https://files.pythonhosted.org/packages/08/f8/9ea0b586b065a623f591e5e7961282ec944b5fbbdca33186c7c0296645b3/py_spy-0.4.1-py2.py3-none-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4972c21890b6814017e39ac233c22572c4a61fd874524ebc5ccab0f2237aee0a", size = 2147541, upload-time = "2025-07-31T19:33:20.565Z" }, + { url = "https://files.pythonhosted.org/packages/68/fb/bc7f639aed026bca6e7beb1e33f6951e16b7d315594e7635a4f7d21d63f4/py_spy-0.4.1-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:6a80ec05eb8a6883863a367c6a4d4f2d57de68466f7956b6367d4edd5c61bb29", size = 2763338, upload-time = "2025-07-31T19:33:22.202Z" }, + { url = "https://files.pythonhosted.org/packages/e1/da/fcc9a9fcd4ca946ff402cff20348e838b051d69f50f5d1f5dca4cd3c5eb8/py_spy-0.4.1-py2.py3-none-win_amd64.whl", hash = "sha256:d92e522bd40e9bf7d87c204033ce5bb5c828fca45fa28d970f58d71128069fdc", size = 1818784, upload-time = "2025-07-31T19:33:23.802Z" }, +] + 
[[package]] name = "pyarrow" version = "22.0.0" @@ -3369,6 +3744,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b0/0c/9e49c3da7502f18483e4deb3273a3104d501c5e9cf1664a136b8ea36df48/ray-2.49.2-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:b7d8214cff86df044fec727eeeabccc3bfc9b0271d28d61ba92c09f0d127d01d", size = 70027331, upload-time = "2025-09-19T19:16:12.968Z" }, ] +[package.optional-dependencies] +default = [ + { name = "aiohttp" }, + { name = "aiohttp-cors" }, + { name = "colorful" }, + { name = "grpcio" }, + { name = "opencensus" }, + { name = "opentelemetry-exporter-prometheus" }, + { name = "opentelemetry-proto" }, + { name = "opentelemetry-sdk" }, + { name = "prometheus-client" }, + { name = "py-spy" }, + { name = "pydantic" }, + { name = "requests" }, + { name = "smart-open" }, + { name = "virtualenv" }, +] + [[package]] name = "referencing" version = "0.37.0" @@ -3847,6 +4240,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, ] +[[package]] +name = "smart-open" +version = "7.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wrapt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/67/9a/0a7acb748b86e2922982366d780ca4b16c33f7246fa5860d26005c97e4f3/smart_open-7.5.0.tar.gz", hash = "sha256:f394b143851d8091011832ac8113ea4aba6b92e6c35f6e677ddaaccb169d7cb9", size = 53920, upload-time = "2025-11-08T21:38:40.698Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ad/95/bc978be7ea0babf2fb48a414b6afaad414c6a9e8b1eafc5b8a53c030381a/smart_open-7.5.0-py3-none-any.whl", hash = "sha256:87e695c5148bbb988f15cec00971602765874163be85acb1c9fb8abc012e6599", size = 63940, upload-time = "2025-11-08T21:38:39.024Z" }, +] + [[package]] name = 
"smmap" version = "5.0.2" @@ -4384,6 +4789,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ee/d9/d88e73ca598f4f6ff671fb5fde8a32925c2e08a637303a1d12883c7305fa/uvicorn-0.38.0-py3-none-any.whl", hash = "sha256:48c0afd214ceb59340075b4a052ea1ee91c16fbc2a9b1469cca0e54566977b02", size = 68109, upload-time = "2025-10-18T13:46:42.958Z" }, ] +[[package]] +name = "virtualenv" +version = "20.35.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "distlib" }, + { name = "filelock" }, + { name = "platformdirs" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/20/28/e6f1a6f655d620846bd9df527390ecc26b3805a0c5989048c210e22c5ca9/virtualenv-20.35.4.tar.gz", hash = "sha256:643d3914d73d3eeb0c552cbb12d7e82adf0e504dbf86a3182f8771a153a1971c", size = 6028799, upload-time = "2025-10-29T06:57:40.511Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/79/0c/c05523fa3181fdf0c9c52a6ba91a23fbf3246cc095f26f6516f9c60e6771/virtualenv-20.35.4-py3-none-any.whl", hash = "sha256:c21c9cede36c9753eeade68ba7d523529f228a403463376cf821eaae2b650f1b", size = 6005095, upload-time = "2025-10-29T06:57:37.598Z" }, +] + [[package]] name = "wandb" version = "0.23.1" From 95dd259604a8ab8c80a502209863951c66f8daa1 Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Thu, 22 Jan 2026 10:30:53 +0100 Subject: [PATCH 3/3] Integrating cosmos-xenna Signed-off-by: Marc Romeyn --- docs/packed-sft-format-design.md | 482 ++++++++++++ docs/packed-sft-impl-memmap.md | 718 ++++++++++++++++++ ...packed-sft-impl-parquet-megatron-bridge.md | 530 +++++++++++++ docs/packed-sft-impl-parquet-nemotron.md | 626 +++++++++++++++ pyproject.toml | 2 +- src/nemotron/data_prep/__init__.py | 56 +- src/nemotron/data_prep/config.py | 45 ++ src/nemotron/data_prep/console.py | 28 +- src/nemotron/data_prep/pipeline.py | 180 ++++- src/nemotron/data_prep/xenna/__init__.py | 54 +- 
src/nemotron/data_prep/xenna/executor.py | 99 +++ src/nemotron/data_prep/xenna/observability.py | 207 +++++ .../data_prep/xenna/pipeline_specs.py | 236 ++++++ src/nemotron/data_prep/xenna/runner.py | 662 ---------------- src/nemotron/data_prep/xenna/stages.py | 27 +- src/nemotron/kit/cli/recipe.py | 19 +- src/nemotron/kit/run.py | 11 +- .../data_prep/data_blend_cache_test.json | 11 + .../config/data_prep/default.yaml | 15 +- .../config/data_prep/tiny_xenna.yaml | 66 ++ .../nano3/stage0_pretrain/prep_xenna.py | 2 + uv.lock | 4 +- 22 files changed, 3350 insertions(+), 730 deletions(-) create mode 100644 docs/packed-sft-format-design.md create mode 100644 docs/packed-sft-impl-memmap.md create mode 100644 docs/packed-sft-impl-parquet-megatron-bridge.md create mode 100644 docs/packed-sft-impl-parquet-nemotron.md create mode 100644 src/nemotron/data_prep/xenna/executor.py create mode 100644 src/nemotron/data_prep/xenna/observability.py create mode 100644 src/nemotron/data_prep/xenna/pipeline_specs.py delete mode 100644 src/nemotron/data_prep/xenna/runner.py create mode 100644 src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/data_blend_cache_test.json create mode 100644 src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/tiny_xenna.yaml diff --git a/docs/packed-sft-format-design.md b/docs/packed-sft-format-design.md new file mode 100644 index 000000000..861a58d4e --- /dev/null +++ b/docs/packed-sft-format-design.md @@ -0,0 +1,482 @@ +# RFC: Scalable Packed SFT Data Format + +**Status:** Proposed +**Authors:** Data Infrastructure Team +**Created:** 2026-01-20 +**Target:** Megatron-Bridge + Nemotron Data Prep + +--- + +## Summary + +This RFC proposes adding new memory-efficient formats for packed SFT data in the Megatron-Bridge training pipeline. + +**The problem:** The current pickle-based `.npy` format requires loading the entire dataset into memory—both when writing (Nemotron data_prep) and when reading (Megatron-Bridge training).
For a typical 50,000-sample packed dataset, this consumes ~4 GB of RAM. This limits the size of datasets we can process and creates memory pressure during multi-node training where every node loads the full dataset. + +**Modern scale requirements:** State-of-the-art SFT pipelines operate at significantly larger scale. For example, [Nemotron-3 Nano](https://huggingface.co/blog/nvidia/nemotron-3-nano-efficient-open-intelligent-models) uses a 13-million-sample post-training corpus spanning code, math, multi-turn conversations, and tool use. At this scale, the current format would require **~1 TB of RAM** just to load the dataset—clearly infeasible. + +**The proposal:** Add new format options (Parquet, Memmap) that support streaming writes and lazy reads, reducing memory usage by 200-500x for large datasets. The current `.npy` format will remain supported and continue to be the default, ensuring full backward compatibility. Users can opt into new formats when working with larger datasets or cloud storage. + +--- + +## Background + +### Current Packed Sequence Format + +Megatron-Bridge's `GPTSFTPackedDataset` reads packed SFT data from `.npy` files produced by Nemotron's data preparation pipeline. The format stores a Python list of dictionaries: + +```python +# Current format: list of dicts with Python lists +[ + { + "input_ids": [101, 2054, 2003, ...], # Variable-length token ids + "loss_mask": [0, 0, 1, 1, ...], # Variable-length loss mask + "seq_start_id": [0, 45, 128, ...] # Sequence boundaries within the pack + }, + { + "input_ids": [...], + "loss_mask": [...], + "seq_start_id": [...] + }, + # ... thousands more packed samples +] +``` + +This is serialized using `np.save(..., allow_pickle=True)`. + +This format works well for small to medium datasets and will remain supported. 
However, it has scalability limitations for larger workloads: + +### Scalability Limitations + +#### Problem 1: Write-Side Memory Explosion + +In Nemotron's `process_chat_sft_pack_from_spool_core()`: + +```python +packed_data: list[dict] = [] +for item in materialize_packed_samples(...): + packed_data.append(item) # Accumulates ALL samples in memory + +np.save(f, packed_data, allow_pickle=True) +``` + +The Python object overhead is severe: +- Each token (4-byte int32) becomes a ~28-byte Python int object +- Each list adds 8 bytes per element for pointers + 56 bytes header +- **Result:** 50,000 packed samples with pack_size=2048 → **~4 GB peak memory** + +#### Problem 2: Read-Side Full Load + +In Megatron-Bridge's `GPTSFTPackedDataset._load_dataset()`: + +```python +self.indexed_dataset = np.load(self.file_path, allow_pickle=True) +``` + +This deserializes the entire pickle into RAM before any training iteration. There is no lazy loading or memory mapping—the "random access" only works because everything is already in memory. + +| Dataset Size | Memory Required | +|--------------|-----------------| +| 10k samples | ~800 MB | +| 50k samples | ~4 GB | +| 200k samples | ~16 GB | + +This becomes prohibitive for multi-node training where each node loads the full dataset. + +For reference, modern SFT pipelines like [Nemotron-3 Nano's post-training](https://developer.nvidia.com/blog/inside-nvidia-nemotron-3-techniques-tools-and-data-that-make-it-efficient-and-accurate/) use **13+ million samples** across diverse domains (code, math, multi-turn conversations, tool use). At this scale, the current format is not viable. + +#### Problem 3: No Cloud Storage Support + +The pickle-based format cannot be streamed from cloud storage (S3, GCS). The entire file must be downloaded and deserialized, adding startup latency to training jobs. 
+ +### Contrast: Megatron Pretrain Format + +The Megatron pretrain format (`.bin` + `.idx`) demonstrates the right approach: + +```python +# Data: memory-mapped binary file +mdata = np.memmap(data_path, dtype=np.uint8, mode="r") + +# Index: offsets for O(1) random access +midx = np.load(idx_path, mmap_mode="r") + +# Access: read only what you need +sample = mdata[midx[i]:midx[i+1]] +``` + +**Properties:** +- O(1) memory at load time (just mmap headers) +- O(sample_size) memory per access +- True random access without loading everything + +We need similar properties for packed SFT data. + +--- + +## Goals + +1. **Streaming writes:** New formats should support writing without accumulating in memory +2. **Lazy reads:** New formats should load only the samples needed, not the entire dataset +3. **Cloud-native:** New formats should support direct read/write to S3, GCS, Azure +4. **Backward compatible:** Existing `.npy` files must continue to work unchanged +5. **Opt-in:** New formats are opt-in via configuration; existing pipelines unaffected +6. **Minimal training code changes:** Same `__getitem__` interface regardless of format + +--- + +## New Format Options + +The following formats are proposed as additional options alongside the existing legacy format. + +**Which option is closest to the current format?** Parquet (Option B) is the most similar to the current pickle-based format: +- Both store variable-length lists without padding +- Both use a single file per shard +- Both preserve the same data model (`input_ids`, `loss_mask`, `seq_start_id` as lists) + +The main difference is the serialization format (Parquet columnar vs Python pickle) and that Parquet adds compression. Migration from legacy to Parquet requires no changes to the logical data structure. + +Memmap (Option A) is a larger departure: it pads sequences to fixed length and splits data across multiple files. This trades some disk space (~10%) for true O(1) random access without row group boundaries. 
+ +--- + +### Option A: Padded Fixed-Shape Memmap Arrays + +Use `np.lib.format.open_memmap()` to write fixed-shape numpy arrays that can be memory-mapped for reading. + +**Format:** +``` +shard_000000/ +├── input_ids.npy # int32[num_bins, pack_size] - padded +├── loss_mask.npy # uint8[num_bins, pack_size] - padded +├── packed_len.npy # uint32[num_bins] - actual lengths +├── seq_offsets.npy # uint32[num_bins + 1] - CSR pointers +├── seq_starts.npy # uint32[total_seq_starts] - boundary values +└── manifest.json +``` + +**Pros:** +- True O(1) random access via memmap +- Simplest implementation (direct numpy) +- Zero dependencies beyond numpy + +**Cons:** +- ~10% padding waste +- Multiple files per shard +- Local filesystem required for writes (cloud needs temp + upload) + +**Implementation:** [packed-sft-impl-memmap.md](./packed-sft-impl-memmap.md) + +--- + +### Option B: Parquet (Recommended) + +Use Apache Parquet with PyArrow for columnar storage with native variable-length support. + +**Format:** +``` +shard_000000.parquet + Schema: + - input_ids: list<int32> + - loss_mask: list<uint8> + - seq_start_id: list<uint32> + Compression: zstd + Row groups: ~1000 rows +``` + +**Pros:** +- Industry standard with excellent tooling +- Default format for Hugging Face datasets +- Native variable-length lists (no padding) +- 2-3x compression with zstd +- Direct cloud storage support (S3, GCS, Azure) +- Single file per shard +- Queryable with DuckDB, Polars, pandas + +**Cons:** +- Adds PyArrow dependency (already used in pipeline) +- Row-group granularity for access (configurable, typically fine) + +**Implementation:** +- Writer (Nemotron): [packed-sft-impl-parquet-nemotron.md](./packed-sft-impl-parquet-nemotron.md) +- Reader (Megatron-Bridge): [packed-sft-impl-parquet-megatron-bridge.md](./packed-sft-impl-parquet-megatron-bridge.md) + +--- + +### Option C: Megatron bin/idx Style + +Mimic the proven Megatron pretrain format with a flat binary data file and separate index.
+ +**Current Megatron pretrain format:** +``` +dataset.bin # Concatenated tokens as raw bytes (uint16/uint32) +dataset.idx # Header + document offsets + sequence lengths +``` + +The pretrain format stores only `input_ids` with document boundaries. For packed SFT, we would need to extend this to also store `loss_mask` and `seq_start_id`. + +**Proposed extended format:** +``` +shard_000000.bin # Concatenated input_ids (int32) +shard_000000.loss.bin # Concatenated loss_mask (uint8) [NEW] +shard_000000.idx # Bin offsets + lengths +shard_000000.seq.idx # Sequence boundary offsets (CSR) [NEW] +shard_000000.seq.bin # Sequence start positions [NEW] +``` + +**What needs to be built:** +1. **Extended index format** — Add fields for `loss_mask` offsets and `seq_start_id` CSR pointers +2. **Loss mask storage** — Separate `.loss.bin` file or interleaved with tokens +3. **Sequence boundary storage** — CSR-style index similar to Option A's `seq_offsets.npy` + `seq_starts.npy` +4. **Writer changes** — Extend `IndexedDatasetBuilder` to write additional arrays +5. **Reader changes** — Extend `MMapIndexedDataset` to read loss mask and boundaries + +**Pros:** +- Proven memory-mapping approach in Megatron ecosystem +- Consistent with existing pretrain data loading patterns +- True O(1) random access + +**Cons:** +- Requires extending Megatron's indexed dataset format (non-trivial) +- 5 files per shard (more than Parquet's 1, similar to Option A's 6) +- No compression (larger on disk than Parquet) +- Less ecosystem tooling compared to Parquet + +**Recommendation:** Consider this option if strict consistency with Megatron pretrain tooling is required. Otherwise, Parquet (Option B) provides similar benefits with less implementation effort and better compression. 
+ +--- + +## Comparison + +| Aspect | Legacy (.npy pickle) | Memmap (Option A) | Parquet (Option B) | +|--------|----------------------|-------------------|---------------------| +| Write memory (Python heap) | O(dataset) | O(pack_size)† | O(row_group) | +| Read memory at load | O(dataset) | O(metadata) | O(metadata) | +| Read memory per sample | O(1)* | O(pack_size) | O(row_group) | +| Random access | O(1)* | O(1) true | O(row_group) | +| Disk size | 1x | 1.1x (padding) | 0.4x (compressed) | +| Files per shard | 1 | 6 | 1 | +| Cloud support | None | Local only | Native (local) / Ranged reads (cloud) | +| Variable-length | Via pickle | Padding | Native | +| Tooling | None | numpy | DuckDB, Polars | +| Status | Supported (default) | New option | New option | + +*After full dataset load +†OS page cache may grow with output size during writes; Python heap stays small + +--- + +## When to Use Each Format + +### Legacy (.npy pickle) — Keep using when: +- Existing pipelines that work well at current scale +- Small to medium datasets (< 50k packed samples, ~4 GB memory budget) +- No need to change what already works + +### Parquet (Option B) — Use when: +- Large datasets where memory is a constraint (100k+ samples, millions for production SFT) +- Cloud storage (S3, GCS, Azure) is involved +- You want compression to reduce disk/network I/O +- You need to inspect data with standard tools (DuckDB, pandas) + +**Parquet is the recommended new format** because: +1. **Already a dependency** — PyArrow is used for reading input data +2. **Ecosystem standard** — Default format for Hugging Face datasets; widely adopted for ML data +3. **Single file** — Simpler than multi-file memmap directory +4. **Compression** — 2-3x smaller on disk reduces I/O +5. **Cloud-native** — Direct S3/GCS support without temp files +6. 
**No padding waste** — Native variable-length arrays + +### Memmap (Option A) — Use when: +- Maximum raw access speed is critical +- You want zero dependencies beyond numpy +- All storage is local POSIX filesystem +- You need true O(1) random access without row group boundaries + +--- + +## Megatron-Bridge Changes Required + +### New Dataset Classes + +```python +# For Parquet format +class GPTSFTPackedParquetDataset(GPTSFTDataset): + def _load_dataset(self): + self._pf = pq.ParquetFile(self.path, memory_map=True) + # No data loaded yet + + def __getitem__(self, idx): + # Read only the row group containing idx + rg = self._pf.read_row_group(self._get_rg(idx)) + return rg[idx % rg_size] + +# For Memmap format +class GPTSFTPackedMemmapDataset(GPTSFTDataset): + def _load_dataset(self): + self.input_ids = np.load(..., mmap_mode='r') + # No data loaded yet + + def __getitem__(self, idx): + L = self.packed_len[idx] + return self.input_ids[idx, :L] +``` + +### Factory Function Update + +```python +def create_sft_dataset(path: Path, tokenizer, ...) -> GPTSFTDataset: + if path.suffix == ".npy": + return GPTSFTPackedDataset(...) # Legacy + elif path.suffix == ".parquet": + return GPTSFTPackedParquetDataset(...) # New + elif path.is_dir() and (path / "manifest.json").exists(): + return GPTSFTPackedMemmapDataset(...) # New +``` + +### Backward Compatibility + +- Existing `.npy` files continue to work unchanged +- Format detected automatically by file extension/structure +- Training code uses same `__getitem__` interface + +--- + +## Format Invariants + +All formats must maintain these invariants for each packed sample. 
Implementations should assert these during writes and verify during tests: + +```python +# For each packed bin: +assert 0 < packed_len <= pack_size +assert len(input_ids) == packed_len +assert len(loss_mask) == packed_len + +# seq_start_id contains START positions of each sequence in the pack +# First sequence always starts at 0 +assert seq_start_id[0] == 0 +assert all(seq_start_id[i] < seq_start_id[i+1] for i in range(len(seq_start_id)-1)) +assert seq_start_id[-1] < packed_len # Last start is before end + +# To reconstruct sequence boundaries for attention masking: +# boundaries = list(seq_start_id) + [packed_len] +# seq_i spans input_ids[boundaries[i]:boundaries[i+1]] +``` + +**Key invariant:** `seq_start_id` stores **start positions** `[0, start_1, start_2, ...]`, not end positions. The final boundary (`packed_len`) is reconstructed at read time. + +--- + +## DataLoader Multi-Worker Considerations + +PyTorch DataLoader with `num_workers > 0` spawns worker processes. File handles and memory-mapped arrays do not survive pickling across process boundaries. + +**Requirements for new dataset classes:** + +1. **Lazy open pattern:** Store paths in `__init__`, open files in a separate `_load_dataset()` method +2. **Per-worker initialization:** Call `_load_dataset()` inside each worker, not in the main process +3. **Use `worker_init_fn`:** Reopen mmaps/parquet files per worker + +```python +class GPTSFTPackedParquetDataset(GPTSFTDataset): + def __init__(self, path: Path, ...): + self.path = path + self._pf = None # Opened lazily + + def _ensure_loaded(self): + if self._pf is None: + self._pf = pq.ParquetFile(self.path, memory_map=True) + + def __getitem__(self, idx): + self._ensure_loaded() # Opens on first access in each worker + ... +``` + +**Testing requirement:** Verify dataset works with `num_workers=4` and `persistent_workers=True`. 
+ +--- + +## Cloud Storage Behavior + +The new formats behave differently for local vs cloud storage: + +### Parquet +- **Local:** Uses true memory-mapped reads via `memory_map=True` +- **Cloud (S3/GCS):** Uses ranged HTTP reads; `memory_map=True` is ignored. PyArrow's filesystem layer handles buffering. Row group size affects read efficiency—smaller groups = more requests but finer granularity. + +### Memmap +- **Local:** True O(1) random access via OS virtual memory +- **Cloud:** Not directly supported. Requires download-to-local or write-local-then-upload pattern. For cloud outputs, write to local temp directory, then upload atomically. + +**Recommendation:** Use Parquet for cloud workloads; use Memmap only for local high-performance scenarios. + +--- + +## Migration Path + +### Phase 1: Add New Formats + +1. Implement Parquet writer in Nemotron data_prep +2. Implement Parquet reader in Megatron-Bridge +3. Add `packed_storage` config option (default: `legacy_npy_pickle`) +4. Existing pipelines continue to work without changes + +### Phase 2: Validation + +1. Run training with new format, verify identical loss curves +2. Benchmark memory usage and throughput +3. Test cloud storage integration + +### Phase 3: Adoption + +1. Document when to use new formats vs legacy +2. Provide conversion tool for users who want to migrate existing datasets +3. 
Legacy format remains fully supported for backward compatibility + +--- + +## Memory Impact (New Formats) + +For workloads that benefit from the new formats: + +| Metric | Legacy | Parquet | Memmap | +|--------|--------|---------|--------| +| Write peak (Python heap, 50k bins) | ~4 GB | ~20 MB | ~16 KB† | +| Read at load (RSS) | ~4 GB | ~few MB (metadata) | ~few KB (metadata) | +| Read per batch (bs=8) | ~4 GB* | ~8 MB (row group) | ~128 KB | + +*Already loaded +†OS page cache grows with output size but Python heap stays minimal + +**Note on memory estimates:** These are back-of-envelope calculations based on Python object overhead (~28 bytes per int vs 4 bytes in numpy). Actual memory depends on CPython version, allocator behavior, and workload. The relative improvements (200-500x) are consistent across configurations. + +**New formats provide ~200x reduction in write memory and ~500x reduction in read memory at load**, enabling larger datasets and more efficient multi-node training. + +--- + +## Open Questions + +1. **Row group size tuning:** What's the optimal row group size for training access patterns? (Proposed: 1000 rows, ~2-10 MB per group) + +2. **Multi-shard handling:** Should we use a single Parquet dataset with partitions or multiple files with a wrapper dataset? + +3. **Conversion tooling:** Should we provide a conversion tool for users who want to migrate existing datasets to new formats? + +4. **Default format:** Should the default remain `legacy_npy_pickle` indefinitely, or switch to Parquet after validation? 
+ +--- + +## References + +### Implementation Plans +- [Parquet Writer (Nemotron)](./packed-sft-impl-parquet-nemotron.md) +- [Parquet Reader (Megatron-Bridge)](./packed-sft-impl-parquet-megatron-bridge.md) +- [Memmap Format](./packed-sft-impl-memmap.md) + +### Existing Code +- [Megatron-Bridge GPTSFTPackedDataset](../../../Megatron-Bridge/src/megatron/bridge/data/datasets/sft.py) +- [Nemotron chat_sft_shard_core.py](../src/nemotron/data_prep/chat_sft_shard_core.py) + +### External Documentation +- [Apache Parquet Format](https://parquet.apache.org/docs/file-format/) +- [NumPy Memory-Mapped Files](https://numpy.org/doc/stable/reference/generated/numpy.memmap.html) +- [Hugging Face Datasets (Parquet backend)](https://huggingface.co/docs/datasets/about_arrow) diff --git a/docs/packed-sft-impl-memmap.md b/docs/packed-sft-impl-memmap.md new file mode 100644 index 000000000..30ae8f702 --- /dev/null +++ b/docs/packed-sft-impl-memmap.md @@ -0,0 +1,718 @@ +# Implementation Plan: Memmap Packed SFT Format + +This document details the implementation for Option A (Padded Fixed-Shape Memmap Arrays) from the [Packed SFT Format Design](./packed-sft-format-design.md). 
+ +## Format Specification + +``` +shard_000000/ +├── input_ids.npy # int32[num_bins, pack_size] - padded to pack_size +├── loss_mask.npy # uint8[num_bins, pack_size] - padded to pack_size +├── packed_len.npy # uint32[num_bins] - actual length per bin +├── seq_offsets.npy # uint32[num_bins + 1] - CSR pointers +├── seq_starts.npy # uint32[total_seq_starts] - seq_start_id values +└── manifest.json # schema version, metadata +``` + +--- + +## Phase 1: Nemotron Writer Changes + +### 1.1 Writer Module + +Create `src/nemotron/data_prep/packing/writers.py`: + +```python +from __future__ import annotations + +import json +import os +from typing import Any, Protocol + +import numpy as np + + +class PackedShardWriter(Protocol): + """Abstract interface for packed shard output.""" + + def write_bin( + self, + bin_id: int, + input_ids: np.ndarray, + loss_mask: np.ndarray, + seq_start_id: np.ndarray, + ) -> None: ... + + def finalize(self) -> dict[str, Any]: ... + + +class MemmapShardWriter: + """Memory-efficient writer using numpy memmap. + + Uses np.lib.format.open_memmap to pre-allocate arrays on disk, + then writes bins incrementally without accumulating in memory. + + Peak memory: O(pack_size) for scratch buffers only. 
+ """ + + def __init__( + self, + output_dir: str, + num_bins: int, + pack_size: int, + total_seq_starts: int, + dtype: np.dtype = np.int32, + ): + self.output_dir = output_dir + self.num_bins = num_bins + self.pack_size = pack_size + self.dtype = dtype + + os.makedirs(output_dir, exist_ok=True) + + # Pre-allocate memmap arrays with known sizes + self.input_ids = np.lib.format.open_memmap( + f"{output_dir}/input_ids.npy.tmp", + mode='w+', + dtype=dtype, + shape=(num_bins, pack_size), + ) + self.loss_mask = np.lib.format.open_memmap( + f"{output_dir}/loss_mask.npy.tmp", + mode='w+', + dtype=np.uint8, + shape=(num_bins, pack_size), + ) + self.packed_len = np.lib.format.open_memmap( + f"{output_dir}/packed_len.npy.tmp", + mode='w+', + dtype=np.uint32, + shape=(num_bins,), + ) + # CSR format for variable-length seq_start_id + self.seq_offsets = np.lib.format.open_memmap( + f"{output_dir}/seq_offsets.npy.tmp", + mode='w+', + dtype=np.uint32, + shape=(num_bins + 1,), + ) + self.seq_starts = np.lib.format.open_memmap( + f"{output_dir}/seq_starts.npy.tmp", + mode='w+', + dtype=np.uint32, + shape=(total_seq_starts,), + ) + + self._seq_write_pos = 0 + self.seq_offsets[0] = 0 + self._bins_written = 0 + + def write_bin( + self, + bin_id: int, + input_ids: np.ndarray, + loss_mask: np.ndarray, + seq_start_id: np.ndarray, + ) -> None: + """Write a single packed bin to the memmap arrays.""" + L = len(input_ids) + + # Write padded input_ids and loss_mask (rest stays zero) + self.input_ids[bin_id, :L] = input_ids + self.loss_mask[bin_id, :L] = loss_mask + self.packed_len[bin_id] = L + + # Write seq_start_id in CSR format + S = len(seq_start_id) + self.seq_starts[self._seq_write_pos:self._seq_write_pos + S] = seq_start_id + self._seq_write_pos += S + self.seq_offsets[bin_id + 1] = self._seq_write_pos + + self._bins_written += 1 + + def finalize(self) -> dict[str, Any]: + """Flush memmaps, rename to final paths, write manifest.""" + # Flush and close memmaps + del self.input_ids + del 
self.loss_mask + del self.packed_len + del self.seq_offsets + del self.seq_starts + + # Atomic rename from .tmp to final + for name in ['input_ids', 'loss_mask', 'packed_len', 'seq_offsets', 'seq_starts']: + tmp_path = f"{self.output_dir}/{name}.npy.tmp" + final_path = f"{self.output_dir}/{name}.npy" + os.rename(tmp_path, final_path) + + # Write manifest with explicit dtype including endianness + manifest = { + "version": "1.0", + "format": "memmap_padded_v1", + "num_bins": self.num_bins, + "pack_size": self.pack_size, + "dtype": self.dtype.str, # e.g., "<i4" for little-endian int32 + } + with open(f"{self.output_dir}/manifest.json", "w") as f: + json.dump(manifest, f, indent=2) + + return manifest + + +class LegacyPickleWriter: + """Legacy writer: accumulates every bin in memory and pickles the list into a single .npy file.""" + + def __init__(self, output_path: str): + self.output_path = output_path + self.packed_data: list[dict[str, Any]] = [] + + def write_bin( + self, + bin_id: int, + input_ids: np.ndarray, + loss_mask: np.ndarray, + seq_start_id: np.ndarray, + ) -> None: + self.packed_data.append({ + "input_ids": input_ids.tolist(), + "loss_mask": loss_mask.tolist(), + "seq_start_id": seq_start_id.tolist(), + }) + + def finalize(self) -> dict[str, Any]: + tmp_path = self.output_path + ".tmp" + with open(tmp_path, "wb") as f: + np.save(f, self.packed_data, allow_pickle=True) + os.rename(tmp_path, self.output_path) + return {"format": "legacy_pickle", "num_bins": len(self.packed_data)} +``` + +### 1.2 Materialize Function + +Add to `src/nemotron/data_prep/packing/materialize.py`: + +```python +def materialize_bin_arrays( + spool_reader: SequenceSpoolReader, + assignment: BinAssignment, + bin_id: int, + pack_size: int, + scratch_input_ids: np.ndarray, + scratch_loss_mask: np.ndarray, +) -> tuple[int, np.ndarray]: + """ + Materialize a single bin directly to numpy arrays. + + Avoids Python list conversion - writes directly to preallocated buffers. + + Args: + spool_reader: Reader for tokenized sequence spool + assignment: Bin assignment from packing algorithm + bin_id: Which bin to materialize + pack_size: Maximum packed sequence length + scratch_input_ids: Preallocated buffer of shape (pack_size,) + scratch_loss_mask: Preallocated buffer of shape (pack_size,) + + Returns: + packed_len: Actual length of packed tokens (excluding padding) + seq_start_id: Array of sequence START positions within the bin. 
+ Invariant: seq_start_id[0] == 0, strictly increasing, + seq_start_id[-1] < packed_len. + To get boundaries: list(seq_start_id) + [packed_len] + """ + seq_indices = assignment.bin_indices(bin_id) + + # Zero the scratch buffers (for padding) + scratch_input_ids[:] = 0 + scratch_loss_mask[:] = 0 + + pos = 0 + seq_start_ids = [] # Collect START positions + + for seq_index in seq_indices: + input_ids_arr, loss_mask_arr = spool_reader.read_sequence(int(seq_index)) + + # Truncate if needed + seq_len = min(len(input_ids_arr), pack_size) + if pos + seq_len > pack_size: + seq_len = pack_size - pos + + if seq_len <= 0: + break + + # Record start position BEFORE writing + seq_start_ids.append(pos) + + # Write directly to scratch buffers (no Python list!) + scratch_input_ids[pos:pos + seq_len] = input_ids_arr[:seq_len] + scratch_loss_mask[pos:pos + seq_len] = loss_mask_arr[:seq_len] + pos += seq_len + + # Apply loss_mask roll (shift right by 1, first position is 0) + # This ensures loss is computed on predicting token[i+1] from token[i] + if pos > 0: + scratch_loss_mask[1:pos] = scratch_loss_mask[:pos-1].copy() + scratch_loss_mask[0] = 0 + + return pos, np.array(seq_start_ids, dtype=np.uint32) +``` + +### 1.3 Update Central Pack Function + +Update `src/nemotron/data_prep/chat_sft_shard_core.py`: + +```python +from nemotron.data_prep.packing.writers import ( + MemmapShardWriter, + LegacyPickleWriter, +) +from nemotron.data_prep.packing.materialize import materialize_bin_arrays + + +def process_chat_sft_pack_from_spool_core( + *, + spool_dir: str, + output_dir: str, + shard_id: str, + pack_size: int, + packer: Packer, + output_fs: AbstractFileSystem, + dtype: np.dtype = np.int32, + packed_storage: str = "legacy_npy_pickle", # NEW PARAM +) -> dict[str, Any]: + """Process spool files into packed output.""" + + # ... existing setup code (load spool, run packer) ... 
+ + # After packing, we know all sizes + num_bins = assignment.num_bins + num_sequences = int(lengths.shape[0]) + total_seq_starts = num_sequences # Each sequence contributes one entry + + # Choose writer based on config + if packed_storage == "memmap_v1": + shard_output_dir = f"{output_dir}/{shard_id}" + writer = MemmapShardWriter( + output_dir=shard_output_dir, + num_bins=num_bins, + pack_size=pack_size, + total_seq_starts=total_seq_starts, + dtype=dtype, + ) + output_path = shard_output_dir + else: + npy_path = f"{output_dir}/{shard_id}.npy" + writer = LegacyPickleWriter(npy_path) + output_path = npy_path + + # Preallocate scratch buffers (reused for every bin!) + scratch_input_ids = np.zeros(pack_size, dtype=dtype) + scratch_loss_mask = np.zeros(pack_size, dtype=np.uint8) + + # Stream bins without accumulating + for bin_id in range(num_bins): + packed_len, seq_start_id = materialize_bin_arrays( + spool_reader=reader, + assignment=assignment, + bin_id=bin_id, + pack_size=pack_size, + scratch_input_ids=scratch_input_ids, + scratch_loss_mask=scratch_loss_mask, + ) + + writer.write_bin( + bin_id=bin_id, + input_ids=scratch_input_ids[:packed_len].copy(), + loss_mask=scratch_loss_mask[:packed_len].copy(), + seq_start_id=seq_start_id, + ) + + writer.finalize() + + # Build receipt + receipt = { + "shard_id": shard_id, + "output_path": output_path, + "format": packed_storage, + "num_bins": num_bins, + "num_sequences": num_sequences, + # ... other receipt fields ... + } + + return receipt +``` + +### 1.4 Config Changes + +Update `src/nemotron/data_prep/config.py`: + +```python +from typing import Literal + +class ChatSftOutputConfig(BaseModel): + # ... existing fields ... 
+ packed_storage: Literal["legacy_npy_pickle", "memmap_v1"] = "legacy_npy_pickle" +``` + +--- + +## Phase 2: Megatron-Bridge Reader Changes + +### 2.1 Memmap Dataset Class + +Add to `src/megatron/bridge/data/datasets/sft.py`: + +```python +import json +from pathlib import Path + +import numpy as np + + +class GPTSFTPackedMemmapDataset(GPTSFTDataset): + """Memory-efficient packed dataset using memmap arrays. + + Reads from the memmap_padded_v1 format: + - input_ids.npy: int32[num_bins, pack_size] + - loss_mask.npy: uint8[num_bins, pack_size] + - packed_len.npy: uint32[num_bins] + - seq_offsets.npy: uint32[num_bins + 1] + - seq_starts.npy: uint32[total_seq_starts] + - manifest.json: metadata + + Memory usage: O(metadata) at load time, O(pack_size) per sample access. + + DataLoader compatibility: Uses lazy-open pattern. Memmaps are opened on + first access in each worker process, avoiding pickling issues with + num_workers > 0. + """ + + def __init__( + self, + file_path: str, + tokenizer: MegatronTokenizer, + **kwargs, + ): + self.shard_dir = file_path + self._loaded = False + self._arrays = None + super().__init__(file_path, tokenizer, **kwargs) + + def _load_dataset(self): + """Load manifest only. Memmaps opened lazily on first access.""" + manifest_path = Path(self.shard_dir) / "manifest.json" + with open(manifest_path) as f: + self.manifest = json.load(f) + + self._num_bins = self.manifest["num_bins"] + self._pack_size = self.manifest["pack_size"] + + def _ensure_memmaps_open(self): + """Open memmaps on first access (per-worker). Thread-safe.""" + if self._loaded: + return + + # Memory-map arrays (NOT loaded into RAM!) 
+ self.input_ids = np.load( + f"{self.shard_dir}/input_ids.npy", mmap_mode='r' + ) + self.loss_mask = np.load( + f"{self.shard_dir}/loss_mask.npy", mmap_mode='r' + ) + self.packed_len = np.load( + f"{self.shard_dir}/packed_len.npy", mmap_mode='r' + ) + self.seq_offsets = np.load( + f"{self.shard_dir}/seq_offsets.npy", mmap_mode='r' + ) + self.seq_starts = np.load( + f"{self.shard_dir}/seq_starts.npy", mmap_mode='r' + ) + self._loaded = True + + def __len__(self): + return self._num_bins + + def __getitem__(self, idx): + """Read a single packed sample with O(1) access.""" + self._ensure_memmaps_open() # Lazy open per worker + + if self.samples_mapping is not None: + idx = self.samples_mapping[idx] + + # Read only the data we need + L = int(self.packed_len[idx]) + input_ids = self.input_ids[idx, :L] + loss_mask = self.loss_mask[idx, :L] + + # Reconstruct seq_start_id from CSR format + start = int(self.seq_offsets[idx]) + end = int(self.seq_offsets[idx + 1]) + seq_start_id = self.seq_starts[start:end].tolist() + + # Boundaries = start positions + final length + # Invariant: seq_start_id contains starts, we add packed_len as final boundary + seq_boundaries = seq_start_id + [L] + + if idx < 0: + loss_mask = np.zeros_like(loss_mask) + + return { + "input_ids": input_ids, + "seq_boundaries": seq_boundaries, + "loss_mask": loss_mask, + } + + def __getstate__(self): + """For pickling across DataLoader workers - exclude memmaps.""" + state = self.__dict__.copy() + # Remove unpicklable memmaps + state['_loaded'] = False + for key in ['input_ids', 'loss_mask', 'packed_len', 'seq_offsets', 'seq_starts']: + state.pop(key, None) + return state + + def __setstate__(self, state): + """Restore from pickle - memmaps will be reopened on first access.""" + self.__dict__.update(state) +``` + +### 2.2 Multi-Shard Support + +```python +from torch.utils.data import Dataset + + +class GPTSFTPackedMultiShardDataset(Dataset): + """Combines multiple packed memmap shards into one dataset.""" + 
+ def __init__( + self, + shard_dirs: list[str], + tokenizer: MegatronTokenizer, + **kwargs, + ): + self.shards = [ + GPTSFTPackedMemmapDataset(d, tokenizer, **kwargs) + for d in shard_dirs + ] + + # Build cumulative index for O(log n) shard lookup (binary search via np.searchsorted) + self.shard_offsets = np.cumsum([0] + [len(s) for s in self.shards]) + + def __len__(self): + return int(self.shard_offsets[-1]) + + def __getitem__(self, idx): + # Find which shard contains this index + shard_id = int(np.searchsorted(self.shard_offsets[1:], idx, side='right')) + local_idx = idx - int(self.shard_offsets[shard_id]) + return self.shards[shard_id][local_idx] +``` + +### 2.3 Update Factory Function + +Update `create_sft_dataset` in `sft.py`: + +```python +def create_sft_dataset(path: Path, tokenizer, ...) -> GPTSFTDataset: + """Factory function to create appropriate dataset based on format.""" + + if path.suffix == ".npy": + # Legacy pickle format + return GPTSFTPackedDataset( + file_path=str(path), + tokenizer=tokenizer, + **gpt_sft_dataset_kwargs, + ) + elif path.is_dir() and (path / "manifest.json").exists(): + # New memmap format + return GPTSFTPackedMemmapDataset( + file_path=str(path), + tokenizer=tokenizer, + **gpt_sft_dataset_kwargs, + ) + # ... rest of existing logic ... +``` + +--- + +## Phase 3: Cloud Storage Support + +Memmap requires local POSIX paths. 
For cloud outputs, use write-local-then-upload: + +```python +import tempfile + + +def write_memmap_to_cloud( + output_fs: AbstractFileSystem, + output_dir: str, + num_bins: int, + pack_size: int, + total_seq_starts: int, + write_fn: callable, +) -> dict[str, Any]: + """Write memmap format to cloud storage.""" + + if output_fs.protocol == "file": + # Local filesystem - write directly + writer = MemmapShardWriter(output_dir, num_bins, pack_size, total_seq_starts) + write_fn(writer) + return writer.finalize() + + # Cloud storage - write to temp, then upload + with tempfile.TemporaryDirectory() as tmpdir: + writer = MemmapShardWriter(tmpdir, num_bins, pack_size, total_seq_starts) + write_fn(writer) + manifest = writer.finalize() + + # Upload all files + output_fs.makedirs(output_dir, exist_ok=True) + for filename in os.listdir(tmpdir): + local_path = f"{tmpdir}/{filename}" + remote_path = f"{output_dir}/{filename}" + output_fs.put(local_path, remote_path) + + return manifest +``` + +--- + +## Conversion Tool + +```python +def convert_legacy_to_memmap(legacy_npy_path: str, output_dir: str) -> dict[str, Any]: + """Convert existing pickle .npy to memmap format.""" + data = np.load(legacy_npy_path, allow_pickle=True) + + num_bins = len(data) + pack_size = max(len(d["input_ids"]) for d in data) + total_seq_starts = sum(len(d["seq_start_id"]) for d in data) + + writer = MemmapShardWriter(output_dir, num_bins, pack_size, total_seq_starts) + + for i, d in enumerate(data): + writer.write_bin( + bin_id=i, + input_ids=np.array(d["input_ids"], dtype=np.int32), + loss_mask=np.array(d["loss_mask"], dtype=np.uint8), + seq_start_id=np.array(d["seq_start_id"], dtype=np.uint32), + ) + + return writer.finalize() +``` + +--- + +## Testing + +### Unit Tests + +```python +import tempfile +import numpy as np +import pytest + +from nemotron.data_prep.packing.writers import MemmapShardWriter + + +def test_memmap_writer_roundtrip(): + """Test write and read produce identical data.""" + with 
tempfile.TemporaryDirectory() as tmpdir: + # Write + writer = MemmapShardWriter( + output_dir=tmpdir, + num_bins=3, + pack_size=10, + total_seq_starts=6, + ) + + test_data = [ + (np.array([1, 2, 3]), np.array([0, 1, 1]), np.array([0])), + (np.array([4, 5, 6, 7]), np.array([0, 0, 1, 1]), np.array([0, 2])), + (np.array([8, 9]), np.array([0, 1]), np.array([0, 1, 2])), + ] + + for i, (ids, mask, starts) in enumerate(test_data): + writer.write_bin(i, ids, mask, starts) + + writer.finalize() + + # Read back + input_ids = np.load(f"{tmpdir}/input_ids.npy", mmap_mode='r') + packed_len = np.load(f"{tmpdir}/packed_len.npy", mmap_mode='r') + + for i, (expected_ids, _, _) in enumerate(test_data): + L = packed_len[i] + actual_ids = input_ids[i, :L] + np.testing.assert_array_equal(actual_ids, expected_ids) + + +def test_memmap_writer_memory_efficiency(): + """Verify peak memory stays constant regardless of data size.""" + import tracemalloc + + with tempfile.TemporaryDirectory() as tmpdir: + tracemalloc.start() + + writer = MemmapShardWriter( + output_dir=tmpdir, + num_bins=10000, + pack_size=2048, + total_seq_starts=50000, + ) + + # Write many bins + for i in range(10000): + writer.write_bin( + i, + np.random.randint(0, 50000, size=2000, dtype=np.int32), + np.random.randint(0, 2, size=2000, dtype=np.uint8), + np.array([0, 500, 1000, 1500], dtype=np.uint32), + ) + + writer.finalize() + + current, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + + # Peak should be well under 100MB (actual data would be ~400MB) + assert peak < 100 * 1024 * 1024, f"Peak memory {peak / 1e6:.1f} MB too high" + + +def test_dataloader_multiworker(): + """Verify dataset works with DataLoader num_workers > 0.""" + from torch.utils.data import DataLoader + + with tempfile.TemporaryDirectory() as tmpdir: + # Setup: write test data + writer = MemmapShardWriter(tmpdir, num_bins=100, pack_size=64, total_seq_starts=200) + for i in range(100): + writer.write_bin( + i, + np.arange(50, 
dtype=np.int32), + np.ones(50, dtype=np.uint8), + np.array([0, 25], dtype=np.uint32), + ) + writer.finalize() + + # Create dataset and dataloader with workers + dataset = GPTSFTPackedMemmapDataset(tmpdir, tokenizer=None) + loader = DataLoader( + dataset, + batch_size=8, + num_workers=4, + persistent_workers=True, + ) + + # Iterate through entire dataset + total_samples = 0 + for batch in loader: + total_samples += len(batch["input_ids"]) + + assert total_samples == 100, f"Expected 100 samples, got {total_samples}" +``` diff --git a/docs/packed-sft-impl-parquet-megatron-bridge.md b/docs/packed-sft-impl-parquet-megatron-bridge.md new file mode 100644 index 000000000..0c9b41b81 --- /dev/null +++ b/docs/packed-sft-impl-parquet-megatron-bridge.md @@ -0,0 +1,530 @@ +# Implementation Plan: Parquet Reader (Megatron-Bridge) + +This document details the Megatron-Bridge changes for Option B (Parquet) from the [Packed SFT Format Design](./packed-sft-format-design.md). + +For the Nemotron writer implementation, see [packed-sft-impl-parquet-nemotron.md](./packed-sft-impl-parquet-nemotron.md). + +## Format Specification + +``` +shard_000000.parquet + Schema: + - input_ids: list<int32> # Variable-length token ids + - loss_mask: list<uint8> # Variable-length loss mask + - seq_start_id: list<int32> # Variable-length sequence start positions + + Compression: zstd (default) + Row groups: ~1000 rows each +``` + +**Invariant:** `seq_start_id` contains START positions `[0, start_1, start_2, ...]`. To reconstruct boundaries for attention masking: `boundaries = list(seq_start_id) + [len(input_ids)]`. 
+ +--- + +## Files to Modify + +| File | Change | +|------|--------| +| `src/megatron/bridge/data/datasets/sft.py` | Add `GPTSFTPackedParquetDataset` class (after line 1003) | +| `src/megatron/bridge/data/datasets/sft.py` | Update `create_sft_dataset()` factory (lines 173-179) | +| `tests/functional_tests/data/datasets/test_sft.py` | Add `TestDataGPTSFTPackedParquetDataset` class | + +--- + +## Existing Class Structure + +The new class must follow the existing `GPTSFTPackedDataset` pattern (lines 738-1003 in `sft.py`): + +``` +GPTSFTDataset (base class, lines 194-736) +├── __init__(file_path, tokenizer, **kwargs) +├── __len__() +├── __getitem__(idx) → dict +├── _load_dataset() +├── _build_samples_mapping() +├── _build_loss_mask(processed_example) +├── collate_fn(batch) → dict of tensors +└── _collate_item(item, max_length, pad_id) + +GPTSFTPackedDataset (current .npy reader, lines 738-1003) +├── Inherits all base methods +├── Overrides _load_dataset() to use np.load(..., allow_pickle=True) +├── Overrides __getitem__() to return {input_ids, seq_boundaries, loss_mask} +├── Overrides collate_fn() for packed sequence handling with cu_seqlens +└── Adds return_cu_seqlen, pad_cu_seqlens parameters +``` + +--- + +## Step 1: Add Parquet Dataset Class + +Add after `GPTSFTPackedDataset` class (after line 1003) in `sft.py`: + +```python +import pyarrow.parquet as pq + + +class GPTSFTPackedParquetDataset(GPTSFTPackedDataset): + """Memory-efficient packed dataset using Parquet format. + + Identical interface to GPTSFTPackedDataset but reads from .parquet files + instead of .npy files. Uses lazy loading with row group caching for + memory efficiency. + + Memory usage: O(metadata) at load, O(row_group_size * pack_size) during access. + + DataLoader compatibility: Uses lazy-open pattern. ParquetFile is opened on + first access in each worker process, avoiding pickling issues with + num_workers > 0. + + Cloud behavior: For local files, uses true memory-mapping. 
For cloud URIs + (s3://, gs://), falls back to ranged HTTP reads. + """ + + def __init__( + self, + file_path: str, + tokenizer: MegatronTokenizer, + return_cu_seqlen: bool = True, + pad_cu_seqlens: bool = False, + pack_metadata_file_path: str | None = None, + **kwargs, + ): + # Store path for lazy loading + self._parquet_path = file_path + self._pq_loaded = False + + # Call parent init (will call _load_dataset) + super().__init__( + file_path=file_path, + tokenizer=tokenizer, + return_cu_seqlen=return_cu_seqlen, + pad_cu_seqlens=pad_cu_seqlens, + pack_metadata_file_path=pack_metadata_file_path, + **kwargs, + ) + + def _load_dataset(self): + """Override: Load only metadata, defer full load to first access. + + This enables pickling across DataLoader workers - the ParquetFile + handle is not picklable, so we open it lazily per-worker. + """ + # Don't load yet - will be opened in _ensure_parquet_open() + self._pq_loaded = False + + # Create a minimal indexed_dataset proxy for __len__ to work + # We need to peek at the parquet metadata to get row count + try: + if MultiStorageClientFeature.is_enabled(): + # For cloud storage, we still need to read metadata + msc = MultiStorageClientFeature.import_package() + pf = pq.ParquetFile(self._parquet_path, filesystem=msc.get_filesystem(self._parquet_path)) + else: + pf = pq.ParquetFile(self._parquet_path) + + self._num_rows = pf.metadata.num_rows + self._num_row_groups = pf.metadata.num_row_groups + + # Build row group offset index + self._row_group_offsets = [0] + for i in range(self._num_row_groups): + rg_rows = pf.metadata.row_group(i).num_rows + self._row_group_offsets.append(self._row_group_offsets[-1] + rg_rows) + self._row_group_offsets = np.array(self._row_group_offsets) + + # Close the file - will reopen lazily + pf = None + + except Exception as e: + logger.error( + f"Failed to load packed Parquet dataset. The dataset should be a `.parquet` file. " + f"Please check if the packed dataset was prepared correctly. 
The original error was:\n {e}", + ) + exit(1) + + # Create a proxy object that supports len() for _build_samples_mapping + class _ParquetProxy: + def __init__(self, num_rows): + self._num_rows = num_rows + def __len__(self): + return self._num_rows + + self.indexed_dataset = _ParquetProxy(self._num_rows) + + def _ensure_parquet_open(self): + """Open ParquetFile on first access (per-worker).""" + if self._pq_loaded: + return + + if MultiStorageClientFeature.is_enabled(): + msc = MultiStorageClientFeature.import_package() + self._parquet_file = pq.ParquetFile( + self._parquet_path, + filesystem=msc.get_filesystem(self._parquet_path), + memory_map=True, + ) + else: + self._parquet_file = pq.ParquetFile( + self._parquet_path, + memory_map=True, # Only effective for local files + ) + + # Row group cache + self._cached_rg_id = -1 + self._cached_data = None + + self._pq_loaded = True + + def _get_row_group_for_idx(self, idx: int) -> int: + """Find which row group contains the given index.""" + return int(np.searchsorted(self._row_group_offsets[1:], idx, side='right')) + + def __len__(self): + return self._num_rows + + def __getitem__(self, idx): + """Read a single packed sample. 
+ + Returns same format as GPTSFTPackedDataset: + { + "input_ids": list[int], + "seq_boundaries": list[int], # seq_start_id + [len(input_ids)] + "loss_mask": list[int], + } + """ + self._ensure_parquet_open() # Lazy open per worker + + if self.samples_mapping is not None: + idx = self.samples_mapping[idx] + + # Find row group and local index within row group + rg_id = self._get_row_group_for_idx(idx) + local_idx = idx - self._row_group_offsets[rg_id] + + # Load row group if not cached + if rg_id != self._cached_rg_id: + table = self._parquet_file.read_row_group( + rg_id, + columns=['input_ids', 'loss_mask', 'seq_start_id'], + ) + # Convert to Python lists once per row group + self._cached_data = { + 'input_ids': table['input_ids'].to_pylist(), + 'loss_mask': table['loss_mask'].to_pylist(), + 'seq_start_id': table['seq_start_id'].to_pylist(), + } + self._cached_rg_id = rg_id + + input_ids = self._cached_data['input_ids'][local_idx] + loss_mask = self._cached_data['loss_mask'][local_idx] + seq_start_id = self._cached_data['seq_start_id'][local_idx] + + # Reconstruct seq_boundaries from seq_start_id (same as GPTSFTPackedDataset) + seq_boundaries = seq_start_id + [len(input_ids)] + + if idx < 0: + loss_mask = [0] * len(loss_mask) + + return { + "input_ids": input_ids, + "seq_boundaries": seq_boundaries, + "loss_mask": loss_mask, + } + + def __getstate__(self): + """For pickling across DataLoader workers - exclude file handle.""" + state = self.__dict__.copy() + state['_pq_loaded'] = False + for key in ['_parquet_file', '_cached_rg_id', '_cached_data']: + state.pop(key, None) + return state + + def __setstate__(self, state): + """Restore from pickle - file will be reopened on first access.""" + self.__dict__.update(state) + + # Note: collate_fn is inherited from GPTSFTPackedDataset - no changes needed + # since __getitem__ returns the same format +``` + +--- + +## Step 2: Update Factory Function + +Modify `create_sft_dataset()` (around line 173) to detect `.parquet` 
files: + +```python +def create_sft_dataset( + path: Path, + tokenizer: "MegatronTokenizer", + # ... existing parameters ... +) -> "GPTSFTDataset": + # ... existing docstring and gpt_sft_dataset_kwargs setup ... + + if path.suffix == ".npy": + return GPTSFTPackedDataset( + pack_metadata_file_path=pack_metadata_file_path, + pad_cu_seqlens=pad_cu_seqlens, + **gpt_sft_dataset_kwargs, + **kwargs, + ) + elif path.suffix == ".parquet": + # NEW: Parquet packed format + return GPTSFTPackedParquetDataset( + pack_metadata_file_path=pack_metadata_file_path, + pad_cu_seqlens=pad_cu_seqlens, + **gpt_sft_dataset_kwargs, + **kwargs, + ) + elif chat: + return GPTSFTChatDataset( + # ... existing code ... + ) + else: + return GPTSFTDataset( + # ... existing code ... + ) +``` + +--- + +## Step 3: Add Import + +Add PyArrow import at top of `sft.py` (around line 20): + +```python +try: + import pyarrow.parquet as pq + PYARROW_AVAILABLE = True +except ImportError: + PYARROW_AVAILABLE = False + pq = None +``` + +And add a check in `GPTSFTPackedParquetDataset.__init__`: + +```python +if not PYARROW_AVAILABLE: + raise ImportError( + "PyArrow is required for Parquet dataset support. " + "Install with: pip install pyarrow" + ) +``` + +--- + +## Step 4: Multi-Shard Support (Optional) + +For datasets split across multiple parquet files: + +```python +class GPTSFTPackedParquetMultiShardDataset(Dataset): + """Combines multiple Parquet shards into one dataset. 
+ + Usage: + paths = sorted(glob.glob("data/shard_*.parquet")) + dataset = GPTSFTPackedParquetMultiShardDataset(paths, tokenizer, **kwargs) + """ + + def __init__( + self, + parquet_paths: list[str], + tokenizer: MegatronTokenizer, + **kwargs, + ): + self.shards = [ + GPTSFTPackedParquetDataset(p, tokenizer, **kwargs) + for p in parquet_paths + ] + + # Build cumulative index for O(log n) shard lookup + self.shard_offsets = np.cumsum([0] + [len(s) for s in self.shards]) + + # Store collate_fn from first shard + self.collate_fn = self.shards[0].collate_fn + + def __len__(self): + return int(self.shard_offsets[-1]) + + def __getitem__(self, idx): + shard_id = int(np.searchsorted(self.shard_offsets[1:], idx, side='right')) + local_idx = idx - int(self.shard_offsets[shard_id]) + return self.shards[shard_id][local_idx] +``` + +--- + +## Testing + +Add tests to `tests/functional_tests/data/datasets/test_sft.py`: + +```python +class TestDataGPTSFTPackedParquetDataset: + """Tests for GPTSFTPackedParquetDataset.""" + + @pytest.fixture(autouse=True) + def setup_and_teardown_parallel_state(self): + """Initialize distributed state for tests.""" + if not dist.is_initialized(): + dist.init_process_group("gloo", rank=0, world_size=1) + parallel_state.initialize_model_parallel() + yield + parallel_state.destroy_model_parallel() + + def test_parquet_dataset_basic(self, tmp_path, get_tokenizer): + """Test basic read functionality.""" + import pyarrow as pa + import pyarrow.parquet as pq + + # Create test parquet file + parquet_path = tmp_path / "test.parquet" + schema = pa.schema([ + ('input_ids', pa.list_(pa.int32())), + ('loss_mask', pa.list_(pa.uint8())), + ('seq_start_id', pa.list_(pa.int32())), + ]) + + table = pa.Table.from_pydict({ + 'input_ids': [[1, 2, 3, 4, 5], [10, 20, 30, 40]], + 'loss_mask': [[0, 0, 1, 1, 1], [0, 1, 1, 1]], + 'seq_start_id': [[0, 2], [0]], + }, schema=schema) + + pq.write_table(table, str(parquet_path), compression='zstd') + + # Load dataset + tokenizer 
= get_tokenizer() + dataset = GPTSFTPackedParquetDataset( + file_path=str(parquet_path), + tokenizer=tokenizer, + max_seq_length=2048, + ) + + assert len(dataset) == 2 + + sample = dataset[0] + assert sample['input_ids'] == [1, 2, 3, 4, 5] + assert sample['seq_boundaries'] == [0, 2, 5] + assert sample['loss_mask'] == [0, 0, 1, 1, 1] + + def test_parquet_dataset_collate_fn(self, tmp_path, get_tokenizer): + """Test that collate_fn works correctly (inherited from parent).""" + import pyarrow as pa + import pyarrow.parquet as pq + + parquet_path = tmp_path / "test.parquet" + schema = pa.schema([ + ('input_ids', pa.list_(pa.int32())), + ('loss_mask', pa.list_(pa.uint8())), + ('seq_start_id', pa.list_(pa.int32())), + ]) + + # Create batch of samples + table = pa.Table.from_pydict({ + 'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8]], + 'loss_mask': [[0, 1, 1, 1], [0, 0, 1, 1]], + 'seq_start_id': [[0], [0]], + }, schema=schema) + + pq.write_table(table, str(parquet_path)) + + tokenizer = get_tokenizer() + dataset = GPTSFTPackedParquetDataset( + file_path=str(parquet_path), + tokenizer=tokenizer, + max_seq_length=2048, + ) + + batch = [dataset[0], dataset[1]] + collated = dataset.collate_fn(batch) + + assert 'tokens' in collated + assert 'labels' in collated + assert 'loss_mask' in collated + assert 'position_ids' in collated + + def test_parquet_dataset_with_samples_mapping(self, tmp_path, get_tokenizer): + """Test that max_num_samples and shuffling work.""" + import pyarrow as pa + import pyarrow.parquet as pq + + parquet_path = tmp_path / "test.parquet" + schema = pa.schema([ + ('input_ids', pa.list_(pa.int32())), + ('loss_mask', pa.list_(pa.uint8())), + ('seq_start_id', pa.list_(pa.int32())), + ]) + + # Create 10 samples + rows = [{'input_ids': [i], 'loss_mask': [1], 'seq_start_id': [0]} for i in range(10)] + table = pa.Table.from_pylist(rows, schema=schema) + pq.write_table(table, str(parquet_path)) + + tokenizer = get_tokenizer() + dataset = GPTSFTPackedParquetDataset( + 
file_path=str(parquet_path), + tokenizer=tokenizer, + max_seq_length=2048, + max_num_samples=5, # Limit to 5 samples + seed=42, + ) + + assert len(dataset) == 5 +``` + +--- + +## Cloud Storage Support + +The implementation automatically supports cloud storage through `MultiStorageClientFeature` (existing Megatron-Bridge pattern): + +```python +# Local file +dataset = GPTSFTPackedParquetDataset("data/shard.parquet", tokenizer) + +# S3 (if MultiStorageClient is configured) +dataset = GPTSFTPackedParquetDataset("s3://bucket/data/shard.parquet", tokenizer) + +# GCS +dataset = GPTSFTPackedParquetDataset("gs://bucket/data/shard.parquet", tokenizer) +``` + +--- + +## Memory Characteristics + +| Metric | Value | +|--------|-------| +| Load time memory | O(metadata) - parquet footer + row group index | +| Per-sample access | O(row_group) - cached row group converted to Python lists | +| Row group cache | 1 row group per dataset instance per worker | + +**Performance notes:** +- Row group cache uses `to_pylist()` which reintroduces Python object overhead +- For training with sequential access, this is efficient (cache hit rate ~99%) +- For random access patterns, consider smaller row groups (100-500 rows) + +--- + +## Debugging & Inspection + +```python +# Quick inspection with PyArrow +import pyarrow.parquet as pq + +pf = pq.ParquetFile('shard.parquet') +print(f"Rows: {pf.metadata.num_rows}") +print(f"Row groups: {pf.metadata.num_row_groups}") +print(f"Schema: {pf.schema_arrow}") + +# Read first row group +df = pf.read_row_group(0).to_pandas() +print(df.head()) +``` + +```bash +# Query with DuckDB +duckdb -c "SELECT COUNT(*) FROM 'shard.parquet'" +duckdb -c "SELECT * FROM 'shard.parquet' LIMIT 5" +``` diff --git a/docs/packed-sft-impl-parquet-nemotron.md b/docs/packed-sft-impl-parquet-nemotron.md new file mode 100644 index 000000000..84188bfcb --- /dev/null +++ b/docs/packed-sft-impl-parquet-nemotron.md @@ -0,0 +1,626 @@ +# Implementation Plan: Parquet Writer (Nemotron) + 
+This document details the Nemotron data_prep changes for Option B (Parquet) from the [Packed SFT Format Design](./packed-sft-format-design.md). + +For the Megatron-Bridge reader implementation, see [packed-sft-impl-parquet-megatron-bridge.md](./packed-sft-impl-parquet-megatron-bridge.md). + +## Format Specification + +``` +shard_000000.parquet + Schema: + - input_ids: list[int32] # Variable-length token ids + - loss_mask: list[uint8] # Variable-length loss mask + - seq_start_id: list[int32] # Variable-length sequence start positions + + Compression: zstd (default) + Row groups: ~1000 rows each (tunable) +``` + +**Key advantages:** +- Single file per shard +- Native variable-length support (no padding waste) +- 2-3x compression with zstd +- Direct cloud storage support (S3, GCS, Azure) + +--- + +## Files to Modify + +| File | Change | +|------|--------| +| `src/nemotron/data_prep/packing/writers.py` | New file with `ParquetShardWriter` class | +| `src/nemotron/data_prep/packing/materialize.py` | Add `materialize_bin_arrays()` function | +| `src/nemotron/data_prep/chat_sft_shard_core.py` | Integrate writer selection | +| `src/nemotron/data_prep/config.py` | Add `packed_storage` config option | + +--- + +## Step 1: Writer Module + +Create `src/nemotron/data_prep/packing/writers.py`: + +```python +from __future__ import annotations + +import os +from typing import Any + +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq + + +class ParquetShardWriter: + """Memory-efficient writer using Parquet with streaming row groups. + + Writes bins incrementally, flushing to disk every `row_group_size` bins. + Supports direct cloud writes via PyArrow filesystem integration. + + Peak memory: O(row_group_size * pack_size) for the row group buffer.
+ """ + + SCHEMA = pa.schema([ + ('input_ids', pa.list_(pa.int32())), + ('loss_mask', pa.list_(pa.uint8())), + ('seq_start_id', pa.list_(pa.int32())), + ]) + + def __init__( + self, + output_path: str, + row_group_size: int = 1000, + compression: str = 'zstd', + filesystem: pa.fs.FileSystem | None = None, + ): + """ + Args: + output_path: Path to output .parquet file (local or cloud URI) + row_group_size: Number of bins per row group (tune for access patterns) + compression: Compression codec ('zstd', 'snappy', 'gzip', 'none') + filesystem: Optional PyArrow filesystem for cloud storage + """ + self.output_path = output_path + self.tmp_path = output_path + '.tmp' + self.row_group_size = row_group_size + self.compression = compression + self.filesystem = filesystem + + # Accumulate numpy arrays for current row group + self._input_ids_values: list[np.ndarray] = [] + self._loss_mask_values: list[np.ndarray] = [] + self._seq_start_values: list[np.ndarray] = [] + self._count = 0 + self._total_bins = 0 + + # Open writer + self._writer = pq.ParquetWriter( + self.tmp_path, + self.SCHEMA, + compression=compression, + filesystem=filesystem, + ) + + def write_bin( + self, + bin_id: int, + input_ids: np.ndarray, + loss_mask: np.ndarray, + seq_start_id: np.ndarray, + ) -> None: + """Buffer a single bin, flushing to disk when row group is full.""" + # Store numpy arrays directly (no .tolist() conversion!) 
+ self._input_ids_values.append(input_ids.astype(np.int32, copy=False)) + self._loss_mask_values.append(loss_mask.astype(np.uint8, copy=False)) + self._seq_start_values.append(seq_start_id.astype(np.int32, copy=False)) + self._count += 1 + self._total_bins += 1 + + # Flush row group when buffer is full + if self._count >= self.row_group_size: + self._flush_buffer() + + def _flush_buffer(self) -> None: + """Write buffered bins as a row group.""" + if self._count == 0: + return + + # Build Arrow arrays from numpy (efficient, minimal copying) + input_ids_arr = pa.array(self._input_ids_values, type=pa.list_(pa.int32())) + loss_mask_arr = pa.array(self._loss_mask_values, type=pa.list_(pa.uint8())) + seq_start_arr = pa.array(self._seq_start_values, type=pa.list_(pa.int32())) + + table = pa.Table.from_arrays( + [input_ids_arr, loss_mask_arr, seq_start_arr], + schema=self.SCHEMA, + ) + self._writer.write_table(table) + + # Clear buffers + self._input_ids_values.clear() + self._loss_mask_values.clear() + self._seq_start_values.clear() + self._count = 0 + + def finalize(self) -> dict[str, Any]: + """Flush remaining data, close writer, rename to final path.""" + self._flush_buffer() + self._writer.close() + + # Atomic rename + if self.filesystem: + self.filesystem.move(self.tmp_path, self.output_path) + else: + os.rename(self.tmp_path, self.output_path) + + return { + 'format': 'parquet', + 'compression': self.compression, + 'num_bins': self._total_bins, + 'row_group_size': self.row_group_size, + } + + +class LegacyPickleWriter: + """Backward-compatible writer using pickle-of-dicts (current format).""" + + def __init__(self, output_path: str): + self.output_path = output_path + self.packed_data: list[dict] = [] + + def write_bin( + self, + bin_id: int, + input_ids: np.ndarray, + loss_mask: np.ndarray, + seq_start_id: np.ndarray, + ) -> None: + self.packed_data.append({ + "input_ids": input_ids.tolist(), + "loss_mask": loss_mask.tolist(), + "seq_start_id": 
seq_start_id.tolist(), + }) + + def finalize(self) -> dict[str, Any]: + tmp_path = self.output_path + ".tmp" + with open(tmp_path, "wb") as f: + np.save(f, self.packed_data, allow_pickle=True) + os.rename(tmp_path, self.output_path) + return {"format": "legacy_pickle", "num_bins": len(self.packed_data)} +``` + +--- + +## Step 2: Materialize Function + +Add to `src/nemotron/data_prep/packing/materialize.py`: + +```python +def materialize_bin_arrays( + spool_reader: SequenceSpoolReader, + assignment: BinAssignment, + bin_id: int, + pack_size: int, + scratch_input_ids: np.ndarray, + scratch_loss_mask: np.ndarray, +) -> tuple[int, np.ndarray]: + """ + Materialize a single bin directly to numpy arrays. + + Avoids Python list conversion - writes directly to preallocated buffers. + + Args: + spool_reader: Reader for tokenized sequence spool + assignment: Bin assignment from packing algorithm + bin_id: Which bin to materialize + pack_size: Maximum packed sequence length + scratch_input_ids: Preallocated buffer of shape (pack_size,) + scratch_loss_mask: Preallocated buffer of shape (pack_size,) + + Returns: + packed_len: Actual length of packed tokens (excluding padding) + seq_start_id: Array of sequence START positions within the bin. + Invariant: seq_start_id[0] == 0, strictly increasing, + seq_start_id[-1] < packed_len. 
+ To get boundaries: list(seq_start_id) + [packed_len] + """ + seq_indices = assignment.bin_indices(bin_id) + + # Zero the scratch buffers (for padding) + scratch_input_ids[:] = 0 + scratch_loss_mask[:] = 0 + + pos = 0 + seq_start_ids = [] # Collect START positions + + for seq_index in seq_indices: + input_ids_arr, loss_mask_arr = spool_reader.read_sequence(int(seq_index)) + + # Truncate if needed + seq_len = min(len(input_ids_arr), pack_size) + if pos + seq_len > pack_size: + seq_len = pack_size - pos + + if seq_len <= 0: + break + + # Record start position BEFORE writing + seq_start_ids.append(pos) + + # Write directly to scratch buffers (no Python list!) + scratch_input_ids[pos:pos + seq_len] = input_ids_arr[:seq_len] + scratch_loss_mask[pos:pos + seq_len] = loss_mask_arr[:seq_len] + pos += seq_len + + # Apply loss_mask roll (shift right by 1, first position is 0) + # This ensures loss is computed on predicting token[i+1] from token[i] + if pos > 0: + scratch_loss_mask[1:pos] = scratch_loss_mask[:pos-1].copy() + scratch_loss_mask[0] = 0 + + return pos, np.array(seq_start_ids, dtype=np.uint32) +``` + +--- + +## Step 3: Update Central Pack Function + +Update `src/nemotron/data_prep/chat_sft_shard_core.py`: + +```python +from nemotron.data_prep.packing.writers import ( + ParquetShardWriter, + LegacyPickleWriter, +) +from nemotron.data_prep.packing.materialize import materialize_bin_arrays + + +def process_chat_sft_pack_from_spool_core( + *, + spool_dir: str, + output_dir: str, + shard_id: str, + pack_size: int, + packer: Packer, + output_fs: AbstractFileSystem, + dtype: np.dtype = np.int32, + packed_storage: str = "legacy_npy_pickle", # NEW PARAM + parquet_row_group_size: int = 1000, + parquet_compression: str = "zstd", +) -> dict[str, Any]: + """Process spool files into packed output.""" + + # ... existing setup code (load spool, run packer) ... 
+ + # Choose writer based on config + if packed_storage == "parquet": + parquet_path = f"{output_dir}/{shard_id}.parquet" + + # Get PyArrow filesystem for cloud support + pa_filesystem = _get_pyarrow_filesystem(output_fs) + + writer = ParquetShardWriter( + output_path=parquet_path, + row_group_size=parquet_row_group_size, + compression=parquet_compression, + filesystem=pa_filesystem, + ) + output_path = parquet_path + else: + npy_path = f"{output_dir}/{shard_id}.npy" + writer = LegacyPickleWriter(npy_path) + output_path = npy_path + + # Preallocate scratch buffers + scratch_input_ids = np.zeros(pack_size, dtype=dtype) + scratch_loss_mask = np.zeros(pack_size, dtype=np.uint8) + + # Stream bins + for bin_id in range(num_bins): + packed_len, seq_start_id = materialize_bin_arrays( + spool_reader=reader, + assignment=assignment, + bin_id=bin_id, + pack_size=pack_size, + scratch_input_ids=scratch_input_ids, + scratch_loss_mask=scratch_loss_mask, + ) + + writer.write_bin( + bin_id=bin_id, + input_ids=scratch_input_ids[:packed_len].copy(), + loss_mask=scratch_loss_mask[:packed_len].copy(), + seq_start_id=seq_start_id, + ) + + result = writer.finalize() + + # Build receipt + receipt = { + "shard_id": shard_id, + "output_path": output_path, + "format": packed_storage, + "num_bins": result["num_bins"], + # ... other receipt fields ... 
+ } + + return receipt + + +def _get_pyarrow_filesystem(fsspec_fs: AbstractFileSystem) -> pa.fs.FileSystem | None: + """Convert fsspec filesystem to PyArrow filesystem.""" + if fsspec_fs.protocol == "file": + return None # Use default local filesystem + + if fsspec_fs.protocol == "s3": + return pa.fs.S3FileSystem() + elif fsspec_fs.protocol == "gs" or fsspec_fs.protocol == "gcs": + return pa.fs.GcsFileSystem() + elif fsspec_fs.protocol == "az" or fsspec_fs.protocol == "abfs": + return pa.fs.AzureFileSystem() + else: + # Fallback: use PyArrow's fsspec wrapper + from pyarrow.fs import PyFileSystem, FSSpecHandler + return PyFileSystem(FSSpecHandler(fsspec_fs)) +``` + +--- + +## Step 4: Config Changes + +Update `src/nemotron/data_prep/config.py`: + +```python +from typing import Literal + + +class ChatSftOutputConfig(BaseModel): + # ... existing fields ... + packed_storage: Literal["legacy_npy_pickle", "parquet"] = "legacy_npy_pickle" + parquet_row_group_size: int = 1000 + parquet_compression: Literal["zstd", "snappy", "gzip", "none"] = "zstd" +``` + +--- + +## Cloud Storage Support + +Parquet has native cloud storage support via PyArrow: + +### Direct Cloud Writes + +```python +import pyarrow.fs as pafs + + +def get_filesystem_for_uri(uri: str) -> tuple[pa.fs.FileSystem, str]: + """Parse URI and return appropriate filesystem + path.""" + if uri.startswith("s3://"): + return pafs.S3FileSystem(), uri[5:] + elif uri.startswith("gs://"): + return pafs.GcsFileSystem(), uri[5:] + elif uri.startswith("az://") or uri.startswith("abfs://"): + return pafs.AzureFileSystem(), uri.split("://", 1)[1] + else: + return None, uri # Local filesystem + + +# Usage in writer +def write_to_cloud(output_uri: str, ...): + filesystem, path = get_filesystem_for_uri(output_uri) + + writer = ParquetShardWriter( + output_path=path, + filesystem=filesystem, + compression='zstd', + ) + # ... write bins ... 
+ writer.finalize() +``` + +--- + +## Performance Tuning + +### Row Group Size + +The `row_group_size` parameter affects both write and read performance: + +| Row Group Size | Write Memory | Read Latency | Best For | +|----------------|--------------|--------------|----------| +| 100 | ~2 MB | Lower | Random access heavy | +| 1000 (default) | ~20 MB | Medium | Balanced workloads | +| 10000 | ~200 MB | Higher | Sequential scans | + +```python +# For training (sequential access): larger row groups +writer = ParquetShardWriter(output_path, row_group_size=5000) + +# For inference (random access): smaller row groups +writer = ParquetShardWriter(output_path, row_group_size=100) +``` + +### Compression + +| Codec | Ratio | Write Speed | Read Speed | +|-------|-------|-------------|------------| +| zstd | ~2.5x | Medium | Fast | +| snappy | ~1.5x | Fast | Fast | +| gzip | ~3x | Slow | Medium | +| none | 1x | Fastest | Fastest | + +--- + +## Conversion Tool + +Convert existing legacy `.npy` files to Parquet: + +```python +def convert_legacy_to_parquet( + legacy_npy_path: str, + output_path: str, + row_group_size: int = 1000, + compression: str = 'zstd', +) -> dict[str, Any]: + """Convert existing pickle .npy to Parquet format. + + WARNING: Only use with trusted .npy files - pickle loading is unsafe + with untrusted inputs. 
+ """ + import os + + data = np.load(legacy_npy_path, allow_pickle=True) + + writer = ParquetShardWriter( + output_path=output_path, + row_group_size=row_group_size, + compression=compression, + ) + + for i, d in enumerate(data): + writer.write_bin( + bin_id=i, + input_ids=np.array(d["input_ids"], dtype=np.int32), + loss_mask=np.array(d["loss_mask"], dtype=np.uint8), + seq_start_id=np.array(d["seq_start_id"], dtype=np.int32), + ) + + result = writer.finalize() + + # Print size comparison + original_size = os.path.getsize(legacy_npy_path) + new_size = os.path.getsize(output_path) + print(f"Converted {result['num_bins']} bins to {output_path}") + print(f"Size: {original_size / 1e6:.1f} MB -> {new_size / 1e6:.1f} MB " + f"({new_size / original_size * 100:.1f}%)") + + return result +``` + +--- + +## Testing + +### Unit Tests + +```python +import tempfile +import numpy as np +import pytest + +from nemotron.data_prep.packing.writers import ParquetShardWriter + + +def test_parquet_writer_roundtrip(): + """Test write and read produce identical data.""" + with tempfile.NamedTemporaryFile(suffix='.parquet', delete=False) as f: + output_path = f.name + + try: + # Write + writer = ParquetShardWriter( + output_path=output_path, + row_group_size=2, + compression='zstd', + ) + + test_data = [ + (np.array([1, 2, 3]), np.array([0, 1, 1]), np.array([0])), + (np.array([4, 5, 6, 7]), np.array([0, 0, 1, 1]), np.array([0, 2])), + (np.array([8, 9]), np.array([0, 1]), np.array([0, 1, 2])), + ] + + for i, (ids, mask, starts) in enumerate(test_data): + writer.write_bin(i, ids, mask, starts) + + writer.finalize() + + # Read back + import pyarrow.parquet as pq + table = pq.read_table(output_path) + + for i, (expected_ids, expected_mask, expected_starts) in enumerate(test_data): + actual_ids = table['input_ids'][i].as_py() + actual_mask = table['loss_mask'][i].as_py() + actual_starts = table['seq_start_id'][i].as_py() + + np.testing.assert_array_equal(actual_ids, expected_ids) + 
np.testing.assert_array_equal(actual_mask, expected_mask) + np.testing.assert_array_equal(actual_starts, expected_starts) + + finally: + import os + os.unlink(output_path) + + +def test_parquet_memory_efficiency(): + """Verify peak memory stays reasonable.""" + import tracemalloc + + with tempfile.NamedTemporaryFile(suffix='.parquet', delete=False) as f: + output_path = f.name + + try: + tracemalloc.start() + + writer = ParquetShardWriter( + output_path=output_path, + row_group_size=100, # Small row groups + compression='zstd', + ) + + for i in range(10000): + writer.write_bin( + i, + np.random.randint(0, 50000, size=2000, dtype=np.int32), + np.random.randint(0, 2, size=2000, dtype=np.uint8), + np.array([0, 500, 1000, 1500], dtype=np.int32), + ) + + writer.finalize() + + current, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + + print(f"Peak memory: {peak / 1e6:.1f} MB") + # With row_group_size=100, peak should be well under 50MB + assert peak < 50 * 1024 * 1024 + + finally: + import os + os.unlink(output_path) + + +def test_parquet_compression_ratio(): + """Verify compression achieves expected ratio.""" + with tempfile.NamedTemporaryFile(suffix='.parquet', delete=False) as f: + output_path = f.name + + try: + writer = ParquetShardWriter( + output_path=output_path, + row_group_size=1000, + compression='zstd', + ) + + for i in range(1000): + writer.write_bin( + i, + np.random.randint(0, 50000, size=2000, dtype=np.int32), + np.random.randint(0, 2, size=2000, dtype=np.uint8), + np.array([0, 500, 1000, 1500], dtype=np.int32), + ) + + writer.finalize() + + import os + file_size = os.path.getsize(output_path) + raw_size = 1000 * (2000 * 4 + 2000 * 1 + 4 * 4) # Uncompressed estimate + + compression_ratio = raw_size / file_size + print(f"Compression ratio: {compression_ratio:.1f}x") + assert compression_ratio > 1.5 + + finally: + os.unlink(output_path) +``` diff --git a/pyproject.toml b/pyproject.toml index 6cda1e63f..23572abf9 100644 --- a/pyproject.toml +++ 
b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "omegaconf>=2.3.0", "rich>=13.0.0", "textual>=0.70.0", - "ray==2.49.2", + "ray[default]==2.49.2", "fsspec>=2024.0.0", "numpy>=1.24.0", "pyarrow>=14.0.0", diff --git a/src/nemotron/data_prep/__init__.py b/src/nemotron/data_prep/__init__.py index ecb1accb0..96de6e44b 100644 --- a/src/nemotron/data_prep/__init__.py +++ b/src/nemotron/data_prep/__init__.py @@ -212,6 +212,31 @@ class DataPrepConfig: """Max retries for HF downloads before giving up (Xenna path only).""" +def _ensure_driver_hf_home() -> None: + """Ensure HF_HOME is set for the driver process. + + In nemo-run Ray job mode, runtime_env_yaml env_vars apply to Ray workers, + but not to the driver script. This function derives HF_HOME from NEMO_RUN_DIR + (which IS set for the driver) so that: + 1. The driver's HF cache goes to shared storage (e.g., Lustre) + 2. The value propagates to Ray workers via ray.init(runtime_env=...) + 3. Xenna stages get it via _get_hf_runtime_env() -> env_info -> actor_pool + + This prevents "No space left on device" errors from HF downloads filling + local node storage instead of shared Lustre. + """ + if os.environ.get("HF_HOME"): + return # Already set, respect user's explicit setting + + nemo_run_dir = os.environ.get("NEMO_RUN_DIR") + if not nemo_run_dir: + return # Not running via nemo-run, fall back to HF defaults + + # Use same convention as nemo-run's worker-side: $NEMO_RUN_DIR/hf + hf_home = str(Path(nemo_run_dir) / "hf") + os.environ["HF_HOME"] = hf_home + + def run_data_prep( config: DataPrepConfig, *, artifact_class: type = PretrainBlendsArtifact ) -> DataBlendsArtifact | PretrainBlendsArtifact: @@ -236,6 +261,12 @@ def run_data_prep( >>> artifact = run_data_prep(config) >>> print(f"Blend path: {artifact.path}") """ + # Ensure HF_HOME is set for the driver process early. + # In nemo-run Ray job mode, runtime_env_yaml env_vars apply to Ray workers only, + # not the driver.
We derive HF_HOME from NEMO_RUN_DIR (which IS set for the driver) + # so that all downstream code sees a consistent cache directory on shared storage. + _ensure_driver_hf_home() + # Load data blend specification blend = DataBlend.load(config.blend_path) @@ -264,6 +295,9 @@ def run_data_prep( # Initialize Ray for download tasks (Xenna and Ray executors both use Ray) if config.execution_engine in ("ray", "xenna"): + # Enable uv integration for Ray workers (Ray 2.43+) + # Must be set BEFORE importing ray + os.environ.setdefault("RAY_RUNTIME_ENV_HOOK", "ray._private.runtime_env.uv_runtime_env_hook.hook") import ray if not ray.is_initialized(): @@ -291,7 +325,27 @@ def run_data_prep( if os.environ.get("HF_TOKEN"): runtime_env["env_vars"]["HF_TOKEN"] = os.environ["HF_TOKEN"] - ray.init(address="auto", ignore_reinit_error=True, runtime_env=runtime_env) + # Set environment variables required by cosmos-xenna monitoring + # These must be set BEFORE ray.init() + os.environ.setdefault("RAY_MAX_LIMIT_FROM_API_SERVER", "40000") + os.environ.setdefault("RAY_MAX_LIMIT_FROM_DATA_SOURCE", "40000") + + # Try connecting to existing cluster, fall back to local mode + # include_dashboard=True is required for cosmos-xenna's State API monitoring + try: + ray.init( + address="auto", + ignore_reinit_error=True, + runtime_env=runtime_env, + include_dashboard=True, + ) + except ConnectionError: + # No cluster found - start Ray locally + ray.init( + ignore_reinit_error=True, + runtime_env=runtime_env, + include_dashboard=True, + ) # Build Ray Data config if enabled, auto-detecting cluster resources ray_data_config = None diff --git a/src/nemotron/data_prep/config.py b/src/nemotron/data_prep/config.py index 2e70eed0f..aeb72a174 100644 --- a/src/nemotron/data_prep/config.py +++ b/src/nemotron/data_prep/config.py @@ -194,6 +194,33 @@ def __post_init__(self) -> None: OutputFormat = BinIdxOutputConfig | JsonlOutputConfig | PackedOutputConfig | ChatSftOutputConfig +@dataclass(frozen=True) +class 
XennaConfig: + """Configuration for Xenna pipeline execution. + + Attributes: + max_concurrent_downloads: Maximum parallel HuggingFace file downloads + max_shard_workers: Maximum workers for shard processing stage. + Each worker uses ~4GB memory. Set based on node memory. + None means auto-scale (cosmos-xenna default). + wandb_log_downloads: Log download progress to wandb + wandb_log_pipeline_stats: Log pipeline stats (actors, queues, progress) to wandb + wandb_download_log_interval_sec: Interval for download progress logging + hf_download_timeout_sec: Timeout for HuggingFace downloads + hf_download_max_retries: Max retries for HuggingFace downloads + pipeline_logging_interval_s: Interval for pipeline stats logging + """ + + max_concurrent_downloads: int = 64 + max_shard_workers: int | None = None + wandb_log_downloads: bool = False + wandb_log_pipeline_stats: bool = False + wandb_download_log_interval_sec: int = 30 + hf_download_timeout_sec: int = 300 + hf_download_max_retries: int = 3 + pipeline_logging_interval_s: int = 30 + + @dataclass(frozen=True) class RayDataConfig: """Configuration for Ray Data shard-task execution. 
@@ -314,12 +341,30 @@ class PipelineConfig: console_mode: str = "simple" simple_log_interval_sec: int = 30 execution_engine: Literal["ray", "xenna"] = "ray" + xenna: XennaConfig | None = None + # Legacy fields for backward compatibility (prefer xenna.* instead) max_concurrent_downloads: int = 64 wandb_log_downloads: bool = False wandb_download_log_interval_sec: int = 30 hf_download_timeout_sec: int = 300 hf_download_max_retries: int = 3 num_actors: int | None = None + xenna_max_shard_workers: int | None = None # Max workers for xenna shard processing + + def effective_xenna(self) -> XennaConfig: + """Get effective XennaConfig, merging legacy fields if xenna is not set.""" + if self.xenna is not None: + return self.xenna + return XennaConfig( + max_concurrent_downloads=self.max_concurrent_downloads, + max_shard_workers=self.xenna_max_shard_workers, + wandb_log_downloads=self.wandb_log_downloads, + wandb_log_pipeline_stats=False, # New field, no legacy equivalent + wandb_download_log_interval_sec=self.wandb_download_log_interval_sec, + hf_download_timeout_sec=self.hf_download_timeout_sec, + hf_download_max_retries=self.hf_download_max_retries, + pipeline_logging_interval_s=30, # New field, default + ) # ============================================================================ diff --git a/src/nemotron/data_prep/console.py b/src/nemotron/data_prep/console.py index 833f8a7eb..680ac5685 100644 --- a/src/nemotron/data_prep/console.py +++ b/src/nemotron/data_prep/console.py @@ -717,8 +717,19 @@ def start(self) -> None: console.print("\n[bold]Starting data preparation...[/bold]") self._print_simple_status() - def stop(self) -> None: - """Stop the live display.""" + def stop(self, success: bool | None = None) -> None: + """Stop the live display. + + Args: + success: Whether the pipeline completed successfully. If None (default), + auto-detects based on whether all datasets are complete/cached. 
+ """ + # Auto-detect completion if not explicitly provided + if success is None: + done, cached, pending, processing = self._get_summary_counts() + total = len(self.datasets) + success = (done + cached) == total and total > 0 + if self.console_mode == "rich": if self._live: self._live.stop() @@ -728,7 +739,20 @@ def stop(self) -> None: else: # Simple mode: Print final status self._print_simple_status() + + # Print completion message based on actual status + if success: console.print("[bold green]✓ Data preparation complete[/bold green]\n") + else: + total_completed, total_shards = self._get_total_shards_progress() + if total_shards > 0: + pct = total_completed / total_shards * 100 + console.print( + f"[bold yellow]⚠ Data preparation interrupted " + f"({total_completed}/{total_shards} shards, {pct:.1f}%)[/bold yellow]\n" + ) + else: + console.print("[bold yellow]⚠ Data preparation interrupted[/bold yellow]\n") def refresh(self) -> None: """Refresh the live display and cycle pages.""" diff --git a/src/nemotron/data_prep/pipeline.py b/src/nemotron/data_prep/pipeline.py index 793f6c0cf..328b93d73 100644 --- a/src/nemotron/data_prep/pipeline.py +++ b/src/nemotron/data_prep/pipeline.py @@ -861,21 +861,89 @@ def _process_all_shards_parallel( # Dispatch to Xenna executor if requested if execution_engine == "xenna": - from nemotron.data_prep.xenna.runner import run_xenna_pipeline + from dataclasses import asdict - run_xenna_pipeline( - execution_plans=execution_plans, - output_config=output_config, - output_root=output_root, - fs=fs, - live_status=live_status, - results=results, + from nemotron.data_prep.config import XennaConfig + from nemotron.data_prep.xenna.executor import run_xenna + from nemotron.data_prep.xenna.pipeline_specs import build_pretrain_pipeline_spec + from nemotron.data_prep.xenna.work_items import ShardWorkItem + + # Build XennaConfig from individual parameters (legacy compatibility) + xenna_cfg = XennaConfig( 
max_concurrent_downloads=max_concurrent_downloads, wandb_log_downloads=wandb_log_downloads, + wandb_log_pipeline_stats=True, # Enable pipeline stats logging wandb_download_log_interval_sec=wandb_download_log_interval_sec, hf_download_timeout_sec=hf_download_timeout_sec, hf_download_max_retries=hf_download_max_retries, ) + + # Get resolved tokenizer from first plan (should be uniform) + resolved_tokenizer = execution_plans[0].plan.resolved_tokenizer + + # Build work items + tasks: list[ShardWorkItem] = [] + dataset_receipt_dirs: dict[str, str] = {} + + for ep in execution_plans: + live_status.start_dataset(ep.name) + live_status.report_phase(ep.name, "processing", "xenna") + dataset_receipt_dirs[ep.name] = ep.receipts_dir + + assignment_dicts = {} + for a in ep.plan.file_assignments: + assignment_dicts[a.shard_index] = { + "shard_index": a.shard_index, + "files": [asdict(f) for f in a.files], + "total_bytes": a.total_bytes, + } + + for shard_idx in ep.pending_indices: + tasks.append( + ShardWorkItem( + dataset_name=ep.name, + plan_hash=ep.plan.plan_hash, + shard_index=shard_idx, + assignment=assignment_dicts[shard_idx], + output_dir=ep.dataset_dir, + receipts_dir=ep.receipts_dir, + text_field=ep.config.text_field, + dtype=output_config.dtype, + min_doc_chars=output_config.min_doc_chars, + max_doc_tokens=output_config.max_doc_tokens, + max_rows=output_config.max_rows, + ) + ) + + if tasks: + # Build pipeline spec + pipeline_spec = build_pretrain_pipeline_spec( + tasks=tasks, + resolved_tokenizer=resolved_tokenizer, + output_root=output_root, + xenna_cfg=xenna_cfg, + ) + + # Run pipeline + run_xenna( + pipeline_spec=pipeline_spec, + dataset_receipt_dirs=dataset_receipt_dirs, + output_root=output_root, + fs=fs, + live_status=live_status, + xenna_cfg=xenna_cfg, + ) + + # Aggregate results + for ep in execution_plans: + results[ep.name] = _aggregate_stats_from_receipts(ep.receipts_dir, ep.plan, fs) + live_status.report_metrics( + ep.name, + 
rows=results[ep.name].get("total_sequences", 0), + tokens=results[ep.name].get("total_tokens", 0), + ) + live_status.complete_dataset(ep.name) + return # Dispatch to Ray Data executor if enabled @@ -1587,10 +1655,14 @@ def _process_jsonl_blend(blend: DataBlend, config: PipelineConfig) -> PipelineRe if has_work and config.execution_engine == "xenna": from dataclasses import asdict - from nemotron.data_prep.xenna.runner import run_xenna_jsonl_pipeline + from nemotron.data_prep.xenna.executor import run_xenna + from nemotron.data_prep.xenna.pipeline_specs import build_jsonl_pipeline_spec from nemotron.data_prep.xenna.work_items import JsonlShardWorkItem + xenna_cfg = config.effective_xenna() + tasks: list[JsonlShardWorkItem] = [] + dataset_receipt_dirs: dict[str, str] = {} dataset_infos: list[dict] = [] live_status = con.create_live_status( @@ -1628,6 +1700,7 @@ def _process_jsonl_blend(blend: DataBlend, config: PipelineConfig) -> PipelineRe ) ) + dataset_receipt_dirs[dataset.name] = receipts_dir dataset_infos.append( { "name": dataset.name, @@ -1641,29 +1714,38 @@ def _process_jsonl_blend(blend: DataBlend, config: PipelineConfig) -> PipelineRe live_status.start_dataset(info["name"]) if tasks: - run_xenna_jsonl_pipeline( + # Build pipeline spec + pipeline_spec = build_jsonl_pipeline_spec( tasks=tasks, - dataset_infos=dataset_infos, output_root=str(config.output.dir), - fs=fs, - live_status=live_status, - results=results, text_field=dataset_plans[0][0].text_field if dataset_plans else "text", transform=format_config.transform, compression=format_config.compression, max_rows=config.output.max_rows, resolve_hf_placeholders=format_config.resolve_hf_placeholders, - max_concurrent_downloads=config.max_concurrent_downloads, - wandb_log_downloads=config.wandb_log_downloads, - wandb_download_log_interval_sec=config.wandb_download_log_interval_sec, - hf_download_timeout_sec=config.hf_download_timeout_sec, - hf_download_max_retries=config.hf_download_max_retries, + 
xenna_cfg=xenna_cfg, ) - else: - for info in dataset_infos: - stats = _aggregate_jsonl_stats(info["dataset_dir"], num_shards, fs) - results[info["name"]] = stats - live_status.complete_dataset(info["name"]) + + # Run pipeline + run_xenna( + pipeline_spec=pipeline_spec, + dataset_receipt_dirs=dataset_receipt_dirs, + output_root=str(config.output.dir), + fs=fs, + live_status=live_status, + xenna_cfg=xenna_cfg, + ) + + # Aggregate results + for info in dataset_infos: + stats = _aggregate_jsonl_stats(info["dataset_dir"], num_shards, fs) + results[info["name"]] = stats + live_status.report_metrics( + info["name"], + rows=stats.get("num_records", 0), + tokens=0, + ) + live_status.complete_dataset(info["name"]) for dataset, dataset_dir, _, _, _, _ in dataset_plans: weight = dataset.weight @@ -2369,10 +2451,14 @@ def _process_chat_sft_blend(blend: DataBlend, config: PipelineConfig) -> Pipelin if config.execution_engine == "xenna": from dataclasses import asdict - from nemotron.data_prep.xenna.runner import run_xenna_chat_sft_pipeline + from nemotron.data_prep.xenna.executor import run_xenna + from nemotron.data_prep.xenna.pipeline_specs import build_chat_sft_pipeline_spec from nemotron.data_prep.xenna.work_items import ChatSftShardWorkItem + xenna_cfg = config.effective_xenna() + tasks: list[ChatSftShardWorkItem] = [] + dataset_receipt_dirs: dict[str, str] = {} dataset_infos: list[dict] = [] live_status = con.create_live_status( @@ -2407,6 +2493,7 @@ def _process_chat_sft_blend(blend: DataBlend, config: PipelineConfig) -> Pipelin ) ) + dataset_receipt_dirs[dataset.name] = receipts_dir dataset_infos.append( { "name": dataset.name, @@ -2419,13 +2506,10 @@ def _process_chat_sft_blend(blend: DataBlend, config: PipelineConfig) -> Pipelin live_status.start_dataset(info["name"]) if tasks: - run_xenna_chat_sft_pipeline( + # Build pipeline spec + pipeline_spec = build_chat_sft_pipeline_spec( tasks=tasks, - dataset_infos=dataset_infos, output_root=str(config.output.dir), - fs=fs, - 
live_status=live_status, - results=results, resolved_tokenizer=resolved_tokenizer, messages_field=format_config.messages_field, tools_field=format_config.tools_field, @@ -2438,19 +2522,31 @@ def _process_chat_sft_blend(blend: DataBlend, config: PipelineConfig) -> Pipelin seed=42, used_in_filter=format_config.used_in_filter, used_in_field=format_config.used_in_field, - max_concurrent_downloads=config.max_concurrent_downloads, - wandb_log_downloads=config.wandb_log_downloads, - wandb_download_log_interval_sec=config.wandb_download_log_interval_sec, - hf_download_timeout_sec=config.hf_download_timeout_sec, - hf_download_max_retries=config.hf_download_max_retries, + xenna_cfg=xenna_cfg, ) - else: - for info in dataset_infos: - stats = _aggregate_packed_stats( - info["dataset_dir"], info["receipts_dir"], fs - ) - results[info["name"]] = stats - live_status.complete_dataset(info["name"]) + + # Run pipeline + run_xenna( + pipeline_spec=pipeline_spec, + dataset_receipt_dirs=dataset_receipt_dirs, + output_root=str(config.output.dir), + fs=fs, + live_status=live_status, + xenna_cfg=xenna_cfg, + ) + + # Aggregate results + for info in dataset_infos: + stats = _aggregate_packed_stats( + info["dataset_dir"], info["receipts_dir"], fs + ) + results[info["name"]] = stats + live_status.report_metrics( + info["name"], + rows=stats.get("num_sequences", 0), + tokens=stats.get("total_tokens", 0), + ) + live_status.complete_dataset(info["name"]) for dataset, dataset_dir, _, _, _ in dataset_plans: weight = dataset.weight diff --git a/src/nemotron/data_prep/xenna/__init__.py b/src/nemotron/data_prep/xenna/__init__.py index 73ba56ce8..46bea0e12 100644 --- a/src/nemotron/data_prep/xenna/__init__.py +++ b/src/nemotron/data_prep/xenna/__init__.py @@ -12,15 +12,59 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Xenna integration for Nemotron data prep.""" +"""Xenna integration for Nemotron data prep. 
-from nemotron.data_prep.xenna.runner import run_xenna_pipeline -from nemotron.data_prep.xenna.stages import HfPredownloadStage, PretrainShardStage -from nemotron.data_prep.xenna.work_items import ShardWorkItem +Architecture: +- executor.py: run_xenna() - format-agnostic pipeline executor +- pipeline_specs.py: build_*_pipeline_spec() - PipelineSpec factories +- observability.py: wandb callback and polling helpers +- stages.py: Xenna Stage implementations +- work_items.py: Work item dataclasses +""" + +from nemotron.data_prep.xenna.executor import run_xenna +from nemotron.data_prep.xenna.observability import ( + start_download_poller, + start_receipt_poller, +) +from nemotron.data_prep.xenna.pipeline_specs import ( + build_chat_sft_pipeline_spec, + build_jsonl_pipeline_spec, + build_pretrain_pipeline_spec, +) + +# Stage and work item exports +from nemotron.data_prep.xenna.stages import ( + ChatSftCentralPackStage, + ChatSftSpoolStage, + HfPredownloadStage, + JsonlShardStage, + PretrainShardStage, +) +from nemotron.data_prep.xenna.work_items import ( + ChatSftShardWorkItem, + ChatSftSpoolWorkItem, + JsonlShardWorkItem, + ShardWorkItem, +) __all__ = [ + # Core + "run_xenna", + "build_pretrain_pipeline_spec", + "build_jsonl_pipeline_spec", + "build_chat_sft_pipeline_spec", + "start_receipt_poller", + "start_download_poller", + # Stages "HfPredownloadStage", "PretrainShardStage", + "JsonlShardStage", + "ChatSftSpoolStage", + "ChatSftCentralPackStage", + # Work items "ShardWorkItem", - "run_xenna_pipeline", + "JsonlShardWorkItem", + "ChatSftShardWorkItem", + "ChatSftSpoolWorkItem", ] diff --git a/src/nemotron/data_prep/xenna/executor.py b/src/nemotron/data_prep/xenna/executor.py new file mode 100644 index 000000000..27460ba29 --- /dev/null +++ b/src/nemotron/data_prep/xenna/executor.py @@ -0,0 +1,99 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Xenna pipeline executor - format-agnostic. + +This module provides a single run_xenna() function that executes any Xenna pipeline +with observability (receipt polling, download logging). It is format-agnostic - the +caller is responsible for building the PipelineSpec and aggregating results. +""" + +from __future__ import annotations + +import threading +from typing import TYPE_CHECKING + +import cosmos_xenna.pipelines.v1 as pipelines_v1 + +from nemotron.data_prep.xenna.observability import ( + start_download_poller, + start_receipt_poller, +) + +if TYPE_CHECKING: + from nemotron.data_prep.config import XennaConfig + + +def run_xenna( + *, + pipeline_spec: pipelines_v1.PipelineSpec, + dataset_receipt_dirs: dict[str, str], + output_root: str, + fs, + live_status, + xenna_cfg: "XennaConfig", +) -> None: + """Run a Xenna pipeline with observability. + + This is a format-agnostic executor. The caller is responsible for: + 1. Building the PipelineSpec (via pipeline_specs.build_*()) + 2. Aggregating results after this function returns + + Args: + pipeline_spec: Pre-built PipelineSpec from pipeline_specs module. 
+ dataset_receipt_dirs: Mapping of dataset_name -> receipts_dir for progress tracking + output_root: Root output directory (used for download progress tracking) + fs: Filesystem abstraction (fsspec-compatible) + live_status: LiveExecutionStatus for progress UI updates + xenna_cfg: Xenna configuration (controls wandb logging) + + Note: + This function does NOT call live_status.start_dataset() or live_status.complete_dataset(). + The caller should handle those for consistent semantics across execution paths. + """ + if not pipeline_spec.input_data: + return + + total_tasks = len(pipeline_spec.input_data) + print(f"[Xenna] Launching pipeline for {total_tasks} task(s)") + + stop_event = threading.Event() + + # Start receipt poller for progress tracking and wandb logging + receipt_thread = start_receipt_poller( + fs=fs, + dataset_receipt_dirs=dataset_receipt_dirs, + live_status=live_status, + stop_event=stop_event, + log_to_wandb=xenna_cfg.wandb_log_pipeline_stats, + total_shards=total_tasks, + ) + + # Optionally start download progress poller + download_thread = None + if xenna_cfg.wandb_log_downloads: + download_thread = start_download_poller( + fs=fs, + output_root=output_root, + stop_event=stop_event, + interval_sec=xenna_cfg.wandb_download_log_interval_sec, + ) + + try: + pipelines_v1.run_pipeline(pipeline_spec) + finally: + stop_event.set() + receipt_thread.join(timeout=2.0) + if download_thread: + download_thread.join(timeout=2.0) diff --git a/src/nemotron/data_prep/xenna/observability.py b/src/nemotron/data_prep/xenna/observability.py new file mode 100644 index 000000000..40be59472 --- /dev/null +++ b/src/nemotron/data_prep/xenna/observability.py @@ -0,0 +1,207 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Observability helpers for Xenna pipelines: wandb logging, receipt/download polling.""" + +from __future__ import annotations + +import json +import threading +import time + + +def start_receipt_poller( + *, + fs, + dataset_receipt_dirs: dict[str, str], + live_status, + stop_event: threading.Event, + interval_sec: float = 10.0, + log_to_wandb: bool = True, + total_shards: int | None = None, +) -> threading.Thread: + """Start a daemon thread that polls receipt directories for progress updates. + + Args: + fs: Filesystem abstraction (fsspec-compatible) + dataset_receipt_dirs: Mapping of dataset_name -> receipts_dir path + live_status: LiveExecutionStatus for progress UI updates + stop_event: Event to signal thread termination + interval_sec: Polling interval in seconds + log_to_wandb: Whether to log metrics to wandb + total_shards: Total number of shards across all datasets (for progress %) + + Returns: + The started daemon thread + """ + + def _poll() -> None: + # Check wandb availability once at start + wandb = None + if log_to_wandb: + try: + import wandb as _wandb + + if _wandb.run is not None: + wandb = _wandb + except ImportError: + pass + + last_counts: dict[str, int] = {name: 0 for name in dataset_receipt_dirs} + seen_receipts: dict[str, set[str]] = {name: set() for name in dataset_receipt_dirs} + tokens_by_dataset: dict[str, int] = {name: 0 for name in dataset_receipt_dirs} + shards_by_dataset: dict[str, int] = {name: 0 for name in dataset_receipt_dirs} + + while not stop_event.is_set(): + metrics_changed = False + + for name, receipts_dir in 
dataset_receipt_dirs.items(): + try: + if not fs.exists(receipts_dir): + continue + entries = [p for p in fs.ls(receipts_dir, detail=False) if str(p).endswith(".json")] + count = len(entries) + except Exception: + continue + + last = last_counts.get(name, 0) + if count > last: + metrics_changed = True + # Process newly seen receipts + for p in entries: + if p not in seen_receipts[name]: + seen_receipts[name].add(p) + try: + with fs.open(p, "r") as f: + receipt = json.load(f) + tokens_by_dataset[name] += int( + receipt.get("stats", {}).get("total_tokens", 0) + ) + except Exception: + pass + + # Advance progress by delta + for _ in range(count - last): + live_status.advance_dataset(name) + last_counts[name] = count + shards_by_dataset[name] = count + live_status.report_tokens(name, tokens_by_dataset[name]) + + # Log to wandb if metrics changed + if wandb is not None and metrics_changed: + total_shards_completed = sum(shards_by_dataset.values()) + total_tokens = sum(tokens_by_dataset.values()) + + metrics = { + "data_prep/shards_completed": total_shards_completed, + "data_prep/tokens_total": total_tokens, + } + + # Add progress percentage if we know the total + if total_shards is not None and total_shards > 0: + metrics["data_prep/progress"] = total_shards_completed / total_shards + + # Per-dataset metrics + for name in dataset_receipt_dirs: + safe_name = name.replace("-", "_").replace(" ", "_") + metrics[f"data_prep/datasets/{safe_name}/shards"] = shards_by_dataset[name] + metrics[f"data_prep/datasets/{safe_name}/tokens"] = tokens_by_dataset[name] + + wandb.log(metrics) + + stop_event.wait(interval_sec) + + thread = threading.Thread(target=_poll, daemon=True) + thread.start() + return thread + + +def start_download_poller( + *, + fs, + output_root: str, + stop_event: threading.Event, + interval_sec: int = 30, +) -> threading.Thread: + """Start a daemon thread that polls download progress and logs to wandb. 
+ + Args: + fs: Filesystem abstraction (fsspec-compatible) + output_root: Root output directory (download progress is at {output_root}/.xenna/downloads/) + stop_event: Event to signal thread termination + interval_sec: Polling interval in seconds + + Returns: + The started daemon thread + """ + + def _poll() -> None: + try: + import wandb + except ImportError: + return + + if wandb.run is None: + return + + progress_dir = f"{output_root.rstrip('/')}/.xenna/downloads" + last_logged = 0.0 + + while not stop_event.is_set(): + now = time.time() + if now - last_logged < interval_sec: + stop_event.wait(1.0) + continue + last_logged = now + + try: + if not fs.exists(progress_dir): + continue + entries = fs.ls(progress_dir, detail=False) + except Exception: + continue + + total_completed = 0 + total_files = 0 + max_elapsed = 0.0 + max_rate = 0.0 + + for path in entries: + if not str(path).endswith(".json"): + continue + try: + with fs.open(path, "r") as f: + data = json.load(f) + except Exception: + continue + total_completed += int(data.get("completed", 0)) + total_files += int(data.get("total", 0)) + max_elapsed = max(max_elapsed, float(data.get("elapsed_sec", 0.0))) + max_rate = max(max_rate, float(data.get("rate", 0.0))) + + if total_files == 0: + continue + + wandb.log( + { + "data_prep/hf_download_completed": total_completed, + "data_prep/hf_download_total": total_files, + "data_prep/hf_download_rate": max_rate, + "data_prep/hf_download_elapsed_sec": max_elapsed, + } + ) + + thread = threading.Thread(target=_poll, daemon=True) + thread.start() + return thread diff --git a/src/nemotron/data_prep/xenna/pipeline_specs.py b/src/nemotron/data_prep/xenna/pipeline_specs.py new file mode 100644 index 000000000..035048c79 --- /dev/null +++ b/src/nemotron/data_prep/xenna/pipeline_specs.py @@ -0,0 +1,236 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PipelineSpec factories for Xenna pipeline types. + +This module provides factory functions to build cosmos_xenna PipelineSpec objects +for each pipeline type (pretrain, jsonl, chat_sft). The factories are pure functions +that do not run pipelines or handle observability - that's the executor's job. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING + +import cosmos_xenna.pipelines.v1 as pipelines_v1 + +from nemotron.data_prep.xenna.stages import ( + ChatSftCentralPackStage, + ChatSftSpoolStage, + HfPredownloadStage, + JsonlShardStage, + PretrainShardStage, +) +from nemotron.data_prep.xenna.work_items import ( + ChatSftShardWorkItem, + JsonlShardWorkItem, + ShardWorkItem, +) + +if TYPE_CHECKING: + from nemotron.data_prep.config import XennaConfig + + +def _make_pipeline_config(*, logging_interval_s: float) -> pipelines_v1.PipelineConfig: + """Create standard PipelineConfig for Xenna pipelines.""" + return pipelines_v1.PipelineConfig( + execution_mode=pipelines_v1.ExecutionMode.STREAMING, + return_last_stage_outputs=False, + logging_interval_s=logging_interval_s, + monitoring_verbosity_level=pipelines_v1.VerbosityLevel.INFO, + ) + + +def build_pretrain_pipeline_spec( + *, + tasks: Sequence[ShardWorkItem], + resolved_tokenizer: dict, + output_root: str, + xenna_cfg: "XennaConfig", +) -> pipelines_v1.PipelineSpec: + """Build PipelineSpec for pretrain 
binidx processing. + + Args: + tasks: List of ShardWorkItem to process + resolved_tokenizer: Resolved tokenizer configuration dict + output_root: Root output directory + xenna_cfg: Xenna configuration + + Returns: + PipelineSpec ready to be executed by run_pipeline() + """ + return pipelines_v1.PipelineSpec( + input_data=list(tasks), + stages=[ + pipelines_v1.StageSpec( + HfPredownloadStage( + max_concurrent_downloads=xenna_cfg.max_concurrent_downloads, + output_root=output_root, + download_timeout_sec=xenna_cfg.hf_download_timeout_sec, + max_retries=xenna_cfg.hf_download_max_retries, + ), + num_workers_per_node=1, + ), + pipelines_v1.StageSpec( + PretrainShardStage( + resolved_tokenizer=resolved_tokenizer, + output_root=output_root, + ), + # Limit workers to prevent OOM (each worker ~4GB) + **({"num_workers": xenna_cfg.max_shard_workers} if xenna_cfg.max_shard_workers else {}), + ), + ], + config=_make_pipeline_config(logging_interval_s=xenna_cfg.pipeline_logging_interval_s), + ) + + +def build_jsonl_pipeline_spec( + *, + tasks: Sequence[JsonlShardWorkItem], + output_root: str, + text_field: str, + transform, + compression: str, + max_rows: int | None, + resolve_hf_placeholders: bool, + xenna_cfg: "XennaConfig", +) -> pipelines_v1.PipelineSpec: + """Build PipelineSpec for JSONL processing. 
+ + Args: + tasks: List of JsonlShardWorkItem to process + output_root: Root output directory + text_field: Field name containing text in input records + transform: Optional transform function for records + compression: Output compression ("none" or "zstd") + max_rows: Maximum rows per shard (None for unlimited) + resolve_hf_placeholders: Whether to resolve HuggingFace placeholders + xenna_cfg: Xenna configuration + + Returns: + PipelineSpec ready to be executed by run_pipeline() + """ + return pipelines_v1.PipelineSpec( + input_data=list(tasks), + stages=[ + pipelines_v1.StageSpec( + HfPredownloadStage( + max_concurrent_downloads=xenna_cfg.max_concurrent_downloads, + output_root=output_root, + download_timeout_sec=xenna_cfg.hf_download_timeout_sec, + max_retries=xenna_cfg.hf_download_max_retries, + ), + num_workers_per_node=1, + ), + pipelines_v1.StageSpec( + JsonlShardStage( + output_root=output_root, + text_field=text_field, + transform=transform, + compression=compression, + max_rows=max_rows, + resolve_hf_placeholders=resolve_hf_placeholders, + ), + # Limit workers to prevent OOM on large datasets + num_workers=4, + ), + ], + config=_make_pipeline_config(logging_interval_s=xenna_cfg.pipeline_logging_interval_s), + ) + + +def build_chat_sft_pipeline_spec( + *, + tasks: Sequence[ChatSftShardWorkItem], + output_root: str, + resolved_tokenizer: dict, + messages_field: str, + tools_field: str, + pack_size: int, + algorithm: str, + dtype: str, + chat_template: str | None, + max_doc_tokens: int | None, + max_rows: int | None, + seed: int | None, + used_in_filter: str | None, + used_in_field: str, + xenna_cfg: "XennaConfig", +) -> pipelines_v1.PipelineSpec: + """Build PipelineSpec for Chat SFT processing. 
+ + Args: + tasks: List of ChatSftShardWorkItem to process + output_root: Root output directory + resolved_tokenizer: Resolved tokenizer configuration dict + messages_field: Field name for messages in input records + tools_field: Field name for tools in input records + pack_size: Maximum tokens per packed sequence + algorithm: Packing algorithm + dtype: Token dtype + chat_template: Chat template (name, path, or inline) + max_doc_tokens: Maximum tokens per document + max_rows: Maximum rows per shard + seed: Random seed for packing + used_in_filter: Filter for used_in field + used_in_field: Field name for used_in filtering + xenna_cfg: Xenna configuration + + Returns: + PipelineSpec ready to be executed by run_pipeline() + """ + return pipelines_v1.PipelineSpec( + input_data=list(tasks), + stages=[ + pipelines_v1.StageSpec( + HfPredownloadStage( + max_concurrent_downloads=xenna_cfg.max_concurrent_downloads, + output_root=output_root, + download_timeout_sec=xenna_cfg.hf_download_timeout_sec, + max_retries=xenna_cfg.hf_download_max_retries, + ), + num_workers_per_node=1, + ), + pipelines_v1.StageSpec( + ChatSftSpoolStage( + resolved_tokenizer=resolved_tokenizer, + output_root=output_root, + messages_field=messages_field, + tools_field=tools_field, + pack_size=pack_size, + algorithm=algorithm, + dtype=dtype, + chat_template=chat_template, + max_doc_tokens=max_doc_tokens, + max_rows=max_rows, + seed=seed, + used_in_filter=used_in_filter, + used_in_field=used_in_field, + ), + # Let Xenna auto-scale - spool stage is memory-efficient + ), + pipelines_v1.StageSpec( + ChatSftCentralPackStage( + output_root=output_root, + pack_size=pack_size, + algorithm=algorithm, + dtype=dtype, + seed=seed, + ), + num_workers=1, # Must be single-worker for centralized packing + ), + ], + config=_make_pipeline_config(logging_interval_s=xenna_cfg.pipeline_logging_interval_s), + ) diff --git a/src/nemotron/data_prep/xenna/runner.py b/src/nemotron/data_prep/xenna/runner.py deleted file mode 
100644 index aafdb045f..000000000 --- a/src/nemotron/data_prep/xenna/runner.py +++ /dev/null @@ -1,662 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Xenna pipeline runner for Nemotron data prep.""" - -from __future__ import annotations - -from dataclasses import asdict -import gc -import json -import threading -import time -from typing import TYPE_CHECKING - -import cosmos_xenna.pipelines.v1 as pipelines_v1 - - -def _log_memory_status(label: str) -> None: - """Log memory usage for debugging OOM issues.""" - try: - import psutil - process = psutil.Process() - rss_gb = process.memory_info().rss / (1024**3) - print(f"[Memory] {label}: RSS={rss_gb:.2f} GB") - except ImportError: - print(f"[Memory] {label}: psutil not available") - - try: - import ray - if ray.is_initialized(): - resources = ray.available_resources() - obj_store = resources.get("object_store_memory", 0) / (1024**3) - print(f"[Ray] {label}: object_store_available={obj_store:.2f} GB") - except Exception as e: - print(f"[Ray] {label}: error getting status - {e}") - -if TYPE_CHECKING: - from cosmos_xenna.pipelines.private.monitoring_types import PipelineStats - -from nemotron.data_prep.xenna.stages import ( - ChatSftCentralPackStage, - ChatSftSpoolStage, - HfPredownloadStage, - JsonlShardStage, - PretrainShardStage, -) -from nemotron.data_prep.xenna.work_items import ( - ChatSftShardWorkItem, - JsonlShardWorkItem, - ShardWorkItem, 
-) - - -def run_xenna_pipeline( - *, - execution_plans: list, - output_config, - output_root: str, - fs, - live_status, - results: dict, - max_concurrent_downloads: int = 64, - wandb_log_downloads: bool = False, - wandb_log_pipeline_stats: bool = False, - wandb_download_log_interval_sec: int = 30, - hf_download_timeout_sec: int = 300, - hf_download_max_retries: int = 3, -) -> None: - """Run shard processing via Xenna pipeline.""" - if not execution_plans: - return - - resolved_tokenizer = execution_plans[0].plan.resolved_tokenizer - for ep in execution_plans[1:]: - if ep.plan.resolved_tokenizer != resolved_tokenizer: - raise ValueError( - f"Tokenizer mismatch: dataset '{ep.name}' uses different tokenizer. " - "Xenna executor requires uniform tokenizer across datasets in v1." - ) - - tasks: list[ShardWorkItem] = [] - for ep in execution_plans: - live_status.start_dataset(ep.name) - live_status.report_phase(ep.name, "processing", "xenna") - - assignment_dicts = {} - for a in ep.plan.file_assignments: - assignment_dicts[a.shard_index] = { - "shard_index": a.shard_index, - "files": [asdict(f) for f in a.files], - "total_bytes": a.total_bytes, - } - - for shard_idx in ep.pending_indices: - tasks.append( - ShardWorkItem( - dataset_name=ep.name, - plan_hash=ep.plan.plan_hash, - shard_index=shard_idx, - assignment=assignment_dicts[shard_idx], - output_dir=ep.dataset_dir, - receipts_dir=ep.receipts_dir, - text_field=ep.config.text_field, - dtype=output_config.dtype, - min_doc_chars=output_config.min_doc_chars, - max_doc_tokens=output_config.max_doc_tokens, - max_rows=output_config.max_rows, - ) - ) - - if not tasks: - return - - print(f"[Xenna] Launching pipeline for {len(tasks)} shard(s)") - - pipeline_spec = pipelines_v1.PipelineSpec( - input_data=tasks, - stages=[ - pipelines_v1.StageSpec( - HfPredownloadStage( - max_concurrent_downloads=max_concurrent_downloads, - output_root=output_root, - download_timeout_sec=hf_download_timeout_sec, - 
max_retries=hf_download_max_retries, - ), - num_workers_per_node=1, - ), - pipelines_v1.StageSpec( - PretrainShardStage( - resolved_tokenizer=resolved_tokenizer, - output_root=output_root, - ), - ) - ], - config=pipelines_v1.PipelineConfig( - execution_mode=pipelines_v1.ExecutionMode.STREAMING, - return_last_stage_outputs=False, - ), - ) - - dataset_pending_counts = {ep.name: len(ep.pending_indices) for ep in execution_plans} - - stop_event = threading.Event() - - def _poll_receipts() -> None: - last_counts: dict[str, int] = {ep.name: 0 for ep in execution_plans} - seen_receipts: dict[str, set[str]] = {ep.name: set() for ep in execution_plans} - tokens_by_dataset: dict[str, int] = {ep.name: 0 for ep in execution_plans} - while not stop_event.is_set(): - for ep in execution_plans: - try: - if not fs.exists(ep.receipts_dir): - continue - entries = [p for p in fs.ls(ep.receipts_dir, detail=False) if str(p).endswith(".json")] - count = len(entries) - except Exception: - continue - - last = last_counts.get(ep.name, 0) - if count > last: - new_entries = [] - for p in entries: - if p not in seen_receipts[ep.name]: - seen_receipts[ep.name].add(p) - new_entries.append(p) - - for receipt_path in new_entries: - try: - with fs.open(receipt_path, "r") as f: - receipt = json.load(f) - tokens_by_dataset[ep.name] += _extract_tokens(receipt) - except Exception: - pass - - for _ in range(count - last): - live_status.advance_dataset(ep.name) - last_counts[ep.name] = count - live_status.report_tokens(ep.name, tokens_by_dataset[ep.name]) - stop_event.wait(10.0) - - poll_thread = threading.Thread(target=_poll_receipts, daemon=True) - poll_thread.start() - - wandb_thread = None - if wandb_log_downloads: - wandb_thread = threading.Thread( - target=_poll_download_stats, - args=(fs, output_root, stop_event, wandb_download_log_interval_sec), - daemon=True, - ) - wandb_thread.start() - - try: - pipelines_v1.run_pipeline(pipeline_spec) - finally: - stop_event.set() - 
poll_thread.join(timeout=2.0) - if wandb_thread is not None: - wandb_thread.join(timeout=2.0) - - for ep in execution_plans: - results[ep.name] = _aggregate_stats_from_receipts(ep.receipts_dir, ep.plan, fs) - live_status.report_metrics( - ep.name, - rows=results[ep.name].get("total_sequences", 0), - tokens=results[ep.name].get("total_tokens", 0), - ) - live_status.complete_dataset(ep.name) - - -def _aggregate_stats_from_receipts(receipts_dir: str, plan, fs) -> dict: - """Import-free wrapper; actual implementation lives in pipeline.py.""" - from nemotron.data_prep.pipeline import _aggregate_stats_from_receipts as _agg - - return _agg(receipts_dir, plan, fs) - - -def run_xenna_jsonl_pipeline( - *, - tasks: list[JsonlShardWorkItem], - dataset_infos: list[dict], - output_root: str, - fs, - live_status, - results: dict, - text_field: str, - transform, - compression: str, - max_rows: int | None, - resolve_hf_placeholders: bool, - max_concurrent_downloads: int = 64, - wandb_log_downloads: bool = False, - wandb_log_pipeline_stats: bool = False, - wandb_download_log_interval_sec: int = 30, - hf_download_timeout_sec: int = 300, - hf_download_max_retries: int = 3, -) -> None: - if not tasks: - return - - pipeline_spec = pipelines_v1.PipelineSpec( - input_data=tasks, - stages=[ - pipelines_v1.StageSpec( - HfPredownloadStage( - max_concurrent_downloads=max_concurrent_downloads, - output_root=output_root, - download_timeout_sec=hf_download_timeout_sec, - max_retries=hf_download_max_retries, - ), - num_workers_per_node=1, - ), - pipelines_v1.StageSpec( - JsonlShardStage( - output_root=output_root, - text_field=text_field, - transform=transform, - compression=compression, - max_rows=max_rows, - resolve_hf_placeholders=resolve_hf_placeholders, - ), - # Limit workers to prevent OOM on large datasets - num_workers=4, - ), - ], - config=pipelines_v1.PipelineConfig( - execution_mode=pipelines_v1.ExecutionMode.STREAMING, - return_last_stage_outputs=False, - ), - ) - - stop_event = 
threading.Event() - - def _poll_receipts() -> None: - last_counts: dict[str, int] = {info["name"]: 0 for info in dataset_infos} - seen_receipts: dict[str, set[str]] = {info["name"]: set() for info in dataset_infos} - tokens_by_dataset: dict[str, int] = {info["name"]: 0 for info in dataset_infos} - while not stop_event.is_set(): - for info in dataset_infos: - name = info["name"] - receipts_dir = info["receipts_dir"] - try: - if not fs.exists(receipts_dir): - continue - entries = [p for p in fs.ls(receipts_dir, detail=False) if str(p).endswith(".json")] - count = len(entries) - except Exception: - continue - - last = last_counts.get(name, 0) - if count > last: - new_entries = [] - for p in entries: - if p not in seen_receipts[name]: - seen_receipts[name].add(p) - new_entries.append(p) - - for receipt_path in new_entries: - try: - with fs.open(receipt_path, "r") as f: - receipt = json.load(f) - tokens_by_dataset[name] += _extract_tokens(receipt) - except Exception: - pass - - for _ in range(count - last): - live_status.advance_dataset(name) - last_counts[name] = count - live_status.report_tokens(name, tokens_by_dataset[name]) - stop_event.wait(10.0) - - poll_thread = threading.Thread(target=_poll_receipts, daemon=True) - poll_thread.start() - - wandb_thread = None - if wandb_log_downloads: - wandb_thread = threading.Thread( - target=_poll_download_stats, - args=(fs, output_root, stop_event, wandb_download_log_interval_sec), - daemon=True, - ) - wandb_thread.start() - - try: - pipelines_v1.run_pipeline(pipeline_spec) - finally: - stop_event.set() - poll_thread.join(timeout=2.0) - if wandb_thread is not None: - wandb_thread.join(timeout=2.0) - - for info in dataset_infos: - name = info["name"] - stats = _aggregate_jsonl_stats_from_receipts( - dataset_dir=info["dataset_dir"], - num_shards=info["num_shards"], - fs=fs, - ) - results[name] = stats - live_status.report_metrics( - name, - rows=stats.get("num_records", 0), - tokens=stats.get("total_tokens", 0), - ) - 
live_status.complete_dataset(name) - - -def run_xenna_chat_sft_pipeline( - *, - tasks: list[ChatSftShardWorkItem], - dataset_infos: list[dict], - output_root: str, - fs, - live_status, - results: dict, - resolved_tokenizer: dict, - messages_field: str, - tools_field: str, - pack_size: int, - algorithm: str, - dtype: str, - chat_template: str | None, - max_doc_tokens: int | None, - max_rows: int | None, - seed: int | None, - used_in_filter: str | None, - used_in_field: str, - max_concurrent_downloads: int = 64, - wandb_log_downloads: bool = False, - wandb_log_pipeline_stats: bool = False, - wandb_download_log_interval_sec: int = 30, - hf_download_timeout_sec: int = 300, - hf_download_max_retries: int = 3, -) -> None: - if not tasks: - return - - stages: list[pipelines_v1.StageSpec] = [ - pipelines_v1.StageSpec( - HfPredownloadStage( - max_concurrent_downloads=max_concurrent_downloads, - output_root=output_root, - download_timeout_sec=hf_download_timeout_sec, - max_retries=hf_download_max_retries, - ), - num_workers_per_node=1, - ), - pipelines_v1.StageSpec( - ChatSftSpoolStage( - resolved_tokenizer=resolved_tokenizer, - output_root=output_root, - messages_field=messages_field, - tools_field=tools_field, - pack_size=pack_size, - algorithm=algorithm, - dtype=dtype, - chat_template=chat_template, - max_doc_tokens=max_doc_tokens, - max_rows=max_rows, - seed=seed, - used_in_filter=used_in_filter, - used_in_field=used_in_field, - ), - # Let Xenna auto-scale - spool stage is memory-efficient - ), - pipelines_v1.StageSpec( - ChatSftCentralPackStage( - output_root=output_root, - pack_size=pack_size, - algorithm=algorithm, - dtype=dtype, - seed=seed, - ), - num_workers=1, # Must be single-worker for centralized packing - ), - ] - - pipeline_spec = pipelines_v1.PipelineSpec( - input_data=tasks, - stages=stages, - config=pipelines_v1.PipelineConfig( - execution_mode=pipelines_v1.ExecutionMode.STREAMING, - return_last_stage_outputs=False, - ), - ) - - stop_event = 
threading.Event() - - def _poll_receipts() -> None: - last_counts: dict[str, int] = {info["name"]: 0 for info in dataset_infos} - seen_receipts: dict[str, set[str]] = {info["name"]: set() for info in dataset_infos} - tokens_by_dataset: dict[str, int] = {info["name"]: 0 for info in dataset_infos} - while not stop_event.is_set(): - for info in dataset_infos: - name = info["name"] - receipts_dir = info["receipts_dir"] - try: - if not fs.exists(receipts_dir): - continue - entries = [p for p in fs.ls(receipts_dir, detail=False) if str(p).endswith(".json")] - count = len(entries) - except Exception: - continue - - last = last_counts.get(name, 0) - if count > last: - new_entries = [] - for p in entries: - if p not in seen_receipts[name]: - seen_receipts[name].add(p) - new_entries.append(p) - - for receipt_path in new_entries: - try: - with fs.open(receipt_path, "r") as f: - receipt = json.load(f) - tokens_by_dataset[name] += _extract_tokens(receipt) - except Exception: - pass - - for _ in range(count - last): - live_status.advance_dataset(name) - last_counts[name] = count - live_status.report_tokens(name, tokens_by_dataset[name]) - stop_event.wait(10.0) - - poll_thread = threading.Thread(target=_poll_receipts, daemon=True) - poll_thread.start() - - wandb_thread = None - if wandb_log_downloads: - wandb_thread = threading.Thread( - target=_poll_download_stats, - args=(fs, output_root, stop_event, wandb_download_log_interval_sec), - daemon=True, - ) - wandb_thread.start() - - _log_memory_status("Before run_pipeline") - try: - pipelines_v1.run_pipeline(pipeline_spec) - finally: - _log_memory_status("After run_pipeline (in finally)") - stop_event.set() - poll_thread.join(timeout=2.0) - if wandb_thread is not None: - wandb_thread.join(timeout=2.0) - - _log_memory_status("After thread cleanup") - - # Force garbage collection to release memory from pipeline - gc.collect() - _log_memory_status("After gc.collect()") - - for info in dataset_infos: - name = info["name"] - 
_log_memory_status(f"Before aggregating {name}") - stats = _aggregate_packed_stats_from_receipts( - dataset_dir=info["dataset_dir"], - receipts_dir=info["receipts_dir"], - fs=fs, - ) - results[name] = stats - live_status.report_metrics( - name, - rows=stats.get("num_sequences", 0), - tokens=stats.get("total_tokens", 0), - ) - live_status.complete_dataset(name) - - _log_memory_status("After all aggregation - pipeline complete") - - -def _aggregate_jsonl_stats_from_receipts(*, dataset_dir: str, num_shards: int, fs) -> dict: - from nemotron.data_prep.pipeline import _aggregate_jsonl_stats as _agg - - return _agg(dataset_dir, num_shards, fs) - - -def _aggregate_packed_stats_from_receipts(*, dataset_dir: str, receipts_dir: str, fs) -> dict: - from nemotron.data_prep.pipeline import _aggregate_packed_stats as _agg - - return _agg(dataset_dir, receipts_dir, fs) - - -def _extract_tokens(receipt: dict) -> int: - return int(receipt.get("stats", {}).get("total_tokens", 0)) - - -def _make_wandb_stats_callback(): - """Create a callback function for logging pipeline stats to wandb. - - Returns a callback if wandb is active, None otherwise. 
- """ - try: - import wandb - except ImportError: - return None - - if wandb.run is None: - return None - - def _log_stats(stats: "PipelineStats") -> None: - """Log PipelineStats to wandb.""" - metrics = { - # Overall pipeline progress - "data_prep/pipeline_duration_min": stats.pipeline_duration_s / 60, - "data_prep/inputs_initial": stats.num_initial_input_tasks, - "data_prep/inputs_remaining": stats.num_input_tasks_remaining, - "data_prep/outputs_total": stats.num_outputs, - "data_prep/main_loop_rate_hz": stats.main_loop_rate_hz, - # Cluster resources - "data_prep/cluster_cpus_total": stats.cluster.total.num_cpus, - "data_prep/cluster_cpus_available": stats.cluster.available.num_cpus, - "data_prep/cluster_gpus_total": stats.cluster.total.num_gpus, - "data_prep/cluster_gpus_available": stats.cluster.available.num_gpus, - "data_prep/cluster_memory_total_gb": stats.cluster.total.memory / 1e9, - "data_prep/cluster_memory_available_gb": stats.cluster.available.memory / 1e9, - } - - # Progress percentage - if stats.num_initial_input_tasks > 0: - progress = 1.0 - (stats.num_input_tasks_remaining / stats.num_initial_input_tasks) - metrics["data_prep/pipeline_progress"] = progress - - # Per-stage resource usage - for stage_name, usage in stats.resource_usage_per_stage.items(): - safe_name = stage_name.replace(" ", "_").replace("-", "_") - metrics[f"data_prep/stage_{safe_name}_cpu_pct"] = usage.cpu_utilization - metrics[f"data_prep/stage_{safe_name}_memory_gb"] = usage.memory_usage / 1e9 - metrics[f"data_prep/stage_{safe_name}_actor_count"] = usage.actor_count - - # Per-stage state from actor pools - for pool_stats in stats.actor_pools: - safe_name = pool_stats.name.replace(" ", "_").replace("-", "_") - # Actor counts - metrics[f"data_prep/stage_{safe_name}_actors_target"] = pool_stats.actor_stats.target - metrics[f"data_prep/stage_{safe_name}_actors_ready"] = pool_stats.actor_stats.ready - metrics[f"data_prep/stage_{safe_name}_actors_running"] = 
pool_stats.actor_stats.running - metrics[f"data_prep/stage_{safe_name}_actors_idle"] = pool_stats.actor_stats.idle - # Task stats - metrics[f"data_prep/stage_{safe_name}_tasks_completed"] = pool_stats.task_stats.total_completed - metrics[f"data_prep/stage_{safe_name}_input_queue_size"] = pool_stats.task_stats.input_queue_size - metrics[f"data_prep/stage_{safe_name}_output_queue_size"] = pool_stats.task_stats.output_queue_size - # Slot stats - metrics[f"data_prep/stage_{safe_name}_slots_used"] = pool_stats.slot_stats.num_used - metrics[f"data_prep/stage_{safe_name}_slots_empty"] = pool_stats.slot_stats.num_empty - # Speed - if pool_stats.processing_speed_tasks_per_second is not None: - metrics[f"data_prep/stage_{safe_name}_speed_tasks_per_sec"] = pool_stats.processing_speed_tasks_per_second - - wandb.log(metrics) - - return _log_stats - - -def _poll_download_stats(fs, output_root: str, stop_event: threading.Event, interval_sec: int) -> None: - try: - import wandb - except ImportError: - return - - if wandb.run is None: - return - - progress_dir = f"{output_root.rstrip('/')}/.xenna/downloads" - last_logged = 0.0 - - while not stop_event.is_set(): - now = time.time() - if now - last_logged < interval_sec: - stop_event.wait(1.0) - continue - last_logged = now - - try: - if not fs.exists(progress_dir): - continue - entries = fs.ls(progress_dir, detail=False) - except Exception: - continue - - total_completed = 0 - total_files = 0 - max_elapsed = 0.0 - max_rate = 0.0 - - for path in entries: - if not str(path).endswith(".json"): - continue - try: - with fs.open(path, "r") as f: - data = json.load(f) - except Exception: - continue - total_completed += int(data.get("completed", 0)) - total_files += int(data.get("total", 0)) - max_elapsed = max(max_elapsed, float(data.get("elapsed_sec", 0.0))) - max_rate = max(max_rate, float(data.get("rate", 0.0))) - - if total_files == 0: - continue - - wandb.log( - { - "data_prep/hf_download_completed": total_completed, - 
"data_prep/hf_download_total": total_files, - "data_prep/hf_download_rate": max_rate, - "data_prep/hf_download_elapsed_sec": max_elapsed, - } - ) diff --git a/src/nemotron/data_prep/xenna/stages.py b/src/nemotron/data_prep/xenna/stages.py index c328c868c..e1f0270bb 100644 --- a/src/nemotron/data_prep/xenna/stages.py +++ b/src/nemotron/data_prep/xenna/stages.py @@ -504,8 +504,10 @@ def process_data(self, tasks: list[ShardWorkItem]) -> list[ShardWorkItem]: last_report = start_time self._write_progress(completed, total_files, start_time) + failed_downloads: list[tuple[dict[str, str], Exception]] = [] + with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [ + futures_to_files = { executor.submit( _download_hf_file, file_info["repo_id"], @@ -514,15 +516,16 @@ def process_data(self, tasks: list[ShardWorkItem]) -> list[ShardWorkItem]: cache_dir, self._download_timeout_sec, self._max_retries, - ) + ): file_info for file_info in unique_files - ] + } - for future in as_completed(futures): + for future in as_completed(futures_to_files): + file_info = futures_to_files[future] try: future.result() - except Exception: - pass + except Exception as exc: + failed_downloads.append((file_info, exc)) completed += 1 now = time.perf_counter() if now - last_report >= 5.0 or completed == total_files: @@ -531,6 +534,18 @@ def process_data(self, tasks: list[ShardWorkItem]) -> list[ShardWorkItem]: last_report = now self._write_progress(completed, total_files, start_time) + # Report and fail on download errors - don't let PretrainShardStage run on missing files + if failed_downloads: + print(f"[Pre-download] ERROR: {len(failed_downloads)} downloads failed:") + for file_info, exc in failed_downloads[:10]: + print(f" - {file_info['repo_id']}/{file_info['filename']}: {type(exc).__name__}: {exc}") + if len(failed_downloads) > 10: + print(f" ... 
and {len(failed_downloads) - 10} more") + raise RuntimeError( + f"Pre-download failed: {len(failed_downloads)} files could not be downloaded. " + "Cannot proceed with tokenization - files would be missing from cache." + ) + self._write_progress(completed, total_files, start_time) return tasks diff --git a/src/nemotron/kit/cli/recipe.py b/src/nemotron/kit/cli/recipe.py index 866261677..00435ee69 100644 --- a/src/nemotron/kit/cli/recipe.py +++ b/src/nemotron/kit/cli/recipe.py @@ -20,6 +20,7 @@ from __future__ import annotations +import os import shutil import subprocess import sys @@ -229,6 +230,8 @@ def wrapper(ctx: typer.Context) -> None: # Execute based on mode if global_ctx.mode == "local": + # Set env vars so subprocess inherits them (wandb, HF tokens, etc.) + os.environ.update(env_vars) _execute_local(script_path, train_path, passthrough, torchrun=torchrun) else: _execute_nemo_run( @@ -637,6 +640,16 @@ def _build_executor( else: partition = env_config.get("batch_partition") or env_config.get("partition") + # Build container mounts, adding /lustre and Ray temp directory + mounts = list(env_config.get("mounts") or []) + # Mount /lustre for access to shared storage (HF cache, data, etc.) 
+ mounts.append("/lustre:/lustre") + remote_job_dir = env_config.get("remote_job_dir") + if remote_job_dir: + # Ray temp directory mount (avoids filling container storage with Ray logs) + ray_temp_path = f"{remote_job_dir}/ray_temp" + mounts.append(f"{ray_temp_path}:/ray-cluster") + # Build executor kwargs, only including exclusive if True executor_kwargs: dict[str, Any] = { "account": env_config.get("account"), @@ -647,7 +660,7 @@ def _build_executor( "cpus_per_task": env_config.get("cpus_per_task"), "time": env_config.get("time", "04:00:00"), "container_image": container_image, - "container_mounts": env_config.get("mounts") or [], + "container_mounts": mounts, "tunnel": tunnel, "packager": packager, "mem": env_config.get("mem"), @@ -692,11 +705,9 @@ def _build_env_vars(job_config: Any, env_config: dict | None = None) -> dict: # Set NEMO_RUN_DIR to actual lustre path for output paths # This ensures artifacts store the real path, not /nemo_run container mount + # Only set for remote execution - local execution uses default paths if env_config and env_config.get("remote_job_dir"): env_vars["NEMO_RUN_DIR"] = env_config["remote_job_dir"] - else: - # Fallback to container mount if remote_job_dir not configured - env_vars["NEMO_RUN_DIR"] = "/nemo_run" # Set HF_HOME to remote_job_dir/hf if not explicitly set by user # This ensures HuggingFace downloads go to Lustre storage with sufficient space diff --git a/src/nemotron/kit/run.py b/src/nemotron/kit/run.py index 7edcc075e..9135dfa66 100644 --- a/src/nemotron/kit/run.py +++ b/src/nemotron/kit/run.py @@ -540,6 +540,15 @@ def build_executor(config: RunConfig, env_vars: dict[str, str] | None = None) -> tunnel = _build_tunnel(config) packager = _build_packager() + # Build container mounts, adding /lustre and Ray temp directory + mounts = list(config.mounts) + # Mount /lustre for access to shared storage (HF cache, data, etc.) 
+ mounts.append("/lustre:/lustre") + if config.remote_job_dir: + # Ray temp directory mount (avoids filling container storage with Ray logs) + ray_temp_path = f"{config.remote_job_dir}/ray_temp" + mounts.append(f"{ray_temp_path}:/ray-cluster") + return run.SlurmExecutor( account=config.account, partition=config.partition, @@ -560,7 +569,7 @@ def build_executor(config: RunConfig, env_vars: dict[str, str] | None = None) -> gres=config.gres, array=config.array, container_image=config.container_image, - container_mounts=config.mounts, + container_mounts=mounts, tunnel=tunnel, packager=packager, env_vars=merged_env, diff --git a/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/data_blend_cache_test.json b/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/data_blend_cache_test.json new file mode 100644 index 000000000..a6147f787 --- /dev/null +++ b/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/data_blend_cache_test.json @@ -0,0 +1,11 @@ +{ + "datasets": [ + { + "name": "nemotron-math-3", + "path": "hf://nvidia/Nemotron-CC-Math-v1", + "subset": "3", + "text_field": "text", + "weight": 1.0 + } + ] +} diff --git a/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/default.yaml b/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/default.yaml index 5fd34caf8..a40271415 100644 --- a/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/default.yaml +++ b/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/default.yaml @@ -14,8 +14,10 @@ run: # Path to data blend JSON file blend_path: ${oc.env:PWD}/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/data_blend_raw.json -# Output directory for tokenized data (outputs to job directory for persistence) -output_dir: ${oc.env:PWD}/../output/stage0_pretrain +# Output directory for tokenized data +# Using fixed path to enable caching across job submissions +# (previously used ${oc.env:PWD}/../output which created new dir per job) +output_dir: 
/lustre/fsw/portfolios/coreai/users/mromeijn/output/nano3/stage0_pretrain # Number of output shards for parallel loading num_shards: 128 @@ -48,8 +50,13 @@ max_doc_tokens: null sample: null # Ray Data executor settings -# Set to 48 as balance between parallelism and memory on 172GB nodes -ray_data_max_actors: 48 +# Reduced from 48 to 32 to avoid OOM (each worker ~3.7GB) +ray_data_max_actors: 32 + +# Xenna pipeline settings +# Maximum workers for shard processing (each worker ~4GB memory) +# Set based on node memory. null = auto-scale (cosmos-xenna default) +xenna_max_shard_workers: 24 # Console output mode: 'simple' for periodic text updates, 'rich' for animated progress console_mode: simple diff --git a/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/tiny_xenna.yaml b/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/tiny_xenna.yaml new file mode 100644 index 000000000..b5f7e68d2 --- /dev/null +++ b/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/tiny_xenna.yaml @@ -0,0 +1,66 @@ +run: + env: + container: anyscale/ray:2.49.2-py312 + +# Tiny config for pretrain data preparation with Xenna execution engine +# +# Usage: +# python data_prep.py --config tiny_xenna + +# Path to data blend JSON file +blend_path: ${oc.env:PWD}/src/nemotron/recipes/nano3/stage0_pretrain/config/data_prep/data_blend_raw_small.json + +# Output directory for tokenized data +output_dir: ${oc.env:PWD}/../output/stage0_pretrain_tiny_xenna + +# Number of output shards - smaller for tiny config +num_shards: 4 + +# Number of shards for validation split +valid_shards: 1 + +# Number of shards for test split +test_shards: 1 + +# HuggingFace tokenizer model name +tokenizer_model: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16 + +# Prepend BOS token to documents +add_bos: false + +# Append EOS token to documents +add_eos: true + +# Default text field name in datasets +text_field: text + +# Skip documents shorter than this (null = no limit) +min_doc_chars: null + +# 
Truncate documents longer than this (null = no limit) +max_doc_tokens: null + +# Limit rows per dataset for quick tests - small sample for tiny +sample: 100 + +# Console output mode +console_mode: simple + +# Interval in seconds for simple mode status updates +simple_log_interval_sec: 5 + +# Force new run, ignoring cache +force: true + +# Config name for artifact naming +config_name: tiny_xenna + +# Use Xenna execution engine for pipeline processing +execution_engine: xenna + +# Xenna-specific configuration +xenna: + wandb_log_pipeline_stats: true + wandb_log_downloads: true + wandb_download_log_interval_sec: 5 + max_concurrent_downloads: 8 diff --git a/src/nemotron/recipes/nano3/stage0_pretrain/prep_xenna.py b/src/nemotron/recipes/nano3/stage0_pretrain/prep_xenna.py index 5e42cd77c..eebc81c0f 100644 --- a/src/nemotron/recipes/nano3/stage0_pretrain/prep_xenna.py +++ b/src/nemotron/recipes/nano3/stage0_pretrain/prep_xenna.py @@ -64,6 +64,7 @@ class PreTrainDataPrepConfig: sample: int | None = None num_actors: int | None = None ray_data_max_actors: int | None = None + xenna_max_shard_workers: int | None = None force: bool = False config_name: str = "default" @@ -117,6 +118,7 @@ def run_data_prep_main(cfg: PreTrainDataPrepConfig) -> PretrainBlendsArtifact: console_mode=getattr(cfg, "console_mode", "simple"), simple_log_interval_sec=getattr(cfg, "simple_log_interval_sec", 30), ray_data_max_actors=cfg.ray_data_max_actors, + xenna_max_shard_workers=cfg.xenna_max_shard_workers, execution_engine="xenna", ) artifact = run_data_prep(data_prep_config) diff --git a/uv.lock b/uv.lock index 5cd06c598..b1237089d 100644 --- a/uv.lock +++ b/uv.lock @@ -2265,7 +2265,7 @@ dependencies = [ { name = "pyarrow" }, { name = "pydantic" }, { name = "pyyaml" }, - { name = "ray" }, + { name = "ray", extra = ["default"] }, { name = "rich" }, { name = "textual" }, { name = "tomli", marker = "python_full_version < '3.11'" }, @@ -2346,7 +2346,7 @@ requires-dist = [ { name = "pytest", marker = 
"extra == 'dev'", specifier = ">=7.0.0" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0.0" }, { name = "pyyaml", specifier = ">=6.0" }, - { name = "ray", specifier = "==2.49.2" }, + { name = "ray", extras = ["default"], specifier = "==2.49.2" }, { name = "rich", specifier = ">=13.0.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.0" }, { name = "s3fs", marker = "extra == 'all'", specifier = ">=2024.0.0" },