diff --git a/docs/train/data-prep.md b/docs/train/data-prep.md index bf385367c..fad3726e2 100644 --- a/docs/train/data-prep.md +++ b/docs/train/data-prep.md @@ -446,6 +446,39 @@ format=JsonlOutputConfig(num_shards=64) Supported size formats: `"256MB"`, `"1G"`, `"500MiB"`, etc. +## Per-Dataset Shard Allocation + +When a blend includes multiple datasets, shard counts are now allocated per dataset +instead of using a single global count. The total shard budget comes from +`num_shards` or `shard_size`, and is distributed proportionally by dataset weight +and estimated size, with at least one shard per dataset and a cap based on the +number of input files (to avoid empty shards). Dataset weights still control +training mixture ratios; shard allocation is only a sizing heuristic. + +> **Note: Weight vs Shard Counts** +> +> - **Dataset.weight** (e.g., 0.7, 0.3): Controls *training-time sampling* in +> Megatron-Bridge. A blend with weights [0.7, 0.3] means 70% of training +> samples come from dataset 1 during training. +> +> - **Shard counts**: Controlled by `shard_size` (recommended) or explicit +> `num_shards`. These determine how many physical output files are created +> during data preparation, independent of weights. +> +> For blends with datasets of different sizes, use `shard_size="256MB"` instead +> of explicit `num_shards` to let each dataset get an appropriate shard count +> based on its size. + +`blend.json` now includes a `num_shards` map with the effective per-dataset counts: + +```json +{ + "data_paths": ["1.0", "/path/to/ds1/shard", "0.3", "/path/to/ds2/shard"], + "num_shards": {"ds1": 120, "ds2": 8}, + "split": "99990,8,2" +} +``` + ## Per-Split Output Generate separate train/valid/test outputs using `PerSplitConfig`: @@ -485,7 +518,12 @@ output/ { "train": [["1.0", "/path/to/train/shard_000000"], ["1.0", "/path/to/train/shard_000001"]], "valid": [["1.0", "/path/to/valid/shard_000000"]], - "test": [["1.0", "/path/to/test/shard_000000"]] + "test": [["1.0", "/path/to/test/shard_000000"]], + "num_shards": { + "train": {"train_ds": 128}, + "valid": {"valid_ds": 2}, + "test": {"test_ds": 2} + } } ``` diff --git a/src/nemotron/cli/nano3/data/import_/pretrain.py b/src/nemotron/cli/nano3/data/import_/pretrain.py index 39fa63f46..499798d3b 100644 --- a/src/nemotron/cli/nano3/data/import_/pretrain.py +++ b/src/nemotron/cli/nano3/data/import_/pretrain.py @@ -16,6 +16,7 @@ from __future__ import annotations +import json from pathlib import Path import typer @@ -61,12 +62,21 @@ def pretrain( # Initialize W&B init_wandb_if_configured(wandb_config, job_type="data-import", tags=["pretrain", "import"]) + dataset_shards = None + try: + with open(data_path) as f: + blend_data = json.load(f) + dataset_shards = blend_data.get("num_shards") + except Exception: + dataset_shards = None + # Create artifact with minimal required fields artifact_name = name or "nano3/pretrain/data" artifact = DataBlendsArtifact( path=data_path, total_tokens=0, total_sequences=0, + dataset_shards=dataset_shards, name=artifact_name, ) diff --git a/src/nemotron/cli/nano3/data/import_/sft.py b/src/nemotron/cli/nano3/data/import_/sft.py index 58bd7e060..79978d925 100644 --- a/src/nemotron/cli/nano3/data/import_/sft.py +++ b/src/nemotron/cli/nano3/data/import_/sft.py @@ -16,6 +16,7 @@ from __future__ import annotations +import json from pathlib import Path import typer @@ -71,12 +72,21 @@ def sft( # Initialize W&B init_wandb_if_configured(wandb_config, job_type="data-import", tags=["sft", "import"]) + dataset_shards = None + try: + with open(blend_path) as f: + blend_data = json.load(f) + dataset_shards = blend_data.get("num_shards") + except Exception: + dataset_shards = None + # Create artifact with minimal required fields artifact_name = name or "nano3/sft/data" artifact = DataBlendsArtifact( path=blend_path, total_tokens=0, total_sequences=0, + dataset_shards=dataset_shards, name=artifact_name, ) diff --git a/src/nemotron/data_prep/__init__.py b/src/nemotron/data_prep/__init__.py index ae22c5fb7..19bb63e4c 100644 --- a/src/nemotron/data_prep/__init__.py +++ b/src/nemotron/data_prep/__init__.py @@ -364,17 +364,25 @@ def run_data_prep( # Build output artifact - path points to output directory, blend_path to blend.json blend_json_path = result.output_dir / "blend.json" - artifact = artifact_class( - path=result.output_dir, - blend_path=str(blend_json_path), - total_tokens=result.total_tokens, - total_sequences=result.total_sequences, - elapsed_sec=result.elapsed_sec, - num_shards=num_shards, - source_datasets=source_datasets, - tokenizer_uri=tok_uri, - name=config.artifact_name, # Semantic name for W&B artifact naming - ) + artifact_kwargs = { + "path": result.output_dir, + "blend_path": str(blend_json_path), + "total_tokens": result.total_tokens, + "total_sequences": result.total_sequences, + "elapsed_sec": result.elapsed_sec, + "num_shards": num_shards, + "source_datasets": source_datasets, + "tokenizer_uri": tok_uri, + "name": config.artifact_name, # Semantic name for W&B artifact naming + } + # Optionally include per-dataset shard counts if supported by the artifact schema + if hasattr(artifact_class, "model_fields") and "dataset_shards" in artifact_class.model_fields: + all_split = result.splits.get("all") if result.splits else None + artifact_kwargs["dataset_shards"] = ( + all_split.dataset_shards if all_split is not None else None + ) + + artifact = artifact_class(**artifact_kwargs) artifact.save() # Mark wandb run as successful (before Ray shutdown to avoid socket noise) diff --git a/src/nemotron/data_prep/blend.py b/src/nemotron/data_prep/blend.py index 8781bb821..0b45aae57 100644 --- a/src/nemotron/data_prep/blend.py +++ b/src/nemotron/data_prep/blend.py @@ -28,7 +28,11 @@ class Dataset(BaseModel): Attributes: name: Unique identifier for this dataset path: Data location (hf://repo/name, s3://bucket/prefix, /local/path) - weight: Relative weight in the blend (default: 1.0) + weight: Training-time sampling weight (default: 1.0). Controls how + Megatron-Bridge samples from datasets during training, NOT how + many shards are created during data prep. For example, weights + [0.7, 0.3] mean 70% of training samples come from dataset 1. + Shard counts are determined by dataset size and shard_size config. split: HuggingFace split name (required for hf:// paths) subset: HuggingFace config/subset name text_field: Field containing text to tokenize (default: "text") diff --git a/src/nemotron/data_prep/config.py b/src/nemotron/data_prep/config.py index 60e839b9f..0afba35ac 100644 --- a/src/nemotron/data_prep/config.py +++ b/src/nemotron/data_prep/config.py @@ -67,8 +67,13 @@ class BinIdxOutputConfig: Attributes: format: Format identifier (always "binidx") - shard_size: Target size per shard (e.g., "256MB"). Mutually exclusive with num_shards. - num_shards: Exact number of output shards. Mutually exclusive with shard_size. + shard_size: Target size per shard (e.g., "256MB"). When set, shard count + is computed per-dataset based on individual dataset sizes. This is + recommended for blends with datasets of varying sizes, as it prevents + empty shard files for small datasets. Mutually exclusive with num_shards. + num_shards: Exact number of output shards applied to ALL datasets. + Use shard_size instead when datasets have very different sizes to + avoid empty shards. Mutually exclusive with shard_size. dtype: Token dtype (int32, int64, uint16) """ diff --git a/src/nemotron/data_prep/pipeline.py b/src/nemotron/data_prep/pipeline.py index 61ea18399..6fe527be1 100644 --- a/src/nemotron/data_prep/pipeline.py +++ b/src/nemotron/data_prep/pipeline.py @@ -35,6 +35,7 @@ InternalOutputConfig, InternalTokenizerConfig, JsonlOutputConfig, + OutputFormat, OutputConfig, PackedOutputConfig, PipelineConfig, @@ -149,9 +150,12 @@ class SplitResult: run_hash: Unique hash for this processing run output_dir: Directory containing tokenized shards data_paths: Megatron-Bridge format ["weight", "path", ...] - num_shards: Number of shards produced + num_shards: Total number of shards produced across datasets total_tokens: Total tokens across all shards total_sequences: Total sequences (documents) processed + dataset_shards: Per-dataset shard counts {dataset_name: num_shards}. + When using shard_size config, each dataset gets an appropriate + shard count based on its size, avoiding empty shards for small datasets. """ name: str @@ -161,6 +165,7 @@ class SplitResult: num_shards: int total_tokens: int total_sequences: int + dataset_shards: dict[str, int] | None = None @dataclass @@ -355,12 +360,13 @@ def _tokenize_single(blend: DataBlend, config: PipelineConfig) -> PipelineResult split_name="all", config=config, ) + dataset_shards = split_result.dataset_shards or {} # Check if per-split output mode is enabled if config.per_split is not None and config.per_split.enabled: blend_data = _distribute_shards_to_splits( data_paths=split_result.data_paths, - num_shards=split_result.num_shards, + dataset_shards=dataset_shards, valid_shards=config.per_split.valid_shards, test_shards=config.per_split.test_shards, ) @@ -374,6 +380,9 @@ def _tokenize_single(blend: DataBlend, config: PipelineConfig) -> PipelineResult is_per_split = False split_ratio = config.split + if dataset_shards: + blend_data["num_shards"] = dataset_shards + blend_path = config.output.dir / "blend.json" _write_json(blend_path, blend_data) @@ -389,7 +398,7 @@ def _tokenize_single(blend: DataBlend, config: PipelineConfig) -> PipelineResult def _distribute_shards_to_splits( data_paths: list[str], - num_shards: int, + dataset_shards: dict[str, int], valid_shards: int = 1, test_shards: int = 1, seed: int = 42, @@ -407,7 +416,7 @@ def _distribute_shards_to_splits( Args: data_paths: Megatron-Bridge format path list ["weight", "path", ...] - num_shards: Total number of shards per dataset + dataset_shards: Per-dataset shard counts {dataset_name: num_shards} valid_shards: Number of shards for validation (total, not per-dataset) test_shards: Number of shards for test (total, not per-dataset) seed: Random seed for reproducible shard selection @@ -429,8 +438,20 @@ def _distribute_shards_to_splits( # Collect ALL shards from ALL datasets into one pool # Each entry is (weight, shard_path) where shard_path has the _XXXX suffix all_shards: list[tuple[str, str]] = [] - for weight, prefix in pairs: - for shard_idx in range(num_shards): + dataset_names = list(dataset_shards.keys()) + if len(dataset_names) != len(pairs): + # Fallback: attempt to infer dataset names from path prefixes + dataset_names = [] + for _, prefix in pairs: + if "/datasets/" in prefix: + remainder = prefix.split("/datasets/", 1)[1] + dataset_names.append(remainder.split("/", 1)[0]) + else: + dataset_names.append("") + + for (weight, prefix), name in zip(pairs, dataset_names, strict=False): + shard_count = dataset_shards.get(name, 0) + for shard_idx in range(shard_count): all_shards.append((weight, f"{prefix}_{shard_idx:06d}")) # Use seeded RNG for reproducibility @@ -473,7 +494,8 @@ def _tokenize_per_split(blend: DataBlend, config: PipelineConfig) -> PipelineRes format compatible with Megatron-Bridge's per_split_data_args_path parameter. """ splits: dict[str, SplitResult] = {} - blend_data: dict[str, list[str]] = {} + blend_data: dict[str, list[str] | dict] = {} + dataset_shards_by_split: dict[str, dict[str, int]] = {} for split_name, datasets in blend.splits.items(): # Create split-specific output config (preserve format from parent config) @@ -504,6 +526,11 @@ def _tokenize_per_split(blend: DataBlend, config: PipelineConfig) -> PipelineRes splits[split_name] = split_result # Use simple key names for Megatron-Bridge compatibility blend_data[split_name] = split_result.data_paths + if split_result.dataset_shards: + dataset_shards_by_split[split_name] = split_result.dataset_shards + + if dataset_shards_by_split: + blend_data["num_shards"] = dataset_shards_by_split # Generate combined blend.json blend_path = config.output.dir / "blend.json" @@ -535,6 +562,16 @@ def _process_split( # Get filesystem fs, base_path = get_filesystem(str(config.output.dir)) + # Compute per-dataset shard counts (size + weight aware) + from nemotron.data_prep.blend import DataBlend + + dataset_shards = compute_dataset_shard_counts( + DataBlend.from_datasets(*datasets), + config.output.format, + fs, + ) + total_shards = sum(dataset_shards.values()) if dataset_shards else 0 + # Build internal config dict for planning/processing pipeline_dict = { "datasets": [ @@ -557,6 +594,8 @@ def _process_split( }, "output": { "num_shards": config.output.format.num_shards, + "shard_size": getattr(config.output.format, "shard_size", None), + "dataset_shards": dataset_shards, "dtype": config.output.format.dtype, "min_doc_chars": config.output.min_doc_chars, "max_doc_tokens": config.output.max_doc_tokens, @@ -579,7 +618,13 @@ def _process_split( write_json(fs, f"{run_dir}/config.json", run_config) tokenizer_config = InternalTokenizerConfig(**pipeline_dict["tokenizer"]) - output_config = InternalOutputConfig(**pipeline_dict["output"]) + output_config = InternalOutputConfig( + num_shards=max(dataset_shards.values(), default=1), + dtype=pipeline_dict["output"]["dtype"], + min_doc_chars=pipeline_dict["output"]["min_doc_chars"], + max_doc_tokens=pipeline_dict["output"]["max_doc_tokens"], + max_rows=pipeline_dict["output"]["max_rows"], + ) # Planning phase con.planning_header() @@ -593,10 +638,18 @@ def _process_split( dataset_config = DatasetConfig(**dataset_entry) name = dataset_config.name + dataset_output_config = InternalOutputConfig( + num_shards=dataset_shards.get(name, output_config.num_shards), + dtype=output_config.dtype, + min_doc_chars=output_config.min_doc_chars, + max_doc_tokens=output_config.max_doc_tokens, + max_rows=output_config.max_rows, + ) + # Create or load plan plan = _load_or_create_plan( dataset_config=dataset_config, - output_config=output_config, + output_config=dataset_output_config, tokenizer_config=tokenizer_config, config_hash=config_hash, run_dir=run_dir, @@ -659,6 +712,7 @@ def _process_split( receipts_dir=receipts_dir, pending_indices=pending_indices, cached_stats=cached_stats, + num_shards=plan.num_shards, ) ) @@ -708,7 +762,14 @@ def _process_split( # Generate outputs _generate_manifest( - run_dir, pipeline_dict, results, plan_hashes, run_hash, resolved_tokenizer, fs + run_dir, + pipeline_dict, + results, + plan_hashes, + run_hash, + resolved_tokenizer, + dataset_shards, + fs, ) # Build data_paths in Megatron-Bridge format @@ -728,7 +789,8 @@ def _process_split( run_hash=run_hash, output_dir=Path(config.output.dir), data_paths=data_paths, - num_shards=config.output.format.num_shards, + num_shards=total_shards, + dataset_shards=dataset_shards, total_tokens=sum(r.get("total_tokens", 0) for r in results.values()), total_sequences=sum(r.get("total_sequences", 0) for r in results.values()), ) @@ -759,6 +821,7 @@ class _DatasetExecutionPlan: receipts_dir: str pending_indices: list[int] cached_stats: dict + num_shards: int class _PlanDriftError(Exception): @@ -1370,6 +1433,7 @@ def _generate_manifest( plan_hashes: dict[str, str], run_hash: str, resolved_tokenizer: dict | None, + dataset_shards: dict[str, int], fs, ): """Generate manifest summary.""" @@ -1381,11 +1445,13 @@ def _generate_manifest( "generated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), "run_hash": run_hash, "tokenizer": tokenizer_info, + "dataset_shards": dataset_shards, + "plan_hashes": plan_hashes, "datasets": {}, } for name, stats in results.items(): - num_shards = config["output"]["num_shards"] + num_shards = dataset_shards.get(name, config["output"].get("num_shards")) completed = stats.get("num_shards_completed", 0) status = "completed" if completed == num_shards else "in_progress" @@ -1417,6 +1483,8 @@ def _process_jsonl_blend(blend: DataBlend, config: PipelineConfig) -> PipelineRe # Get filesystem fs, base_path = get_filesystem(str(config.output.dir)) + dataset_shards = compute_dataset_shard_counts(blend, format_config, fs) + # Compute run hash (different from tokenization - no tokenizer info) run_config = { "datasets": [ @@ -1432,6 +1500,9 @@ def _process_jsonl_blend(blend: DataBlend, config: PipelineConfig) -> PipelineRe "output": { "format": "jsonl", "compression": format_config.compression, + "num_shards": format_config.num_shards, + "shard_size": format_config.shard_size, + "dataset_shards": dataset_shards, }, } if config.sample is not None: @@ -1446,9 +1517,6 @@ def _process_jsonl_blend(blend: DataBlend, config: PipelineConfig) -> PipelineRe # Freeze config write_json(fs, f"{run_dir}/config.json", run_config) - # Determine num_shards from format config - num_shards = _resolve_num_shards(format_config, blend, fs) - # For JSONL, we use a simpler processing model: # Each dataset's files are distributed across shards and written directly results = {} @@ -1459,11 +1527,12 @@ def _process_jsonl_blend(blend: DataBlend, config: PipelineConfig) -> PipelineRe # Planning phase: discover files and check cache for all datasets from nemotron.data_prep.discovery import discover_input_files, get_dataset_metadata - dataset_plans: list[tuple] = [] # (dataset, dataset_dir, files, cached_stats) + dataset_plans: list[tuple] = [] # (dataset, dataset_dir, files, cached_stats, num_shards) plan_infos = [] for dataset in blend.datasets: name = dataset.name + num_shards = dataset_shards.get(name, 1) # Create dataset directory dataset_dir = f"{run_dir}/datasets/{name}" @@ -1504,7 +1573,7 @@ def _process_jsonl_blend(blend: DataBlend, config: PipelineConfig) -> PipelineRe ) ) - dataset_plans.append((dataset, dataset_dir, files, cached_stats)) + dataset_plans.append((dataset, dataset_dir, files, cached_stats, num_shards)) # Show plan summary (auto-detect workers from cluster) con.plan_summary(plan_infos, run_hash) @@ -1512,13 +1581,13 @@ def _process_jsonl_blend(blend: DataBlend, config: PipelineConfig) -> PipelineRe # Execution phase has_work = any( num_shards - cached_stats.get("num_shards_completed", 0) > 0 - for _, _, _, cached_stats in dataset_plans + for _, _, _, cached_stats, num_shards in dataset_plans ) if has_work: con.execution_header() - for dataset, dataset_dir, files, cached_stats in dataset_plans: + for dataset, dataset_dir, files, cached_stats, num_shards in dataset_plans: name = dataset.name # Process with actors @@ -1547,6 +1616,8 @@ def _process_jsonl_blend(blend: DataBlend, config: PipelineConfig) -> PipelineRe # Generate blend.json blend_data: dict = {"data_paths": data_paths} + if dataset_shards: + blend_data["num_shards"] = dataset_shards if config.split: blend_data["split"] = config.split @@ -1562,7 +1633,8 @@ def _process_jsonl_blend(blend: DataBlend, config: PipelineConfig) -> PipelineRe run_hash=run_hash, output_dir=config.output.dir, data_paths=data_paths, - num_shards=num_shards, + num_shards=sum(dataset_shards.values()) if dataset_shards else 0, + dataset_shards=dataset_shards, total_tokens=0, # No tokenization total_sequences=sum(r.get("num_records", 0) for r in results.values()), ) @@ -1590,6 +1662,145 @@ def _resolve_num_shards(format_config, blend: DataBlend, fs) -> int: return 128 +def compute_dataset_shard_counts( + blend: DataBlend, + format_config: OutputFormat, + fs, +) -> dict[str, int]: + """Compute per-dataset shard counts from a total shard budget. + + Weights still control sampling/mixture ratios at training time. Here they are + only used as a heuristic to allocate the shard budget across datasets. + """ + from nemotron.data_prep.discovery import discover_input_files + from nemotron.data_prep.utils.size import compute_num_shards + + if blend.datasets is None: + datasets = [] + for split_datasets in blend.splits.values(): + datasets.extend(split_datasets) + else: + datasets = blend.datasets + + if not datasets: + return {} + + dataset_infos: list[dict] = [] + total_bytes = 0 + any_size = False + + for dataset in datasets: + try: + dataset_config = DatasetConfig( + name=dataset.name, + path=dataset.path, + split=dataset.split, + subset=dataset.subset, + text_field=dataset.text_field, + ) + files = discover_input_files(dataset_config, fs) + size_bytes = sum(f.size for f in files) + file_count = len(files) + except Exception: + size_bytes = 0 + file_count = 0 + + dataset_infos.append( + { + "name": dataset.name, + "weight": dataset.weight, + "size_bytes": size_bytes, + "file_count": file_count, + } + ) + total_bytes += size_bytes + if size_bytes > 0: + any_size = True + + # Total shard budget + requested_num_shards = getattr(format_config, "num_shards", None) + shard_size = getattr(format_config, "shard_size", None) + if requested_num_shards is not None: + total_budget = requested_num_shards + elif shard_size is not None: + total_budget = compute_num_shards(total_bytes, shard_size) + else: + total_budget = 128 + + # Ensure at least one shard per dataset + total_budget = max(total_budget, len(dataset_infos)) + + # Allocation weights: size-aware if available, else fallback to weights + if any_size and total_bytes > 0: + values = { + info["name"]: info["weight"] * info["size_bytes"] for info in dataset_infos + } + else: + values = {info["name"]: info["weight"] for info in dataset_infos} + + total_value = sum(values.values()) + if total_value <= 0: + values = {info["name"]: 1.0 for info in dataset_infos} + total_value = float(len(dataset_infos)) + + remaining = total_budget - len(dataset_infos) + if remaining < 0: + remaining = 0 + + # Initial allocation: 1 per dataset + proportional extras + ideal_extras = { + name: (remaining * values[name] / total_value) if remaining else 0.0 + for name in values + } + extras_floor = {name: int(extra) for name, extra in ideal_extras.items()} + remainder = remaining - sum(extras_floor.values()) + + assigned = {name: 1 + extras_floor[name] for name in values} + + # Distribute remainder by largest fractional part (deterministic) + fractional_order = sorted( + dataset_infos, + key=lambda info: ( + -(ideal_extras[info["name"]] - extras_floor[info["name"]]), + info["name"], + ), + ) + for info in fractional_order: + if remainder <= 0: + break + name = info["name"] + assigned[name] += 1 + remainder -= 1 + + # Cap by file count to avoid empty shard assignments + capacities = { + info["name"]: max(1, info["file_count"]) for info in dataset_infos + } + freed = 0 + for name, cap in capacities.items(): + if assigned[name] > cap: + freed += assigned[name] - cap + assigned[name] = cap + + # Redistribute freed shards to datasets with remaining capacity + if freed > 0: + while freed > 0: + progressed = False + for info in fractional_order: + if freed <= 0: + break + name = info["name"] + if assigned[name] < capacities[name]: + assigned[name] += 1 + freed -= 1 + progressed = True + if not progressed: + break + + # Preserve dataset order for determinism + return {info["name"]: assigned[info["name"]] for info in dataset_infos} + + def _estimate_blend_bytes(blend: DataBlend, fs) -> int: """Estimate total bytes in blend for shard planning.""" from nemotron.data_prep.discovery import discover_input_files @@ -1738,6 +1949,8 @@ def _process_packed_blend(blend: DataBlend, config: PipelineConfig) -> PipelineR # Get filesystem fs, base_path = get_filesystem(str(config.output.dir)) + dataset_shards = compute_dataset_shard_counts(blend, format_config, fs) + # Compute run hash (includes tokenizer, pack_size, algorithm) run_config = { "datasets": [ @@ -1763,6 +1976,9 @@ def _process_packed_blend(blend: DataBlend, config: PipelineConfig) -> PipelineR "pack_size": format_config.pack_size, "algorithm": format_config.algorithm, "dtype": format_config.dtype, + "num_shards": format_config.num_shards, + "shard_size": format_config.shard_size, + "dataset_shards": dataset_shards, }, } if config.sample is not None: @@ -1777,9 +1993,6 @@ def _process_packed_blend(blend: DataBlend, config: PipelineConfig) -> PipelineR # Freeze config write_json(fs, f"{run_dir}/config.json", run_config) - # Determine num_shards from format config - num_shards = _resolve_num_shards(format_config, blend, fs) - # Resolve tokenizer to get SHA for determinism from nemotron.data_prep.planning import resolve_tokenizer @@ -1794,6 +2007,7 @@ def _process_packed_blend(blend: DataBlend, config: PipelineConfig) -> PipelineR for dataset in blend.datasets: name = dataset.name + num_shards = dataset_shards.get(name, 1) # Create dataset directory structure dataset_dir = f"{run_dir}/datasets/{name}" @@ -1846,6 +2060,8 @@ def _process_packed_blend(blend: DataBlend, config: PipelineConfig) -> PipelineR # Generate blend.json blend_data: dict = {"data_paths": data_paths} + if dataset_shards: + blend_data["num_shards"] = dataset_shards if config.split: blend_data["split"] = config.split @@ -1861,7 +2077,8 @@ def _process_packed_blend(blend: DataBlend, config: PipelineConfig) -> PipelineR run_hash=run_hash, output_dir=config.output.dir, data_paths=data_paths, - num_shards=num_shards, + num_shards=sum(dataset_shards.values()) if dataset_shards else 0, + dataset_shards=dataset_shards, total_tokens=sum(r.get("total_tokens", 0) for r in results.values()), total_sequences=sum(r.get("num_sequences", 0) for r in results.values()), ) @@ -2015,6 +2232,8 @@ def _process_chat_sft_blend(blend: DataBlend, config: PipelineConfig) -> Pipelin # Get filesystem fs, base_path = get_filesystem(str(config.output.dir)) + dataset_shards = compute_dataset_shard_counts(blend, format_config, fs) + # Compute run hash (includes tokenizer, pack_size, algorithm, chat_template) run_config = { "datasets": [ @@ -2042,6 +2261,9 @@ def _process_chat_sft_blend(blend: DataBlend, config: PipelineConfig) -> Pipelin "chat_template": format_config.chat_template, "messages_field": format_config.messages_field, "tools_field": format_config.tools_field, + "num_shards": format_config.num_shards, + "shard_size": format_config.shard_size, + "dataset_shards": dataset_shards, }, } if config.sample is not None: @@ -2056,9 +2278,6 @@ def _process_chat_sft_blend(blend: DataBlend, config: PipelineConfig) -> Pipelin # Freeze config write_json(fs, f"{run_dir}/config.json", run_config) - # Determine num_shards from format config - num_shards = _resolve_num_shards(format_config, blend, fs) - # Resolve tokenizer to get SHA for determinism from nemotron.data_prep.planning import resolve_tokenizer @@ -2072,9 +2291,10 @@ def _process_chat_sft_blend(blend: DataBlend, config: PipelineConfig) -> Pipelin con.planning_header() # Discover files for all datasets - dataset_plans: list[tuple] = [] # (dataset, dataset_dir, receipts_dir, files) + dataset_plans: list[tuple] = [] # (dataset, dataset_dir, receipts_dir, files, num_shards) for dataset in blend.datasets: name = dataset.name + num_shards = dataset_shards.get(name, 1) # Create dataset directory structure dataset_dir = f"{run_dir}/datasets/{name}" @@ -2095,11 +2315,11 @@ def _process_chat_sft_blend(blend: DataBlend, config: PipelineConfig) -> Pipelin # Display discovered info logger.info(f"Discovered dataset '{name}' with {len(files)} files") - dataset_plans.append((dataset, dataset_dir, receipts_dir, files)) + dataset_plans.append((dataset, dataset_dir, receipts_dir, files, num_shards)) # Build plan info for display plan_infos = [] - for dataset, dataset_dir, receipts_dir, files in dataset_plans: + for dataset, dataset_dir, receipts_dir, files, num_shards in dataset_plans: # Check cached stats cached_stats = _aggregate_packed_stats(dataset_dir, receipts_dir, fs) cached_shards = cached_stats.get("num_shards_completed", 0) @@ -2124,7 +2344,7 @@ def _process_chat_sft_blend(blend: DataBlend, config: PipelineConfig) -> Pipelin con.plan_summary(plan_infos, run_hash) # Execution phase - has_work = any(len(files) > 0 for _, _, _, files in dataset_plans) + has_work = any(len(files) > 0 for _, _, _, files, _ in dataset_plans) if has_work: con.execution_header() @@ -2158,7 +2378,9 @@ def _process_chat_sft_blend(blend: DataBlend, config: PipelineConfig) -> Pipelin # Create live status panel with all datasets live_status = con.create_live_status( datasets=[ - (dataset.name, num_shards) for dataset, _, _, files in dataset_plans if files + (dataset.name, num_shards) + for dataset, _, _, files, num_shards in dataset_plans + if files ], run_hash=run_hash, ) @@ -2175,14 +2397,36 @@ def _process_chat_sft_blend(blend: DataBlend, config: PipelineConfig) -> Pipelin all_tasks: list[ tuple[str, str, str, int, list] ] = [] # (name, dataset_dir, receipts_dir, shard_idx, files) - for dataset, dataset_dir, receipts_dir, files in dataset_plans: + dataset_pending_counts: dict[str, int] = {} + dataset_completed_counts: dict[str, int] = {} + + for dataset, dataset_dir, receipts_dir, files, num_shards in dataset_plans: if not files: continue - # Each dataset gets 1 shard (since num_shards is computed per-dataset with 1 file) + + live_status.start_dataset(dataset.name) + dataset_pending_counts[dataset.name] = num_shards + dataset_completed_counts[dataset.name] = 0 + # Convert files to dicts for Ray serialization files_as_dicts = [asdict(f) if hasattr(f, "__dict__") else f for f in files] - all_tasks.append((dataset.name, dataset_dir, receipts_dir, 0, files_as_dicts)) - live_status.start_dataset(dataset.name) + + # Distribute files across shards (round-robin) + shard_assignments: dict[int, list] = {i: [] for i in range(num_shards)} + for i, file_info in enumerate(files_as_dicts): + shard_idx = i % num_shards + shard_assignments[shard_idx].append(file_info) + + for shard_idx in range(num_shards): + all_tasks.append( + ( + dataset.name, + dataset_dir, + receipts_dir, + shard_idx, + shard_assignments[shard_idx], + ) + ) # Submit all tasks with backpressure num_actors = len(actors) @@ -2226,23 +2470,25 @@ def submit_task(task: tuple) -> None: # Update progress live_status.advance_dataset(name) - - # Aggregate stats for this dataset - stats = _aggregate_packed_stats(dataset_dir, receipts_dir, fs) - results[name] = stats - live_status.report_metrics( - name, - rows=stats.get("num_sequences", 0), - tokens=stats.get("total_tokens", 0), - ) - live_status.complete_dataset(name) + dataset_completed_counts[name] = dataset_completed_counts.get(name, 0) + 1 + + if dataset_completed_counts[name] >= dataset_pending_counts.get(name, 0): + # Aggregate stats for this dataset + stats = _aggregate_packed_stats(dataset_dir, receipts_dir, fs) + results[name] = stats + live_status.report_metrics( + name, + rows=stats.get("num_sequences", 0), + tokens=stats.get("total_tokens", 0), + ) + live_status.complete_dataset(name) # Submit next task if available if task_queue: submit_task(task_queue.pop(0)) # Build data_paths for all completed datasets - for dataset, dataset_dir, receipts_dir, files in dataset_plans: + for dataset, dataset_dir, receipts_dir, files, _ in dataset_plans: if not files: continue weight = dataset.weight @@ -2257,7 +2503,7 @@ def submit_task(task: tuple) -> None: ray.kill(actor) else: # No work to do - all datasets empty or cached - for dataset, dataset_dir, receipts_dir, files in dataset_plans: + for dataset, dataset_dir, receipts_dir, files, _ in dataset_plans: stats = _aggregate_packed_stats(dataset_dir, receipts_dir, fs) results[dataset.name] = stats weight = dataset.weight @@ -2271,7 +2517,7 @@ def submit_task(task: tuple) -> None: if config.per_split is not None and config.per_split.enabled: blend_data = _distribute_shards_to_splits( data_paths=data_paths, - num_shards=num_shards, + dataset_shards=dataset_shards, valid_shards=config.per_split.valid_shards, test_shards=config.per_split.test_shards, ) @@ -2284,6 +2530,9 @@ def submit_task(task: tuple) -> None: is_per_split = False split_ratio = config.split + if dataset_shards: + blend_data["num_shards"] = dataset_shards + blend_path = config.output.dir / "blend.json" _write_json(blend_path, blend_data) @@ -2296,7 +2545,8 @@ def submit_task(task: tuple) -> None: run_hash=run_hash, output_dir=config.output.dir, data_paths=data_paths, - num_shards=num_shards, + num_shards=sum(dataset_shards.values()) if dataset_shards else 0, + dataset_shards=dataset_shards, total_tokens=sum(r.get("total_tokens", 0) for r in results.values()), total_sequences=sum(r.get("num_sequences", 0) for r in results.values()), ) diff --git a/src/nemotron/data_prep/planning.py b/src/nemotron/data_prep/planning.py index be793a256..d918447c1 100644 --- a/src/nemotron/data_prep/planning.py +++ b/src/nemotron/data_prep/planning.py @@ -208,6 +208,20 @@ def create_shard_plan( if not files: raise ValueError(f"No input files found for {dataset_config.name}") + # Clamp requested shards to avoid empty assignments + # (each shard should have at least one file assigned) + requested_shards = output_config.num_shards + effective_shards = max(1, min(requested_shards, len(files))) + + if effective_shards < requested_shards: + logger = logging.getLogger(__name__) + logger.warning( + f"Dataset '{dataset_config.name}' has {len(files)} files but " + f"{requested_shards} shards requested. Using {effective_shards} shards " + f"to avoid empty outputs. Consider using 'shard_size' instead of " + f"explicit 'num_shards' for blends with varied dataset sizes." + ) + # Resolve tokenizer to immutable revision resolved_tokenizer = resolve_tokenizer(tokenizer_config) @@ -215,7 +229,7 @@ def create_shard_plan( source_fingerprint = compute_source_fingerprint(files, dataset_config) # Create size-balanced assignments - assignments = create_size_balanced_assignments(files, output_config.num_shards) + assignments = create_size_balanced_assignments(files, effective_shards) # Determinism constraints determinism_constraints = { @@ -230,7 +244,7 @@ def create_shard_plan( plan_content = json.dumps( { "dataset_name": dataset_config.name, - "num_shards": output_config.num_shards, + "num_shards": effective_shards, "source_fingerprint": source_fingerprint, "resolved_tokenizer": resolved_tokenizer, "determinism_constraints": determinism_constraints, @@ -246,7 +260,7 @@ def create_shard_plan( created_at=datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"), plan_hash=plan_hash, dataset_name=dataset_config.name, - num_shards=output_config.num_shards, + num_shards=effective_shards, source_fingerprint=source_fingerprint, config_hash=config_hash, determinism_constraints=determinism_constraints, diff --git a/src/nemotron/kit/artifacts/data_blends.py b/src/nemotron/kit/artifacts/data_blends.py index 6e48f6ace..09814e7cf 100644 --- a/src/nemotron/kit/artifacts/data_blends.py +++ b/src/nemotron/kit/artifacts/data_blends.py @@ -52,6 +52,12 @@ class DataBlendsArtifact(Artifact): int | None, Field(default=None, ge=0, description="Tokens in test split") ] + # Per-dataset shard counts (optional, populated when available) + dataset_shards: Annotated[ + dict[str, int] | None, + Field(default=None, description="Per-dataset shard counts"), + ] + # Source datasets for lineage tracking # Accepts InputDatasetInfo (with metadata) or str (URI only, for backwards compat) source_datasets: Annotated[ diff --git a/tests/recipes/nano3/stage0_pretrain/test_data_prep_train_integration.py b/tests/recipes/nano3/stage0_pretrain/test_data_prep_train_integration.py index d25c95fd1..21c281cda 100644 --- a/tests/recipes/nano3/stage0_pretrain/test_data_prep_train_integration.py +++ b/tests/recipes/nano3/stage0_pretrain/test_data_prep_train_integration.py @@ -43,11 +43,14 @@ class TestNano3DataPrepTrainIntegration: def test_distribute_shards_produces_valid_per_split_format(self): """Test _distribute_shards_to_splits produces correct format.""" - data_paths = ["1.0", "/path/to/shard", "0.5", "/path/to/other"] + data_paths = [ + "1.0", "/path/to/datasets/ds1/plan/shard", + "0.5", "/path/to/datasets/ds2/plan/shard", + ] result = _distribute_shards_to_splits( data_paths=data_paths, - num_shards=4, + dataset_shards={"ds1": 2, "ds2": 2}, # 4 total shards valid_shards=1, test_shards=1, ) @@ -67,11 +70,12 @@ def test_distribute_shards_produces_valid_per_split_format(self): def test_distribute_shards_respects_shard_counts(self): """Test that valid_shards and test_shards control split sizes.""" - data_paths = ["1.0", "/path/to/shard"] + # Path must contain /datasets/{name}/ for dataset name extraction + data_paths = ["1.0", "/path/to/datasets/my_dataset/plan123/shard"] result = _distribute_shards_to_splits( data_paths=data_paths, - num_shards=10, + dataset_shards={"my_dataset": 10}, valid_shards=2, test_shards=3, ) @@ -83,6 +87,35 @@ def test_distribute_shards_respects_shard_counts(self): # train should have remaining 5 shards (10 elements) assert len(result["train"]) == 10 + def test_distribute_shards_multi_dataset(self): + """Test shard distribution with multiple datasets of different sizes.""" + # Two datasets with different shard counts + data_paths = [ + "0.7", "/path/to/datasets/general/plan123/shard", + "0.3", "/path/to/datasets/domain/plan456/shard", + ] + + result = _distribute_shards_to_splits( + data_paths=data_paths, + dataset_shards={"general": 10, "domain": 3}, # 13 total shards + valid_shards=2, + test_shards=2, + ) + + # Total shards: 13 (10 + 3) + # valid: 2 shards = 4 elements + assert len(result["valid"]) == 4 + # test: 2 shards = 4 elements + assert len(result["test"]) == 4 + # train: remaining 9 shards = 18 elements + assert len(result["train"]) == 18 + + # Verify shard indices are correct (should have _XXXXXX suffix) + for split_data in result.values(): + for i in range(1, len(split_data), 2): + path = split_data[i] + assert "_" in path, f"Path {path} should have shard suffix" + def test_blend_json_format_matches_train_expectation(self): """Test blend.json format is compatible with train.py config.""" with tempfile.TemporaryDirectory() as tmpdir: @@ -172,11 +205,11 @@ def test_pretrain_blends_artifact_with_blend_path(self): def test_shard_path_naming_convention(self): """Test that shard paths follow the expected naming convention.""" - data_paths = ["1.0", "/output/shard"] + data_paths = ["1.0", "/output/datasets/my_ds/plan/shard"] result = _distribute_shards_to_splits( data_paths=data_paths, - num_shards=10, + dataset_shards={"my_ds": 10}, valid_shards=1, test_shards=1, ) @@ -192,11 +225,11 @@ def test_shard_path_naming_convention(self): def test_distribute_shards_deterministic_with_seed(self): """Test that shard distribution is deterministic with same seed.""" - data_paths = ["1.0", "/path/to/shard"] + data_paths = ["1.0", "/path/to/datasets/ds/plan/shard"] result1 = _distribute_shards_to_splits( data_paths=data_paths, - num_shards=10, + dataset_shards={"ds": 10}, valid_shards=2, test_shards=2, seed=42, @@ -204,7 +237,7 @@ def test_distribute_shards_deterministic_with_seed(self): result2 = _distribute_shards_to_splits( data_paths=data_paths, - num_shards=10, + dataset_shards={"ds": 10}, valid_shards=2, test_shards=2, seed=42, @@ -214,11 +247,11 @@ def test_distribute_shards_deterministic_with_seed(self): def test_distribute_shards_different_with_different_seed(self): """Test that different seeds produce different distributions.""" - data_paths = ["1.0", "/path/to/shard"] + data_paths = ["1.0", "/path/to/datasets/ds/plan/shard"] result1 = _distribute_shards_to_splits( data_paths=data_paths, - num_shards=10, + dataset_shards={"ds": 10}, valid_shards=2, test_shards=2, seed=42, @@ -226,7 +259,7 @@ def test_distribute_shards_different_with_different_seed(self): result2 = _distribute_shards_to_splits( data_paths=data_paths, - num_shards=10, + dataset_shards={"ds": 10}, valid_shards=2, test_shards=2, seed=123,