There are three device meshes: dp_shard, no_shard, and ep_shard. Each corresponds to one of three register pairs.
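This is a minimal sketch of how the three pairs could be formed. It assumes each pair is (mesh name, list of parameter names). `PARAM_MESH` and `build_register_pairs` are hypothetical names. The mapping copies a few weight rows from the table in this note. Scale tensors are left out because they presumably travel with their weight.

```python
from collections import defaultdict

# Hypothetical name -> mesh mapping, copied from a few rows of the table.
PARAM_MESH = {
    "embed_tokens": "dp_shard",
    "self_attn.fused_qkv_a_proj.weight": "no_shard",
    "self_attn.q_b_proj.weight": "dp_shard",
    "mlp.gate.weight": "no_shard",
    "mlp.experts.w13_weight": "ep_shard",
    "mlp.experts.w2_weight": "ep_shard",
}

MESHES = ("dp_shard", "no_shard", "ep_shard")


def build_register_pairs(param_mesh):
    """Group parameter names by mesh: one (mesh, sorted names) pair per mesh."""
    groups = defaultdict(list)
    for name, mesh in param_mesh.items():
        if mesh not in MESHES:
            raise ValueError(f"unknown mesh {mesh!r} for {name}")
        groups[mesh].append(name)
    # Always emit all three pairs in a fixed order, even if a group is empty.
    return [(mesh, sorted(groups[mesh])) for mesh in MESHES]
```

A fixed mesh order keeps every rank's registration sequence identical, whatever order the parameters come in.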
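Registering a tensor needs to support two options:
- Whether the tensor is quantized. If it is, the serialized data of its scale tensor is sent along with it.
- Whether to cast to bfloat16. This separate option applies to ordinary tensors and controls whether they are converted to bfloat16 before being sent.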
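The sketch below shows the two options at the metadata level only. It computes which tensors would be sent, with what dtype and size. `TensorMeta` and `register_tensor` are hypothetical names, and so are the keyword arguments `quantized`, `scale`, and `cast_to_bf16`. The rule that a quantized tensor cannot also be cast is an assumption based on the cast option targeting ordinary tensors.

```python
from dataclasses import dataclass
from math import prod
from typing import List, Optional, Tuple

# Bytes per element for the dtypes that appear in the table.
DTYPE_BYTES = {"float8_e4m3fn": 1, "bfloat16": 2, "float32": 4}


@dataclass(frozen=True)
class TensorMeta:
    name: str
    shape: Tuple[int, ...]
    dtype: str

    def nbytes(self) -> int:
        return prod(self.shape) * DTYPE_BYTES[self.dtype]


def register_tensor(
    tensor: TensorMeta,
    *,
    quantized: bool = False,
    scale: Optional[TensorMeta] = None,
    cast_to_bf16: bool = False,
) -> List[TensorMeta]:
    """Return the tensors that one registration would send (hypothetical API)."""
    if quantized:
        if scale is None:
            raise ValueError(f"{tensor.name}: quantized tensor needs its scale")
        if cast_to_bf16:
            # Assumption: the bf16 cast only applies to ordinary tensors.
            raise ValueError(f"{tensor.name}: cannot cast a quantized tensor")
        # The quantized payload is sent together with its scale tensor.
        return [tensor, scale]
    if scale is not None:
        raise ValueError(f"{tensor.name}: scale given for unquantized tensor")
    if cast_to_bf16 and tensor.dtype != "bfloat16":
        # Only the metadata changes here; a real sender would convert the data.
        return [TensorMeta(tensor.name, tensor.shape, "bfloat16")]
    return [tensor]
```

As an example, registering `self_attn.q_b_proj.weight` as quantized sends 24576 × 1536 FP8 bytes plus a 12 × 192 float32 scale. That comes to 37,757,952 bytes. Casting the float32 `mlp.gate.e_score_correction_bias` halves its 1024 bytes to 512.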
| Parameter | Shape | Dtype | Device mesh |
|---|---|---|---|
| embed_tokens | (129280, 7168) | bfloat16 | dp_shard |
| self_attn.fused_qkv_a_proj.weight | (2112, 7168) | float8_e4m3fn | no_shard |
| self_attn.fused_qkv_a_proj.weight_scale | (56, 17) | float32 | |
| self_attn.q_a_layernorm.weight | (1536,) | bfloat16 | no_shard |
| self_attn.q_b_proj.weight | (24576, 1536) | float8_e4m3fn | dp_shard |
| self_attn.q_b_proj.weight_scale | (12, 192) | float32 | |
| self_attn.kv_a_layernorm.weight | (512,) | bfloat16 | no_shard |
| self_attn.kv_b_proj.weight | (32768, 512) | float8_e4m3fn | dp_shard |
| self_attn.kv_b_proj.weight_scale | (4, 256) | float32 | |
| self_attn.o_proj.weight | (28672, 4096) | float8_e4m3fn | dp_shard |
| self_attn.o_proj.weight_scale | (32, 224) | float32 | |
| mlp.gate.weight | (256, 7168) | bfloat16 | no_shard |
| mlp.gate.e_score_correction_bias | (256,) | float32 | no_shard |
| mlp.shared_experts.gate_up_proj.weight | (4096, 7168) | float8_e4m3fn | no_shard |
| mlp.shared_experts.gate_up_proj.weight_scale | (56, 32) | float32 | |
| mlp.shared_experts.down_proj.weight | (7168, 2048) | float8_e4m3fn | no_shard |
| mlp.shared_experts.down_proj.weight_scale | (16, 56) | float32 | |
| mlp.experts.w13_weight | (256, 4096, 7168) | float8_e4m3fn | ep_shard |
| mlp.experts.w13_weight_scale_inv | (256, 32, 56) | float32 | |
| mlp.experts.w2_weight | (256, 7168, 2048) | float8_e4m3fn | ep_shard |
| mlp.experts.w2_weight_scale_inv | (256, 56, 16) | float32 | |
| input_layernorm.weight | (7168,) | bfloat16 | no_shard |
| post_attention_layernorm.weight | (7168,) | bfloat16 | no_shard |
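The scale shapes in the table can be derived from the weight shapes by assuming 128 × 128 block quantization. The block size is an assumption, not stated in these notes. Under it, each scale dimension is the ceiling of the weight dimension divided by 128; for example, 2112 / 128 = 16.5 rounds up to 17. The sketch checks every quantized row. By this arithmetic, the non-expert scales are listed with their two dimensions swapped relative to the weight. The expert scales (`*_scale_inv`) keep the weight's order.

```python
# Assumption: 128 x 128 block-wise FP8 quantization (not stated in these notes).
BLOCK = 128


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def block_grid(shape, block=BLOCK):
    """Expected scale shape: ceil-divide the last two weight dims by the block size.

    Leading dims, such as the expert dim of w13/w2, are kept unchanged.
    """
    *lead, rows, cols = shape
    return (*lead, ceil_div(rows, block), ceil_div(cols, block))


def scale_layout(weight_shape, scale_shape):
    """Return 'same' if the scale follows the weight's dim order, 'transposed'
    if its last two dims are swapped, and 'mismatch' otherwise."""
    grid = block_grid(weight_shape)
    swapped = (*grid[:-2], grid[-1], grid[-2])
    scale_shape = tuple(scale_shape)
    if scale_shape == grid:
        return "same"
    if scale_shape == swapped:
        return "transposed"
    return "mismatch"


# (weight shape, listed scale shape) for every quantized row of the table.
QUANTIZED_ROWS = {
    "self_attn.fused_qkv_a_proj": ((2112, 7168), (56, 17)),
    "self_attn.q_b_proj": ((24576, 1536), (12, 192)),
    "self_attn.kv_b_proj": ((32768, 512), (4, 256)),
    "self_attn.o_proj": ((28672, 4096), (32, 224)),
    "mlp.shared_experts.gate_up_proj": ((4096, 7168), (56, 32)),
    "mlp.shared_experts.down_proj": ((7168, 2048), (16, 56)),
    "mlp.experts.w13_weight": ((256, 4096, 7168), (256, 32, 56)),
    "mlp.experts.w2_weight": ((256, 7168, 2048), (256, 56, 16)),
}

if __name__ == "__main__":
    for name, (w_shape, s_shape) in QUANTIZED_ROWS.items():
        grid = block_grid(w_shape)
        print(f"{name}: grid {grid}, listed {s_shape} -> {scale_layout(w_shape, s_shape)}")
```

If the table reflects the runtime layout, a receiver that reshapes or reinterprets scales cannot assume one layout for all quantized tensors.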