From ab585d79aba42fca41ba0e82f65e5acabde9fcd3 Mon Sep 17 00:00:00 2001 From: Yashaswi Karnati Date: Thu, 14 May 2026 19:41:23 +0000 Subject: [PATCH] Add hetero MIMO DDP performance flags --- .../run_hetero_nemotron_54l_hel_train.sh | 30 +++++++++++++- examples/mimo/training/hetero/args.py | 39 ++++++++++++++++++- examples/mimo/training/hetero/optimizer.py | 1 + examples/mimo/training/hetero/runtime.py | 35 ++++++++++++++--- 4 files changed, 95 insertions(+), 10 deletions(-) diff --git a/examples/mimo/scripts/run_hetero_nemotron_54l_hel_train.sh b/examples/mimo/scripts/run_hetero_nemotron_54l_hel_train.sh index 74be1ed2134..ebccea20091 100755 --- a/examples/mimo/scripts/run_hetero_nemotron_54l_hel_train.sh +++ b/examples/mimo/scripts/run_hetero_nemotron_54l_hel_train.sh @@ -58,6 +58,11 @@ LLM_EP="${LLM_EP:-16}" LLM_EXPT_TP="${LLM_EXPT_TP:-1}" ENABLE_EXPERIMENTAL="${ENABLE_EXPERIMENTAL:-1}" MOE_ROUTER_FORCE_LOAD_BALANCING="${MOE_ROUTER_FORCE_LOAD_BALANCING:-0}" +OVERLAP_GRAD_REDUCE="${OVERLAP_GRAD_REDUCE:-1}" +OVERLAP_PARAM_GATHER="${OVERLAP_PARAM_GATHER:-0}" +DDP_NUM_BUCKETS="${DDP_NUM_BUCKETS:-8}" +DDP_BUCKET_SIZE="${DDP_BUCKET_SIZE:-}" +DDP_PAD_BUCKETS_FOR_HIGH_NCCL_BUSBW="${DDP_PAD_BUCKETS_FOR_HIGH_NCCL_BUSBW:-1}" ENCODER_SIZE=$((ENCODER_TP * ENCODER_CP * ENCODER_PP * ENCODER_DP)) LLM_SIZE=$((LLM_TP * LLM_CP * LLM_PP * LLM_DP)) @@ -197,6 +202,9 @@ if [[ "${RANK_ID}" -eq 0 ]]; then echo "enable_experimental=${ENABLE_EXPERIMENTAL}" echo "moe_router_force_load_balancing=${MOE_ROUTER_FORCE_LOAD_BALANCING}" echo "moe_router_fusion=model-provider-default" + echo "overlap_grad_reduce=${OVERLAP_GRAD_REDUCE} overlap_param_gather=${OVERLAP_PARAM_GATHER}" + echo "ddp_num_buckets=${DDP_NUM_BUCKETS:-unset} ddp_bucket_size=${DDP_BUCKET_SIZE:-unset}" + echo "ddp_pad_buckets_for_high_nccl_busbw=${DDP_PAD_BUCKETS_FOR_HIGH_NCCL_BUSBW}" echo "data=${DATA_TRAIN}" echo "tokenizer=${TOKENIZER_MODEL}" echo "run_dir=${RUN_DIR}" @@ -218,6 +226,25 @@ fi if [[ "${MOE_ROUTER_FORCE_LOAD_BALANCING}" == "1" || "${MOE_ROUTER_FORCE_LOAD_BALANCING}" == "true" ]]; then MODEL_ARGS+=(--moe-router-force-load-balancing) fi +DDP_ARGS=() +if [[ "${OVERLAP_GRAD_REDUCE}" == "1" || "${OVERLAP_GRAD_REDUCE}" == "true" ]]; then + DDP_ARGS+=(--overlap-grad-reduce) +else + DDP_ARGS+=(--no-overlap-grad-reduce) +fi +if [[ "${OVERLAP_PARAM_GATHER}" == "1" || "${OVERLAP_PARAM_GATHER}" == "true" ]]; then + DDP_ARGS+=(--overlap-param-gather) +else + DDP_ARGS+=(--no-overlap-param-gather) +fi +if [[ -n "${DDP_NUM_BUCKETS}" ]]; then + DDP_ARGS+=(--ddp-num-buckets "${DDP_NUM_BUCKETS}") +elif [[ -n "${DDP_BUCKET_SIZE}" ]]; then + DDP_ARGS+=(--ddp-bucket-size "${DDP_BUCKET_SIZE}") +fi +if [[ "${DDP_PAD_BUCKETS_FOR_HIGH_NCCL_BUSBW}" == "1" || "${DDP_PAD_BUCKETS_FOR_HIGH_NCCL_BUSBW}" == "true" ]]; then + DDP_ARGS+=(--ddp-pad-buckets-for-high-nccl-busbw) +fi CMD=( "${PYTHON_BIN}" -u examples/mimo/train_hetero.py @@ -257,8 +284,7 @@ CMD=( --adam-beta1 0.9 --adam-beta2 0.95 --clip-grad 1.0 - --no-overlap-grad-reduce - --ddp-bucket-size 0 + "${DDP_ARGS[@]}" --log-interval "${LOG_INTERVAL}" --train-iters "${TRAIN_ITERS}" "$@" diff --git a/examples/mimo/training/hetero/args.py b/examples/mimo/training/hetero/args.py index 12b7f9f041d..5c4380c5029 100644 --- a/examples/mimo/training/hetero/args.py +++ b/examples/mimo/training/hetero/args.py @@ -126,11 +126,38 @@ def parse_args() -> argparse.Namespace: "keeps overlap disabled because actual-data batches may be text-only." ), ) + train.add_argument( + "--overlap-param-gather", + action=argparse.BooleanOptionalAction, + default=False, + help=( + "Enable DDP parameter-gather overlap for the language module. Vision encoder DDP " + "keeps overlap disabled because actual-data batches may be text-only." + ), + ) + train.add_argument( + "--ddp-num-buckets", + type=int, + default=None, + help=( + "Number of language-model DDP buckets. Mutually exclusive with " + "--ddp-bucket-size." + ), + ) train.add_argument( "--ddp-bucket-size", type=int, - default=10000, - help="DDP bucket size. Use 0 for a single unbounded bucket.", + default=None, + help="DDP bucket size. Defaults to 10000. Use 0 for a single unbounded bucket.", + ) + train.add_argument( + "--ddp-pad-buckets-for-high-nccl-busbw", + action="store_true", + default=False, + help=( + "Pad language-model distributed-optimizer buckets to improve NCCL bus bandwidth " + "at large DP sizes." + ), ) train.add_argument("--seed", type=int, default=12345) train.add_argument("--log-interval", type=int, default=1) @@ -150,6 +177,14 @@ def validate_args(args: argparse.Namespace, world_size: int) -> tuple[int, int]: raise ValueError("Phase 2 mock training currently supports CP=1 only") if args.log_interval < 1: raise ValueError("--log-interval must be >= 1") + if args.ddp_num_buckets is not None and args.ddp_num_buckets < 1: + raise ValueError("--ddp-num-buckets must be >= 1") + if args.ddp_bucket_size is not None and args.ddp_bucket_size < 0: + raise ValueError("--ddp-bucket-size must be >= 0") + if args.ddp_num_buckets is not None and args.ddp_bucket_size is not None: + raise ValueError("--ddp-num-buckets and --ddp-bucket-size are mutually exclusive") + if args.overlap_param_gather and not args.overlap_grad_reduce: + raise ValueError("--overlap-param-gather requires --overlap-grad-reduce") if args.timeline_dp_replica < 0: raise ValueError("--timeline-dp-replica must be >= 0") diff --git a/examples/mimo/training/hetero/optimizer.py b/examples/mimo/training/hetero/optimizer.py index fb3bcca8ca1..f6eb7edb0cd 100644 --- a/examples/mimo/training/hetero/optimizer.py +++ b/examples/mimo/training/hetero/optimizer.py @@ -26,6 +26,7 @@ def build_optimizer(args: argparse.Namespace, model: MimoModel): clip_grad=args.clip_grad, bf16=not args.fp32, use_distributed_optimizer=True, + overlap_param_gather=args.overlap_param_gather and model.language_model is not None, log_num_zeros_in_grad=args.log_num_zeros_in_grad, ), ) diff --git a/examples/mimo/training/hetero/runtime.py b/examples/mimo/training/hetero/runtime.py index 10ba4eea0ac..aeeb63904bd 100644 --- a/examples/mimo/training/hetero/runtime.py +++ b/examples/mimo/training/hetero/runtime.py @@ -74,19 +74,21 @@ def wrap_active_modules_with_ddp( args: argparse.Namespace, mimo_model: MimoModel, topology: HeteroTopology ) -> None: """Freeze and DDP-wrap active local MIMO modules.""" - language_ddp_config = DistributedDataParallelConfig( - overlap_grad_reduce=args.overlap_grad_reduce, - bucket_size=args.ddp_bucket_size if args.ddp_bucket_size > 0 else None, - use_distributed_optimizer=True, - ) vision_ddp_config = DistributedDataParallelConfig( overlap_grad_reduce=False, - bucket_size=args.ddp_bucket_size if args.ddp_bucket_size > 0 else None, + bucket_size=resolve_fixed_ddp_bucket_size(args.ddp_bucket_size), use_distributed_optimizer=True, ) if mimo_model.language_model is not None: if args.freeze_lm: set_module_requires_grad(mimo_model.language_model, False) + language_ddp_config = DistributedDataParallelConfig( + overlap_grad_reduce=args.overlap_grad_reduce, + overlap_param_gather=args.overlap_param_gather, + bucket_size=resolve_language_ddp_bucket_size(args, mimo_model.language_model), + use_distributed_optimizer=True, + pad_buckets_for_high_nccl_busbw=args.ddp_pad_buckets_for_high_nccl_busbw, + ) debug_rank("wrapping language model in DDP") mimo_model.language_model = DistributedDataParallel( config=mimo_model.language_model.config, @@ -117,6 +119,27 @@ def wrap_active_modules_with_ddp( debug_rank("vision submodule DDP ready") +def resolve_language_ddp_bucket_size( + args: argparse.Namespace, module: torch.nn.Module +) -> Optional[int]: + """Return the configured language DDP bucket size.""" + if args.ddp_num_buckets is not None: + num_trainable_params = sum( + param.numel() for param in module.parameters() if param.requires_grad + ) + return max(1, num_trainable_params // args.ddp_num_buckets) + return resolve_fixed_ddp_bucket_size(args.ddp_bucket_size) + + +def resolve_fixed_ddp_bucket_size(bucket_size: Optional[int]) -> Optional[int]: + """Return the concrete DDP bucket size, preserving the historical default.""" + if bucket_size is None: + return 10000 + if bucket_size == 0: + return None + return bucket_size + + def set_module_requires_grad(module: Optional[torch.nn.Module], requires_grad: bool) -> None: """Set requires_grad for every parameter in a module when the module exists.""" if module is None: