Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions examples/mimo/scripts/run_hetero_nemotron_54l_hel_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ LLM_EP="${LLM_EP:-16}"
LLM_EXPT_TP="${LLM_EXPT_TP:-1}"
ENABLE_EXPERIMENTAL="${ENABLE_EXPERIMENTAL:-1}"
MOE_ROUTER_FORCE_LOAD_BALANCING="${MOE_ROUTER_FORCE_LOAD_BALANCING:-0}"
OVERLAP_GRAD_REDUCE="${OVERLAP_GRAD_REDUCE:-1}"
OVERLAP_PARAM_GATHER="${OVERLAP_PARAM_GATHER:-0}"
DDP_NUM_BUCKETS="${DDP_NUM_BUCKETS:-8}"
DDP_BUCKET_SIZE="${DDP_BUCKET_SIZE:-}"
DDP_PAD_BUCKETS_FOR_HIGH_NCCL_BUSBW="${DDP_PAD_BUCKETS_FOR_HIGH_NCCL_BUSBW:-1}"

ENCODER_SIZE=$((ENCODER_TP * ENCODER_CP * ENCODER_PP * ENCODER_DP))
LLM_SIZE=$((LLM_TP * LLM_CP * LLM_PP * LLM_DP))
Expand Down Expand Up @@ -197,6 +202,9 @@ if [[ "${RANK_ID}" -eq 0 ]]; then
echo "enable_experimental=${ENABLE_EXPERIMENTAL}"
echo "moe_router_force_load_balancing=${MOE_ROUTER_FORCE_LOAD_BALANCING}"
echo "moe_router_fusion=model-provider-default"
echo "overlap_grad_reduce=${OVERLAP_GRAD_REDUCE} overlap_param_gather=${OVERLAP_PARAM_GATHER}"
echo "ddp_num_buckets=${DDP_NUM_BUCKETS:-unset} ddp_bucket_size=${DDP_BUCKET_SIZE:-unset}"
echo "ddp_pad_buckets_for_high_nccl_busbw=${DDP_PAD_BUCKETS_FOR_HIGH_NCCL_BUSBW}"
echo "data=${DATA_TRAIN}"
echo "tokenizer=${TOKENIZER_MODEL}"
echo "run_dir=${RUN_DIR}"
Expand All @@ -218,6 +226,25 @@ fi
if [[ "${MOE_ROUTER_FORCE_LOAD_BALANCING}" == "1" || "${MOE_ROUTER_FORCE_LOAD_BALANCING}" == "true" ]]; then
MODEL_ARGS+=(--moe-router-force-load-balancing)
fi
# Assemble the DDP-related command-line flags from the environment toggles.
DDP_ARGS=()

# Boolean toggles accept "1" or "true"; any other value emits the explicit --no-* form.
if [[ "${OVERLAP_GRAD_REDUCE}" == "1" || "${OVERLAP_GRAD_REDUCE}" == "true" ]]; then
    DDP_ARGS+=(--overlap-grad-reduce)
else
    DDP_ARGS+=(--no-overlap-grad-reduce)
fi
if [[ "${OVERLAP_PARAM_GATHER}" == "1" || "${OVERLAP_PARAM_GATHER}" == "true" ]]; then
    DDP_ARGS+=(--overlap-param-gather)
else
    DDP_ARGS+=(--no-overlap-param-gather)
fi

# An explicitly provided DDP_BUCKET_SIZE takes precedence. DDP_NUM_BUCKETS is
# defaulted with ":-" (which also replaces an empty value), so it is never empty
# here; testing it first would make the bucket-size branch unreachable.
# Only one of the two flags is emitted because the Python argument validation
# rejects --ddp-num-buckets together with --ddp-bucket-size.
if [[ -n "${DDP_BUCKET_SIZE}" ]]; then
    DDP_ARGS+=(--ddp-bucket-size "${DDP_BUCKET_SIZE}")
elif [[ -n "${DDP_NUM_BUCKETS}" ]]; then
    DDP_ARGS+=(--ddp-num-buckets "${DDP_NUM_BUCKETS}")
fi

# store_true flag on the Python side: only emitted when enabled.
if [[ "${DDP_PAD_BUCKETS_FOR_HIGH_NCCL_BUSBW}" == "1" || "${DDP_PAD_BUCKETS_FOR_HIGH_NCCL_BUSBW}" == "true" ]]; then
    DDP_ARGS+=(--ddp-pad-buckets-for-high-nccl-busbw)
fi

