Skip to content
11 changes: 6 additions & 5 deletions tests/data/test_dynamic_batching_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def build_command(shuffle=True, save_by_idx=True):
"--train.rmpad=false",
"--train.rmpad_with_pos_ids=true",
"--train.dyn_bsz=true",
"--dyn_bsz_in_dataloader=false",
"--train.dyn_bsz_runtime=worker",
f"--save_by_idx={str(save_by_idx).lower()}",
"--train.seed=42",
]
Expand Down Expand Up @@ -453,7 +453,6 @@ def _run_distributed_test():
_parser = argparse.ArgumentParser()
_parser.add_argument("--shuffle", type=lambda x: x.lower() == "true", default=True)
_parser.add_argument("--save_by_idx", type=lambda x: x.lower() == "true", default=True)
_parser.add_argument("--dyn_bsz_in_dataloader", type=lambda x: x.lower() == "true", default=True)
test_args, remaining_argv = _parser.parse_known_args()
sys.argv = [sys.argv[0]] + remaining_argv

Expand Down Expand Up @@ -505,7 +504,7 @@ def _run_distributed_test():
train_steps=train_steps,
rmpad=args.train.rmpad,
dyn_bsz=args.train.dyn_bsz,
dyn_bsz_in_dataloader=test_args.dyn_bsz_in_dataloader,
dyn_bsz_runtime=args.train.dyn_bsz_runtime,
bsz_warmup_ratio=args.train.bsz_warmup_ratio,
rmpad_with_pos_ids=args.train.rmpad_with_pos_ids,
dyn_bsz_buffer_size=READY_FOR_MICRO_BATCH_THRESHOLD,
Expand Down Expand Up @@ -584,6 +583,7 @@ def _run_distributed_test():
"extra_state": {
"curr_epoch": epoch,
"curr_step": local_step,
"global_step": global_step,
"train_dataloader": dataloader.state_dict(),
"environ_meter": environ_meter.state_dict(),
},
Expand All @@ -603,9 +603,10 @@ def _run_distributed_test():
dataloader.load_state_dict(state["extra_state"]["train_dataloader"])
environ_meter.load_state_dict(state["extra_state"]["environ_meter"])
start_epoch = state["extra_state"]["curr_epoch"]
assert start_epoch == 1
assert start_epoch == save_epoch
start_step = state["extra_state"]["curr_step"] + 1
assert start_step == 1
assert start_step == save_step + 1
global_step = state["extra_state"]["global_step"]
dl_state = state["extra_state"]["train_dataloader"]
logger.error(f"[rank{rank}] Loaded dataloader state: {dl_state}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def build_command(dataset_type, dataloader_type):
"--nnodes=1",
"--nproc_per_node=8",
f"--master_port={port}",
"tests/data/test_multisource_datasets.py",
"tests/data/test_interleave_datasets.py",
"--data.enable_multisource=True",
"--model.config_path=test",
"--data.train_path=None",
Expand All @@ -276,7 +276,7 @@ def build_command(dataset_type, dataloader_type):
return command


def test_multisource_data_rmpad_with_pos_ids():
def test_interleave_rmpad_with_pos_ids():
command = build_command(dataset_type="mapping", dataloader_type="rmpad_with_pos_ids")
result = subprocess.run(command, check=True)
assert result.returncode == 0
Expand All @@ -286,7 +286,7 @@ def test_multisource_data_rmpad_with_pos_ids():
assert result.returncode == 0


def test_multisource_data_padding():
def test_interleave_padding():
command = build_command(dataset_type="mapping", dataloader_type="padding")
result = subprocess.run(command, check=True)
assert result.returncode == 0
Expand Down
Loading
Loading