From 8455dbfefd4740c944fb1876d4a6599dc9622cbc Mon Sep 17 00:00:00 2001 From: yueming-yuan Date: Thu, 23 Apr 2026 20:23:11 -0700 Subject: [PATCH 1/2] DeepSeek V4 RL support Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> --- .../core/distributed/finalize_model_grads.py | 4 +- megatron/core/models/gpt/gpt_model.py | 1 + megatron/core/optimizer/distrib_optimizer.py | 3 +- .../pipeline_parallel/p2p_communication.py | 9 +- megatron/core/pipeline_parallel/schedules.py | 10 +- megatron/core/tensor_parallel/mappings.py | 68 +++++++-- megatron/core/transformer/attention.py | 10 -- .../experimental_attention_variant/dsa.py | 137 +++++++++++++----- megatron/core/transformer/module.py | 12 ++ megatron/core/transformer/moe/moe_layer.py | 17 ++- megatron/core/transformer/moe/moe_utils.py | 21 ++- megatron/core/transformer/moe/router.py | 112 +++++++++++--- .../core/transformer/moe/shared_experts.py | 3 + .../transformer/multi_latent_attention.py | 4 +- .../core/transformer/transformer_block.py | 31 ++++ .../core/transformer/transformer_config.py | 66 ++++++++- .../core/transformer/transformer_layer.py | 106 +++++++++++--- megatron/training/arguments.py | 8 + megatron/training/training.py | 6 - 19 files changed, 492 insertions(+), 136 deletions(-) diff --git a/megatron/core/distributed/finalize_model_grads.py b/megatron/core/distributed/finalize_model_grads.py index a52592bb269..4ccdaf20fb6 100644 --- a/megatron/core/distributed/finalize_model_grads.py +++ b/megatron/core/distributed/finalize_model_grads.py @@ -281,7 +281,7 @@ def reset_model_temporary_tensors(config: TransformerConfig, model: List[torch.n """ for model_chunk in model: for module in get_attr_wrapped_model(model_chunk, 'modules')(): - if config.moe_router_enable_expert_bias and hasattr(module, 'expert_bias'): + if config.moe_router_enable_expert_bias and getattr(module, 'expert_bias', None) is not None: module.local_tokens_per_expert.zero_() if ( config.moe_router_load_balancing_type == "global_aux_loss" @@ -473,7 +473,7 @@ def finalize_model_grads( if config.timers is not None: config.timers('embedding-grads-all-reduce').stop() - if config.moe_router_enable_expert_bias: + if config.moe_router_enable_expert_bias and not config.freeze_e_score_correction_bias: _update_router_expert_bias(model, config) reset_model_temporary_tensors(config, model) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 0915f7e878b..a1c8665833a 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -540,6 +540,7 @@ def forward( packed_seq_params=packed_seq_params, sequence_len_offset=sequence_len_offset, padding_mask=padding_mask, + input_ids=input_ids, **(extra_block_kwargs or {}), ) diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index 8fe58f92bbb..bc37e4d0f2d 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -422,8 +422,7 @@ def _build_model_and_main_param_groups( # fp32 params. elif model_param.type() == 'torch.cuda.FloatTensor': - # Keep shard tensors as leaf tensors for torch Optimizer. - shard_model_param = model_param.detach().view(-1)[param_range.start : param_range.end] + shard_model_param = model_param.view(-1)[param_range.start : param_range.end] model_fp32_params_this_group.append(model_param) shard_fp32_params_this_group.append(shard_model_param) tensor_parallel.copy_tensor_model_parallel_attributes( diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py index f18309217c3..6c7561af548 100644 --- a/megatron/core/pipeline_parallel/p2p_communication.py +++ b/megatron/core/pipeline_parallel/p2p_communication.py @@ -181,17 +181,18 @@ def _communicate_shapes(self, tensor_send_next, tensor_send_prev, recv_prev, rec (recv_prev_shape, recv_next_shape) """ config = self.config + num_dims = 4 if config.dsv4_mode else 3 recv_prev_shape_tensor = None recv_next_shape_tensor = None send_prev_shape_tensor = None send_next_shape_tensor = None if recv_prev: recv_prev_shape_tensor = torch.empty( - (3,), device=torch.cuda.current_device(), dtype=torch.int64 + (num_dims,), device=torch.cuda.current_device(), dtype=torch.int64 ) if recv_next: recv_next_shape_tensor = torch.empty( - (3,), device=torch.cuda.current_device(), dtype=torch.int64 + (num_dims,), device=torch.cuda.current_device(), dtype=torch.int64 ) if tensor_send_prev is not None: send_prev_shape_tensor = torch.tensor( @@ -241,11 +242,11 @@ def _communicate_shapes(self, tensor_send_next, tensor_send_prev, recv_prev, rec # should take this out once the bug with batch_isend_irecv is resolved. torch.cuda.synchronize() - recv_prev_shape = [0, 0, 0] + recv_prev_shape = [0] * num_dims if recv_prev_shape_tensor is not None: recv_prev_shape = recv_prev_shape_tensor.tolist() - recv_next_shape = [0, 0, 0] + recv_next_shape = [0] * num_dims if recv_next_shape_tensor is not None: recv_next_shape = recv_next_shape_tensor.tolist() diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index f15dcd1400b..a7588e83562 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -1149,7 +1149,10 @@ def enable_grad_sync(): model_type = get_model_type(model[0]) - tensor_shape = [seq_length, micro_batch_size, config.hidden_size] + if config.dsv4_mode: + tensor_shape = [seq_length, micro_batch_size, config.dsv4_hc_mult, config.hidden_size] + else: + tensor_shape = [seq_length, micro_batch_size, config.hidden_size] tensor_shape[0] = tensor_shape[0] // cp_group.size() if config.sequence_parallel: tensor_shape[0] = tensor_shape[0] // tp_group.size() @@ -2098,7 +2101,10 @@ def get_tensor_shapes( if config.sequence_parallel: effective_seq_length = effective_seq_length // tp_group.size() - tensor_shapes.append((effective_seq_length, micro_batch_size, config.hidden_size)) + if config.dsv4_mode: + tensor_shapes.append((effective_seq_length, micro_batch_size, config.dsv4_hc_mult, config.hidden_size)) + else: + tensor_shapes.append((effective_seq_length, micro_batch_size, config.hidden_size)) return tensor_shapes diff --git a/megatron/core/tensor_parallel/mappings.py b/megatron/core/tensor_parallel/mappings.py index 9ff69c9dc31..4931bc41ba6 100644 --- a/megatron/core/tensor_parallel/mappings.py +++ b/megatron/core/tensor_parallel/mappings.py @@ -19,8 +19,14 @@ dist_reduce_scatter_func = torch.distributed._reduce_scatter_base -def _reduce(input_, group): - """All-reduce the input tensor across model parallel group.""" +def _reduce(input_, group, fp32=False): + """All-reduce the input tensor across model parallel group. + + Args: + input_: Input tensor. + group: Process group for all-reduce. + fp32: If True, cast to FP32 before all-reduce, then cast back. + """ assert group is not None, "group should not be None" # Bypass the function if we are using only 1 GPU. @@ -28,7 +34,13 @@ def _reduce(input_, group): return input_ # All-reduce. - torch.distributed.all_reduce(input_.contiguous(), group=group) + if fp32: + orig_dtype = input_.dtype + input_fp32 = input_.float().contiguous() + torch.distributed.all_reduce(input_fp32, group=group) + input_.copy_(input_fp32.to(orig_dtype)) + else: + torch.distributed.all_reduce(input_.contiguous(), group=group) return input_ @@ -194,24 +206,56 @@ def _reduce_scatter_along_first_dim(input_, group, input_split_sizes=None, use_g return output +def split_along_nth_dim(input_, dim, group): + """Split the tensor along the specified dimension and keep the + corresponding slice. This is a pure function without autograd. + + Args: + input_: Input tensor to split. + dim: The dimension along which to split. + group: The process group for splitting. + + Returns: + The slice of the input tensor corresponding to the current rank. + """ + assert group is not None, "group should not be None" + + world_size = group.size() + if world_size == 1: + return input_ + + dim_size = input_.size(dim) + assert ( + dim_size % world_size == 0 + ), f"Dimension {dim} of the tensor (size {dim_size}) should be divisible by world size {world_size}" + local_dim_size = dim_size // world_size + rank = group.rank() + dim_offset = rank * local_dim_size + + output = input_.narrow(dim, dim_offset, local_dim_size).contiguous() + + return output + + class _CopyToModelParallelRegion(torch.autograd.Function): """Pass the input to the model parallel region.""" @staticmethod - def symbolic(graph, input_, group): + def symbolic(graph, input_, group, all_reduce_grad_fp32): """Symbolic function for tracing.""" return input_ @staticmethod - def forward(ctx, input_, group): + def forward(ctx, input_, group, all_reduce_grad_fp32): """Forward function.""" ctx.group = group + ctx.all_reduce_grad_fp32 = all_reduce_grad_fp32 return input_ @staticmethod def backward(ctx, grad_output): """Backward function.""" - return _reduce(grad_output, ctx.group), None + return _reduce(grad_output, ctx.group, fp32=ctx.all_reduce_grad_fp32), None, None class _ReduceFromModelParallelRegion(torch.autograd.Function): @@ -466,10 +510,16 @@ def backward(ctx, *grad_output): # ----------------- -def copy_to_tensor_model_parallel_region(input_, group=None): - """Wrapper for autograd function: forward: copy, backward allreduce""" +def copy_to_tensor_model_parallel_region(input_, group=None, all_reduce_grad_fp32=False): + """Wrapper for autograd function: forward: copy, backward allreduce + + Args: + input_: Input tensor. + group: Process group for all-reduce. If None, uses default TP group. + all_reduce_grad_fp32: If True, cast gradients to FP32 before all-reduce, then cast back. + """ group = get_tensor_model_parallel_group_if_none(group) - return _CopyToModelParallelRegion.apply(input_, group) + return _CopyToModelParallelRegion.apply(input_, group, all_reduce_grad_fp32) def reduce_from_tensor_model_parallel_region(input_, group=None): diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 74e52c81977..bc5e4e2ee0d 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -1491,16 +1491,6 @@ def get_query_key_value_tensors( if output_gate: # Gate [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] gate = gate.reshape(*gate.shape[:2], -1, self.hidden_size_per_attention_head) - if self.config.num_query_groups < self.world_size: - # gate has the same head layout as query before slicing. - # Apply the same TP slice so gate matches the per-rank query. - idx = get_tensor_model_parallel_rank() % ( - self.world_size // self.config.num_query_groups - ) - size = self.num_attention_heads_per_partition // ( - self.world_size // self.config.num_query_groups - ) - gate = gate[:, :, idx * size : (idx + 1) * size, :] return query, key, value, gate return query, key, value diff --git a/megatron/core/transformer/experimental_attention_variant/dsa.py b/megatron/core/transformer/experimental_attention_variant/dsa.py index 3734db7043f..97110f1f123 100644 --- a/megatron/core/transformer/experimental_attention_variant/dsa.py +++ b/megatron/core/transformer/experimental_attention_variant/dsa.py @@ -316,7 +316,7 @@ def fused_qk_topk_naive( # ========================================= # Select top-k indices # ========================================= - topk_k = min(index_topk, seqlen) + topk_k = min(index_topk, index_scores.size(-1)) # [batch, seqlen, index_topk] topk_indices = index_scores.topk(topk_k, dim=-1)[1] @@ -687,6 +687,7 @@ def __init__( pg_collection (ProcessGroupCollection, optional): Process groups for the indexer. """ super().__init__(config=config) + self.dsv4_mode = config.dsv4_mode self.hidden_size = self.config.hidden_size self.qk_pos_emb_head_dim = self.config.qk_pos_emb_head_dim self.q_lora_rank = ( @@ -743,26 +744,27 @@ def __init__( parallel_mode="duplicated", ) - self.linear_wk = build_module( - submodules.linear_wk, - self.hidden_size, - self.index_head_dim, - config=self.config, - init_method=self.config.init_method, - bias=False, - skip_bias_add=False, - skip_weight_param_allocation=False, - parallel_mode="duplicated", - ) + if not self.dsv4_mode: + self.linear_wk = build_module( + submodules.linear_wk, + self.hidden_size, + self.index_head_dim, + config=self.config, + init_method=self.config.init_method, + bias=False, + skip_bias_add=False, + skip_weight_param_allocation=False, + parallel_mode="duplicated", + ) - k_norm_config = copy.copy(self.config) - k_norm_config.normalization = "LayerNorm" - self.k_norm = build_module( - submodules.k_norm, - config=k_norm_config, - hidden_size=self.index_head_dim, - eps=self.config.layernorm_epsilon, - ) + k_norm_config = copy.copy(self.config) + k_norm_config.normalization = "LayerNorm" + self.k_norm = build_module( + submodules.k_norm, + config=k_norm_config, + hidden_size=self.index_head_dim, + eps=self.config.layernorm_epsilon, + ) self.linear_weights_proj = build_module( submodules.linear_weights_proj, @@ -776,6 +778,30 @@ def __init__( parallel_mode="duplicated", ) + # V4-specific: compressor for key computation and custom RoPE + if self.dsv4_mode: + from miles_plugins.models.deepseek_v4.ops.compressor import DeepSeekV4Compressor + from miles_plugins.models.deepseek_v4.ops.utils import wrapped_precompute_freqs_cis + from miles_plugins.models.deepseek_v4.ops.qat import fp8_simulate_qat + self._fp8_simulate_qat = fp8_simulate_qat + + self.compress_ratio = 4 + self.compressor = DeepSeekV4Compressor( + config=self.config, + head_dim=self.index_head_dim, + compress_ratio=self.compress_ratio, + rotate=True, + cp_group=pg_collection.cp, + ) + + self.rope_head_dim = config.qk_pos_emb_head_dim + rope_base = config.dsv4_compress_rope_theta if self.compress_ratio else config.rotary_base + freqs_cis = wrapped_precompute_freqs_cis(config, rope_head_dim=self.rope_head_dim, base=rope_base) + self.register_buffer("freqs_cis", freqs_cis, persistent=False) + + from miles.utils.replay_base import indexer_replay_manager + indexer_replay_manager.register_to_module(self, "indexer_replay") + def _apply_rope(self, x: torch.Tensor, rotary_pos_emb: torch.Tensor, mscale: float): """Apply RoPE to the input tensor.""" # x_nope [seqlen, batch, *, index_head_dim - qk_pos_emb_head_dim] @@ -802,14 +828,15 @@ def forward_before_topk( # ========================================= # Prepare RoPE params # ========================================= - rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( - None, None, x, self.config, packed_seq_params - ) - if self.config.rope_type == "rope": - rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=False) - mscale = 1.0 - else: - rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len, packed_seq=False) + if not self.dsv4_mode: + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + None, None, x, self.config, packed_seq_params + ) + if self.config.rope_type == "rope": + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=False) + mscale = 1.0 + else: + rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len, packed_seq=False) # ========================================= # Gather inputs if sp is enabled @@ -831,25 +858,48 @@ def forward_before_topk( # [seqlen, batch, index_n_heads * index_head_dim] # -> [seqlen, batch, index_n_heads, index_head_dim] q = q.reshape(seqlen, bsz, self.index_n_heads, self.index_head_dim) - q = self._apply_rope(q, rotary_pos_emb, mscale) + if self.dsv4_mode: + import einops + from miles_plugins.models.deepseek_v4.ops.cp_utils import get_freqs_cis_for_cp + from miles_plugins.models.deepseek_v4.ops.ref_model import apply_rotary_emb + + rd = self.rope_head_dim + cp_size = parallel_state.get_context_parallel_world_size() + freqs_cis = get_freqs_cis_for_cp( + self.freqs_cis, seqlen, cp_size, self.pg_collection.cp, stride=1 + ) + q = q.clone() + q = einops.rearrange(q, 's b ... -> b s ...') + apply_rotary_emb(q[..., -rd:], freqs_cis) + q = einops.rearrange(q, 'b s ... -> s b ...') + else: + q = self._apply_rope(q, rotary_pos_emb, mscale) # ========================================= # k linear and apply rope to k # ========================================= - # [seqlen, batch, hidden_size] -> [seqlen, batch, index_head_dim] - k, _ = self.linear_wk(x) - k = self.k_norm(k) - # [seqlen, batch, index_head_dim] -> [seqlen, batch, 1, index_head_dim] - k = k.reshape(seqlen, bsz, 1, self.index_head_dim) - k = self._apply_rope(k, rotary_pos_emb, mscale) - # [seqlen, batch, 1, index_head_dim] -> [seqlen, batch, index_head_dim] - k = k.reshape(seqlen, bsz, self.index_head_dim) + if self.dsv4_mode: + k = self.compressor(x) + else: + # [seqlen, batch, hidden_size] -> [seqlen, batch, index_head_dim] + k, _ = self.linear_wk(x) + k = self.k_norm(k) + # [seqlen, batch, index_head_dim] -> [seqlen, batch, 1, index_head_dim] + k = k.reshape(seqlen, bsz, 1, self.index_head_dim) + k = self._apply_rope(k, rotary_pos_emb, mscale) + # [seqlen, batch, 1, index_head_dim] -> [seqlen, batch, index_head_dim] + k = k.reshape(seqlen, bsz, self.index_head_dim) # ========================================= # Rotate activation # ========================================= q = rotate_activation(q) - k = rotate_activation(k) + if hasattr(self, '_fp8_simulate_qat'): + import os + if os.environ.get("MEGATRON_USE_KV_QAT", "0") == "1": + q = self._fp8_simulate_qat(q, block_size=128) + if not self.dsv4_mode: + k = rotate_activation(k) # ========================================= # Prepare weights for index scores @@ -892,6 +942,17 @@ def forward_with_scores( # [batch, seqlen, seqlen], [batch, seqlen, index_topk] index_scores, topk_indices = fused_qk_topk_naive(q, k, weights, self.index_topk, mask) + # V4 mode: apply indexer replay if registered + if self.dsv4_mode: + from miles.utils.replay_base import indexer_replay_manager + + def _original_topk(scores, k, **kwargs): + k = min(k, scores.size(-1)) + return scores.topk(k, dim=-1)[1] + + topk_fn = indexer_replay_manager.get_topk_fn(_original_topk, return_probs=False) + topk_indices = topk_fn(index_scores, self.index_topk) + return index_scores, topk_indices def forward( diff --git a/megatron/core/transformer/module.py b/megatron/core/transformer/module.py index c30c107e791..056aef6fc1a 100644 --- a/megatron/core/transformer/module.py +++ b/megatron/core/transformer/module.py @@ -430,6 +430,13 @@ def __init__(self, config: TransformerConfig, module: torch.nn.Module): self.vp_stage = getattr(module, 'vp_stage', None) self.pg_collection = getattr(module, 'pg_collection', None) + # Snapshot FP32 params marked with _keep_fp32 before precision conversion + fp32_params = {} + for name, param in module.named_parameters(): + if getattr(param, '_keep_fp32', False): + assert param.dtype == torch.float32 + fp32_params[name] = param.data.clone() + if self.fp16: self.add_module('module', module.half()) @@ -445,6 +452,11 @@ def float16_convertor(val): else: raise Exception('Either config.fp16 or config.bf16 should be True.') + # Restore FP32 params after precision conversion + for name, param in self.module.named_parameters(): + if name in fp32_params: + param.data = fp32_params[name] + self.float16_convertor = float16_convertor def set_input_tensor(self, input_tensor): # pylint: disable=missing-function-docstring diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 80d5a04b0f0..c5953a45b1d 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -167,7 +167,7 @@ def __init__( self.tp_group = pg_collection.tp # Initialize router. - self.router = submodules.router(config=self.config, pg_collection=pg_collection) + self.router = submodules.router(config=self.config, pg_collection=pg_collection, layer_number=layer_number) self.tp_group = pg_collection.tp # Initialize latent projections. @@ -247,13 +247,16 @@ def __init__( self.fwd_execution_map = ["route", "expert_compute", "postprocess"] @maybe_skip_or_early_return_by_cudagraph("route") - def route(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): + def route(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None, input_ids: Optional[torch.Tensor] = None): """Compute token routing for preprocessing. This method uses the router to determine which experts to send each token to, producing routing probabilities and a mapping. """ - probs, routing_map = apply_module(self.router)(hidden_states, padding_mask) + if input_ids is not None and self.config.sequence_parallel: + from megatron.core.tensor_parallel.mappings import split_along_nth_dim + input_ids = split_along_nth_dim(input_ids, dim=1, group=parallel_state.get_tensor_model_parallel_group()) + probs, routing_map = apply_module(self.router)(hidden_states, padding_mask, input_ids=input_ids) return probs, routing_map @maybe_skip_or_early_return_by_cudagraph("preprocess") @@ -363,6 +366,7 @@ def forward( hidden_states: torch.Tensor, intermediate_tensors=None, padding_mask: Optional[torch.Tensor] = None, + input_ids: Optional[torch.Tensor] = None, ): """Forward pass for the MoE layer. @@ -377,6 +381,7 @@ def forward( padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens. Shape [seq_length, bsz]. True for valid tokens, False for padding tokens. Defaults to None. + input_ids (torch.Tensor, optional): Input token IDs for hash routing. Returns: A tuple containing the output tensor and the MLP bias, if any. """ @@ -390,11 +395,11 @@ def forward( padding_mask = padding_mask.transpose(0, 1).bool() # MoE forward: route -> dispatch -> compute -> combine - def custom_forward(hidden_states, intermediate_tensors, padding_mask=None): + def custom_forward(hidden_states, intermediate_tensors, padding_mask=None, input_ids=None): try: if "route" in self.fwd_execution_map: shared_expert_output = self.shared_experts_compute(hidden_states) - probs, routing_map = self.route(hidden_states, padding_mask) + probs, routing_map = self.route(hidden_states, padding_mask, input_ids=input_ids) hidden_states, probs = self.preprocess(hidden_states, probs, routing_map) if intermediate_tensors is not None: @@ -448,7 +453,7 @@ def custom_forward(hidden_states, intermediate_tensors, padding_mask=None): custom_forward, False, hidden_states, padding_mask ) else: - outputs = custom_forward(hidden_states, intermediate_tensors, padding_mask) + outputs = custom_forward(hidden_states, intermediate_tensors, padding_mask, input_ids=input_ids) return outputs diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index c6ab213d9c3..0f4316edc77 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -620,6 +620,8 @@ def topk_routing_with_score_function( expert_bias: Optional[torch.Tensor] = None, fused: bool = False, is_mtp: bool = False, + tid2eid: Optional[torch.Tensor] = None, + input_ids: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute the routing probabilities and map for top-k selection with score function. @@ -700,7 +702,8 @@ def _compute_topk( from miles.utils.replay_base import routing_replay_manager # MTP layers cannot use rollout routing replay - if not is_mtp: + # Hash-routed layers (tid2eid is not None) also bypass replay since routing is deterministic + if not is_mtp and tid2eid is None: compute_topk = routing_replay_manager.get_topk_fn(_compute_topk, return_probs=True) else: compute_topk = _compute_topk @@ -721,6 +724,22 @@ def _compute_topk( else: scores, top_indices = compute_topk(scores, topk, num_groups, group_topk) probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores + elif score_function == "sqrtsoftplus": + assert num_groups is None + assert group_topk is None + scores = torch.nn.functional.softplus(logits.float()).sqrt().type_as(logits) + if tid2eid is not None: + assert not tid2eid.requires_grad + assert input_ids is not None and not input_ids.requires_grad + top_indices = tid2eid[input_ids] + assert torch.all(top_indices >= 0) + else: + assert expert_bias is not None + scores_for_routing = scores + expert_bias + assert len(scores_for_routing.shape) == 2 + _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk) + scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits) + probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) else: raise ValueError(f"Invalid score_function: {score_function}") diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index 26d29b799a8..98057078127 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -65,6 +65,11 @@ def __init__( self.calculate_per_token_loss = self.config.calculate_per_token_loss self.reset_parameters() + if self.config.moe_router_freeze_gate: + self.weight.requires_grad = False + if self.bias is not None: + self.bias.requires_grad = False + def reset_parameters(self): """Reset the router parameters.""" if self.config.perform_initialization: @@ -92,6 +97,11 @@ def gating(self, input: torch.Tensor): if self.bias is not None and self.bias.device.type == 'cpu': self.bias.data = self.bias.data.to(device=torch.cuda.current_device()) + if self.config.moe_router_freeze_gate: + assert not self.weight.requires_grad + if self.bias is not None: + assert not self.bias.requires_grad + # Convert to specified datatype for routing computation if enabled router_dtype = input.dtype if self.config.moe_router_dtype == 'fp32': @@ -146,21 +156,66 @@ class TopKRouter(Router): """ def __init__( - self, config: TransformerConfig, pg_collection: Optional[ProcessGroupCollection] = None + self, config: TransformerConfig, pg_collection: Optional[ProcessGroupCollection] = None, layer_number: int = None ) -> None: """Initialize the zero token dropping router. Args: config (TransformerConfig): The configuration for the transformer model. pg_collection (ProcessGroupCollection, optional): Process groups for MoE operations. + layer_number (int, optional): The layer number for this router. """ super().__init__(config=config, pg_collection=pg_collection) + self.layer_number = layer_number self.topk = self.config.moe_router_topk self.routing_type = self.config.moe_router_load_balancing_type self.score_function = self.config.moe_router_score_function self.input_jitter = None - self.enable_expert_bias = self.config.moe_router_enable_expert_bias + # Routing mode (expert_bias vs tid2eid) depends on layer_number, which + # may not be known at construction time — Megatron's TransformerLayer + # calls build_module(MoELayer) without layer_number, then immediately + # calls set_layer_number(). We defer creation to _init_routing_mode() + # and assert it has been called before the first forward pass. + self._routing_mode_initialized = False + self.enable_expert_bias = False + self.tid2eid = None + self._frozen_expert_bias_snapshot = None + if layer_number is not None: + self._init_routing_mode(layer_number) + + # Initialize global tokens per expert for global aux loss + if self.get_aux_loss_coeff("global_aux_loss") > 0: + self.register_buffer( + 'global_tokens_per_expert', + torch.zeros( + self.config.num_moe_experts, + dtype=torch.float32, + device=torch.cuda.current_device(), + ), + persistent=False, + ) + self.register_buffer( + 'ga_steps', + torch.tensor(0, dtype=torch.float32, device=torch.cuda.current_device()), + persistent=False, + ) + else: + self.global_tokens_per_expert = None + self.ga_steps = None + + from miles.utils.replay_base import routing_replay_manager + routing_replay_manager.register_to_module(self, "routing_replay") + + def _init_routing_mode(self, layer_number): + assert not self._routing_mode_initialized + self._routing_mode_initialized = True + + mode_hash = layer_number <= self.config.dsv4_n_hash_layers + + self.enable_expert_bias = ( + self.config.moe_router_enable_expert_bias and not mode_hash + ) if self.enable_expert_bias: self.register_buffer( 'local_tokens_per_expert', @@ -183,28 +238,23 @@ def __init__( self.local_tokens_per_expert = None self.expert_bias = None - # Initialize global tokens per expert for global aux loss - if self.get_aux_loss_coeff("global_aux_loss") > 0: - self.register_buffer( - 'global_tokens_per_expert', - torch.zeros( - self.config.num_moe_experts, - dtype=torch.float32, - device=torch.cuda.current_device(), + if self.config.freeze_e_score_correction_bias and self.enable_expert_bias: + self._frozen_expert_bias_snapshot = None + + if mode_hash: + self.tid2eid = torch.nn.Parameter( + torch.full( + (self.config.vocab_size, self.topk), + fill_value=-1, + dtype=torch.int32, ), - persistent=False, + requires_grad=False, ) - self.register_buffer( - 'ga_steps', - torch.tensor(0, dtype=torch.float32, device=torch.cuda.current_device()), - persistent=False, - ) - else: - self.global_tokens_per_expert = None - self.ga_steps = None - from miles.utils.replay_base import routing_replay_manager - routing_replay_manager.register_to_module(self, "routing_replay") + def set_layer_number(self, layer_number: int): + self.layer_number = layer_number + if not self._routing_mode_initialized: + self._init_routing_mode(layer_number) def _maintain_float32_expert_bias(self): """ @@ -543,7 +593,7 @@ def _apply_expert_bias( routing_map = routing_map & (~padding_mask) self.local_tokens_per_expert += routing_map.sum(dim=0) - def routing(self, logits: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): + def routing(self, logits: torch.Tensor, padding_mask: Optional[torch.Tensor] = None, input_ids: Optional[torch.Tensor] = None): """Top-k routing function Args: @@ -551,12 +601,16 @@ def routing(self, logits: torch.Tensor, padding_mask: Optional[torch.Tensor] = N padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens. Shape [seq_length, bsz]. True for valid tokens, False for padding tokens. Defaults to None. + input_ids (torch.Tensor, optional): Input token IDs for hash routing. Returns: probs (torch.Tensor): The probabilities of token to experts assignment. routing_map (torch.Tensor): The mapping of token to experts assignment, with shape [num_tokens, num_experts]. """ + if self.config.dsv4_mode: + assert self._routing_mode_initialized + seq_length, bsz = logits.shape[:2] logits = logits.view(-1, self.config.num_moe_experts) @@ -582,6 +636,8 @@ def routing(self, logits: torch.Tensor, padding_mask: Optional[torch.Tensor] = N expert_bias=self.expert_bias, fused=self.config.moe_router_fusion, is_mtp=self.is_mtp, + tid2eid=self.tid2eid, + input_ids=input_ids.view(-1) if self.tid2eid is not None and input_ids is not None else None, ) # Apply token dropping to probs and routing_map. @@ -637,7 +693,7 @@ def reset_global_aux_loss_tracker(self): self.global_tokens_per_expert.zero_() self.ga_steps.zero_() - def forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): + def forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = None, input_ids: Optional[torch.Tensor] = None): """ Forward pass of the router. @@ -646,9 +702,17 @@ def forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = No padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens. Shape [seq_length, bsz]. True for valid tokens, False for padding tokens. Defaults to None. + input_ids (torch.Tensor, optional): Input token IDs for hash routing. """ self._maintain_float32_expert_bias() + if self.config.freeze_e_score_correction_bias and self.enable_expert_bias: + if self._frozen_expert_bias_snapshot is None: + self._frozen_expert_bias_snapshot = self.expert_bias.clone() + else: + assert torch.equal(self.expert_bias, self._frozen_expert_bias_snapshot), \ + "expert_bias was modified but freeze_e_score_correction_bias is enabled!" + # Apply input jitter input = self.apply_input_jitter(input) logits = self.gating(input) @@ -663,7 +727,7 @@ def forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = No logits, self.config.moe_router_force_biased, self.layer_number ) - probs, routing_map = self.routing(logits, padding_mask=padding_mask) + probs, routing_map = self.routing(logits, padding_mask=padding_mask, input_ids=input_ids) return probs, routing_map diff --git a/megatron/core/transformer/moe/shared_experts.py b/megatron/core/transformer/moe/shared_experts.py index 35066b1a8b0..cee9719c89e 100644 --- a/megatron/core/transformer/moe/shared_experts.py +++ b/megatron/core/transformer/moe/shared_experts.py @@ -47,6 +47,9 @@ def __init__( assert config.add_bias_linear == False, "bias is not supported in the shared experts, " "please set '--disable-bias-linear' instead." + if not config.activation_func_clamp_shared_expert: + config.activation_func_clamp_value = None + config.ffn_hidden_size = config.moe_shared_expert_intermediate_size # TODO(Hepteract): pass pg_collection to MLP after refactoring MLP super().__init__(config=config, submodules=submodules, tp_group=pg_collection.tp) diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index cd3db50a35b..6f5c1f58f3e 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -283,8 +283,8 @@ def forward( else: if inference_context is None or inference_context.is_static_batching(): extra_kwargs = {} - if self.config.experimental_attention_variant == "dsa": - # For dsa we need to pass in the original hidden states and the compressed + if self.config.experimental_attention_variant in ("dsa", "dsv4"): + # For dsa/dsv4 we need to pass in the original hidden states and the compressed # query representation. extra_kwargs["x"] = hidden_states extra_kwargs["qr"] = q_compressed diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index f222a2c3a6b..9105e610480 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -384,6 +384,16 @@ def build_layer(layer_spec, layer_number): else: self.final_layernorm = None # Either this or nn.Identity + # DeepSeek V4 Hyper-Connection + if self.config.dsv4_mode: + from miles_plugins.models.deepseek_v4.ops.hyper_connection import ( + DeepSeekV4HyperConnectionUtil, + HCHeadParams, + ) + self.hc_util = DeepSeekV4HyperConnectionUtil(self.config) + if self.has_final_layernorm_in_this_stage(): + self.hc_head_params = HCHeadParams(self.config) + if self.config.inference_fuse_tp_communication: self._setup_fused_tp_communication() @@ -453,6 +463,7 @@ def _checkpointed_forward( packed_seq_params: PackedSeqParams, use_inner_quantization_context: bool, padding_mask: Optional[Tensor] = None, + input_ids: Optional[Tensor] = None, ): """Forward method with activation checkpointing.""" @@ -464,6 +475,7 @@ def custom_forward( context_mask, rotary_pos_emb, padding_mask=None, + input_ids=None, ): for index in range(start, end): layer = self._get_layer(index) @@ -495,6 +507,7 @@ def custom_forward( inference_context=None, packed_seq_params=packed_seq_params, padding_mask=padding_mask, + input_ids=input_ids, ) return hidden_states, context @@ -515,6 +528,7 @@ def checkpoint_handler(forward_func): context_mask, rotary_pos_emb, padding_mask, + input_ids, ) else: return tensor_parallel.checkpoint( @@ -526,6 +540,7 @@ def checkpoint_handler(forward_func): context_mask, rotary_pos_emb, padding_mask, + input_ids, ) if self.config.recompute_method == 'uniform': @@ -632,6 +647,7 @@ def forward( packed_seq_params: Optional[PackedSeqParams] = None, sequence_len_offset: Optional[Tensor] = None, padding_mask: Optional[Tensor] = None, + input_ids: Optional[Tensor] = None, *, inference_params: Optional[BaseInferenceContext] = None, dynamic_inference_decode_only: Optional[bool] = None, @@ -702,6 +718,10 @@ def forward( # is called here to be future-proof and corner-case-proof. hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + # HC expand: [s, b, d] -> [s, b, hc, d] + if self.config.dsv4_mode and self.pre_process: + hidden_states = self.hc_util.block_expand(hidden_states) + if self.config.sequence_parallel: rng_context = tensor_parallel.get_cuda_rng_tracker().fork() else: @@ -742,6 +762,7 @@ def forward( packed_seq_params=packed_seq_params, use_inner_quantization_context=use_inner_quantization_context, padding_mask=padding_mask, + input_ids=input_ids, ) else: for l_no, layer in enumerate(self.layers): @@ -775,6 +796,7 @@ def forward( packed_seq_params=packed_seq_params, sequence_len_offset=sequence_len_offset, padding_mask=padding_mask, + input_ids=input_ids, ) if ( @@ -784,6 +806,15 @@ def forward( ): hidden_states = self.group_prefetch_offload_commit_async(hidden_states) + # HC head: [s, b, hc, d] -> [s, b, d] + if self.config.dsv4_mode and self.post_process and hasattr(self, 'hc_head_params'): + hidden_states = self.hc_util.block_head( + hidden_states, + self.hc_head_params.hc_head_fn, + self.hc_head_params.hc_head_scale, + self.hc_head_params.hc_head_base, + ) + # Final layer norm. if self.final_layernorm is not None: hidden_states = self.final_layernorm(hidden_states) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index de51edaf31f..2d460efb2c0 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -187,6 +187,10 @@ class TransformerConfig(ModelParallelConfig): """Clamp the output of the linear_fc1 in the activation function. Only used when activation_func is quick_gelu.""" + activation_func_clamp_shared_expert: bool = True + """If False, skip activation_func_clamp_value inside SharedExpertMLP so only routed MoE + experts get the clamp.""" + num_moe_experts: Optional[int] = None """Number of experts to use for MoE layer. When set, it replaces MLP with MoE layer. Set to None for no MoE.""" @@ -252,8 +256,8 @@ class TransformerConfig(ModelParallelConfig): #################### # attention variant #################### - experimental_attention_variant: Optional[Literal['gated_delta_net', 'dsa']] = None - """Type of attention variant to use. Currently support gated_delta_net and dsa.""" + experimental_attention_variant: Optional[Literal['gated_delta_net', 'dsa', 'dsv4']] = None + """Type of attention variant to use. Currently support gated_delta_net, dsa, and dsv4.""" #################### # DSA @@ -274,6 +278,48 @@ class TransformerConfig(ModelParallelConfig): """Whether to use sparse DSA indexer loss. If True, the indexer loss will be computed using the top-k indices.""" + #################### + # DeepSeek V4 + #################### + dsv4_mode: bool = False + """Enable DeepSeek V4 mode (MLA + MoE with window sparse attention, topk, hyper-connections).""" + + dsv4_hc_mult: Optional[int] = None + """Hyper-Connection multiplier (number of HC streams).""" + + dsv4_hc_sinkhorn_iters: int = 20 + """Number of Sinkhorn iterations for HC doubly-stochastic normalization.""" + + dsv4_hc_eps: float = 1e-6 + """Epsilon for HC Sinkhorn normalization.""" + + dsv4_compress_ratios: Optional[List[int]] = None + """Per-layer compression ratios for compressor. None or 0 means no compression.""" + + dsv4_compress_rope_theta: float = 40000.0 + """RoPE theta for compressor positional embeddings.""" + + dsv4_o_groups: Optional[int] = None + """Number of output groups for grouped output projection.""" + + dsv4_o_lora_rank: Optional[int] = None + """LoRA rank for output projection.""" + + dsv4_n_hash_layers: int = 0 + """Number of layers using hash routing (from layer 0). Remaining layers use learned routing.""" + + dsv4_window_size: int = 4096 + """Window size for local window attention in sparse attention.""" + + vocab_size: Optional[int] = None + """Vocabulary size, passed through for hash routing tid2eid initialization.""" + + freeze_e_score_correction_bias: bool = False + """Freeze expert score correction bias during training.""" + + moe_router_freeze_gate: bool = False + """Freeze MoE router gate weights during training.""" + #################### # linear attention #################### @@ -631,8 +677,8 @@ class TransformerConfig(ModelParallelConfig): """Scaling factor for routing score in top-k selection, only works when moe_router_pre_softmax enabled. Defaults to None, which means no scaling.""" - moe_router_score_function: Literal['softmax', 'sigmoid'] = "softmax" - """Score function for MoE routing. Can be "softmax" or "sigmoid".""" + moe_router_score_function: Literal['softmax', 'sigmoid', 'sqrtsoftplus'] = "softmax" + """Score function for MoE routing. Can be "softmax", "sigmoid", or "sqrtsoftplus".""" moe_router_dtype: Optional[Literal['fp32', 'fp64']] = None """Data type for routing and expert output weighted averaging. Using fp32 or fp64 can @@ -971,6 +1017,9 @@ def __post_init__(self): self.experimental_attention_variant = self.linear_attention_type self.linear_attention_type = None + if self.experimental_attention_variant == "dsv4": + self.dsv4_mode = True + if self.experimental_attention_variant in ["gated_delta_net"]: assert ( self.linear_attention_freq is not None @@ -1619,10 +1668,13 @@ def __post_init__(self): self.expert_tensor_parallel_size == 1 ), "Bias in Moe is only supported when ETP==1" - if self.moe_router_enable_expert_bias and self.moe_router_score_function != "sigmoid": + if self.moe_router_enable_expert_bias and self.moe_router_score_function not in ( + "sigmoid", + "sqrtsoftplus", + ): raise ValueError( - "Expert bias for aux-loss-free routing only supports sigmoid score function." - "Please set --moe-router-score-function sigmoid for sigmoid score function." + "Expert bias for aux-loss-free routing only supports sigmoid or sqrtsoftplus score function. " + "Please set --moe-router-score-function to sigmoid or sqrtsoftplus." ) if self.num_moe_experts and self.fp8: diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 4c755d5a264..f707a17940a 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -392,6 +392,23 @@ def __init__( eps=self.config.layernorm_epsilon ) + # DeepSeek V4 Hyper-Connection per-layer parameters + if self.config.dsv4_mode: + hc_mult = self.config.dsv4_hc_mult + hc_dim = hc_mult * self.config.hidden_size + mix_size = (2 + hc_mult) * hc_mult + # HC attention parameters + self.hc_attn_fn = torch.nn.Parameter(torch.empty(mix_size, hc_dim, dtype=torch.float32)) + self.hc_attn_base = torch.nn.Parameter(torch.empty(mix_size, dtype=torch.float32)) + self.hc_attn_scale = torch.nn.Parameter(torch.empty(3, dtype=torch.float32)) + # HC FFN parameters + self.hc_ffn_fn = torch.nn.Parameter(torch.empty(mix_size, hc_dim, dtype=torch.float32)) + self.hc_ffn_base = torch.nn.Parameter(torch.empty(mix_size, dtype=torch.float32)) + self.hc_ffn_scale = torch.nn.Parameter(torch.empty(3, dtype=torch.float32)) + for p in [self.hc_attn_fn, self.hc_attn_base, self.hc_attn_scale, + self.hc_ffn_fn, self.hc_ffn_base, self.hc_ffn_scale]: + p._keep_fp32 = True + self.recompute_input_layernorm = False self.recompute_pre_mlp_layernorm = False self.recompute_mlp = False @@ -526,11 +543,13 @@ def forward(self, *args, **kwargs): # this is only used to uniquely identify decode and non-decode cuda graph # runners in the cuda graph manager kwargs.pop("dynamic_inference_decode_only", None) + input_ids = kwargs.pop("input_ids", None) hidden_states, context = self._forward_attention(*args, **kwargs) output = self._forward_mlp( hidden_states, kwargs.get("inference_context", None), padding_mask=kwargs.get("padding_mask", None), + input_ids=input_ids, ) return output, context @@ -588,6 +607,14 @@ def _forward_attention( # Residual connection. residual = hidden_states + # HC pre for attention sublayer + if self.config.dsv4_mode: + from miles_plugins.models.deepseek_v4.ops.hyper_connection import DeepSeekV4HyperConnectionUtil + hc_util = DeepSeekV4HyperConnectionUtil(self.config) + hidden_states, hc_attn_post, hc_attn_comb = hc_util.layer_pre( + hidden_states, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base + ) + # Optional Input Layer norm if self.recompute_input_layernorm: self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput() @@ -631,23 +658,29 @@ def _forward_attention( attention_output_with_bias[0] ) - attention_output, attention_output_bias = attention_output_with_bias - attention_output = self.post_self_attn_layernorm(attention_output) - attention_output_with_bias = (attention_output, attention_output_bias) - - # TODO: could we move `bias_dropout_add_exec_handler` itself - # inside the module provided in the `bias_dropout_add_spec` module? nvtx_range_push(suffix="self_attn_bda") - if using_fused_tp_inference_kernel: - # In inference optimized transformer layer, there is no bias and dropout - # The remaining residual add is already handled inside the - # self attention module. - hidden_states = attention_output_with_bias[0] + + if self.config.dsv4_mode: + hidden_states = hc_util.layer_post( + attention_output_with_bias, residual, hc_attn_post, hc_attn_comb + ) else: - with self.bias_dropout_add_exec_handler(): - hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)( - attention_output_with_bias, residual, self.hidden_dropout - ) + attention_output, attention_output_bias = attention_output_with_bias + attention_output = self.post_self_attn_layernorm(attention_output) + attention_output_with_bias = (attention_output, attention_output_bias) + + # TODO: could we move `bias_dropout_add_exec_handler` itself + # inside the module provided in the `bias_dropout_add_spec` module? + if using_fused_tp_inference_kernel: + # In inference optimized transformer layer, there is no bias and dropout + # The remaining residual add is already handled inside the + # self attention module. + hidden_states = attention_output_with_bias[0] + else: + with self.bias_dropout_add_exec_handler(): + hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)( + attention_output_with_bias, residual, self.hidden_dropout + ) nvtx_range_pop(suffix="self_attn_bda") # Delay the offload of the attention norm until after the self_attn_bda has been computed @@ -700,7 +733,7 @@ def _forward_pre_mlp_layernorm(self, hidden_states): return pre_mlp_layernorm_output - def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None): + def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None, input_ids=None): """ Perform a forward pass through the feed-forward layer. @@ -709,6 +742,7 @@ def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None) Shape [seq_length, batch_size, hidden_size]. inference_context: Inference context for optimizations. padding_mask (Tensor, optional): Padding mask for MoE routing. + input_ids (Tensor, optional): Input token IDs for hash routing in MoE. Shape [bsz, seq_length]. True = padding (exclude), False = valid (include). Only used for MoE layers to exclude padding tokens from aux loss computations. The MoELayer will internally transform this to [seq_length, bsz] format. @@ -719,6 +753,14 @@ def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None) # Residual connection. residual = hidden_states + # HC pre for MLP sublayer + if self.config.dsv4_mode: + from miles_plugins.models.deepseek_v4.ops.hyper_connection import DeepSeekV4HyperConnectionUtil + hc_util = DeepSeekV4HyperConnectionUtil(self.config) + hidden_states, hc_ffn_post, hc_ffn_comb = hc_util.layer_pre( + hidden_states, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base + ) + # Optional Layer norm post the cross-attention. pre_mlp_layernorm_output = self._forward_pre_mlp_layernorm(hidden_states) @@ -773,7 +815,11 @@ def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None) # Set the residual for fused reduce-scatter + add + layer-norm + all-gather # operation in MLP's fc2. self._set_fc2_residual(residual) - mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output, padding_mask=padding_mask) + mlp_output_with_bias = self.mlp( + pre_mlp_layernorm_output, + padding_mask=padding_mask, + **(dict(input_ids=input_ids) if self.is_moe_layer else {}), + ) mlp_output, mlp_output_bias = mlp_output_with_bias mlp_output = self.post_mlp_layernorm(mlp_output) @@ -797,15 +843,21 @@ def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None) self.pre_mlp_norm_checkpoint.discard_output_and_register_recompute(tensor) return list(mlp_output_with_bias) + [residual] else: - return self._forward_post_mlp(mlp_output_with_bias, residual) + return self._forward_post_mlp( + mlp_output_with_bias, residual, + hc_ffn_post=hc_ffn_post if self.config.dsv4_mode else None, + hc_ffn_comb=hc_ffn_comb if self.config.dsv4_mode else None, + ) - def _forward_post_mlp(self, mlp_output_with_bias, residual): + def _forward_post_mlp(self, mlp_output_with_bias, residual, *, hc_ffn_post=None, hc_ffn_comb=None): """ Perform operations after the MLP computation. Args: mlp_output_with_bias (Tensor): Output tensor of the MLP layer with bias. residual (Tensor): Residual tensor. + hc_ffn_post (Tensor, optional): HC post weights for DSV4 mode. + hc_ffn_comb (Tensor, optional): HC comb weights for DSV4 mode. Returns: output (Tensor): Transformed hidden states of shape [s, b, h]. @@ -828,7 +880,14 @@ def _forward_post_mlp(self, mlp_output_with_bias, residual): # TODO: could we move `bias_dropout_add_exec_handler` itself # inside the module provided in the `bias_dropout_add_spec` module? nvtx_range_push(suffix="mlp_bda") - if using_fused_tp_inference_kernel: + if self.config.dsv4_mode: + # DSV4: skip bias_dropout_add; residual connection is handled by HC layer_post. + from miles_plugins.models.deepseek_v4.ops.hyper_connection import DeepSeekV4HyperConnectionUtil + hc_util = DeepSeekV4HyperConnectionUtil(self.config) + hidden_states = hc_util.layer_post( + mlp_output_with_bias, residual, hc_ffn_post, hc_ffn_comb + ) + elif using_fused_tp_inference_kernel: # In inference optimized transformer layer, there is no bias and dropout # The remaining residual add is already handled inside the # MLP module. @@ -839,6 +898,7 @@ def _forward_post_mlp(self, mlp_output_with_bias, residual): mlp_output_with_bias, residual, self.hidden_dropout ) nvtx_range_pop(suffix="mlp_bda") + # Delay the offload of the mlp norm until after the mlp_bda has been computed # because the residual is needed in the mlp_bda. if self.offload_mlp_norm: @@ -1141,7 +1201,7 @@ def _te_cuda_graph_replay(self, *args, **kwargs): return residual, hidden_states, probs, shared_expert_output # CUDA Graph does not capture the MLP/MoE part at all. - output = self._forward_mlp(*cuda_graph_output) + output = self._forward_mlp(*cuda_graph_output, input_ids=kwargs.get("input_ids", None)) return output, context def _get_te_cuda_graph_replay_args(self, *args, **kwargs): @@ -1364,7 +1424,7 @@ def _forward_mlp_postprocess(self, residual, output, shared_expert_output, mlp_b output = self.mlp(None, intermediate_tensors=(output, shared_expert_output)) return self._forward_post_mlp((output, mlp_bias), residual) - def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None): + def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None, input_ids=None): """ Orchestrates the MLP forward pass, handling partial CUDA graph execution logic. @@ -1414,4 +1474,4 @@ def _forward_mlp_partial_cudagraphs( else: return _forward_mlp_partial_cudagraphs(hidden_states, padding_mask=padding_mask) else: - return super()._forward_mlp(hidden_states, padding_mask=padding_mask) + return super()._forward_mlp(hidden_states, padding_mask=padding_mask, input_ids=input_ids) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 5fc410a76d8..28fea46195d 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1634,6 +1634,8 @@ def _add_network_size_args(parser): "barrier_with_L1_time", # args uses same var with a different name "num_moe_experts", + # defined in _add_tokenizer_args, required by topk router + "vocab_size", "fp8_param", # incompatible defaults in dataclass "gradient_accumulation_fusion", @@ -2779,6 +2781,12 @@ def _add_mla_args(parser): help="Mscale all dimensions for YaRN RoPE in multi-latent attention.") group.add_argument('--cache-mla-latents', action='store_true', default=False, help="If set caches the mla down projected latents with mla flash decode.") + group.add_argument('--original-max-position-embeddings', type=int, default=4096, + help="Original maximum position embeddings for the original model, used by yarn.") + group.add_argument('--beta-fast', type=float, default=32, + help="Beta fast for YaRN RoPE.") + group.add_argument('--beta-slow', type=float, default=1, + help="Beta slow for YaRN RoPE.") return parser diff --git a/megatron/training/training.py b/megatron/training/training.py index 121f6795711..847abba33f1 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -1274,12 +1274,6 @@ def build_model(): # After TE2.x: Below function is an empty function and does nothing. correct_amax_history_if_needed(model) - # Miles allows selected parameters to opt out of Float16Module's global - # bf16/fp16 cast. Restore those dtypes immediately after model materialization. - from miles.backends.megatron_utils.fp32_param_utils import enforce_marked_param_dtypes - - enforce_marked_param_dtypes(model) - if wrap_with_ddp: if args.use_torch_fsdp2: assert HAVE_FSDP2, "Torch FSDP2 requires torch>=2.4.0" From 1c6e5b7bcde0097ebe193b6258115ff3558f69d6 Mon Sep 17 00:00:00 2001 From: yueming-yuan Date: Tue, 28 Apr 2026 17:47:33 -0700 Subject: [PATCH 2/2] add optimizer fixes for fp32 param --- .../cpu_offloading/hybrid_optimizer.py | 21 +++++++++++++------ megatron/core/optimizer/distrib_optimizer.py | 3 ++- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py b/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py index 28487c3b367..499c17ee214 100644 --- a/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +++ b/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py @@ -122,7 +122,7 @@ def param_copy_back_gpu_hook(optimizer, args, kwargs): for param in _param_generator(optimizer): gpu_param = self.cpu_copys_map_gpu_param[param] gpu_param.data.copy_(param.data, non_blocking=True) - self._d2h_stream.record_event().wait(torch.cuda.current_stream()) + self._h2d_stream.record_event().wait(torch.cuda.current_stream()) return param_copy_back_gpu_hook @@ -370,15 +370,24 @@ def _update_fp32_params_by_new_state(self): if not self.param_update_in_fp32: return for param, v in self.state.items(): - fp32_param = self.param_to_fp32_param[param] - fp32_param.data.copy_(v["master_param"]) + inner_param = self.param_to_inner_param.get(param, param) + if inner_param is param: + continue + # Do the device/dtype conversion inside copy_ so the destination + # tensor owns the synchronization. Creating an intermediate + # non_blocking CPU tensor can race with the following CPU copy. + inner_param.data.copy_(v["master_param"].detach(), non_blocking=False) def update_fp32_param_by_new_param(self): """ - Update the fp32 parameters by the new parameters. + Refresh optimizer-side parameter copies after model weights are loaded + or otherwise changed outside the optimizer. """ - for param, fp32_param in self.param_to_fp32_param.items(): - fp32_param.data.copy_(param) + for param, inner_param in self.param_to_inner_param.items(): + if inner_param is param: + continue + # Blocking direct D2H copy is required here. + inner_param.data.copy_(param.detach(), non_blocking=False) def _register_load_state_dict_hooks(self): def pre_load_state_dict_hook(self, state_dict): diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index bc37e4d0f2d..8fe58f92bbb 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -422,7 +422,8 @@ def _build_model_and_main_param_groups( # fp32 params. elif model_param.type() == 'torch.cuda.FloatTensor': - shard_model_param = model_param.view(-1)[param_range.start : param_range.end] + # Keep shard tensors as leaf tensors for torch Optimizer. + shard_model_param = model_param.detach().view(-1)[param_range.start : param_range.end] model_fp32_params_this_group.append(model_param) shard_fp32_params_this_group.append(shard_model_param) tensor_parallel.copy_tensor_model_parallel_attributes(