Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions src/paddlefleet/transformer/transformer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,18 @@
from __future__ import annotations

import functools
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal

import paddle.nn.functional as F

from ..model_parallel_config import ModelParallelConfig
from ..utils import init_method_normal, scaled_init_method_normal
from ..utils import (
get_magic_init_method,
init_method_normal,
scaled_init_method_normal,
)

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -757,6 +762,9 @@ class TransformerConfig(ModelParallelConfig):
routing_map_fusion: bool = False
"""If True, use Triton fused routing map kernel for MoE routing."""

magic_init: bool = False
"""Use the magic initialization method."""

# Field name mapping rules: HuggingFace config.json name -> TransformerConfig name
transform_rules = {
# DSA field mapping
Expand Down Expand Up @@ -867,7 +875,15 @@ def __post_init__(self):
# init_method is not None
self.embedding_init_method = self.init_method

if self.init_method is None:
if self.magic_init:
if self.hidden_size == 0:
raise ValueError(
"hidden_size must be non-zero when magic_init is True."
)
Comment thread
DanielSun11 marked this conversation as resolved.
sigma = math.sqrt(0.3333 / self.hidden_size)
self.init_method = get_magic_init_method(sigma)
self.init_method_std = sigma
elif self.init_method is None:
self.init_method = init_method_normal(self.init_method_std)

if (
Expand Down Expand Up @@ -924,7 +940,9 @@ def __post_init__(self):
"recompute_granularity must be one of full and selective"
)

if self.output_layer_init_method is None:
if self.magic_init:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 magic_init=True 时无条件覆盖用户显式设置的 output_layer_init_method

当前逻辑在 magic_init=True 时直接赋值,即使用户通过构造参数显式传入了自定义的 output_layer_init_method,也会被静默覆盖。同样的问题也存在于下方 embedding_init_method 的处理(line 956-958)。

建议修复策略:仅在用户未显式指定时覆盖,与 elif ... is None 模式保持一致:

if self.magic_init and self.output_layer_init_method is None:
    self.output_layer_init_method = self.init_method
elif self.output_layer_init_method is None:
    ...

如果设计意图就是强制统一,建议在 docstring 中明确说明 magic_init=True 会忽略用户指定的其他 init method。

self.output_layer_init_method = self.init_method
elif self.output_layer_init_method is None:
self.output_layer_init_method = scaled_init_method_normal(
self.init_method_std,
self.num_hidden_layers,
Expand All @@ -936,7 +954,10 @@ def __post_init__(self):
# By default, use the same init std as you use for every other non-output layer.
self.embedding_init_method_std = self.init_method_std

if self.embedding_init_method is None:
if self.magic_init:
self.embedding_init_method = self.init_method
self.embedding_init_method_std = self.init_method_std
elif self.embedding_init_method is None:
if self.init_method is None or (
self.embedding_init_method_std != self.init_method_std
):
Expand Down
11 changes: 11 additions & 0 deletions src/paddlefleet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,17 @@ def scaled_init_method_normal(sigma, num_layers, multiplier=2.0):
return functools.partial(paddle.nn.init.normal_, mean=0.0, std=std)


def get_magic_init_method(sigma):
    """Build an initializer that fills a weight with sigma-scaled normal noise.

    Args:
        sigma: Factor applied to the ``paddle.randn`` sample via ``Tensor.scale``.

    Returns:
        A callable taking a single ``weight`` tensor. It draws a sample with the
        weight's own shape and dtype, scales it by ``sigma`` and writes the
        result into ``weight`` through ``set_value``.

    NOTE(review): the sample is drawn directly in ``weight.dtype``; this
    function applies no fp32 default-dtype guard itself -- confirm whether
    callers are expected to provide one.
    """

    def init_method(weight):
        # Sample first, then scale, so the write into the parameter is a
        # single set_value call on the final tensor.
        noise = paddle.randn(weight.shape, dtype=weight.dtype)
        scaled_noise = noise.scale(sigma)
        weight.set_value(scaled_noise)

    return init_method


def get_pg_size(group=None):
"""Get world size for a distributed group.
Expand Down
133 changes: 133 additions & 0 deletions tests/single_card_tests/test_transformer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,5 +189,138 @@ def test_hybridep_dispatcher_type_is_preserved(self):
self.assertTrue(config.moe_use_fusion_node)


class TestMagicInit(unittest.TestCase):
    """Tests for the magic_init functionality in TransformerConfig."""

    @staticmethod
    def _expected_sigma(hidden_size):
        """Return the std these tests expect from magic init.

        The expected value is ``sqrt(0.3333 / hidden_size)``.
        """
        # Imported locally so the helper does not depend on the module's
        # top-of-file imports.
        import math

        return math.sqrt(0.3333 / hidden_size)

    def test_magic_init_false_default_behavior(self):
        """When magic_init is False (default), normal init methods should be used."""
        config = TransformerConfig(
            num_hidden_layers=12,
            hidden_size=768,
            magic_init=False,
        )
        # When False, init_method should be set but not the magic init
        self.assertIsNotNone(config.init_method)
        self.assertIsNotNone(config.output_layer_init_method)

    def test_magic_init_true_sigma_calculation(self):
        """When magic_init is True, sigma should be sqrt(0.3333 / hidden_size)."""
        hidden_size = 768
        config = TransformerConfig(
            num_hidden_layers=12,
            hidden_size=hidden_size,
            magic_init=True,
        )
        self.assertAlmostEqual(
            config.init_method_std,
            self._expected_sigma(hidden_size),
            places=6,
        )

    def test_magic_init_true_all_methods_same(self):
        """When magic_init is True, all init methods should be the same."""
        config = TransformerConfig(
            num_hidden_layers=12,
            hidden_size=768,
            magic_init=True,
        )
        # All init methods should be the same function
        self.assertIs(config.init_method, config.output_layer_init_method)
        self.assertIs(config.init_method, config.embedding_init_method)

    def test_magic_init_true_different_hidden_sizes(self):
        """Test sigma calculation with different hidden sizes."""
        for hidden_size in [512, 768, 1024, 2048, 4096]:
            # subTest reports which hidden_size failed instead of stopping
            # at the first failing iteration without context.
            with self.subTest(hidden_size=hidden_size):
                config = TransformerConfig(
                    num_hidden_layers=12,
                    hidden_size=hidden_size,
                    magic_init=True,
                )
                self.assertAlmostEqual(
                    config.init_method_std,
                    self._expected_sigma(hidden_size),
                    places=6,
                )

    def test_magic_init_true_init_method_matches_get_magic_init_method(self):
        """When magic_init is True, init method should match get_magic_init_method."""
        from paddlefleet.utils import get_magic_init_method

        hidden_size = 768
        config = TransformerConfig(
            num_hidden_layers=12,
            hidden_size=hidden_size,
            magic_init=True,
        )
        reference_init = get_magic_init_method(
            self._expected_sigma(hidden_size)
        )

        # Re-seed before each run so both initializers consume the same
        # random stream; only then are the results comparable.
        paddle.seed(1234)
        weight1 = paddle.randn([100, 100])
        config.init_method(weight1)

        paddle.seed(1234)
        weight2 = paddle.randn([100, 100])
        reference_init(weight2)

        paddle.testing.assert_close(weight1, weight2, rtol=1e-6, atol=1e-6)

    def test_magic_init_false_uses_normal_init(self):
        """When magic_init is False, normal init methods should be used."""
        config = TransformerConfig(
            num_hidden_layers=12,
            hidden_size=768,
            magic_init=False,
        )
        # Should have init_method_std set to normal value
        self.assertIsNotNone(config.init_method_std)
        # Should be a reasonable value for normal init (not the magic init value)
        self.assertNotAlmostEqual(
            config.init_method_std, self._expected_sigma(768), places=6
        )

    def test_magic_init_true_with_moe(self):
        """Test magic_init works correctly with MoE models."""
        config = TransformerConfig(
            num_hidden_layers=12,
            hidden_size=768,
            n_routed_experts=8,
            magic_init=True,
        )
        self.assertAlmostEqual(
            config.init_method_std, self._expected_sigma(768), places=6
        )
        # All init methods should still be the same
        self.assertIs(config.init_method, config.output_layer_init_method)
        self.assertIs(config.init_method, config.embedding_init_method)

    def test_magic_init_true_raises_on_zero_hidden_size(self):
        """When magic_init is True and hidden_size is 0, should raise ValueError."""
        # assertRaisesRegex checks the exception text. The ``msg`` keyword of
        # assertRaises is only the message shown when the assertion fails, so
        # it would accept a ValueError raised by any unrelated validation.
        with self.assertRaisesRegex(
            ValueError,
            "hidden_size must be non-zero when magic_init is True",
        ):
            TransformerConfig(
                num_hidden_layers=12,
                hidden_size=0,
                magic_init=True,
            )


# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Loading