diff --git a/tests/data/test_dynamic_batching_dataset.py b/tests/data/test_dynamic_batching_dataset.py index 3bcd17b07..907971cca 100644 --- a/tests/data/test_dynamic_batching_dataset.py +++ b/tests/data/test_dynamic_batching_dataset.py @@ -393,7 +393,7 @@ def build_command(shuffle=True, save_by_idx=True): "--train.rmpad=false", "--train.rmpad_with_pos_ids=true", "--train.dyn_bsz=true", - "--dyn_bsz_in_dataloader=false", + "--train.dyn_bsz_runtime=worker", f"--save_by_idx={str(save_by_idx).lower()}", "--train.seed=42", ] @@ -453,7 +453,6 @@ def _run_distributed_test(): _parser = argparse.ArgumentParser() _parser.add_argument("--shuffle", type=lambda x: x.lower() == "true", default=True) _parser.add_argument("--save_by_idx", type=lambda x: x.lower() == "true", default=True) - _parser.add_argument("--dyn_bsz_in_dataloader", type=lambda x: x.lower() == "true", default=True) test_args, remaining_argv = _parser.parse_known_args() sys.argv = [sys.argv[0]] + remaining_argv @@ -505,7 +504,7 @@ def _run_distributed_test(): train_steps=train_steps, rmpad=args.train.rmpad, dyn_bsz=args.train.dyn_bsz, - dyn_bsz_in_dataloader=test_args.dyn_bsz_in_dataloader, + dyn_bsz_runtime=args.train.dyn_bsz_runtime, bsz_warmup_ratio=args.train.bsz_warmup_ratio, rmpad_with_pos_ids=args.train.rmpad_with_pos_ids, dyn_bsz_buffer_size=READY_FOR_MICRO_BATCH_THRESHOLD, @@ -584,6 +583,7 @@ def _run_distributed_test(): "extra_state": { "curr_epoch": epoch, "curr_step": local_step, + "global_step": global_step, "train_dataloader": dataloader.state_dict(), "environ_meter": environ_meter.state_dict(), }, @@ -603,9 +603,10 @@ def _run_distributed_test(): dataloader.load_state_dict(state["extra_state"]["train_dataloader"]) environ_meter.load_state_dict(state["extra_state"]["environ_meter"]) start_epoch = state["extra_state"]["curr_epoch"] - assert start_epoch == 1 + assert start_epoch == save_epoch start_step = state["extra_state"]["curr_step"] + 1 - assert start_step == 1 + assert start_step == save_step + 1 + global_step = state["extra_state"]["global_step"] dl_state = state["extra_state"]["train_dataloader"] logger.error(f"[rank{rank}] Loaded dataloader state: {dl_state}") diff --git a/tests/data/test_multisource_datasets.py b/tests/data/test_interleave_datasets.py similarity index 98% rename from tests/data/test_multisource_datasets.py rename to tests/data/test_interleave_datasets.py index 6fe1ea166..da083f2a6 100644 --- a/tests/data/test_multisource_datasets.py +++ b/tests/data/test_interleave_datasets.py @@ -256,7 +256,7 @@ def build_command(dataset_type, dataloader_type): "--nnodes=1", "--nproc_per_node=8", f"--master_port={port}", - "tests/data/test_multisource_datasets.py", + "tests/data/test_interleave_datasets.py", "--data.enable_multisource=True", "--model.config_path=test", "--data.train_path=None", @@ -276,7 +276,7 @@ def build_command(dataset_type, dataloader_type): return command -def test_multisource_data_rmpad_with_pos_ids(): +def test_interleave_rmpad_with_pos_ids(): command = build_command(dataset_type="mapping", dataloader_type="rmpad_with_pos_ids") result = subprocess.run(command, check=True) assert result.returncode == 0 @@ -286,7 +286,7 @@ def test_multisource_data_rmpad_with_pos_ids(): assert result.returncode == 0 -def test_multisource_data_padding(): +def test_interleave_padding(): command = build_command(dataset_type="mapping", dataloader_type="padding") result = subprocess.run(command, check=True) assert result.returncode == 0 diff --git a/tests/data/test_multisource_dataset.py 
b/tests/data/test_multisource_dataset.py new file mode 100644 index 000000000..256ffcf05 --- /dev/null +++ b/tests/data/test_multisource_dataset.py @@ -0,0 +1,759 @@ +import os +import random +import subprocess +import sys +import time +from typing import Literal, cast +from unittest.mock import patch + +import numpy as np +import pytest +import torch +import torch.distributed as dist +import yaml +from torch.utils.data import IterableDataset +from transformers import PretrainedConfig +from utils import DummyIterableDataset, DummyMappingDataset, FakeModel, compare_global_batch + +from veomni.arguments import VeOmniArguments, parse_args +from veomni.checkpoint import build_checkpointer +from veomni.data import build_dataloader +from veomni.data.multisource_dataset import MultiSourceDataset +from veomni.distributed.parallel_state import init_parallel_state +from veomni.utils import helper +from veomni.utils.device import get_device_type, get_dist_comm_backend, get_torch_device +from veomni.utils.helper import get_cache_dir + + +logger = helper.create_logger(__name__) + + +# Patch empty_cache to avoid AttributeError on CPU +def _mock_empty_cache(): + """Mock empty_cache that does nothing on CPU.""" + pass + + +def _torch_shm_manager_executable() -> bool: + torch_dir = os.path.dirname(torch.__file__) + shm_manager = os.path.join(torch_dir, "bin", "torch_shm_manager") + return os.path.exists(shm_manager) and os.access(shm_manager, os.X_OK) + + +class MockIterableDataset(IterableDataset): + def __init__(self, data, name="mock"): + self.data = list(data) + self.name = name + self._state = {"consumed": 0} + + def __iter__(self): + for item in self.data: + self._state["consumed"] += 1 + yield item + + def state_dict(self): + return dict(self._state) + + def load_state_dict(self, state): + self._state = dict(state) + + +def run_multisource_dataset_test(): + args = parse_args(VeOmniArguments) + world_size = int(os.environ["WORLD_SIZE"]) + rank = int(os.environ["RANK"]) + device_type = get_device_type() + if device_type != "cpu": + get_torch_device().set_device(f"{device_type}:{args.train.local_rank}") + backend = "gloo" if device_type == "cpu" else get_dist_comm_backend() + dist.init_process_group(backend=backend, world_size=world_size, rank=rank) + + init_parallel_state( + dp_size=args.train.data_parallel_size, + dp_replicate_size=args.train.data_parallel_replicate_size, + dp_shard_size=args.train.data_parallel_shard_size, + tp_size=args.train.tensor_parallel_size, + ep_size=args.train.expert_parallel_size, + pp_size=args.train.pipeline_parallel_size, + cp_size=args.train.context_parallel_size, + ulysses_size=args.train.ulysses_parallel_size, + dp_mode=args.train.data_parallel_mode, + ) + + Checkpointer = build_checkpointer(dist_backend=args.train.data_parallel_mode, ckpt_manager=args.train.ckpt_manager) + + multisource_names = ["dataset_a", "dataset_b"] + multisource_weights = [0.5, 0.5] + + # Build DummyIterableDataset instances directly (bypasses HuggingFace shuffle bug) + iterable_datasets = [ + DummyIterableDataset(DummyMappingDataset(size=100), shuffle=True, seed=args.train.seed + i) + for i in range(len(multisource_names)) + ] + + args.data.enable_multisource = True + train_dataset = MultiSourceDataset( + datasets=iterable_datasets, + weights=multisource_weights, + seed=args.train.seed, + level="token", + source_names=multisource_names, + sharded=True, + stopping_strategy="all_exhausted", + ) + + # YAML config for EnvironMeter's MultiSourceInfoTracker + multisource_config = dict( + 
sources=multisource_names, + names=multisource_names, + schedule=[dict(schedule_type="const", weights=multisource_weights)], + ) + tmp_yaml_path = os.path.join(get_cache_dir("./tmp_simple_ms.yaml"), "tmp_simple_ms.yaml") + if dist.get_rank() == 0: + with open(tmp_yaml_path, "w") as f: + yaml.safe_dump(multisource_config, f) + dist.barrier() + + state = cast(MultiSourceDataset, train_dataset).state_dict() + assert state["version"] == 0 + assert state["topology"]["stopping_strategy"] == "all_exhausted" + assert state["topology"]["level"] == "token" + assert state["topology"]["source_names"] == multisource_names + source_ids = state["topology"]["source_ids"] + assert len(source_ids) == len(multisource_names) + assert len(set(source_ids)) == len(source_ids) + assert sorted(state["runtime"]["avg_len_sum"].keys()) == sorted(source_ids) + assert sorted(state["runtime"]["avg_len_count"].keys()) == sorted(source_ids) + assert sorted(state["runtime"]["dataset_states"].keys()) == sorted(source_ids) + + dataset_length = None + args.train.compute_train_steps(args.data.max_seq_len, args.data.train_size, dataset_length) + + global_batch_size = cast(int, args.train.global_batch_size) + dataloader = build_dataloader( + dataloader_type="native", + dataset=train_dataset, + micro_batch_size=args.train.micro_batch_size, + global_batch_size=global_batch_size, + dataloader_batch_size=args.train.dataloader_batch_size, + max_seq_len=args.data.max_seq_len, + train_steps=args.train.train_steps, + rmpad=args.train.rmpad, + bsz_warmup_ratio=0.0, + rmpad_with_pos_ids=args.train.rmpad_with_pos_ids, + num_workers=1, + drop_last=args.data.drop_last, + pin_memory=args.data.pin_memory, + prefetch_factor=args.data.prefetch_factor, + dyn_bsz=args.train.dyn_bsz, + dyn_bsz_buffer_size=1, + ) + + config = PretrainedConfig() + environ_meter = helper.EnvironMeter( + config=config, + global_batch_size=global_batch_size, + rmpad=args.train.rmpad, + rmpad_with_pos_ids=args.train.rmpad_with_pos_ids, + empty_cache_steps=args.train.empty_cache_steps, + enable_multisource=args.data.enable_multisource, + dataloader=dataloader, + data_path=tmp_yaml_path, + ) + + gt_global_batch_list = [] + epoch_num = 3 + train_steps = args.train.train_steps + start_epoch, start_step, global_step = 0, 0, 0 + save_epoch, save_step = 1, args.train.train_steps - 2 + + fake_model = FakeModel().to(get_device_type()) + for epoch in range(start_epoch, epoch_num): + dataloader.set_epoch(epoch) + data_iterator = iter(dataloader) + start_time = time.time() + for local_step in range(start_step, args.train.train_steps): + global_step += 1 + try: + micro_batches = next(data_iterator) + except StopIteration: + logger.info(f"epoch:{epoch} Dataloader finished with drop_last {args.data.drop_last}") + break + + if global_step == 1: + helper.print_example(example=micro_batches[0], rank=args.train.local_rank) + for micro_batch in micro_batches: + assert "ds_idx" in micro_batch + assert "source_name" in micro_batch + source_name = micro_batch["source_name"] + if isinstance(source_name, list): + assert all(name in multisource_names for name in source_name) + else: + assert source_name in multisource_names + ds_idx = micro_batch["ds_idx"] + if isinstance(ds_idx, torch.Tensor): + assert torch.all((ds_idx >= 0) & (ds_idx < len(multisource_names))) + elif isinstance(ds_idx, list): + assert all(0 <= int(idx) < len(multisource_names) for idx in ds_idx) + else: + assert 0 <= int(ds_idx) < len(multisource_names) + assert micro_batch["attention_mask"].shape[-1] == 
micro_batch["input_ids"].shape[-1] + assert micro_batch["labels"].shape[-1] == micro_batch["input_ids"].shape[-1] + assert torch.all(micro_batch["attention_mask"] == 1) + assert torch.all(micro_batch["labels"] == micro_batch["input_ids"]) + + if epoch > save_epoch or (epoch == save_epoch and local_step > save_step): + gt_global_batch_list.append(micro_batches) + + for micro_step, micro_batch in enumerate(micro_batches): + if global_step == 1: + logger.info(f"[rank{rank}] micro step: {micro_step}, {type(micro_batch)}") + + environ_meter.add(micro_batch) + + delta_time = time.time() - start_time + try: + metrics_resume = environ_meter.step(delta_time, global_step=global_step) + except AttributeError as e: + # Skip metrics on CPU + logger.warning(f"[rank{rank}] Skipping metrics: {e}") + metrics_resume = {} + if epoch == save_epoch and local_step == save_step: + state = { + "model": fake_model, + "extra_state": { + "curr_epoch": epoch, + "curr_step": local_step, + "global_step": global_step, + "train_dataloader": dataloader.state_dict(), + "environ_meter": environ_meter.state_dict(), + }, + } + save_checkpoint_path = os.path.join(args.train.save_checkpoint_path, f"global_step_{global_step}") + Checkpointer.save(args.train.save_checkpoint_path, state, global_steps=global_step) + dist.barrier() + state = {"model": fake_model, "extra_state": {}} + Checkpointer.load(save_checkpoint_path, state) + dataloader.load_state_dict(state["extra_state"]["train_dataloader"]) + environ_meter.load_state_dict(state["extra_state"]["environ_meter"]) + start_epoch = state["extra_state"]["curr_epoch"] + assert start_epoch == save_epoch + start_step = state["extra_state"]["curr_step"] + 1 # resume from the next step + assert start_step == save_step + 1 + global_step = state["extra_state"]["global_step"] + + pred_global_batch_list = [] + + for epoch in range(start_epoch, epoch_num): + dataloader.set_epoch(epoch) + data_iterator = iter(dataloader) + for local_step in range(start_step, train_steps): + global_step += 1 + global_batch = next(data_iterator) + + if epoch > save_epoch or (epoch == save_epoch and local_step > save_step): + pred_global_batch_list.append(global_batch) + + start_time = time.time() + for micro_batch in global_batch: + environ_meter.add(micro_batch) + delta_time = time.time() - start_time + try: + metrics = environ_meter.step(delta_time, global_step=global_step) + except AttributeError as e: + # Skip metrics on CPU (torch.cpu has no attribute 'get_device_name') + logger.warning(f"[rank{rank}] Skipping metrics: {e}") + metrics = {} + start_step = 0 + + compare_global_batch(gt_global_batch_list, pred_global_batch_list) + if ( + metrics is not None + and metrics_resume is not None + and "consume_tokens(M)" in metrics + and "consume_tokens(M)" in metrics_resume + ): + assert metrics["consume_tokens(M)"] == metrics_resume["consume_tokens(M)"] + + logger.info_rank0( + f"dataset_a: {metrics.get('multi_source/consumed_chunk_num/dataset_a', 0)} dataset_b: {metrics.get('multi_source/consumed_chunk_num/dataset_b', 0)}" + ) + + if dist.is_initialized(): + dist.barrier() + + if not dist.is_initialized() or dist.get_rank() == 0: + os.remove(tmp_yaml_path) + + if world_size > 1: + dist.destroy_process_group() + + +def _make_simple_dataset( + datasets, + weights, + level="sample", + stopping_strategy: Literal["first_exhausted", "all_exhausted", "never_exhausted"] = "first_exhausted", + source_names=None, + source_ids=None, +): + return MultiSourceDataset( + datasets=datasets, + weights=weights, + seed=123, + 
level=level, + sample_token_len_fn=None, + source_names=source_names, + source_ids=source_ids, + sharded=False, + stopping_strategy=stopping_strategy, + ) + + +def test_state_dict_structure(): + ds1 = MockIterableDataset([{"input_ids": [1, 2]}], name="a") + ds2 = MockIterableDataset([{"input_ids": [3, 4, 5]}], name="b") + dataset = _make_simple_dataset( + datasets=[ds1, ds2], + weights=[0.5, 0.5], + level="token", + stopping_strategy="all_exhausted", + source_names=["a", "b"], + source_ids=["id_a", "id_b"], + ) + state = dataset.state_dict() + assert state["version"] == 0 + assert state["topology"]["source_ids"] == ["id_a", "id_b"] + assert sorted(state["runtime"]["avg_len_sum"].keys()) == ["id_a", "id_b"] + assert sorted(state["runtime"]["avg_len_count"].keys()) == ["id_a", "id_b"] + assert sorted(state["runtime"]["dataset_states"].keys()) == ["id_a", "id_b"] + assert sorted(state["runtime"]["exhausted"].keys()) == ["id_a", "id_b"] + + +def test_exhausted_state_save_restore_and_elastic(): + """Test exhausted state save/restore with elastic source add/remove scenarios.""" + # Scenario 1: Basic save and restore + ds1 = MockIterableDataset([{"input_ids": [1]}], name="a") + ds2 = MockIterableDataset([{"input_ids": [2]}, {"input_ids": [3]}], name="b") + dataset = _make_simple_dataset( + datasets=[ds1, ds2], + weights=[0.5, 0.5], + stopping_strategy="all_exhausted", + source_ids=["id_a", "id_b"], + ) + dataset._exhausted = [True, False] + state = dataset.state_dict() + assert state["runtime"]["exhausted"] == {"id_a": True, "id_b": False} + + # Restore to same structure + ds1_new = MockIterableDataset([{"input_ids": [1]}], name="a") + ds2_new = MockIterableDataset([{"input_ids": [2]}, {"input_ids": [3]}], name="b") + dataset_new = _make_simple_dataset( + datasets=[ds1_new, ds2_new], + weights=[0.5, 0.5], + stopping_strategy="all_exhausted", + source_ids=["id_a", "id_b"], + ) + dataset_new.load_state_dict(state) + assert dataset_new._exhausted == [True, False] + + # Scenario 2: Add a new source - new source should default to False + ds1_new2 = MockIterableDataset([{"input_ids": [1]}], name="a") + ds2_new2 = MockIterableDataset([{"input_ids": [2]}, {"input_ids": [3]}], name="b") + ds3_new = MockIterableDataset([{"input_ids": [4]}], name="c") + dataset_with_new = _make_simple_dataset( + datasets=[ds1_new2, ds2_new2, ds3_new], + weights=[0.3, 0.3, 0.4], + stopping_strategy="all_exhausted", + source_ids=["id_a", "id_b", "id_c"], + ) + dataset_with_new.load_state_dict(state, reconcile_policy="allow_add") + assert dataset_with_new._exhausted == [True, False, False] + + # Scenario 3: Remove a source - only remaining sources' states preserved + ds1_new3 = MockIterableDataset([{"input_ids": [1]}], name="a") + dataset_removed = _make_simple_dataset( + datasets=[ds1_new3], + weights=[1.0], + stopping_strategy="all_exhausted", + source_ids=["id_a"], + ) + dataset_removed.load_state_dict(state, reconcile_policy="allow_add_remove") + assert dataset_removed._exhausted == [True] + + +def test_exhausted_state_backward_compatible(): + """Test that loading old checkpoint without exhausted field defaults to all False.""" + ds1 = MockIterableDataset([{"input_ids": [1]}], name="a") + ds2 = MockIterableDataset([{"input_ids": [2]}], name="b") + dataset = _make_simple_dataset( + datasets=[ds1, ds2], + weights=[0.5, 0.5], + stopping_strategy="all_exhausted", + source_ids=["id_a", "id_b"], + ) + + # Simulate old checkpoint without exhausted field + old_state = { + "topology": {"source_ids": ["id_a", "id_b"]}, + "runtime": 
{ + "random_state": np.random.RandomState(42).get_state(), + "avg_len_sum": {"id_a": 1.0, "id_b": 2.0}, + "avg_len_count": {"id_a": 1, "id_b": 2}, + "dataset_states": {"id_a": {"consumed": 1}, "id_b": {"consumed": 2}}, + }, + } + + dataset.load_state_dict(old_state) + assert dataset._exhausted == [False, False] + + +def test_elastic_load_add_source(): + ds1 = MockIterableDataset([{"input_ids": [1]}], name="a") + ds2 = MockIterableDataset([{"input_ids": [2]}], name="b") + dataset = _make_simple_dataset( + datasets=[ds1, ds2], + weights=[0.5, 0.5], + source_ids=["id_a", "id_b"], + ) + next(iter(dataset)) + state = dataset.state_dict() + ds3 = MockIterableDataset([{"input_ids": [3]}], name="c") + dataset_new = _make_simple_dataset( + datasets=[ds1, ds2, ds3], + weights=[0.3, 0.3, 0.4], + source_ids=["id_a", "id_b", "id_c"], + ) + dataset_new.load_state_dict(state, reconcile_policy="allow_add") + assert ds1.state_dict()["consumed"] >= 1 + + +def test_elastic_load_remove_source(): + ds1 = MockIterableDataset([{"input_ids": [1]}], name="a") + ds2 = MockIterableDataset([{"input_ids": [2]}], name="b") + dataset = _make_simple_dataset( + datasets=[ds1, ds2], + weights=[0.5, 0.5], + source_ids=["id_a", "id_b"], + ) + next(iter(dataset)) + state = dataset.state_dict() + dataset_new = _make_simple_dataset( + datasets=[ds1], + weights=[1.0], + source_ids=["id_a"], + ) + dataset_new.load_state_dict(state, reconcile_policy="allow_add_remove") + assert ds1.state_dict()["consumed"] >= 1 + + +def test_elastic_load_strict_policy(): + ds1 = MockIterableDataset([{"input_ids": [1]}], name="a") + ds2 = MockIterableDataset([{"input_ids": [2]}], name="b") + dataset = _make_simple_dataset( + datasets=[ds1, ds2], + weights=[0.5, 0.5], + source_ids=["id_a", "id_b"], + ) + state = dataset.state_dict() + dataset_new = _make_simple_dataset( + datasets=[ds1], + weights=[1.0], + source_ids=["id_a"], + ) + with pytest.raises(ValueError): + dataset_new.load_state_dict(state, reconcile_policy="strict") + + +def test_stopping_strategy_first_exhausted(): + ds1 = MockIterableDataset([{"input_ids": [1]}], name="a") + ds2 = MockIterableDataset([{"input_ids": [2]}], name="b") + dataset = _make_simple_dataset( + datasets=[ds1, ds2], + weights=[0.5, 0.5], + source_ids=["id_a", "id_b"], + stopping_strategy="first_exhausted", + ) + dataset._iters = [iter(ds1), iter(ds2)] + dataset._exhausted = [False, False] + first = dataset._next_sample(0) + assert first["input_ids"] == [1] + with pytest.raises(StopIteration): + dataset._next_sample(0) + + +def test_stopping_strategy_all_exhausted(): + ds1 = MockIterableDataset([{"input_ids": [1]}], name="a") + ds2 = MockIterableDataset([{"input_ids": [2]}], name="b") + dataset = _make_simple_dataset( + datasets=[ds1, ds2], + weights=[0.5, 0.5], + source_ids=["id_a", "id_b"], + stopping_strategy="all_exhausted", + ) + dataset._iters = [iter(ds1), iter(ds2)] + dataset._exhausted = [False, False] + first = dataset._next_sample(0) + second = dataset._next_sample(0) + assert first["input_ids"] == [1] + assert second["input_ids"] == [1] + + +def test_stopping_strategy_never_exhausted(): + ds1 = MockIterableDataset([{"input_ids": [1]}], name="a") + ds2 = MockIterableDataset([{"input_ids": [2]}], name="b") + dataset = _make_simple_dataset( + datasets=[ds1, ds2], + weights=[0.5, 0.5], + source_ids=["id_a", "id_b"], + stopping_strategy="never_exhausted", + ) + dataset._iters = [iter(ds1), iter(ds2)] + dataset._exhausted = [False, False] + first = dataset._next_sample(0) + second = dataset._next_sample(0) + 
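+    # Under "never_exhausted", draining source 0 marks it exhausted and immediately
+    # restarts its iterator, so both draws re-yield the only sample from ds1.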
assert first["input_ids"] == [1] + assert second["input_ids"] == [1] + + +def test_determinism_with_seed(): + data_a = [{"input_ids": [i]} for i in range(10)] + data_b = [{"input_ids": [i]} for i in range(10, 20)] + ds1_a = MockIterableDataset(data_a, name="a") + ds2_a = MockIterableDataset(data_b, name="b") + ds1_b = MockIterableDataset(data_a, name="a") + ds2_b = MockIterableDataset(data_b, name="b") + dataset1 = _make_simple_dataset( + datasets=[ds1_a, ds2_a], + weights=[0.5, 0.5], + source_ids=["id_a", "id_b"], + stopping_strategy="all_exhausted", + ) + dataset2 = _make_simple_dataset( + datasets=[ds1_b, ds2_b], + weights=[0.5, 0.5], + source_ids=["id_a", "id_b"], + stopping_strategy="all_exhausted", + ) + dataset1.set_epoch(0) + dataset2.set_epoch(0) + it1 = iter(dataset1) + it2 = iter(dataset2) + for _ in range(10): + sample1 = cast(dict, next(it1)) + sample2 = cast(dict, next(it2)) + assert sample1["ds_idx"] == sample2["ds_idx"] + + +def test_level_token_weighting(): + ds1 = MockIterableDataset([{"input_ids": [1, 2, 3, 4]}], name="a") + ds2 = MockIterableDataset([{"input_ids": [5]}], name="b") + dataset = _make_simple_dataset( + datasets=[ds1, ds2], + weights=[1.0, 1.0], + level="token", + source_ids=["id_a", "id_b"], + ) + dataset._avg_len_sum = [4.0, 1.0] + dataset._avg_len_count = [1, 1] + weights = dataset._runtime_weights() + assert weights[0] == 0.2 + assert weights[1] == 0.8 + + +@pytest.mark.parametrize( + ("kwargs", "match"), + [ + ( + { + "datasets": [MockIterableDataset([{"input_ids": [1]}]), MockIterableDataset([{"input_ids": [2]}])], + "weights": [1.0], + }, + "weights length must match datasets length", + ), + ( + { + "datasets": [MockIterableDataset([{"input_ids": [1]}]), MockIterableDataset([{"input_ids": [2]}])], + "weights": [0.5, 0.5], + "source_names": ["only_one"], + }, + "source_names length must match datasets length", + ), + ( + { + "datasets": [MockIterableDataset([{"input_ids": [1]}]), MockIterableDataset([{"input_ids": [2]}])], + "weights": [0.5, 0.5], + "source_ids": ["id_a"], + }, + "source_ids length must match datasets length", + ), + ( + { + "datasets": [MockIterableDataset([{"input_ids": [1]}]), MockIterableDataset([{"input_ids": [2]}])], + "weights": [0.5, 0.5], + "source_ids": ["same_id", "same_id"], + }, + "source_ids must be unique", + ), + ( + {"datasets": [MockIterableDataset([{"input_ids": [1]}])], "weights": [1.0], "level": "invalid"}, + "level must be 'sample' or 'token'", + ), + ( + { + "datasets": [MockIterableDataset([{"input_ids": [1]}])], + "weights": [1.0], + "stopping_strategy": cast(Literal["first_exhausted", "all_exhausted", "never_exhausted"], "invalid"), + }, + "stopping_strategy must be", + ), + ], +) +def test_init_validation(kwargs, match): + with pytest.raises(ValueError, match=match): + MultiSourceDataset(**kwargs, seed=42) + + +@pytest.mark.parametrize( + ("sample", "expected"), + [ + ({"attention_mask": torch.tensor([1, 1, 0])}, 2.0), + ({"attention_mask": [1, 1, 1, 0]}, 3.0), + ({"input_ids": torch.tensor([1, 2, 3])}, 3.0), + ({"input_ids": [1, 2, 3, 4]}, 4.0), + ([{"input_ids": [1, 2]}, {"input_ids": [3, 4, 5]}], 5.0), + ({"other_field": "value"}, 1.0), + (None, 0.0), + ], +) +def test_default_sample_token_len(sample, expected): + ds1 = MockIterableDataset([{"input_ids": [1, 2, 3]}], name="a") + dataset = _make_simple_dataset(datasets=[ds1], weights=[1.0], source_ids=["id_a"]) + assert dataset._default_sample_token_len(sample) == expected + + +class TestLoadStateDictBoundary: + def test_missing_topology(self): + ds1 = 
MockIterableDataset([{"input_ids": [1]}], name="a") + dataset = _make_simple_dataset(datasets=[ds1], weights=[1.0], source_ids=["id_a"]) + with pytest.raises(ValueError, match="state_dict missing required keys"): + dataset.load_state_dict({"runtime": {}}) + + def test_missing_runtime(self): + ds1 = MockIterableDataset([{"input_ids": [1]}], name="a") + dataset = _make_simple_dataset(datasets=[ds1], weights=[1.0], source_ids=["id_a"]) + with pytest.raises(ValueError, match="state_dict missing required keys"): + dataset.load_state_dict({"topology": {}}) + + def test_missing_source_ids_in_topology(self): + ds1 = MockIterableDataset([{"input_ids": [1]}], name="a") + dataset = _make_simple_dataset(datasets=[ds1], weights=[1.0], source_ids=["id_a"]) + state = { + "topology": {"weights": [1.0], "level": "sample"}, + "runtime": { + "random_state": np.random.RandomState(42).get_state(), + "avg_len_sum": {}, + "avg_len_count": {}, + "dataset_states": {}, + }, + } + with pytest.raises(ValueError, match="state_dict missing topology.source_ids"): + dataset.load_state_dict(state) + + def test_avg_len_not_dict(self): + ds1 = MockIterableDataset([{"input_ids": [1]}], name="a") + dataset = _make_simple_dataset(datasets=[ds1], weights=[1.0], source_ids=["id_a"]) + state = { + "topology": {"source_ids": ["id_a"]}, + "runtime": { + "random_state": np.random.RandomState(42).get_state(), + "avg_len_sum": [1.0], + "avg_len_count": [1], + "dataset_states": {}, + }, + } + with pytest.raises(ValueError, match="must be dicts keyed by source_id"): + dataset.load_state_dict(state) + + def test_dataset_states_not_dict(self): + ds1 = MockIterableDataset([{"input_ids": [1]}], name="a") + dataset = _make_simple_dataset(datasets=[ds1], weights=[1.0], source_ids=["id_a"]) + state = { + "topology": {"source_ids": ["id_a"]}, + "runtime": { + "random_state": np.random.RandomState(42).get_state(), + "avg_len_sum": {"id_a": 1.0}, + "avg_len_count": {"id_a": 1}, + "dataset_states": [], + }, + } + with pytest.raises(ValueError, match="must be a dict keyed by source_id"): + dataset.load_state_dict(state) + + def test_warn_only_policy(self): + ds1 = MockIterableDataset([{"input_ids": [1]}], name="a") + ds2 = MockIterableDataset([{"input_ids": [2]}], name="b") + dataset = _make_simple_dataset( + datasets=[ds1, ds2], + weights=[0.5, 0.5], + source_ids=["id_a", "id_b"], + ) + dataset._avg_len_sum = [2.0, 5.0] + dataset._avg_len_count = [1, 2] + dataset._global_sample_idx = 7 + dataset._random_state = np.random.RandomState(999) + state = dataset.state_dict() + dataset_new = _make_simple_dataset( + datasets=[ds1], + weights=[1.0], + source_ids=["id_a"], + ) + dataset_new.load_state_dict(state, reconcile_policy="warn_only") + assert dataset_new._avg_len_sum == [2.0] + assert dataset_new._avg_len_count == [1] + assert dataset_new._global_sample_idx == 7 + rng = np.random.RandomState() + rng.set_state(state["runtime"]["random_state"]) + assert dataset_new._random_state.randint(0, 2**31 - 1) == rng.randint(0, 2**31 - 1) + + +def build_command(): + port = 12345 + random.randint(0, 100) + command = [ + "torchrun", + "--nnodes=1", + "--nproc_per_node=2", + f"--master_port={port}", + "tests/data/test_multisource_dataset.py", + "--data.enable_multisource=True", + "--model.config_path=test", + "--data.train_path=None", + "--data.train_size=1000", + "--data.max_seq_len=8", + "--data.datasets_type=iterable", + "--train.global_batch_size=8", + "--train.micro_batch_size=2", + "--train.data_parallel_mode=ddp", + "--train.ckpt_manager=dcp", + 
"--train.ulysses_parallel_size=1", + "--train.bsz_warmup_ratio=0", + "--train.output_dir=.tests/cache", + "--train.rmpad=false", + "--train.rmpad_with_pos_ids=true", + "--train.dyn_bsz=True", + "--train.max_steps=6", + ] + return command + + +def test_multisource_dataset_chain(): + if sys.platform == "darwin": + pytest.skip(f"torch_shm_manager not supported on macOS: executable={_torch_shm_manager_executable()}") + command = build_command() + result = subprocess.run(command, check=True) + assert result.returncode == 0 + + +if __name__ == "__main__": + with patch("veomni.utils.device.empty_cache", _mock_empty_cache): + run_multisource_dataset_test() diff --git a/veomni/arguments/arguments_types.py b/veomni/arguments/arguments_types.py index b41ab3a03..abb84b954 100644 --- a/veomni/arguments/arguments_types.py +++ b/veomni/arguments/arguments_types.py @@ -276,6 +276,10 @@ class DataArguments: default=2, metadata={"help": "Number of workers to load data."}, ) + worker_num_threads: Optional[int] = field( + default=None, + metadata={"help": "Per-worker torch thread count for dataloader subprocesses."}, + ) prefetch_factor: int = field( default=2, metadata={"help": "Number of batches loaded in advance by each worker."}, diff --git a/veomni/data/batching_strategy.py b/veomni/data/batching_strategy.py index 40d28ac97..802f4c3fa 100644 --- a/veomni/data/batching_strategy.py +++ b/veomni/data/batching_strategy.py @@ -22,11 +22,7 @@ class DynBszBuffer: """ def __init__(self): - self._buffer = [] - self._buffer_sample_lens = [] - self.del_idxs = [] - self.cur_idx = 0 - self.all_token_cnt = 0 + self.clear() def append(self, item: Dict[str, Any]): """ @@ -80,6 +76,13 @@ def flush(self): ] self.del_idxs = [] + def clear(self): + self._buffer = [] + self._buffer_sample_lens = [] + self.del_idxs = [] + self.cur_idx = 0 + self.all_token_cnt = 0 + def merge(self, buffer_to_merge: "DynBszBuffer"): """ " Merge the buffer with another buffer. @@ -212,3 +215,6 @@ def get_micro_batch(self, step) -> Any: def empty(self) -> bool: return len(self.buffer) == 0 + + def drop_buffer(self): + self.buffer.clear() diff --git a/veomni/data/data_loader.py b/veomni/data/data_loader.py index 67ecea941..918d27712 100644 --- a/veomni/data/data_loader.py +++ b/veomni/data/data_loader.py @@ -13,8 +13,9 @@ # limitations under the License. 
-from typing import TYPE_CHECKING, Callable, List, Optional, Union +from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Union +import torch from torch.utils.data import IterableDataset from torchdata.stateful_dataloader import StatefulDataLoader from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler @@ -35,7 +36,7 @@ TextSequenceShardCollator, UnpackDataCollator, ) -from .dynamic_batching import DynamicBatchingSizeDataset, DynamicBatchSizeDataLoader +from .dynamic_batching import DynamicBatchingSizeDataset, DynamicBatchSizeDataLoader, get_length_by_attention_mask_fn if TYPE_CHECKING: @@ -60,6 +61,16 @@ def set_epoch(self, epoch: int) -> None: self.dataset.set_epoch(epoch) +def _build_worker_init_fn(worker_num_threads: Optional[int]) -> Optional[Callable[[int], None]]: + if worker_num_threads is None: + return None + + def worker_init_fn(_worker_id: int) -> None: + torch.set_num_threads(worker_num_threads) + + return worker_init_fn + + @DATALOADER_REGISTRY.register("native") def build_native_dataloader( dataset: "Dataset", @@ -73,18 +84,19 @@ def build_native_dataloader( bsz_warmup_ratio: float = 0.02, bsz_warmup_init_mbtoken: int = 200, dyn_bsz: bool = True, - dyn_bsz_in_dataloader: bool = True, # If True, dynamic batching is handled in the main process via DynamicBatchSizeDataLoader (legacy). - # If False, batching is done inside each DataLoader worker via DynamicBatchingSizeDataset, which supports StatefulDataLoader checkpoint/resume. + dyn_bsz_runtime: Literal["main", "worker"] = "main", pad_packed_to_length: Optional[int] = None, dyn_bsz_buffer_size: int = 500, dyn_bsz_margin: int = 0, - dyn_bsz_dataset_save_by_idx: bool = True, # Whether to save buffer by index for checkpointing when dyn_bsz_in_dataloader is False. + dyn_bsz_dataset_save_by_idx: bool = True, # Whether to save buffer by index for checkpointing when dyn_bsz_runtime is "worker". 
collate_fn: Optional[Union[Callable, List[Callable]]] = None, num_workers: int = 8, + worker_num_threads: int = 1, drop_last: bool = True, pin_memory: bool = True, prefetch_factor: int = 2, seed: int = 0, + multiprocessing_context=None, ) -> "DistributedDataloader": parallel_state = get_parallel_state() token_micro_bsz = micro_batch_size * max_seq_len @@ -127,7 +139,7 @@ def build_native_dataloader( if use_rmpad and dyn_bsz: dyn_bsz_collate_fn = collate_fn - if dyn_bsz_in_dataloader: + if dyn_bsz_runtime == "main": batching_strategy = TextBatchingStrategy( token_micro_bsz=token_micro_bsz - dyn_bsz_margin * max_seq_len, buffer_size=dyn_bsz_buffer_size, @@ -141,7 +153,7 @@ def build_native_dataloader( dataset=dataset, micro_batch_seq_length=token_micro_bsz, ready_for_micro_batch_threshold=dyn_bsz_buffer_size, - get_length_fn=lambda x: int(x["attention_mask"].sum()), + get_length_fn=get_length_by_attention_mask_fn, dynamic_batching_collate_fn=dyn_bsz_collate_fn, save_by_idx=dyn_bsz_dataset_save_by_idx, ) @@ -164,13 +176,15 @@ def build_native_dataloader( batch_size=dataloader_batch_size, sampler=sampler, num_workers=num_workers, + worker_init_fn=_build_worker_init_fn(worker_num_threads), collate_fn=collate_fn, pin_memory=pin_memory, pin_memory_device=get_device_type(), drop_last=drop_last, prefetch_factor=prefetch_factor, + multiprocessing_context=multiprocessing_context, ) - if use_rmpad and dyn_bsz and dyn_bsz_in_dataloader: + if use_rmpad and dyn_bsz and dyn_bsz_runtime == "main": dataloader = DynamicBatchSizeDataLoader( dataloader, batching_strategy=batching_strategy, diff --git a/veomni/data/dynamic_batching.py b/veomni/data/dynamic_batching.py index e01f73d92..6226d69b0 100644 --- a/veomni/data/dynamic_batching.py +++ b/veomni/data/dynamic_batching.py @@ -30,6 +30,10 @@ logger = logging.get_logger(__name__) +def get_length_by_attention_mask_fn(sample): + return int(sample["attention_mask"].sum()) + + class DynamicBatchingSizeDataset(IterableDataset): """Dynamic batching dataset that yields micro batches based on token count. @@ -69,7 +73,7 @@ def __init__( ready_for_micro_batch_threshold: int, dynamic_batching_collate_fn: Optional[Callable] = None, save_by_idx: bool = True, - get_length_fn: Optional[Callable] = len, + get_length_fn: Optional[Callable] = get_length_by_attention_mask_fn, force_generate_long_sequence: bool = False, ) -> None: """Initialize the DynamicBatchingSizeDataset. @@ -101,26 +105,33 @@ def __init__( self.ready_for_micro_batch_threshold = ready_for_micro_batch_threshold self.micro_batch_seq_length = micro_batch_seq_length self.get_length_fn = get_length_fn + self.save_by_idx = save_by_idx - self._data_iter = None if force_generate_long_sequence: raise ValueError("force_generate_long_sequence is not supported yet.") - self.force_generate_long_sequence = force_generate_long_sequence - if self.save_by_idx and not ( - hasattr(self.dataset, "get_item") and hasattr(self.dataset, "output_refetch_idx") - ): - raise ValueError( - "save_by_idx is True, but dataset does not have get_item method or output_refetch_idx attribute to resume samples in buffers based on idx" - ) - self.dataset.output_refetch_idx = self.save_by_idx - self._buffer = [] self._buffer_of_refetch_idx = [] self._buffer_token_count = 0 + self._just_resumed = False # Flag to indicate if the dataset has just been resumed from a checkpoint, used to skip buffer checks on the first iteration after resume. 
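+        # __iter__ clears the buffer at the start of each pass unless this flag is set,
+        # so buffer contents restored by load_state_dict survive the first pass after resume.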
+ + @property + def save_by_idx(self) -> bool: + return self._save_by_idx + + @save_by_idx.setter + def save_by_idx(self, value: bool) -> None: + if value and not (hasattr(self.dataset, "get_item") and hasattr(self.dataset, "output_refetch_idx")): + raise ValueError( + "save_by_idx is True, but dataset does not have get_item method or output_refetch_idx attribute to resume samples in buffers based on idx" + ) + self._save_by_idx = value + if hasattr(self.dataset, "output_refetch_idx"): + self.dataset.output_refetch_idx = value + def __iter__(self): """Iterate over the dataset and yield dynamically batched micro batches. @@ -137,6 +148,15 @@ def __iter__(self): """ self._data_iter = iter(self.dataset) + if not self._just_resumed: + # Clear buffer state on new iteration unless we just resumed from a checkpoint, + # in which case we want to keep the buffer contents. + self._buffer = [] + self._buffer_of_refetch_idx = [] + self._buffer_token_count = 0 + else: + self._just_resumed = False + while True: try: if ( @@ -154,19 +174,25 @@ def __iter__(self): if self.save_by_idx: item, refetch_idx = item[0], item[1] - length = self.get_length_fn(item) - if length > self.micro_batch_seq_length and not self.force_generate_long_sequence: - # TODO: record the count of discarded long examples for monitoring - logger.warning( - f"Sample length {length} exceeds micro batch seq length {self.micro_batch_seq_length}, skipping. If you want to force generate a micro batch with this sample, enable force_generate_long_sequence." - ) - continue - - self._buffer.append((item, length)) - if self.save_by_idx: - self._buffer_of_refetch_idx.append(refetch_idx) - - self._buffer_token_count += self._buffer[-1][1] + samples_to_add = [] + if type(item) is list: + samples_to_add = item + else: + samples_to_add = [item] + for item in samples_to_add: + length = self.get_length_fn(item) + if length > self.micro_batch_seq_length and not self.force_generate_long_sequence: + # TODO: record the count of discarded long examples for monitoring + logger.warning( + f"Sample length {length} exceeds micro batch seq length {self.micro_batch_seq_length}, skipping. If you want to force generate a micro batch with this sample, enable force_generate_long_sequence." + ) + continue + + self._buffer.append((item, length)) + if self.save_by_idx: + self._buffer_of_refetch_idx.append(refetch_idx) + + self._buffer_token_count += self._buffer[-1][1] except Exception as e: if isinstance(e, StopIteration): @@ -316,6 +342,8 @@ def load_state_dict(self, state_dict): if "dynamic_batch_upstream_dataset_state" in state_dict: self.dataset.load_state_dict(state_dict["dynamic_batch_upstream_dataset_state"]) + self._just_resumed = True + def set_epoch(self, epoch: int): """Set the epoch for the upstream dataset. 
@@ -381,6 +409,7 @@ def __iter__(self) -> Iterator: self.step = 0 self._data_iter = iter(self._dataloader) self._batch_data_iter = self.batch_data_generator() + self.batching_strategy.drop_buffer() self._resume = False return self diff --git a/veomni/data/multisource_dataset.py b/veomni/data/multisource_dataset.py new file mode 100644 index 000000000..48e575371 --- /dev/null +++ b/veomni/data/multisource_dataset.py @@ -0,0 +1,451 @@ +import copy +from typing import Any, Callable, List, Literal, Optional, Sequence + +import numpy as np +import torch +from torch.utils.data import IterableDataset, get_worker_info + +from ..distributed.parallel_state import get_parallel_state +from ..utils import logging + + +logger = logging.get_logger(__name__) + + +class MultiSourceDataset(IterableDataset): + """Multi-source dataset with weighted sampling. + + This dataset samples from multiple upstream iterable datasets according to a + (possibly token-adjusted) weight distribution. + + It supports: + - Per-epoch deterministic randomness (seeded by epoch, dp rank, and worker id). + - Optional distributed sharding behavior controlled by ``sharded``. + - Stopping strategies for how to behave when an upstream source is exhausted. + - Optional refetch-index passthrough for checkpointing buffers by index. + """ + + def __init__( + self, + datasets: Sequence[IterableDataset], + weights: Sequence[float], + seed: int = 42, + level: str = "sample", + sample_token_len_fn: Optional[Callable[[Any], float]] = None, + source_names: Optional[Sequence[str]] = None, + source_ids: Optional[Sequence[str]] = None, + sharded: bool = False, + stopping_strategy: Literal["first_exhausted", "all_exhausted", "never_exhausted"] = "first_exhausted", + output_refetch_idx: bool = False, + ) -> None: + """Initialize a MultiSourceDataset. + + Args: + datasets: Upstream iterable datasets (one per source). + weights: Sampling weights aligned with ``datasets``. + seed: Base random seed. + level: Sampling level. ``sample`` uses ``weights`` directly; ``token`` reweights + by the inverse of the running average token length per source. + sample_token_len_fn: Function that returns the token length of a sample. + If not provided, a default heuristic is used. + source_names: Optional display names for each source (for meta fields). + source_ids: Optional stable IDs for each source (used in checkpoint state). + sharded: If False, performs deterministic modulo-based sharding by dp rank on + the produced samples. If True, assumes upstream datasets already handle + sharding/splitting. + stopping_strategy: + - ``first_exhausted``: Stop the whole dataset once any source is exhausted. + - ``all_exhausted``: Restart an exhausted source until all sources are exhausted. + - ``never_exhausted``: Always restart exhausted sources and never terminate. + output_refetch_idx: If True, yields ``(sample, (source_id, refetch_idx))`` so that + downstream components can checkpoint buffers by indices and reconstruct them. + + Raises: + ValueError: If input arguments are invalid. 
+ """ + self._datasets = list(datasets) + self._weights = np.asarray(weights, dtype=np.float64) + self._seed = seed + self._level = level + self._sample_token_len_fn = sample_token_len_fn or self._default_sample_token_len + self._source_names = list(source_names) if source_names is not None else None + self._source_ids = list(source_ids) if source_ids is not None else [] + self._sharded = sharded + self._stopping_strategy = stopping_strategy + self._ds_num = len(self._datasets) + + if not self._source_names: + self._source_names = [] + for i, dataset in enumerate(self._datasets): + if callable(getattr(dataset, "get_name", None)): + self._source_names.append(dataset.get_name()) + else: + self._source_names.append(f"source_{i}") + + if not self._source_ids: + self._source_ids = copy.deepcopy(self._source_names) + + self._id2dataset = dict(zip(self._source_ids, self._datasets)) + self._avg_len_sum = [0.0 for _ in range(self._ds_num)] + self._avg_len_count = [0 for _ in range(self._ds_num)] + self._global_sample_idx = 0 + self._random_state = np.random.RandomState(seed=self._seed) + self._iters: List[Any] = [] + self._epoch = 0 + self._exhausted = [False for _ in range(self._ds_num)] + if self._weights.shape[0] != self._ds_num: + raise ValueError("weights length must match datasets length") + if self._source_names is not None and len(self._source_names) != self._ds_num: + raise ValueError("source_names length must match datasets length") + if len(self._source_ids) != self._ds_num: + raise ValueError("source_ids length must match datasets length") + if len(set(self._source_ids)) != self._ds_num: + raise ValueError("source_ids must be unique") + if self._level not in ("sample", "token"): + raise ValueError("level must be 'sample' or 'token'") + if self._stopping_strategy not in ("first_exhausted", "all_exhausted", "never_exhausted"): + raise ValueError("stopping_strategy must be 'first_exhausted', 'all_exhausted', or 'never_exhausted'") + + parallel_state = get_parallel_state() + self.dp_rank = max(0, int(getattr(parallel_state, "dp_rank", 0))) + self.dp_size = max(1, int(getattr(parallel_state, "dp_size", 1))) + + self.output_refetch_idx = output_refetch_idx + + self._just_resumed = False + + @property + def output_refetch_idx(self) -> bool: + """Whether to yield refetch indices alongside samples.""" + return self._output_refetch_idx + + @output_refetch_idx.setter + def output_refetch_idx(self, value: bool) -> None: + """Enable or disable refetch-index output. + + When enabled, each upstream dataset must provide: + - ``get_item(idx)`` to fetch a sample by index + - ``output_refetch_idx`` attribute to switch yielding ``(sample, idx)`` + + Args: + value: True to enable refetch indices, False to disable. + + Raises: + ValueError: If any upstream dataset cannot support refetch-by-index. + """ + if value: + for source_id, dataset in self._id2dataset.items(): + if not (callable(getattr(dataset, "get_item", None)) and hasattr(dataset, "output_refetch_idx")): + raise ValueError( + f"output_refetch_idx is True, but dataset '{source_id}' does not have " + f"get_item method or output_refetch_idx attribute to resume samples " + f"in buffers based on idx" + ) + self._output_refetch_idx = value + for dataset in self._datasets: + if hasattr(dataset, "output_refetch_idx"): + setattr(dataset, "output_refetch_idx", value) + + def get_item(self, refetch_idx): + """Fetch a single sample by its source ID and index within that source. 
+ + This is used by downstream checkpoint/resume logic that stores buffer + contents as ``(source_id, idx)`` pairs instead of full samples. + + Args: + refetch_idx: A ``(source_id, idx)`` tuple. ``source_id`` identifies the + sub-dataset, and ``idx`` is the 0-based index within that sub-dataset. + + Returns: + The sample returned by the underlying sub-dataset. + + Raises: + AttributeError: If the underlying sub-dataset does not provide an index-based fetch API. + """ + source_id, idx = refetch_idx + dataset = self._id2dataset[source_id] + get_item_fn = getattr(dataset, "get_item", None) + if callable(get_item_fn): + return get_item_fn(idx) + raise AttributeError(f"dataset '{source_id}' does not implement get_item") + + def set_epoch(self, epoch: int) -> None: + """Set the epoch for deterministic sampling. + + Args: + epoch: Current epoch number. + """ + self._epoch = epoch + for dataset in self._datasets: + set_epoch_fn = getattr(dataset, "set_epoch", None) + if callable(set_epoch_fn): + set_epoch_fn(epoch) + + def __iter__(self): + """Iterate and yield samples from multiple sources. + + Yields: + If ``output_refetch_idx`` is False, yields a sample (typically a dict). + If ``output_refetch_idx`` is True, yields ``(sample, (source_id, refetch_idx))``. + """ + worker_info = get_worker_info() + worker_id = worker_info.id if worker_info is not None else 0 + if not self._just_resumed: + seed_seq = np.random.SeedSequence([self._seed, self._epoch, self.dp_rank, worker_id]) + current_seed = int(seed_seq.generate_state(1, dtype=np.uint32)[0]) + self._random_state = np.random.RandomState(current_seed) + self._exhausted = [False for _ in range(self._ds_num)] + self._avg_len_sum = [0.0 for _ in range(self._ds_num)] + self._avg_len_count = [0 for _ in range(self._ds_num)] + self._global_sample_idx = 0 + else: + self._just_resumed = False + + self._iters = [iter(ds) for ds in self._datasets] + while True: + ds_idx = self._random_state.choice(self._ds_num, p=self._runtime_weights()) + try: + sample = self._next_sample(ds_idx) + except StopIteration: + return + if sample is None: + continue + + if self._output_refetch_idx: + sample, refetch_idx = sample[0], sample[1] + + sample = self._attach_meta(sample, ds_idx) + token_len = self._sample_token_len_fn(sample) + if token_len <= 0: + continue + if self._level == "token": + self._avg_len_sum[ds_idx] += token_len + self._avg_len_count[ds_idx] += 1 + self._global_sample_idx += 1 + if not self._sharded and self._global_sample_idx % self.dp_size != self.dp_rank: + continue + if self._output_refetch_idx: + yield sample, (self._source_ids[ds_idx], refetch_idx) + else: + yield sample + + def _runtime_weights(self) -> np.ndarray: + """Compute the per-source sampling probabilities for the current runtime state. + + Returns: + A probability vector of shape ``(num_sources,)`` that sums to 1. + + Raises: + ValueError: If the weight sum is non-positive. + """ + if self._level == "sample": + weights = self._weights + else: + avg_lens = [] + for idx in range(self._ds_num): + if self._avg_len_count[idx] > 0: + avg_lens.append(self._avg_len_sum[idx] / self._avg_len_count[idx]) + else: + avg_lens.append(1.0) + weights = self._weights / np.asarray(avg_lens, dtype=np.float64) + total = float(np.sum(weights)) + if total <= 0: + raise ValueError("sum of weights must be positive") + return weights / total + + def _next_sample(self, ds_idx: int) -> Any: + """Fetch the next sample from a specific sub-dataset index. + + Args: + ds_idx: Index of the sub-dataset to fetch from. 
+ + Returns: + The next sample from the chosen sub-dataset. + + Raises: + StopIteration: When the dataset terminates under the configured stopping strategy. + """ + while True: + try: + return next(self._iters[ds_idx]) + except StopIteration: + if self._stopping_strategy == "first_exhausted": + raise + if self._stopping_strategy == "all_exhausted": + self._exhausted[ds_idx] = True + if all(self._exhausted): + raise + elif self._stopping_strategy == "never_exhausted": + self._exhausted[ds_idx] = True + if all(self._exhausted): + self._exhausted = [False for _ in range(self._ds_num)] + self._iters[ds_idx] = iter(self._datasets[ds_idx]) + + def _attach_meta(self, sample: Any, ds_idx: int) -> Any: + """Attach per-source metadata fields onto a sample. + + Adds: + - ``ds_idx``: the integer source index + - ``source_name``: optional display name if provided + + Args: + sample: A sample or list of samples. + ds_idx: Source index for this sample. + + Returns: + The updated sample (mutated in place when possible). + """ + source_name = self._source_names[ds_idx] if self._source_names is not None else None + if isinstance(sample, list): + for item in sample: + if isinstance(item, dict): + item["ds_idx"] = ds_idx + if source_name is not None: + item["source_name"] = source_name + return sample + if isinstance(sample, dict): + sample["ds_idx"] = ds_idx + if source_name is not None: + sample["source_name"] = source_name + return sample + + def _default_sample_token_len(self, sample: Any) -> float: + """Default heuristic to estimate token length of a sample. + + Args: + sample: A single sample or a list of samples. + + Returns: + Estimated token length as a float. + """ + if sample is None: + return 0 + if isinstance(sample, list): + return float(sum(self._default_sample_token_len(item) for item in sample)) + if not isinstance(sample, dict): + return 1.0 + if "attention_mask" in sample: + attention_mask = sample["attention_mask"] + if isinstance(attention_mask, torch.Tensor): + return float(attention_mask.sum().item()) + if isinstance(attention_mask, list): + return float(sum(attention_mask)) + if "input_ids" in sample: + input_ids = sample["input_ids"] + if isinstance(input_ids, torch.Tensor): + return float(input_ids.numel()) + if isinstance(input_ids, list): + return float(len(input_ids)) + return 1.0 + + def state_dict(self) -> dict: + """Return a checkpointable state dict for this dataset.""" + dataset_states_by_id = {} + for dataset, source_id in zip(self._datasets, self._source_ids): + state_fn = getattr(dataset, "state_dict", None) + getstate_fn = getattr(dataset, "__getstate__", None) + if callable(state_fn): + ds_state = state_fn() + elif callable(getstate_fn): + ds_state = getstate_fn() + else: + ds_state = None + dataset_states_by_id[source_id] = ds_state + avg_len_sum_by_id = {source_id: self._avg_len_sum[idx] for idx, source_id in enumerate(self._source_ids)} + avg_len_count_by_id = {source_id: self._avg_len_count[idx] for idx, source_id in enumerate(self._source_ids)} + # save _exhausted state + exhausted_by_id = {source_id: self._exhausted[idx] for idx, source_id in enumerate(self._source_ids)} + return { + "version": 0, + "topology": { + "source_ids": list(self._source_ids), + "source_names": list(self._source_names) if self._source_names is not None else None, + "weights": self._weights.tolist(), + "level": self._level, + "stopping_strategy": self._stopping_strategy, + "sharded": self._sharded, + }, + "runtime": { + "random_state": self._random_state.get_state(), + "avg_len_sum": 
avg_len_sum_by_id, + "avg_len_count": avg_len_count_by_id, + "exhausted": exhausted_by_id, + "global_sample_idx": self._global_sample_idx, + "dataset_states": dataset_states_by_id, + }, + } + + def load_state_dict( + self, + state: dict, + reconcile_policy: Literal["strict", "allow_add", "allow_add_remove", "warn_only"] = "allow_add_remove", + ) -> None: + """Restore state from a previous ``state_dict()``. + + Args: + state: State dict previously produced by ``state_dict()``. + reconcile_policy: Policy for handling source-id changes: + - ``strict``: error on any added/removed source. + - ``allow_add``: allow new sources but error on removed ones. + - ``allow_add_remove``: allow both add and remove. + - ``warn_only``: allow changes and log a warning. + + Raises: + ValueError: If required state fields are missing or incompatible. + """ + if "topology" not in state or "runtime" not in state: + raise ValueError("state_dict missing required keys: topology/runtime") + runtime = state["runtime"] + topology = state["topology"] + if "source_ids" not in topology: + raise ValueError("state_dict missing topology.source_ids") + saved_source_ids = topology["source_ids"] + added = [] + removed = [] + if saved_source_ids is not None: + saved_set = set(saved_source_ids) + added = [source_id for source_id in self._source_ids if source_id not in saved_set] + removed = [source_id for source_id in saved_source_ids if source_id not in set(self._source_ids)] + if added or removed: + if reconcile_policy == "strict": + raise ValueError( + f"source_ids mismatch: added={added} removed={removed} with policy={reconcile_policy}" + ) + if reconcile_policy == "allow_add" and removed: + raise ValueError( + f"source_ids removed not allowed: removed={removed} with policy={reconcile_policy}" + ) + if reconcile_policy == "warn_only": + logger.warning( + f"source_ids changed: added={added} removed={removed} with policy={reconcile_policy}" + ) + random_state = runtime["random_state"] + self._random_state.set_state(random_state) + avg_len_sum = runtime["avg_len_sum"] + avg_len_count = runtime["avg_len_count"] + if not isinstance(avg_len_sum, dict) or not isinstance(avg_len_count, dict): + raise ValueError("runtime.avg_len_sum and runtime.avg_len_count must be dicts keyed by source_id") + self._avg_len_sum = [float(avg_len_sum.get(source_id, 0.0)) for source_id in self._source_ids] + self._avg_len_count = [int(avg_len_count.get(source_id, 0)) for source_id in self._source_ids] + self._global_sample_idx = runtime.get("global_sample_idx", 0) + dataset_states = runtime["dataset_states"] + if not isinstance(dataset_states, dict): + raise ValueError("runtime.dataset_states must be a dict keyed by source_id") + dataset_states_by_id = dataset_states + for dataset, source_id in zip(self._datasets, self._source_ids): + ds_state = dataset_states_by_id.get(source_id) + if ds_state is None: + continue + load_state_fn = getattr(dataset, "load_state_dict", None) + if callable(load_state_fn): + load_state_fn(ds_state) + + # Ensure _exhausted is re-initialized for the current source count + # This is important when sources are added/removed during checkpoint resume + if "exhausted" in runtime and isinstance(runtime["exhausted"], dict): + exhausted_dict = runtime["exhausted"] + self._exhausted = [bool(exhausted_dict.get(source_id, False)) for source_id in self._source_ids] + else: + self._exhausted = [False for _ in range(self._ds_num)] + + self._just_resumed = True
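Reviewer note (not part of the diff): a minimal, self-contained sketch of the intended checkpoint/resume flow for the new MultiSourceDataset, mirroring the unit tests above and assuming a single-process context like those tests. The ToySource class, source ids, and weights are illustrative assumptions only; real sources should also implement state_dict()/load_state_dict() so their own read position is restored, which ToySource omits for brevity.

from torch.utils.data import IterableDataset

from veomni.data.multisource_dataset import MultiSourceDataset


class ToySource(IterableDataset):
    """Stateless toy source; mirrors MockIterableDataset from the tests above."""

    def __init__(self, token_ids):
        self.token_ids = list(token_ids)

    def __iter__(self):
        for i in self.token_ids:
            # Variable sample lengths so level="token" reweighting has an effect.
            yield {"input_ids": [i] * (i % 5 + 1)}


dataset = MultiSourceDataset(
    datasets=[ToySource(range(100)), ToySource(range(100, 200))],
    weights=[0.7, 0.3],
    seed=42,
    level="token",                      # reweight by inverse running average sample length
    source_ids=["web_v1", "code_v1"],   # stable ids used as checkpoint keys
    stopping_strategy="all_exhausted",
)
dataset.set_epoch(0)
it = iter(dataset)
for _ in range(50):
    sample = next(it)                   # dict with "ds_idx" and "source_name" attached

ckpt = dataset.state_dict()             # {"version", "topology", "runtime"}

# Resume later with an extra source added; the new source's runtime stats start fresh.
resumed = MultiSourceDataset(
    datasets=[ToySource(range(100)), ToySource(range(100, 200)), ToySource(range(200, 300))],
    weights=[0.5, 0.3, 0.2],
    seed=42,
    level="token",
    source_ids=["web_v1", "code_v1", "math_v1"],
    stopping_strategy="all_exhausted",
)
resumed.load_state_dict(ckpt, reconcile_policy="allow_add")

The reconcile_policy argument is what lets the source mix change between runs without invalidating the checkpoint: "strict" would raise here, while "allow_add_remove" also tolerates dropped sources.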