CMD=(
"${PYTHON_BIN}" -u examples/mimo/train_hetero.py
Expand Down Expand Up @@ -257,8 +284,7 @@ CMD=(
--adam-beta1 0.9
--adam-beta2 0.95
--clip-grad 1.0
--no-overlap-grad-reduce
--ddp-bucket-size 0
"${DDP_ARGS[@]}"
--log-interval "${LOG_INTERVAL}"
--train-iters "${TRAIN_ITERS}"
"$@"
Expand Down
39 changes: 37 additions & 2 deletions examples/mimo/training/hetero/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,38 @@ def parse_args() -> argparse.Namespace:
"keeps overlap disabled because actual-data batches may be text-only."
),
)
train.add_argument(
"--overlap-param-gather",
action=argparse.BooleanOptionalAction,
default=False,
help=(
"Enable DDP parameter-gather overlap for the language module. Vision encoder DDP "
"keeps overlap disabled because actual-data batches may be text-only."
),
)
train.add_argument(
"--ddp-num-buckets",
type=int,
default=None,
help=(
"Number of language-model DDP buckets. Mutually exclusive with "
"--ddp-bucket-size."
),
)
train.add_argument(
"--ddp-bucket-size",
type=int,
default=10000,
help="DDP bucket size. Use 0 for a single unbounded bucket.",
default=None,
help="DDP bucket size. Defaults to 10000. Use 0 for a single unbounded bucket.",
)
train.add_argument(
"--ddp-pad-buckets-for-high-nccl-busbw",
action="store_true",
default=False,
help=(
"Pad language-model distributed-optimizer buckets to improve NCCL bus bandwidth "
"at large DP sizes."
),
)
train.add_argument("--seed", type=int, default=12345)
train.add_argument("--log-interval", type=int, default=1)
Expand All @@ -150,6 +177,14 @@ def validate_args(args: argparse.Namespace, world_size: int) -> tuple[int, int]:
raise ValueError("Phase 2 mock training currently supports CP=1 only")
if args.log_interval < 1:
raise ValueError("--log-interval must be >= 1")
if args.ddp_num_buckets is not None and args.ddp_num_buckets < 1:
raise ValueError("--ddp-num-buckets must be >= 1")
if args.ddp_bucket_size is not None and args.ddp_bucket_size < 0:
raise ValueError("--ddp-bucket-size must be >= 0")
if args.ddp_num_buckets is not None and args.ddp_bucket_size is not None:
raise ValueError("--ddp-num-buckets and --ddp-bucket-size are mutually exclusive")
if args.overlap_param_gather and not args.overlap_grad_reduce:
raise ValueError("--overlap-param-gather requires --overlap-grad-reduce")
if args.timeline_dp_replica < 0:
raise ValueError("--timeline-dp-replica must be >= 0")

Expand Down
1 change: 1 addition & 0 deletions examples/mimo/training/hetero/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def build_optimizer(args: argparse.Namespace, model: MimoModel):
clip_grad=args.clip_grad,
bf16=not args.fp32,
use_distributed_optimizer=True,
overlap_param_gather=args.overlap_param_gather and model.language_model is not None,
log_num_zeros_in_grad=args.log_num_zeros_in_grad,
),
)
Expand Down
35 changes: 29 additions & 6 deletions examples/mimo/training/hetero/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,21 @@ def wrap_active_modules_with_ddp(
args: argparse.Namespace, mimo_model: MimoModel, topology: HeteroTopology
) -> None:
"""Freeze and DDP-wrap active local MIMO modules."""
language_ddp_config = DistributedDataParallelConfig(
overlap_grad_reduce=args.overlap_grad_reduce,
bucket_size=args.ddp_bucket_size if args.ddp_bucket_size > 0 else None,
use_distributed_optimizer=True,
)
vision_ddp_config = DistributedDataParallelConfig(
overlap_grad_reduce=False,
bucket_size=args.ddp_bucket_size if args.ddp_bucket_size > 0 else None,
bucket_size=resolve_fixed_ddp_bucket_size(args.ddp_bucket_size),
use_distributed_optimizer=True,
)
if mimo_model.language_model is not None:
if args.freeze_lm:
set_module_requires_grad(mimo_model.language_model, False)
language_ddp_config = DistributedDataParallelConfig(
overlap_grad_reduce=args.overlap_grad_reduce,
overlap_param_gather=args.overlap_param_gather,
bucket_size=resolve_language_ddp_bucket_size(args, mimo_model.language_model),
use_distributed_optimizer=True,
pad_buckets_for_high_nccl_busbw=args.ddp_pad_buckets_for_high_nccl_busbw,
)
debug_rank("wrapping language model in DDP")
mimo_model.language_model = DistributedDataParallel(
config=mimo_model.language_model.config,
Expand Down Expand Up @@ -117,6 +119,27 @@ def wrap_active_modules_with_ddp(
debug_rank("vision submodule DDP ready")


def resolve_language_ddp_bucket_size(
    args: argparse.Namespace, module: torch.nn.Module
) -> Optional[int]:
    """Return the DDP bucket size to use for the language model.

    Args:
        args: Parsed arguments; reads ``ddp_num_buckets`` and ``ddp_bucket_size``.
        module: The language module whose trainable parameters are counted when
            a bucket count was requested.

    Returns:
        When ``args.ddp_num_buckets`` is set, the number of trainable parameter
        elements in ``module`` floor-divided by that count, but never less
        than 1. Otherwise, the value ``resolve_fixed_ddp_bucket_size`` derives
        from ``args.ddp_bucket_size``.
    """
    if args.ddp_num_buckets is None:
        return resolve_fixed_ddp_bucket_size(args.ddp_bucket_size)

    # Only parameters that still require gradients are counted.
    num_trainable_params = sum(
        param.numel() for param in module.parameters() if param.requires_grad
    )
    # Clamp to 1 so a module with fewer trainable elements than requested
    # buckets (including none at all) never produces a bucket size of 0.
    return max(1, num_trainable_params // args.ddp_num_buckets)


def resolve_fixed_ddp_bucket_size(bucket_size: Optional[int]) -> Optional[int]:
    """Return the concrete DDP bucket size, preserving the historical default.

    Args:
        bucket_size: The user-supplied bucket size, or ``None`` when not given.

    Returns:
        ``10000`` when ``bucket_size`` is ``None`` (the historical default),
        ``None`` when ``bucket_size`` is ``0``, and ``bucket_size`` unchanged
        for any other value.
    """
    if bucket_size is None:
        # Option not supplied: fall back to the historical default.
        return 10000
    if bucket_size == 0:
        # 0 is the sentinel for "no explicit size"; callers receive None.
        return None
    return bucket_size


def set_module_requires_grad(module: Optional[torch.nn.Module], requires_grad: bool) -> None:
"""Set requires_grad for every parameter in a module when the module exists."""
if module is None:
Expand Down