-
Notifications
You must be signed in to change notification settings - Fork 88
Support magic_init #1075
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Support magic_init #1075
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,13 +18,18 @@ | |
| from __future__ import annotations | ||
|
|
||
| import functools | ||
| import math | ||
| from dataclasses import dataclass | ||
| from typing import TYPE_CHECKING, Literal | ||
|
|
||
| import paddle.nn.functional as F | ||
|
|
||
| from ..model_parallel_config import ModelParallelConfig | ||
| from ..utils import init_method_normal, scaled_init_method_normal | ||
| from ..utils import ( | ||
| get_magic_init_method, | ||
| init_method_normal, | ||
| scaled_init_method_normal, | ||
| ) | ||
|
|
||
| if TYPE_CHECKING: | ||
| from collections.abc import Callable | ||
|
|
@@ -757,6 +762,9 @@ class TransformerConfig(ModelParallelConfig): | |
| routing_map_fusion: bool = False | ||
| """If True, use Triton fused routing map kernel for MoE routing.""" | ||
|
|
||
| magic_init: bool = False | ||
| """Use the magic initialization method.""" | ||
|
|
||
| # Field name mapping rules: HuggingFace config.json name -> TransformerConfig name | ||
| transform_rules = { | ||
| # DSA field mapping | ||
|
|
@@ -867,7 +875,15 @@ def __post_init__(self): | |
| # init_method is not None | ||
| self.embedding_init_method = self.init_method | ||
|
|
||
| if self.init_method is None: | ||
| if self.magic_init: | ||
| if self.hidden_size == 0: | ||
| raise ValueError( | ||
| "hidden_size must be non-zero when magic_init is True." | ||
| ) | ||
| sigma = math.sqrt(0.3333 / self.hidden_size) | ||
| self.init_method = get_magic_init_method(sigma) | ||
| self.init_method_std = sigma | ||
| elif self.init_method is None: | ||
| self.init_method = init_method_normal(self.init_method_std) | ||
|
|
||
| if ( | ||
|
|
@@ -924,7 +940,9 @@ def __post_init__(self): | |
| "recompute_granularity must be one of full and selective" | ||
| ) | ||
|
|
||
| if self.output_layer_init_method is None: | ||
| if self.magic_init: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 建议 当前逻辑在 建议修复策略:仅在用户未显式指定时覆盖,与 if self.magic_init and self.output_layer_init_method is None:
self.output_layer_init_method = self.init_method
elif self.output_layer_init_method is None:
...如果设计意图就是强制统一,建议在 docstring 中明确说明 |
||
| self.output_layer_init_method = self.init_method | ||
| elif self.output_layer_init_method is None: | ||
| self.output_layer_init_method = scaled_init_method_normal( | ||
| self.init_method_std, | ||
| self.num_hidden_layers, | ||
|
|
@@ -936,7 +954,10 @@ def __post_init__(self): | |
| # By default, use the same init std as you use for every other non-output layer. | ||
| self.embedding_init_method_std = self.init_method_std | ||
|
|
||
| if self.embedding_init_method is None: | ||
| if self.magic_init: | ||
| self.embedding_init_method = self.init_method | ||
| self.embedding_init_method_std = self.init_method_std | ||
| elif self.embedding_init_method is None: | ||
| if self.init_method is None or ( | ||
| self.embedding_init_method_std != self.init_method_std | ||
| ): | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.