Skip to content
Open
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
<div align="center">

test
Megatron-LM & Megatron Core
===========================

Expand Down
2 changes: 1 addition & 1 deletion megatron/core/dist_checkpointing/strategies/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def load_common(self, checkpoint_dir: Union[str, Path]):
msc = MultiStorageClientFeature.import_package()
return msc.torch.load(load_path, map_location='cpu')
else:
return torch.load(load_path, map_location='cpu')
return torch.load(load_path, map_location='cpu', weights_only=False)
except FileNotFoundError as e:
err_msg = f'Common file {load_path} does not exist'
if MultiStorageClientFeature.is_enabled():
Expand Down
13 changes: 8 additions & 5 deletions megatron/core/dist_checkpointing/strategies/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,10 +597,12 @@ def _expected_shape(sh_ten):
def _validate_global_shapes(self, metadata, sharded_tensors):
for sh_ten in sharded_tensors:
if sh_ten.key not in metadata.state_dict_metadata:
raise KeyError(
f"{sh_ten.key} from model not in state dict:"
f" {sorted(metadata.state_dict_metadata.keys())}"
)
# raise KeyError(
# f"{sh_ten.key} from model not in state dict:"
# f" {sorted(metadata.state_dict_metadata.keys())}"
# )
print(f"{sh_ten.key} from model not in state dict, will skip")
continue
loaded_shape = metadata.state_dict_metadata[sh_ten.key].size
expected_shape = self._expected_shape(sh_ten)
if loaded_shape != expected_shape:
Expand Down Expand Up @@ -630,7 +632,7 @@ def _temporarily_bypass_shape_validation(self):
tensor_metadata = self.metadata.state_dict_metadata
metadata_with_sizes = [
(tensor_metadata[key], tensor_metadata[key].size, sharded_tensor)
for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items()
for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata
]
try:
# Temporarily set sizes to expected shapes
Expand Down Expand Up @@ -959,6 +961,7 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St
planner=MCoreLoadPlanner(
shapes_validation_sharded_tensors=flexible_shape_sharded_tensors,
allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors,
allow_partial_load=True,
),
)

Expand Down
70 changes: 70 additions & 0 deletions megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ def __init__(
)

