Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions megatron/core/distributed/finalize_model_grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def reset_model_temporary_tensors(config: TransformerConfig, model: List[torch.n
"""
for model_chunk in model:
for module in get_attr_wrapped_model(model_chunk, 'modules')():
if config.moe_router_enable_expert_bias and hasattr(module, 'expert_bias'):
if config.moe_router_enable_expert_bias and getattr(module, 'expert_bias', None) is not None:
module.local_tokens_per_expert.zero_()
if (
config.moe_router_load_balancing_type == "global_aux_loss"
Expand Down Expand Up @@ -473,7 +473,7 @@ def finalize_model_grads(
if config.timers is not None:
config.timers('embedding-grads-all-reduce').stop()

if config.moe_router_enable_expert_bias:
if config.moe_router_enable_expert_bias and not config.freeze_e_score_correction_bias:
_update_router_expert_bias(model, config)

reset_model_temporary_tensors(config, model)
Expand Down
1 change: 1 addition & 0 deletions megatron/core/models/gpt/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ def forward(
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
padding_mask=padding_mask,
input_ids=input_ids,
**(extra_block_kwargs or {}),
)

Expand Down
21 changes: 15 additions & 6 deletions megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def param_copy_back_gpu_hook(optimizer, args, kwargs):
for param in _param_generator(optimizer):
gpu_param = self.cpu_copys_map_gpu_param[param]
gpu_param.data.copy_(param.data, non_blocking=True)
self._d2h_stream.record_event().wait(torch.cuda.current_stream())
self._h2d_stream.record_event().wait(torch.cuda.current_stream())

return param_copy_back_gpu_hook

Expand Down Expand Up @@ -370,15 +370,24 @@ def _update_fp32_params_by_new_state(self):
if not self.param_update_in_fp32:
return
for param, v in self.state.items():
fp32_param = self.param_to_fp32_param[param]
fp32_param.data.copy_(v["master_param"])
inner_param = self.param_to_inner_param.get(param, param)
if inner_param is param:
continue
# Do the device/dtype conversion inside copy_ so the destination
# tensor owns the synchronization. Creating an intermediate
# non_blocking CPU tensor can race with the following CPU copy.
inner_param.data.copy_(v["master_param"].detach(), non_blocking=False)

def update_fp32_param_by_new_param(self):
"""
Update the fp32 parameters by the new parameters.
Refresh optimizer-side parameter copies after model weights are loaded
or otherwise changed outside the optimizer.
"""
for param, fp32_param in self.param_to_fp32_param.items():
fp32_param.data.copy_(param)
for param, inner_param in self.param_to_inner_param.items():
if inner_param is param:
continue
# Blocking direct D2H copy is required here.
inner_param.data.copy_(param.detach(), non_blocking=False)

def _register_load_state_dict_hooks(self):
def pre_load_state_dict_hook(self, state_dict):
Expand Down
9 changes: 5 additions & 4 deletions megatron/core/pipeline_parallel/p2p_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,17 +181,18 @@ def _communicate_shapes(self, tensor_send_next, tensor_send_prev, recv_prev, rec
(recv_prev_shape, recv_next_shape)
"""
config = self.config
num_dims = 4 if config.dsv4_mode else 3
recv_prev_shape_tensor = None
recv_next_shape_tensor = None
send_prev_shape_tensor = None
send_next_shape_tensor = None
if recv_prev:
recv_prev_shape_tensor = torch.empty(
(3,), device=torch.cuda.current_device(), dtype=torch.int64
(num_dims,), device=torch.cuda.current_device(), dtype=torch.int64
)
if recv_next:
recv_next_shape_tensor = torch.empty(
(3,), device=torch.cuda.current_device(), dtype=torch.int64
(num_dims,), device=torch.cuda.current_device(), dtype=torch.int64
)
if tensor_send_prev is not None:
send_prev_shape_tensor = torch.tensor(
Expand Down Expand Up @@ -241,11 +242,11 @@ def _communicate_shapes(self, tensor_send_next, tensor_send_prev, recv_prev, rec
# should take this out once the bug with batch_isend_irecv is resolved.
torch.cuda.synchronize()

recv_prev_shape = [0, 0, 0]
recv_prev_shape = [0] * num_dims
if recv_prev_shape_tensor is not None:
recv_prev_shape = recv_prev_shape_tensor.tolist()

recv_next_shape = [0, 0, 0]
recv_next_shape = [0] * num_dims
if recv_next_shape_tensor is not None:
recv_next_shape = recv_next_shape_tensor.tolist()

Expand Down
10 changes: 8 additions & 2 deletions megatron/core/pipeline_parallel/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,10 @@ def enable_grad_sync():

model_type = get_model_type(model[0])

tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
if config.dsv4_mode:
tensor_shape = [seq_length, micro_batch_size, config.dsv4_hc_mult, config.hidden_size]
else:
tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
tensor_shape[0] = tensor_shape[0] // cp_group.size()
if config.sequence_parallel:
tensor_shape[0] = tensor_shape[0] // tp_group.size()
Expand Down Expand Up @@ -2098,7 +2101,10 @@ def get_tensor_shapes(
if config.sequence_parallel:
effective_seq_length = effective_seq_length // tp_group.size()

tensor_shapes.append((effective_seq_length, micro_batch_size, config.hidden_size))
if config.dsv4_mode:
tensor_shapes.append((effective_seq_length, micro_batch_size, config.dsv4_hc_mult, config.hidden_size))
else:
tensor_shapes.append((effective_seq_length, micro_batch_size, config.hidden_size))
return tensor_shapes


Expand Down
68 changes: 59 additions & 9 deletions megatron/core/tensor_parallel/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,28 @@
dist_reduce_scatter_func = torch.distributed._reduce_scatter_base


def _reduce(input_, group):
"""All-reduce the input tensor across model parallel group."""
def _reduce(input_, group, fp32=False):
"""All-reduce the input tensor across model parallel group.

Args:
input_: Input tensor.
group: Process group for all-reduce.
fp32: If True, cast to FP32 before all-reduce, then cast back.
"""
assert group is not None, "group should not be None"

# Bypass the function if we are using only 1 GPU.
if group.size() == 1:
return input_

# All-reduce.
torch.distributed.all_reduce(input_.contiguous(), group=group)
if fp32:
orig_dtype = input_.dtype
input_fp32 = input_.float().contiguous()
torch.distributed.all_reduce(input_fp32, group=group)
input_.copy_(input_fp32.to(orig_dtype))
else:
torch.distributed.all_reduce(input_.contiguous(), group=group)

return input_

Expand Down Expand Up @@ -194,24 +206,56 @@ def _reduce_scatter_along_first_dim(input_, group, input_split_sizes=None, use_g
return output


def split_along_nth_dim(input_, dim, group):
"""Split the tensor along the specified dimension and keep the
corresponding slice. This is a pure function without autograd.

Args:
input_: Input tensor to split.
dim: The dimension along which to split.
group: The process group for splitting.

Returns:
The slice of the input tensor corresponding to the current rank.
"""
assert group is not None, "group should not be None"

world_size = group.size()
if world_size == 1:
return input_

dim_size = input_.size(dim)
assert (
dim_size % world_size == 0
), f"Dimension {dim} of the tensor (size {dim_size}) should be divisible by world size {world_size}"
local_dim_size = dim_size // world_size
rank = group.rank()
dim_offset = rank * local_dim_size

output = input_.narrow(dim, dim_offset, local_dim_size).contiguous()

return output


class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region."""

@staticmethod
def symbolic(graph, input_, group):
def symbolic(graph, input_, group, all_reduce_grad_fp32):
"""Symbolic function for tracing."""
return input_

@staticmethod
def forward(ctx, input_, group):
def forward(ctx, input_, group, all_reduce_grad_fp32):
"""Forward function."""
ctx.group = group
ctx.all_reduce_grad_fp32 = all_reduce_grad_fp32
return input_

@staticmethod
def backward(ctx, grad_output):
"""Backward function."""
return _reduce(grad_output, ctx.group), None
return _reduce(grad_output, ctx.group, fp32=ctx.all_reduce_grad_fp32), None, None


class _ReduceFromModelParallelRegion(torch.autograd.Function):
Expand Down Expand Up @@ -466,10 +510,16 @@ def backward(ctx, *grad_output):
# -----------------


def copy_to_tensor_model_parallel_region(input_, group=None):
"""Wrapper for autograd function: forward: copy, backward allreduce"""
def copy_to_tensor_model_parallel_region(input_, group=None, all_reduce_grad_fp32=False):
"""Wrapper for autograd function: forward: copy, backward allreduce

Args:
input_: Input tensor.
group: Process group for all-reduce. If None, uses default TP group.
all_reduce_grad_fp32: If True, cast gradients to FP32 before all-reduce, then cast back.
"""
group = get_tensor_model_parallel_group_if_none(group)
return _CopyToModelParallelRegion.apply(input_, group)
return _CopyToModelParallelRegion.apply(input_, group, all_reduce_grad_fp32)


def reduce_from_tensor_model_parallel_region(input_, group=None):
Expand Down
10 changes: 0 additions & 10 deletions megatron/core/transformer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1491,16 +1491,6 @@ def get_query_key_value_tensors(
if output_gate:
# Gate [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
gate = gate.reshape(*gate.shape[:2], -1, self.hidden_size_per_attention_head)
if self.config.num_query_groups < self.world_size:
# gate has the same head layout as query before slicing.
# Apply the same TP slice so gate matches the per-rank query.
idx = get_tensor_model_parallel_rank() % (
self.world_size // self.config.num_query_groups
)
size = self.num_attention_heads_per_partition // (
self.world_size // self.config.num_query_groups
)
gate = gate[:, :, idx * size : (idx + 1) * size, :]
return query, key, value, gate

return query, key, value
Expand Down
Loading