Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions penzai/models/transformer/variants/gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,13 +405,18 @@ def gpt_neox_from_huggingface_model(
"rotary_pct",
"vocab_size",
# Ignored by conversion:
"max_position_embeddings",
"torch_dtype",
"_attn_implementation_autoset",
"_name_or_path",
"architectures",
"attention_probs_dropout_prob",
"bos_token_id",
"eos_token_id",
"_attn_implementation_autoset",
"head_dim",
"hidden_dropout_prob",
"is_decoder",
"max_position_embeddings",
"torch_dtype",
"type_vocab_size",
}
bad_attributes = {}
for k, v in hf_config_attributes.items():
Expand Down
14 changes: 11 additions & 3 deletions penzai/models/transformer/variants/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def llama_from_huggingface_model(
reference_attributes = transformers.LlamaConfig().to_dict()
handled_or_ignored_attributes = {
# Handled during conversion:
"hidden_act",
"hidden_size",
"intermediate_size",
"num_attention_heads",
Expand All @@ -75,13 +76,20 @@ def llama_from_huggingface_model(
"rope_theta",
"vocab_size",
# Ignored by conversion:
"max_position_embeddings",
"torch_dtype",
"_attn_implementation_autoset",
"_name_or_path",
"architectures",
"attention_probs_dropout_prob",
"bos_token_id",
"eos_token_id",
"_attn_implementation_autoset",
"head_dim",
"hidden_dropout_prob",
"is_decoder",
"max_position_embeddings",
"pad_token_id",
"torch_dtype",
"type_vocab_size",
"use_cache",
}
bad_attributes = {}
for k, v in hf_config_attributes.items():
Expand Down
16 changes: 13 additions & 3 deletions penzai/models/transformer/variants/llamalike_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class LlamalikeTransformerConfig:
mlp_hidden_dim: int
num_decoder_blocks: int
vocab_size: int
mlp_variant: Literal["geglu_approx", "swiglu"]
mlp_variant: Literal["geglu_exact", "geglu_approx", "swiglu"]
tie_embedder_and_logits: bool
rope_wavelength: float = 10_000
rms_norm_eps: float = 1e-6
Expand Down Expand Up @@ -157,7 +157,9 @@ def build_llamalike_feedforward(
Returns:
An instance of TransformerFeedForward containing the GELU MLP blocks.
"""
if config.mlp_variant == "geglu_approx":
if config.mlp_variant == "geglu_exact":
act_fn = functools.partial(jax.nn.gelu, approximate=False)
elif config.mlp_variant == "geglu_approx":
# Approximate is already the default in JAX, but we specify it explicitly
# because defaults differ between JAX and PyTorch.
act_fn = functools.partial(jax.nn.gelu, approximate=True)
Expand Down Expand Up @@ -641,6 +643,14 @@ def llamalike_from_huggingface_model(
else:
activation_dtype = param_dtype

# Map HuggingFace hidden_act to Penzai mlp_variant
hidden_act_to_mlp_variant = {
"silu": "swiglu",
"gelu": "geglu_exact",
"gelu_new": "geglu_approx",
}
mlp_variant = hidden_act_to_mlp_variant[hf_config.hidden_act]

pz_config = LlamalikeTransformerConfig(
num_kv_heads=num_kv_heads,
query_head_multiplier=query_head_multiplier,
Expand All @@ -649,7 +659,7 @@ def llamalike_from_huggingface_model(
mlp_hidden_dim=hf_config.intermediate_size,
num_decoder_blocks=hf_config.num_hidden_layers,
vocab_size=hf_config.vocab_size,
mlp_variant="swiglu",
mlp_variant=mlp_variant,
rope_wavelength=hf_config.rope_theta,
tie_embedder_and_logits=False,
attention_type=attention_type,
Expand Down
14 changes: 11 additions & 3 deletions penzai/models/transformer/variants/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def mistral_from_huggingface_model(
reference_attributes = transformers.MistralConfig().to_dict()
handled_or_ignored_attributes = {
# Handled during conversion:
"hidden_act",
"hidden_size",
"intermediate_size",
"num_attention_heads",
Expand All @@ -81,11 +82,18 @@ def mistral_from_huggingface_model(
"vocab_size",
"sliding_window",
# Ignored by conversion:
"max_position_embeddings",
"torch_dtype",
"architectures",
"_attn_implementation_autoset",
"_name_or_path",
"architectures",
"attention_probs_dropout_prob",
"head_dim",
"hidden_dropout_prob",
"is_decoder",
"max_position_embeddings",
"pad_token_id",
"torch_dtype",
"type_vocab_size",
"use_cache",
}
bad_attributes = {}
for k, v in hf_config_attributes.items():
Expand Down
42 changes: 41 additions & 1 deletion tests/models/transformer_consistency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,24 @@ class TransformerConsistencyTest(parameterized.TestCase):
)
def test_llama_consistency(self, num_attention_heads, num_key_value_heads):
cfg = transformers.LlamaConfig(
# Adjusted architecture parameters for a smaller version of Llama.
vocab_size=11,
hidden_size=64,
intermediate_size=256,
num_hidden_layers=3,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
# Extra parameters that are set when loading official models from
# HuggingFace but aren't set by default in LlamaConfig.
max_position_embeddings=8192,
rms_norm_eps=1e-05,
rope_theta=500000.0,
torch_dtype="bfloat16",
architectures=["LlamaForCausalLM"],
bos_token_id=128000,
eos_token_id=128001,
_name_or_path="meta-llama/Meta-Llama-3-8B",
_attn_implementation_autoset=True,
)

torch.manual_seed(0)
Expand Down Expand Up @@ -73,15 +85,35 @@ def test_llama_consistency(self, num_attention_heads, num_key_value_heads):
dict(testcase_name="full", num_attention_heads=4, num_key_value_heads=4),
dict(testcase_name="mqa", num_attention_heads=4, num_key_value_heads=1),
dict(testcase_name="gqa", num_attention_heads=4, num_key_value_heads=2),
dict(
testcase_name="act_gelu",
num_attention_heads=4,
num_key_value_heads=4,
hidden_act="gelu",
),
)
def test_mistral_consistency(self, num_attention_heads, num_key_value_heads):
def test_mistral_consistency(
self, num_attention_heads, num_key_value_heads, hidden_act="silu"
):
cfg = transformers.MistralConfig(
# Adjusted architecture parameters for a smaller version of Mistral.
vocab_size=11,
hidden_size=64,
intermediate_size=256,
num_hidden_layers=3,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
hidden_act=hidden_act,
# Extra parameters that are set when loading official models from
# HuggingFace but aren't set by default in MistralConfig.
max_position_embeddings=32768,
sliding_window=None,
rms_norm_eps=1e-05,
rope_theta=1000000.0,
torch_dtype="bfloat16",
architectures=["MistralForCausalLM"],
_name_or_path="fake_org/fake-Mistral-version",
_attn_implementation_autoset=True,
)

torch.manual_seed(0)
Expand Down Expand Up @@ -110,11 +142,19 @@ def test_mistral_consistency(self, num_attention_heads, num_key_value_heads):

def test_gpt_neox_consistency(self):
cfg = transformers.GPTNeoXConfig(
# Adjusted architecture parameters for a smaller version of GPT-NeoX.
vocab_size=11,
hidden_size=64,
intermediate_size=256,
num_hidden_layers=3,
num_attention_heads=4,
# Extra parameters that are set when loading official models from
# HuggingFace but aren't set by default in GPTNeoXConfig.
torch_dtype="float16",
architectures=["GPTNeoXForCausalLM"],
eos_token_id=0,
_name_or_path="fake_org/fake-GPTNeoX-version",
_attn_implementation_autoset=True,
)

torch.manual_seed(0)
Expand Down