Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions examples/mimo/training/hetero/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,18 @@ def parse_args() -> argparse.Namespace:
default=0,
help="Dense data-parallel replica to trace when --timeline-ranks=dp-replica.",
)
runtime.add_argument(
"--timeline-iter-start",
type=int,
default=None,
help="First training iteration to record in timeline traces.",
)
runtime.add_argument(
"--timeline-iter-end",
type=int,
default=None,
help="Last training iteration to record in timeline traces.",
)
runtime.add_argument(
"--timeline-cuda-events",
action="store_true",
Expand Down Expand Up @@ -152,6 +164,16 @@ def validate_args(args: argparse.Namespace, world_size: int) -> tuple[int, int]:
raise ValueError("--log-interval must be >= 1")
if args.timeline_dp_replica < 0:
raise ValueError("--timeline-dp-replica must be >= 0")
if args.timeline_iter_start is not None and args.timeline_iter_start < 1:
raise ValueError("--timeline-iter-start must be >= 1")
if args.timeline_iter_end is not None and args.timeline_iter_end < 1:
raise ValueError("--timeline-iter-end must be >= 1")
if (
args.timeline_iter_start is not None
and args.timeline_iter_end is not None
and args.timeline_iter_end < args.timeline_iter_start
):
raise ValueError("--timeline-iter-end must be >= --timeline-iter-start")

validate_model_provider_args(args)
if args.dataset_provider == "mock":
Expand Down
67 changes: 36 additions & 31 deletions examples/mimo/training/hetero/grad_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from examples.mimo.utils.hetero import debug_rank, is_process_group_member
from megatron.core.distributed.finalize_model_grads import finalize_model_grads
from megatron.core.models.mimo.model.base import MimoModel
from megatron.core.pipeline_parallel.timeline import timeline_event
from megatron.core.pipeline_parallel.utils import is_pp_last_stage


Expand All @@ -34,48 +35,52 @@ def finalize_grads_func(_model_list, num_tokens, force_all_reduce=False, **_kwar
if num_tokens is None:
raise RuntimeError("train_hetero.py expects calculate_per_token_loss=True")

global_num_tokens = torch.zeros(1, dtype=torch.float32, device="cuda")
if is_token_source_rank():
# MCore has already summed loss-mask token counts across microbatches
# for this gradient-accumulation step. Match Megatron's normalization
# domain by reducing the language last-stage count over DP and CP.
token_count = num_tokens.to(device="cuda", dtype=torch.float32).sum().view(1)
dist.all_reduce(token_count, op=dist.ReduceOp.SUM, group=language_pg.dp_cp)
if dist.get_rank(language_pg.dp_cp) == 0:
global_num_tokens.copy_(token_count)
# Publish the already DP/CP-reduced language token count to encoder ranks too.
dist.all_reduce(global_num_tokens, op=dist.ReduceOp.MAX)
global_num_tokens_value = global_num_tokens.item()
with timeline_event("grad_finalize.token_count_reduce"):
global_num_tokens = torch.zeros(1, dtype=torch.float32, device="cuda")
if is_token_source_rank():
# MCore has already summed loss-mask token counts across microbatches
# for this gradient-accumulation step. Match Megatron's normalization
# domain by reducing the language last-stage count over DP and CP.
token_count = num_tokens.to(device="cuda", dtype=torch.float32).sum().view(1)
dist.all_reduce(token_count, op=dist.ReduceOp.SUM, group=language_pg.dp_cp)
if dist.get_rank(language_pg.dp_cp) == 0:
global_num_tokens.copy_(token_count)
# Publish the already DP/CP-reduced language token count to encoder ranks too.
dist.all_reduce(global_num_tokens, op=dist.ReduceOp.MAX)
global_num_tokens_value = global_num_tokens.item()

if mimo_model.language_model is not None:
debug_rank("finalizing language grads")
finalize_model_grads(
[mimo_model.language_model],
num_tokens=None,
pg_collection=language_pg,
force_all_reduce=force_all_reduce,
)
debug_rank("language grads finalized")
for submodule in mimo_model.modality_submodules.values():
if submodule is not None:
debug_rank("finalizing vision grads")
with timeline_event("grad_finalize.language"):
finalize_model_grads(
[submodule],
[mimo_model.language_model],
num_tokens=None,
pg_collection=vision_pg,
pg_collection=language_pg,
force_all_reduce=force_all_reduce,
)
debug_rank("language grads finalized")
for submodule in mimo_model.modality_submodules.values():
if submodule is not None:
debug_rank("finalizing vision grads")
with timeline_event("grad_finalize.vision"):
finalize_model_grads(
[submodule],
num_tokens=None,
pg_collection=vision_pg,
force_all_reduce=force_all_reduce,
)
debug_rank("vision grads finalized")

if global_num_tokens_value > 0:
scale = 1.0 / global_num_tokens_value
if mimo_model.language_model is not None:
debug_rank("scaling language grads")
mimo_model.language_model.scale_gradients(scale)
for submodule in mimo_model.modality_submodules.values():
if submodule is not None:
debug_rank("scaling vision grads")
submodule.scale_gradients(scale)
with timeline_event("grad_finalize.scale"):
if mimo_model.language_model is not None:
debug_rank("scaling language grads")
mimo_model.language_model.scale_gradients(scale)
for submodule in mimo_model.modality_submodules.values():
if submodule is not None:
debug_rank("scaling vision grads")
submodule.scale_gradients(scale)

mimo_model.config.no_sync_func = build_no_sync_func(mimo_model)
mimo_model.config.finalize_model_grads_func = finalize_grads_func
Expand Down
16 changes: 16 additions & 0 deletions examples/mimo/training/hetero/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
from megatron.core.pipeline_parallel.timeline import (
close_pipeline_timeline,
flush_pipeline_timeline,
is_pipeline_timeline_active,
set_pipeline_timeline_iteration,
timeline_instant,
)


Expand Down Expand Up @@ -74,6 +76,7 @@ def run_train_loop(args: argparse.Namespace) -> None:
result = train_step(
args, model, topology, optimizer, opt_param_scheduler, communicator, data_iterator
)
record_cuda_memory_snapshot()
flush_pipeline_timeline()
logger.record_step(result)
logger.maybe_log(iteration, optimizer, result)
Expand All @@ -97,3 +100,16 @@ def build_pipeline_communicator(
dim_mapping={"s": 0, "h": 2, "b": 1},
module_output_ndim={topology.encoder_name: 2},
)


def record_cuda_memory_snapshot() -> None:
    """Emit a ``cuda.memory`` timeline instant carrying this process's CUDA memory counters.

    Returns immediately when CUDA is unavailable or when no pipeline timeline
    is active. Only local ``torch.cuda`` counters are read here; this function
    itself issues no collective calls.
    """
    if not (torch.cuda.is_available() and is_pipeline_timeline_active()):
        return

    # Current and peak figures for both allocated and reserved memory, in bytes.
    memory_stats = {
        "memory_allocated_bytes": torch.cuda.memory_allocated(),
        "max_memory_allocated_bytes": torch.cuda.max_memory_allocated(),
        "memory_reserved_bytes": torch.cuda.memory_reserved(),
        "max_memory_reserved_bytes": torch.cuda.max_memory_reserved(),
    }
    timeline_instant("cuda.memory", **memory_stats)
15 changes: 10 additions & 5 deletions examples/mimo/training/hetero/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,10 @@ def train_step(
data_iterator,
) -> TrainStepResult:
"""Run one Megatron-shaped hetero training step."""
zero_active_grad_buffers(model)
optimizer.zero_grad()
with timeline_event("grad_buffers.zero"):
zero_active_grad_buffers(model)
with timeline_event("optimizer.zero_grad"):
optimizer.zero_grad()

debug_rank("starting forward/backward schedule")
losses = schedule.forward_backward_pipelining_without_interleaving(
Expand All @@ -116,12 +118,15 @@ def train_step(
debug_rank("schedule complete")

debug_rank("optimizer step starting")
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
update_successful = reduce_update_success(update_successful)
with timeline_event("optimizer.step"):
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
with timeline_event("update_success.reduce"):
update_successful = reduce_update_success(update_successful)
debug_rank("optimizer step complete")

if update_successful:
opt_param_scheduler.step(increment=get_global_batch_size(args))
with timeline_event("optimizer_param_scheduler.step"):
opt_param_scheduler.step(increment=get_global_batch_size(args))
skipped_iter = 0
else:
# Match Megatron train_step semantics: failed updates skip LR advancement but
Expand Down
24 changes: 23 additions & 1 deletion examples/mimo/training/hetero/timeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,18 @@ def configure_hetero_timeline(args: argparse.Namespace, topology: HeteroTopology
scope = os.environ.get("MIMO_TIMELINE_RANKS", args.timeline_ranks)
dp_replica = int(os.environ.get("MIMO_TIMELINE_DP_REPLICA", args.timeline_dp_replica))
output_dir = args.timeline_dir or os.environ.get("MIMO_TIMELINE_DIR", "mimo_timeline")
iteration_start = _optional_env_int("MIMO_TIMELINE_ITER_START", args.timeline_iter_start)
iteration_end = _optional_env_int("MIMO_TIMELINE_ITER_END", args.timeline_iter_end)
if iteration_start is not None and iteration_start < 1:
raise ValueError("timeline iteration start must be >= 1")
if iteration_end is not None and iteration_end < 1:
raise ValueError("timeline iteration end must be >= 1")
if (
iteration_start is not None
and iteration_end is not None
and iteration_end < iteration_start
):
raise ValueError("timeline iteration end must be >= timeline iteration start")
selected_ranks = select_timeline_ranks(scope, dp_replica, topology, world_size)
role, coords = rank_role_and_coords(rank, topology)

Expand All @@ -49,13 +61,16 @@ def configure_hetero_timeline(args: argparse.Namespace, topology: HeteroTopology
},
cuda_events=args.timeline_cuda_events or env_flag_enabled("MIMO_TIMELINE_CUDA_EVENTS"),
nvtx=args.timeline_nvtx or env_flag_enabled("MIMO_TIMELINE_NVTX"),
iteration_start=iteration_start,
iteration_end=iteration_end,
)

if rank != 0:
return None
return (
"Pipeline timeline enabled: "
f"dir={output_dir}, scope={scope}, selected_ranks={len(selected_ranks)}"
f"dir={output_dir}, scope={scope}, selected_ranks={len(selected_ranks)}, "
f"iter_start={iteration_start}, iter_end={iteration_end}"
)


Expand Down Expand Up @@ -108,3 +123,10 @@ def grid_coords(grid: HyperCommGrid, rank: int) -> dict[str, int]:
def env_flag_enabled(name: str) -> bool:
    """Report whether environment variable ``name`` holds a truthy flag value.

    The value is compared case-insensitively, ignoring surrounding whitespace,
    against ``1``, ``true``, ``yes`` and ``on``. An unset variable counts as
    disabled.
    """
    truthy_values = {"1", "true", "yes", "on"}
    raw_value = os.environ.get(name, "")
    normalized = raw_value.strip().lower()
    return normalized in truthy_values


def _optional_env_int(name: str, default: Optional[int]) -> Optional[int]:
    """Read an optional integer override from the environment.

    Args:
        name: Environment variable to look up.
        default: Value returned when the variable is unset or blank.

    Returns:
        The parsed integer when the variable holds a non-blank value,
        otherwise ``default`` (which may be ``None``).

    Raises:
        ValueError: If the variable is set to a non-blank value that is not a
            valid integer. The message names the variable so a misconfigured
            launch environment is easy to diagnose.
    """
    value = os.environ.get(name)
    if value is None or value.strip() == "":
        return default
    try:
        return int(value)
    except ValueError as err:
        # The bare int() error only shows the bad literal; say which variable
        # it came from.
        raise ValueError(
            f"environment variable {name} must be an integer, got {value!r}"
        ) from err
Loading