@@ -667,12 +667,12 @@ def _apply_turboquant(model, config):
667667# ---------------------------------------------------------------------------
668668
669669
670- def _set_batched_moe (model , enabled , moe_moe_moe_moe_activation_dtype = "bf16" ):
670+ def _set_batched_moe (model , enabled , moe_activation_dtype = "bf16" ):
671671 """Toggle batched tensor-core MoE kernel for all MoE layers."""
672672 for layer in model .layers :
673673 if hasattr (layer , "mlp" ) and hasattr (layer .mlp , "experts" ):
674674 layer .mlp .experts .use_batched_moe = enabled
675- layer .mlp .experts .moe_moe_moe_moe_activation_dtype = moe_moe_moe_moe_activation_dtype
675+ layer .mlp .experts .moe_activation_dtype = moe_activation_dtype
676676
677677
678678def export_and_lower (model , config , args ):
@@ -916,8 +916,8 @@ def _export_cuda(model, config, args):
916916 # chunk_gated_delta_rule with CHUNK_SIZE=64) for the full range of sequence
917917 # lengths. Smaller examples cause AOTI to bake in intermediate buffer sizes
918918 # that reject longer prompts at runtime.
919- moe_moe_moe_moe_activation_dtype = getattr (args , "moe_moe_moe_moe_activation_dtype " , "bf16" )
920- _set_batched_moe (model , True , moe_moe_moe_moe_activation_dtype = moe_moe_moe_moe_activation_dtype )
919+ moe_activation_dtype = getattr (args , "moe_activation_dtype " , "bf16" )
920+ _set_batched_moe (model , True , moe_activation_dtype = moe_activation_dtype )
921921 dense_prefill = getattr (args , "dense_prefill" , "tinygemm" )
922922 _set_dequant_prefill (model , dense_prefill == "dequant" )
923923 print ("Exporting prefill method..." )
@@ -1087,14 +1087,15 @@ def main(): # noqa: C901
10871087 "--moe-activation-dtype" ,
10881088 choices = ["bf16" , "int8" ],
10891089 default = "bf16" ,
1090- help = "MoE activation dtype for prefill only. Decode always uses bf16. bf16 (default): W4A16 batched GEMM. int8: W4A8 with INT8 tensor cores (~1.5x faster prefill) ." ,
1090+ help = "MoE activation dtype for prefill only. Decode always uses bf16. bf16 (default): W4A16 batched GEMM. int8: W4A8 with INT8 tensor cores." ,
10911091 )
10921092 parser .add_argument (
10931093 "--dense-prefill" ,
10941094 choices = ["tinygemm" , "dequant" ],
10951095 default = "tinygemm" ,
1096- help = "Dense linear kernel: tinygemm (default W4A16 INT4 kernel) or "
1097- "dequant (dequant W4→BF16 + Inductor mm for prefill, int4_matvec for decode)." ,
1096+ help = "Dense linear prefill kernel. Decode always uses int4_matvec (Triton W4A16 vec-mat). "
1097+ "tinygemm (default): W4A16 _weight_int4pack_mm. "
1098+ "dequant: dequant W4→BF16 + cuBLAS GEMM." ,
10981099 )
10991100 args = parser .parse_args ()
11001101
@@ -1139,7 +1140,7 @@ def main(): # noqa: C901
11391140 "(dense weights must be W4 quantized)"
11401141 )
11411142
1142- if args .moe_moe_moe_activation_dtype != "bf16" and args .backend != "cuda" :
1143+ if args .moe_activation_dtype != "bf16" and args .backend != "cuda" :
11431144 parser .error ("--moe-activation-dtype int8 requires --backend cuda" )
11441145
11451146 model , config = load_and_quantize (args )
0 commit comments