for param in self.parameters():
setattr(param, "parallel_mode", parallel_mode)
if is_expert:
# Reduce the gradient on the expert_data_parallel group for expert linear layers
setattr(param, "allreduce", not self.expert_parallel)
Expand Down Expand Up @@ -1161,6 +1162,61 @@ def sharded_state_dict(


if HAVE_TE and is_te_min_version("1.9.0.dev0"):
def ceil_div(x: int, y: int) -> int:
    """Return ``x`` divided by ``y``, rounded up (for positive ``y``)."""
    # Adding ``y - 1`` before the floor division pushes any non-zero
    # remainder over the next multiple of ``y``.
    bumped = x + y - 1
    return bumped // y

class _FakeInt4QuantizationSTE(torch.autograd.Function):
@staticmethod
def forward(ctx, x, group_size):
m, n = x.shape
block_size_m, block_size_n = 1, group_size


m_padded = ceil_div(m, block_size_m) * block_size_m
n_padded = ceil_div(n, block_size_n) * block_size_n

x_padded = torch.zeros(
(m_padded, n_padded),
dtype=x.dtype, device=x.device
)
x_padded[:m, :n] = x

x_view = x_padded.view(
m_padded // block_size_m,
block_size_m,
n_padded // block_size_n,
block_size_n
)

x_max = x_view.abs().float().amax(dim=(1, 3), keepdim=True)
q_max = 7
x_scale = x_max / q_max

x_scale = x_scale.clamp(min=1e-5)

x_div = x_view / x_scale
x_round = torch.round(x_div)

x_q_clamped = x_round.clamp(-q_max, q_max)

x_dequant_view = x_q_clamped * x_scale

x_dequant_full = x_dequant_view.view_as(x_padded)
x_out = x_dequant_full[:m, :n].contiguous().to(x.dtype)

return x_out

@staticmethod
def backward(ctx, grad_output):
return grad_output, None

def fake_int4_quantization_ste(x, group_size):
    """Apply fake INT4 quantization to ``x`` with a straight-through gradient.

    Args:
        x: Tensor to quantize; forwarded to ``_FakeInt4QuantizationSTE``.
        group_size: Number of elements sharing one quantization scale.

    Returns:
        The quantize-dequantized tensor. If ``x`` has a ``main_grad``
        attribute, the same object is attached to the result.
    """
    quantized = _FakeInt4QuantizationSTE.apply(x, group_size)

    if not hasattr(x, 'main_grad'):
        return quantized

    # NOTE(review): presumably downstream gradient-accumulation code reads
    # ``main_grad`` from whatever tensor it is handed as the weight - confirm.
    quantized.main_grad = x.main_grad
    return quantized

class TEGroupedLinear(te.pytorch.GroupedLinear):
"""
Expand Down Expand Up @@ -1361,6 +1417,20 @@ def forward(self, x, m_splits):
return out
return out, None

def _get_weight_tensors(self):
    """Get the weight tensors of the module.

    When the environment variable ``OPEN_TRAINING_INT4_FAKE_QAT_FLAG`` is
    ``"1"``, each weight from the parent class is passed through
    ``fake_int4_quantization_ste`` using the group size given by
    ``OPEN_TRAINING_INT4_GROUP_SIZE`` (default ``128``). Both variables are
    read on every call.
    """
    tensors = super()._get_weight_tensors()

    # Fake-QAT is opt-in; without the flag the weights are returned untouched.
    if os.getenv("OPEN_TRAINING_INT4_FAKE_QAT_FLAG", "0") != "1":
        return tensors

    group_size = int(os.getenv("OPEN_TRAINING_INT4_GROUP_SIZE", "128"))
    return [fake_int4_quantization_ste(tensor, group_size) for tensor in tensors]

def _encode_extra_state(self, state):
# TE 2.0 changed the format of extra_state to be a byte tensor
if is_te_min_version("2.0.0"):
Expand Down
83 changes: 51 additions & 32 deletions megatron/core/fusions/fused_mla_yarn_rope_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@ def rotary_fwd_kv_kernel(
SIN,
emb_dim: tl.constexpr,
k_dim: tl.constexpr,
k_dim_ceil: tl.constexpr,
v_dim: tl.constexpr,
head_num: tl.constexpr,
batch_size,
Expand Down Expand Up @@ -434,21 +435,27 @@ def rotary_fwd_kv_kernel(
cos_right = tl.load(COS + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2))
sin_right = tl.load(SIN + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2))

KV_ptr = KV + pid_m * stride_kv_seq + pid_head * BLOCK_H * stride_kv_nheads
kv_off = tl.arange(0, BLOCK_H)[:, None] * stride_kv_nheads
mask = kv_off < head_num * stride_kv_nheads
k_in_off = kv_off + tl.arange(0, k_dim)[None, :]
v_in_off = kv_off + k_dim + tl.arange(0, v_dim)[None, :]
k = tl.load(KV_ptr + k_in_off, mask=mask)
v = tl.load(KV_ptr + v_in_off, mask=mask)
KV_ptr = KV + pid_m * stride_kv_seq # + pid_head * BLOCK_H * stride_kv_nheads
ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H
kj_range = tl.arange(0, k_dim_ceil)[None, :]
mask_k = (ki_range < head_num) & (kj_range < k_dim)
mask_v = ki_range < head_num
k_off = ki_range * stride_kv_nheads + kj_range
if v_dim > 0:
v_off = ki_range * stride_kv_nheads + k_dim + tl.arange(0, v_dim)[None, :]
v = tl.load(KV_ptr + v_off, mask=mask_v)
else:
v = tl.zeros((BLOCK_H, 1), dtype=KV.dtype.element_ty)
k = tl.load(KV_ptr + k_off, mask=mask_k)

K_ptr = O_KEY + pid_m * stride_k_seq + pid_head * BLOCK_H * stride_k_nheads
V_ptr = O_VALUE + pid_m * stride_v_seq + pid_head * BLOCK_H * stride_v_nheads
K_ptr = O_KEY + pid_m * stride_k_seq # + pid_head * BLOCK_H * stride_k_nheads
V_ptr = O_VALUE + pid_m * stride_v_seq # + pid_head * BLOCK_H * stride_v_nheads

k_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + tl.arange(0, k_dim)[None, :]
v_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_v_nheads + tl.arange(0, v_dim)[None, :]
tl.store(K_ptr + k_out_off, k, mask=mask)
tl.store(V_ptr + v_out_off, v, mask=mask)
k_out_off = ki_range * stride_k_nheads + kj_range
tl.store(K_ptr + k_out_off, k, mask=mask_k)
if v_dim > 0:
v_out_off = ki_range * stride_v_nheads + tl.arange(0, v_dim)[None, :]
tl.store(V_ptr + v_out_off, v, mask=mask_v)

EMB = K_POS_EMB + pid_m * stride_emb_seq
# x1 = t[..., 0::2], x2 = t[..., 1::2]
Expand All @@ -460,14 +467,16 @@ def rotary_fwd_kv_kernel(
x_left = x_left.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2)
x_right = x_right.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2)

x_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H
mask_x = x_range < head_num
x_left_off = (
tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads
x_range * stride_k_nheads
+ k_dim
+ tl.arange(0, emb_dim // 2)[None, :]
)
x_right_off = x_left_off + emb_dim // 2
tl.store(K_ptr + x_left_off, x_left, mask=mask)
tl.store(K_ptr + x_right_off, x_right, mask=mask)
tl.store(K_ptr + x_left_off, x_left, mask=mask_x)
tl.store(K_ptr + x_right_off, x_right, mask=mask_x)


@triton.autotune(
Expand All @@ -493,6 +502,7 @@ def rotary_bwd_kv_kernel(
SIN,
emb_dim: tl.constexpr,
k_dim: tl.constexpr,
k_dim_ceil: tl.constexpr,
v_dim: tl.constexpr,
head_num: tl.constexpr,
batch_size,
Expand Down Expand Up @@ -533,27 +543,32 @@ def rotary_bwd_kv_kernel(
else:
token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size)

dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads
dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads
mask = dkv_off < head_num * stride_dkv_nheads
dk_out_off = dkv_off + tl.arange(0, k_dim)[None, :]
dv_out_off = dkv_off + k_dim + tl.arange(0, v_dim)[None, :]

dK_ptr = dK + pid_m * stride_dk_seq + pid_head * BLOCK_H * stride_dk_nheads
dV_ptr = dV + pid_m * stride_dv_seq + pid_head * BLOCK_H * stride_dv_nheads
dk_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + tl.arange(0, k_dim)[None, :]
dv_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dv_nheads + tl.arange(0, v_dim)[None, :]
dk = tl.load(dK_ptr + dk_in_off, mask=mask)
dv = tl.load(dV_ptr + dv_in_off, mask=mask)
tl.store(dKV_ptr + dk_out_off, dk, mask=mask)
tl.store(dKV_ptr + dv_out_off, dv, mask=mask)
dKV_ptr = dKV + pid_m * stride_dkv_seq # + pid_head * BLOCK_H * stride_dkv_nheads
ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H
kj_range = tl.arange(0, k_dim_ceil)[None, :]
mask_k = (ki_range < head_num) & (kj_range < k_dim)
mask_v = ki_range < head_num
dk_out_off = ki_range * stride_dkv_nheads + kj_range

dK_ptr = dK + pid_m * stride_dk_seq # + pid_head * BLOCK_H * stride_dk_nheads
dV_ptr = dV + pid_m * stride_dv_seq # + pid_head * BLOCK_H * stride_dv_nheads
dk_in_off = ki_range * stride_dk_nheads + kj_range

dk = tl.load(dK_ptr + dk_in_off, mask=mask_k)
tl.store(dKV_ptr + dk_out_off, dk, mask=mask_k)

if v_dim > 0:
dv_out_off = ki_range * stride_dkv_nheads + k_dim + tl.arange(0, v_dim)[None, :]
dv_in_off = ki_range * stride_dv_nheads + tl.arange(0, v_dim)[None, :]
dv = tl.load(dV_ptr + dv_in_off, mask=mask_v)
tl.store(dKV_ptr + dv_out_off, dv, mask=mask_v)

if pid_head == 0:
x_left_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32)
x_right_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32)
for i in tl.static_range(triton.cdiv(head_num, BLOCK_H)):
dK_ptr = dK + pid_m * stride_dk_seq + i * BLOCK_H * stride_dk_nheads
x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim
dK_ptr = dK + pid_m * stride_dk_seq # + i * BLOCK_H * stride_dk_nheads
x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + i * BLOCK_H * stride_dk_nheads
mask = x_off < head_num * stride_dk_nheads
x_left_off = x_off + tl.arange(0, emb_dim // 2)[None, :]
x_right_off = x_left_off + emb_dim // 2
Expand Down Expand Up @@ -632,6 +647,7 @@ def forward(

o_key = kv.new_empty(total_seqlen, nheads, emb_dim + k_dim)
o_value = kv.new_empty(total_seqlen, nheads, v_dim)
k_dim_ceil = triton.next_power_of_2(k_dim)

grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"]))
rotary_fwd_kv_kernel[grid](
Expand All @@ -643,6 +659,7 @@ def forward(
sin,
emb_dim,
k_dim,
k_dim_ceil,
v_dim,
nheads,
batch_size,
Expand Down Expand Up @@ -700,6 +717,7 @@ def backward(ctx, dk, dv):

d_kv = dk.new_empty(total_seqlen, nheads, ctx.k_dim + ctx.v_dim)
d_emb = dk.new_empty(total_seqlen, 1, ctx.emb_dim)
k_dim_ceil = triton.next_power_of_2(ctx.k_dim)

grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"]))
rotary_bwd_kv_kernel[grid](
Expand All @@ -711,6 +729,7 @@ def backward(ctx, dk, dv):
sin,
ctx.emb_dim,
ctx.k_dim,
k_dim_ceil,
ctx.v_dim,
nheads,
batch_size,
Expand Down
10 changes: 9 additions & 1 deletion megatron/core/models/common/language_module/language_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,15 @@ def compute_output_layer_and_language_model_loss(
assert (
column_parallel_linear is not None
), "column_parallel_linear cannot be None when not using fused linear cross entropy."
logits, _ = column_parallel_linear(hidden, **col_linear_kwargs)
# output
output_layer_params = {k: v.detach() for k, v in column_parallel_linear.named_parameters()}
output_layer_buffers = dict(column_parallel_linear.named_buffers())
logits, _ = torch.func.functional_call(
column_parallel_linear,
{**output_layer_params, **output_layer_buffers},
(hidden,),
col_linear_kwargs,
)

return self.compute_language_model_loss(labels, logits)

Expand Down
8 changes: 8 additions & 0 deletions megatron/core/models/gpt/gpt_layer_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ def get_gpt_layer_with_transformer_engine_spec(
use_kitchen: bool = False,
use_te_activation_func: bool = False,
fallback_to_eager_attn: bool = False,
post_self_attn_layernorm: bool = False,
post_mlp_layernorm: bool = False,
) -> ModuleSpec:
"""Use this spec to use lower-level Transformer Engine modules (required for fp8 training).

Expand Down Expand Up @@ -260,6 +262,8 @@ def get_gpt_layer_with_transformer_engine_spec(
mlp=mlp,
sharded_state_dict_keys_map=sharded_state_dict_keys_map,
normalization=normalization,
post_self_attn_layernorm=post_self_attn_layernorm,
post_mlp_layernorm=post_mlp_layernorm,
)


Expand Down Expand Up @@ -349,6 +353,8 @@ def get_transformer_layer_spec_for_backend(
mlp: ModuleSpec,
sharded_state_dict_keys_map: Optional[dict] = None,
normalization: Optional[str] = None,
post_self_attn_layernorm: bool = False,
post_mlp_layernorm: bool = False,
) -> ModuleSpec:
"""Helper function to get module spec for TransformerLayer"""

Expand All @@ -371,9 +377,11 @@ def get_transformer_layer_spec_for_backend(
input_layernorm=input_layernorm,
self_attention=attention,
self_attn_bda=get_bias_dropout_add,
post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp,
pre_mlp_layernorm=pre_mlp_layernorm,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp,
sharded_state_dict_keys_map=sharded_state_dict_keys_map,
),
)
Expand Down
Loading