143 changes: 127 additions & 16 deletions tests/generate/utils_test.py
@@ -41,7 +41,9 @@ def flat_state(self):
def from_flat_path(self, flat_path):
new_params = {}
for keys, param in flat_path:
new_params[".".join(keys)] = param.value
new_params[".".join(keys)] = (
param if hasattr(param, "value") else MockParam(param)
)
return MockState(new_params)


@@ -50,6 +52,31 @@ class MockParam:
def __init__(self, value):
self.value = value

@property
def shape(self):
return self.value.shape

@property
def dtype(self):
return self.value.dtype

@property
def ndim(self):
return self.value.ndim

@property
def sharding(self):
return self.value.sharding

def __getitem__(self, item):
return self.value[item]

def __array__(self, dtype=None):
return np.asarray(self.value, dtype=dtype)

def __jax_array__(self):
return self.value

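The array-protocol methods added to MockParam above (`__array__`, `__jax_array__`, `__getitem__`, plus the shape/dtype/ndim/sharding properties) let the mock be passed straight to jnp/np functions without unwrapping `.value` first. A minimal standalone sketch of the same idea outside the test harness (class name and values here are illustrative, not part of the PR):

import jax.numpy as jnp
import numpy as np


class ArrayLikeParam:
  """Wrapper that NumPy and JAX can consume as if it were an array."""

  def __init__(self, value):
    self.value = value

  def __array__(self, dtype=None):
    # NumPy protocol: np.asarray(wrapper) returns the wrapped buffer.
    return np.asarray(self.value, dtype=dtype)

  def __jax_array__(self):
    # JAX protocol: jnp operations unwrap the parameter automatically.
    return self.value


p = ArrayLikeParam(jnp.arange(4.0))
assert jnp.array_equal(p, jnp.arange(4.0))  # no explicit .value needed
assert np.asarray(p).shape == (4,)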

class Logprob:

@@ -229,14 +256,14 @@ def test_transfer_state_with_mappings_tranpose_and_sharding_device(self):
expected_layer_0_weight = jnp.arange(16).reshape(2, 8).T * 2
self.assertTrue(
jnp.array_equal(
new_tgt_state.params["decoder.layer_0.weight"],
new_tgt_state.params["decoder.layer_0.weight"].value,
expected_layer_0_weight,
)
)
expected_layer_1_weight = jnp.arange(16, 32).reshape(2, 8).T
self.assertTrue(
jnp.array_equal(
new_tgt_state.params["encoder.layer_0.weight"],
new_tgt_state.params["encoder.layer_0.weight"].value,
expected_layer_1_weight,
)
)
@@ -297,7 +324,7 @@ def test_transfer_state_with_bias_padding_and_reshape(self):
# Verify shape
self.assertEqual(result.params[src_key].shape, (4, 128))
# Verify values are repeated correctly
self.assertTrue(jnp.allclose(result.params[src_key], 1.0))
self.assertTrue(jnp.allclose(result.params[src_key].value, 1.0))

def test_transfer_state_with_scanned_layers(self):
"""Comprehensive test for scanned layers covering multiple scenarios."""
@@ -400,7 +427,7 @@ def test_transfer_state_with_scanned_layers(self):
self.assertEqual(transferred.shape, (vocab_size, embed_dim))
self.assertTrue(
jnp.allclose(
transferred,
transferred.value,
jnp.full(
(vocab_size, embed_dim), layer_idx + 1, dtype=jnp.float32
),
@@ -420,7 +447,7 @@ def test_transfer_state_with_scanned_layers(self):

self.assertEqual(transferred.shape, (batch_size, vocab_size))
self.assertTrue(
jnp.allclose(transferred, expected),
jnp.allclose(transferred.value, expected),
f"Scanned bias layer {layer_idx} mismatch",
)

@@ -430,12 +457,96 @@ def test_transfer_state_with_scanned_layers(self):
self.assertEqual(transferred_embedding.shape, (embed_dim, vocab_size))
self.assertTrue(
jnp.allclose(
transferred_embedding,
transferred_embedding.value,
jnp.full((embed_dim, vocab_size), 99.0, dtype=jnp.float32),
),
"Regular parameter with transpose mismatch",
)

def test_transfer_state_with_mappings_gemma4(self):
"""Test transfer_state_with_mappings for Gemma4."""
from tunix.models.gemma4.mapping_vllm_jax import VLLM_JAX_MAPPING

# Mock source state (Tunix style)
src_params = {
"layers.0.attn.kv_einsum.w": MockParam(
jnp.arange(2 * 2 * 16 * 8, dtype=jnp.float32).reshape(2, 2, 16, 8)
),
"layers.0.moe.gating_einsum": MockParam(
jnp.arange(4 * 2 * 8 * 16, dtype=jnp.float32).reshape(4, 2, 8, 16)
),
"layers.0.moe.linear": MockParam(
jnp.arange(4 * 16 * 8, dtype=jnp.float32).reshape(4, 16, 8)
),
}
src_state = MockState(src_params)

# Mock destination state (vLLM Jax backend style)
dst_params = {
"model.layers.0.self_attn.k_proj.weight": MockParam(
jnp.zeros((16, 2, 8), dtype=jnp.float32)
),
"model.layers.0.self_attn.v_proj.weight": MockParam(
jnp.zeros((16, 2, 8), dtype=jnp.float32)
),
"model.layers.0.experts.kernel_gating_upproj_EDF": MockParam(
jnp.zeros((4, 2, 8, 16), dtype=jnp.float32)
),
"model.layers.0.experts.kernel_down_proj_EFD": MockParam(
jnp.zeros((4, 16, 8), dtype=jnp.float32)
),
}
dst_state = MockState(dst_params)

# Apply preprocessing if it exists in mapping
if 'preprocess_src_state' in VLLM_JAX_MAPPING:
src_state = VLLM_JAX_MAPPING['preprocess_src_state'](src_state)

key_mappings = VLLM_JAX_MAPPING['to_hf_mappings']
transpose_keys = VLLM_JAX_MAPPING['to_hf_transpose_keys']

new_tgt_state = utils.transfer_state_with_mappings(
src_state,
dst_state,
key_mappings=key_mappings,
transpose_keys=transpose_keys,
)

# Assertions
src_val = jnp.arange(2 * 2 * 16 * 8, dtype=jnp.float32).reshape(2, 2, 16, 8)
k_val_src = src_val[0]
v_val_src = src_val[1]

expected_k = jnp.transpose(k_val_src, (1, 0, 2))
expected_v = jnp.transpose(v_val_src, (1, 0, 2))

self.assertTrue(
jnp.array_equal(
new_tgt_state.params["model.layers.0.self_attn.k_proj.weight"],
expected_k,
)
)
self.assertTrue(
jnp.array_equal(
new_tgt_state.params["model.layers.0.self_attn.v_proj.weight"],
expected_v,
)
)

self.assertTrue(
jnp.array_equal(
new_tgt_state.params["model.layers.0.experts.kernel_gating_upproj_EDF"],
src_params["layers.0.moe.gating_einsum"].value,
)
)

self.assertTrue(
jnp.array_equal(
new_tgt_state.params["model.layers.0.experts.kernel_down_proj_EFD"],
src_params["layers.0.moe.linear"].value,
)
)

def test_verify_state_closeness(self):
"""Test verify_state_closeness function with various scenarios."""

@@ -1001,28 +1112,30 @@ def test_transfer_state_with_interleaved_scanned_layers(self):

self.assertTrue(
jnp.allclose(
new_tgt_state.params["decoder.layer.0.weight"], expected_layer_0
new_tgt_state.params["decoder.layer.0.weight"].value,
expected_layer_0,
),
"Interleaved layer 0 mismatch",
)
self.assertTrue(
jnp.allclose(
new_tgt_state.params["decoder.layer.2.weight"], expected_layer_2
new_tgt_state.params["decoder.layer.2.weight"].value,
expected_layer_2,
),
"Interleaved layer 2 mismatch",
)

# Layers 1 and 3 should remain zero (not mapped)
self.assertTrue(
jnp.allclose(
new_tgt_state.params["decoder.layer.1.weight"],
new_tgt_state.params["decoder.layer.1.weight"].value,
jnp.zeros((vocab_size, embed_dim), dtype=jnp.float32),
),
"Non-interleaved layer 1 should be zero",
)
self.assertTrue(
jnp.allclose(
new_tgt_state.params["decoder.layer.3.weight"],
new_tgt_state.params["decoder.layer.3.weight"].value,
jnp.zeros((vocab_size, embed_dim), dtype=jnp.float32),
),
"Non-interleaved layer 3 should be zero",
@@ -1401,21 +1514,20 @@ def test_sglang_jax_1d_kv_bias_alignment(self):

self.assertEqual(result.params[src_key].shape, (1024,))
expected = jnp.tile(src_k_bias, 8)
self.assertTrue(jnp.allclose(result.params[src_key], expected))

self.assertTrue(jnp.allclose(result.params[src_key].value, expected))

def test_transfer_state_directly_fuses_moe_weights(self):
"""Tests that wi_0 and wi_1 are fused into wi when target expects it."""
wi_0_val = jnp.array([[1.0, 2.0], [5.0, 6.0]], dtype=jnp.float32)
wi_1_val = jnp.array([[3.0, 4.0], [7.0, 8.0]], dtype=jnp.float32)

src_state = nnx.Dict(
layers=nnx.Dict(
wi_0=nnx.Param(wi_0_val),
wi_1=nnx.Param(wi_1_val),
)
)

dst_state = nnx.Dict(
layers=nnx.Dict(
wi=nnx.Param(jnp.zeros((2, 4), dtype=jnp.float32))
@@ -1605,7 +1717,6 @@ def test_transfer_state_directly_delete_dst_buffers_scanned_layers(self):
dst_state['layers_1']['weight'][...], scanned[1]
)


def test_transfer_state_directly_fuses_moe_weights_with_padding(self):
"""Tests that wi_0 and wi_1 are fused, padded and interleaved into wi."""
# Source: wi_0, wi_1 each (2 experts, 2 features)
10 changes: 9 additions & 1 deletion tunix/generate/mappings.py
@@ -4,7 +4,7 @@

from dataclasses import dataclass
import importlib
from typing import Any, Dict, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple


class BackendMappingMixin:
@@ -61,6 +61,10 @@ def lora_to_hf_transpose_keys(cls, backend: str | None = None):
def to_hf_hook_fns(cls, backend: str | None = None):
return cls.mapping_for(backend).get('to_hf_hook_fns')

@classmethod
def preprocess_src_state(cls, backend: str | None = None):
return cls.mapping_for(backend).get('preprocess_src_state')


@dataclass
class MappingConfig:
@@ -77,6 +81,7 @@ class MappingConfig:
to_hf_hook_fns: Optional[Dict[str, Any]] = None
to_hf_transpose_keys: Optional[Dict[str, Tuple[int, ...]]] = None
lora_to_hf_transpose_keys: Optional[Dict[str, Tuple[int, ...]]] = None
preprocess_src_state: Optional[Callable[[Any], Any]] = None

@classmethod
def build(
@@ -102,6 +107,7 @@ def build(
'to_hf_hook_fns',
'to_hf_transpose_keys',
'lora_to_hf_transpose_keys',
'preprocess_src_state',
)

values: Dict[str, Any] = {}
@@ -129,6 +135,7 @@ def build(
to_hf_hook_fns=resolved.get('to_hf_hook_fns'),
to_hf_transpose_keys=resolved.get('to_hf_transpose_keys'),
lora_to_hf_transpose_keys=resolved.get('lora_to_hf_transpose_keys'),
preprocess_src_state=resolved.get('preprocess_src_state'),
)

@classmethod
@@ -157,6 +164,7 @@ def maybe_call(attr: str):
to_hf_hook_fns=maybe_call('to_hf_hook_fns'),
to_hf_transpose_keys=maybe_call('to_hf_transpose_keys'),
lora_to_hf_transpose_keys=maybe_call('lora_to_hf_transpose_keys'),
preprocess_src_state=maybe_call('preprocess_src_state'),
)

for key, value in overrides.items():
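For orientation, a minimal sketch of how a backend mapping module could expose the new `preprocess_src_state` hook next to the entries that `MappingConfig.build()` already resolves. The dictionary below is a placeholder, not the real Gemma4 `VLLM_JAX_MAPPING`, and the mapping values are omitted:

def _preprocess_src_state(src_state):
  # Placeholder hook: return a (possibly rewritten) source state before key
  # mapping and transposition are applied.
  return src_state


EXAMPLE_MAPPING = {
    'to_hf_mappings': {},          # src-key pattern -> target spec (omitted)
    'to_hf_transpose_keys': {},    # key -> transpose axes (omitted)
    'preprocess_src_state': _preprocess_src_state,
}

# Callers apply the hook before transfer, mirroring the new Gemma4 test:
#   if 'preprocess_src_state' in EXAMPLE_MAPPING:
#     src_state = EXAMPLE_MAPPING['preprocess_src_state'](src_state)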
35 changes: 31 additions & 4 deletions tunix/generate/utils.py
@@ -363,6 +363,7 @@ def build_flat_dict(
compiled_mappings.append((src, re.compile(pattern), sharding))

# ITERATE THROUGH ACTUAL PARAMETERS
unmapped_paths = []
for keys, v in flat_state:
# Convert key tuple ('model', 'layers', '0') to string 'model.layers.0'
path = '.'.join(str(key) for key in keys)
@@ -404,7 +405,10 @@ def build_flat_dict(
break
# There are no mappings for rng related params.
if not mapped:
logging.warning('!!! No mapping for flat state: %s', path)
unmapped_paths.append(path)

if unmapped_paths:
logging.warning('!!! No mapping for flat states: %s', unmapped_paths)

# Sort layers based on layer index to ensure correct order.
for key, (layers, paths, sharding) in new_flat_dict.items():
@@ -507,6 +511,13 @@ def _apply_transpose(
target_key = last_key
elif all_key in transpose_keys and 'lora' not in all_key:
target_key = all_key
else:
for k, _ in transpose_keys.items():
if '*' in k:
pattern = '^' + re.escape(k).replace('\\*', '.*') + '$'
if re.match(pattern, all_key):
target_key = k
break
if target_key != '':
logging.debug('Applying transpose on %s', src_key)
return jnp.transpose(val, transpose_keys[target_key])
@@ -519,7 +530,6 @@ def _apply_transpose(
if re.compile(rf'{r_key}').match(all_key):
logging.debug('Applying LoRA transpose on %s', src_key)
return jnp.transpose(val[None, :, :], transpose_keys[r_key])

return val


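To illustrate the wildcard branch added to `_apply_transpose` above: a `transpose_keys` entry containing `*` is escaped and the `*` widened to `.*`, so one entry covers every concrete layer index. A small self-contained trace (the key name is illustrative):

import re

transpose_keys = {'layers.*.attn.kv_einsum.w': (1, 0, 2)}  # illustrative key
all_key = 'layers.7.attn.kv_einsum.w'

for k in transpose_keys:
  if '*' in k:
    pattern = '^' + re.escape(k).replace('\\*', '.*') + '$'
    if re.match(pattern, all_key):
      print(k, transpose_keys[k])  # layers.*.attn.kv_einsum.w (1, 0, 2)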
@@ -617,6 +627,22 @@ def _align_shape(
padded_dim = (val.shape[-1] + 127) // 128 * 128
repeated_dim = tgt_shape[-1] // padded_dim
new_tgt_shape = tgt_shape[:-1] + (repeated_dim, padded_dim)
elif re.compile(r'layers\..*\.moe\.gating_einsum').match(src_key):
tp_size = kwargs['tp_size']
num_experts, expert_dim, embed_dim = val.shape[0], val.shape[2], val.shape[3]
gate_chunks, up_chunks = val[:, 0, :, :], val[:, 1, :, :]
chunk_size = expert_dim // tp_size
padded_expert_chunk_dim = ((chunk_size + 127)//128)*128
pad_amount = padded_expert_chunk_dim - chunk_size
gate_chunks = gate_chunks.reshape(num_experts, tp_size, -1, embed_dim)
up_chunks = up_chunks.reshape(num_experts, tp_size, -1, embed_dim)
if pad_amount > 0:
gate_chunks = jnp.pad(gate_chunks, ((0, 0), (0, 0), (0, pad_amount), (0, 0)))
up_chunks = jnp.pad(up_chunks, ((0, 0), (0, 0), (0, pad_amount), (0, 0)))
val_chunks = jnp.stack([gate_chunks, up_chunks], axis=2)
val_chunks = val_chunks.reshape(num_experts, -1, embed_dim)
val_chunks = val_chunks.transpose(0, 2, 1)
return val_chunks
else:
raise ShapeMismatchError(
f'Rank mismatch for {src_key}: {val.shape} vs {tgt_shape}'
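Traced on the shapes the new Gemma4 test uses, the `gating_einsum` branch above reshapes as follows (a standalone sketch assuming `tp_size=1`; variable names are illustrative):

import jax.numpy as jnp

val = jnp.zeros((4, 2, 8, 16))  # (experts, gate/up, feature, embed) as in the test
tp_size = 1
experts, feature, embed = val.shape[0], val.shape[2], val.shape[3]
gate, up = val[:, 0, :, :], val[:, 1, :, :]            # each (4, 8, 16)
chunk = feature // tp_size                             # 8 per shard
padded = ((chunk + 127) // 128) * 128                  # padded to 128
gate = jnp.pad(gate.reshape(experts, tp_size, -1, embed),
               ((0, 0), (0, 0), (0, padded - chunk), (0, 0)))
up = jnp.pad(up.reshape(experts, tp_size, -1, embed),
             ((0, 0), (0, 0), (0, padded - chunk), (0, 0)))
out = jnp.stack([gate, up], axis=2)                    # (4, 1, 2, 128, 16)
out = out.reshape(experts, -1, embed)                  # (4, 256, 16)
out = out.transpose(0, 2, 1)                           # (4, 16, 256)
assert out.shape == (4, 16, 256)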
@@ -741,9 +767,10 @@ def _sync_tied_lm_head_if_needed(
embed_param = None
lm_head_param = None
for flat_key, tgt_param in tgt_flat_list:
if flat_key[-1:] == ('embedding',):
path = '.'.join(str(k) for k in flat_key)
if path.endswith(('embedding', 'embed_tokens.weight')):
embed_param = tgt_param
elif flat_key[-1:] == ('lm_head',):
elif path.endswith(('lm_head', 'lm_head.weight')):
lm_head_param = tgt_param

if embed_param is None or lm_head_param is None: