@@ -41,93 +41,73 @@ class GroupedGemmPrimitive(BasePrimitive):
4141
4242 name = "te_grouped_gemm_ffi"
4343 multiple_results = True
44- impl_static_args = (6 , 7 , 8 , 9 )
44+ impl_static_args = ()
4545 inner_primitive = None
4646 outer_primitive = None
4747
4848 @staticmethod
49- def abstract (
50- lhs_contig_aval ,
51- lhs_scale_contig_aval ,
52- rhs_contig_aval ,
53- rhs_scale_contig_aval ,
54- bias_contig_aval ,
55- dim_list_aval ,
56- * ,
57- num_gemms ,
58- scaling_mode ,
59- out_dtype ,
60- out_flat_size ,
61- ):
62- del lhs_contig_aval , lhs_scale_contig_aval
63- del rhs_contig_aval , rhs_scale_contig_aval
64- del bias_contig_aval , dim_list_aval
65- del num_gemms , scaling_mode
66- out_flat_aval = jax .core .ShapedArray (shape = (out_flat_size ,), dtype = out_dtype )
67- wkspace_size = get_cublas_workspace_size_bytes () * num_cublas_streams
68- wkspace_aval = jax .core .ShapedArray (shape = (wkspace_size ,), dtype = jnp .uint8 )
69- return (out_flat_aval , wkspace_aval )
49+ def abstract (* args , num_gemms , scaling_mode , out_dtype , has_bias ):
50+ """
51+ Args:
52+ *args: Size num_gemms * 4 or num_gemms * 5 depending on has_bias:
53+ args[ 0 : num_gemms] are the lhs tensors,
54+ args[ num_gemms : 2*num_gemms] are the rhs tensors,
55+ args[2*num_gemms : 3*num_gemms] are the lhs scale_inv tensors,
56+ args[3*num_gemms : 4*num_gemms] are the rhs scale_inv tensors,
57+ args[4*num_gemms : 5*num_gemms] are the bias tensors if has_bias is True.
58+ num_gemms: Number of GEMM operations to perform.
59+ scaling_mode: Scaling mode for the GEMM operations.
60+ out_dtype: Data type of the output tensors.
61+ has_bias: Boolean indicating if bias tensors are provided.
62+
63+ Returns:
64+ A tuple of ShapedArray objects of size num_gemms+1:
65+ ret[0 : num_gemms]: GEMM output tensors,
66+ ret[num_gemms]:workspace tensor.
67+ """
68+ del scaling_mode
69+ expected_num_args = 5 * num_gemms if has_bias else 4 * num_gemms
70+ assert (
71+ len (args ) == expected_num_args
72+ ), f"Expected { expected_num_args } input arguments, but got { len (args )} "
73+ A_list = args [0 :num_gemms ]
74+ B_list = args [num_gemms : 2 * num_gemms ]
75+ # A and B have shapes [1, m, k] and [1, n, k]
76+ out_list_aval = tuple (
77+ jax .core .ShapedArray ((A .shape [1 ], B .shape [1 ]), dtype = out_dtype )
78+ for A , B in zip (A_list , B_list )
79+ )
80+ workspace_size = get_cublas_workspace_size_bytes () * num_cublas_streams
81+ workspace_aval = jax .core .ShapedArray (shape = (workspace_size ,), dtype = jnp .uint8 )
82+ return (* out_list_aval , workspace_aval )
7083
7184 @staticmethod
7285 def outer_abstract (* args , ** kwargs ):
7386 (out_aval , _ ) = GroupedGemmPrimitive .abstract (* args , ** kwargs )
7487 return out_aval
7588
7689 @staticmethod
77- def lowering (
78- ctx ,
79- lhs_contig ,
80- lhs_scale_inv_contig ,
81- rhs_contig ,
82- rhs_scale_inv_contig ,
83- bias_contig ,
84- dim_list ,
85- * ,
86- num_gemms ,
87- scaling_mode ,
88- out_dtype ,
89- out_flat_size ,
90- ) -> jnp .ndarray :
91- del out_dtype , out_flat_size
90+ def lowering (ctx , * args , num_gemms , scaling_mode , out_dtype , has_bias ):
91+ del out_dtype
9292 return jax .ffi .ffi_lowering (GroupedGemmPrimitive .name )(
9393 ctx ,
94- lhs_contig ,
95- lhs_scale_inv_contig ,
96- rhs_contig ,
97- rhs_scale_inv_contig ,
98- bias_contig ,
99- dim_list ,
94+ * args ,
10095 num_gemms = num_gemms ,
101- scaling_mode = scaling_mode .value ,
96+ scaling_mode = int (scaling_mode ),
97+ has_bias = has_bias ,
10298 )
10399
104100 @staticmethod
105- def impl (
106- lhs_contig ,
107- lhs_scale_inv_contig ,
108- rhs_contig ,
109- rhs_scale_inv_contig ,
110- bias_contig ,
111- dim_list ,
112- num_gemms ,
113- scaling_mode ,
114- out_dtype ,
115- out_flat_size ,
116- ) -> jnp .ndarray :
101+ def impl (* args , num_gemms , scaling_mode , out_dtype , has_bias ):
117102 assert GroupedGemmPrimitive .inner_primitive is not None
118103 out = GroupedGemmPrimitive .inner_primitive .bind (
119- lhs_contig ,
120- lhs_scale_inv_contig ,
121- rhs_contig ,
122- rhs_scale_inv_contig ,
123- bias_contig ,
124- dim_list ,
104+ * args ,
125105 num_gemms = num_gemms ,
126- scaling_mode = scaling_mode ,
106+ scaling_mode = scaling_mode . value ,
127107 out_dtype = out_dtype ,
128- out_flat_size = out_flat_size ,
108+ has_bias = has_bias ,
129109 )
130- return out [0 ] # out is [out_flat , wkspace], only return out_flat
110+ return out [: - 1 ] # out is [out_list , wkspace], only return out_list
131111
132112
133113register_primitive (GroupedGemmPrimitive )
@@ -366,6 +346,7 @@ def swizzled_scale(scales):
366346 rows , cols = scales .shape
367347 scales = scales .reshape (rows // 128 , 4 , 32 , cols // 4 , 4 )
368348 scales = jnp .transpose (scales , (0 , 3 , 2 , 1 , 4 ))
349+ scales = scales .reshape (rows , cols )
369350 return scales
370351
371352
@@ -380,18 +361,12 @@ def grouped_gemm(
380361 len (lhs_list ) == len (rhs_list ) == len (contracting_dims_list )
381362 ), "lhs_list, rhs_list, contracting_dims_list must have the same length"
382363
383- # Flatten inputs and save their shapes
384- num_gemms = len (lhs_list )
385- out_flat_size = 0
386- dims = []
387- lhs_contig_ = []
388- rhs_contig_ = []
389- lhs_scale_inv_contig_ = []
390- rhs_scale_inv_contig_ = []
391- bias_contig_ = []
392- out_offsets = []
393- remain_shape_list = []
394364 num_gemms = len (lhs_list )
365+ lhs_list_ = []
366+ rhs_list_ = []
367+ lhs_sinv_list_ = []
368+ rhs_sinv_list_ = []
369+ bias_list_ = []
395370 for i in range (num_gemms ):
396371 lhs = lhs_list [i ]
397372 rhs = rhs_list [i ]
@@ -402,7 +377,7 @@ def grouped_gemm(
402377 lhs_shape = lhs .data .shape
403378 rhs_shape = rhs .data .shape
404379 out_dtype = lhs .dq_dtype
405- # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING , need to handle internal data_layout
380+ # For ScaledTensors and DELAYED_TENSOR_SCALING , need to handle internal data_layout
406381 if lhs .scaling_mode == ScalingMode .DELAYED_TENSOR_SCALING :
407382 assert not (
408383 lhs .data .dtype == jnp .float8_e5m2 and rhs .data .dtype == jnp .float8_e5m2
@@ -427,6 +402,7 @@ def grouped_gemm(
427402 lhs_remain_shape = _calculate_remaining_shape (lhs_shape , lhs_contract )
428403 rhs_remain_shape = _calculate_remaining_shape (rhs_shape , rhs_contract )
429404
405+ # Note: do not squeeze() for {lhs, rhs}_3d, it will trigger a D2D memcpy
430406 if scaling_mode == ScalingMode .NO_SCALING :
431407 lhs_3d = _shape_normalization (lhs , lhs_dn )
432408 rhs_3d = _shape_normalization (rhs , rhs_dn )
@@ -438,13 +414,13 @@ def grouped_gemm(
438414 rhs_3d = _shape_normalization (rhs .data , rhs_dn )
439415 lhs_scale_inv = _shape_normalization (lhs .scale_inv , lhs_dn )
440416 rhs_scale_inv = _shape_normalization (rhs .scale_inv , rhs_dn )
417+ # swizzled_scale requires a matrix
441418 lhs_scale_inv = swizzled_scale (lhs_scale_inv .squeeze ())
442419 rhs_scale_inv = swizzled_scale (rhs_scale_inv .squeeze ())
443420 else :
444421 raise NotImplementedError ("Unsupported ScalingMode: {scaling_mode}" )
445422
446- # Note: if _shape_normalization() is updated to support non-TN, need to update here
447- # already_transposed doesn't matter for the output shape
423+ # Note: already_transposed doesn't matter for the output shape
448424 # x.shape = [B, D1, D2]
449425 # contracting_dims = (2, ) --> output.shape = [1, B * D1, D2]
450426 # contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1]
@@ -455,66 +431,37 @@ def grouped_gemm(
455431 bn = rhs_remain_shape [0 ]
456432 kl = lhs_3d .shape [- 1 ]
457433 kr = rhs_3d .shape [- 1 ]
458- remain_shape_list .append (((bm ,), (bn ,)))
459- assert kl == kr , f"lhs_3d.shape[-1] ({ kl } ) != rhs_3d.shape[-1] ({ kr } )"
460- k = kl
461-
462- if (bm % 16 != 0 ) or (bn % 16 != 0 ) or (k % 16 != 0 ):
463- print (f"grouped_gemm input pair { i } has invalid problem shape for lowering: " )
464- print (
465- f"m = { bm } , n = { bn } , k = { k } ; cuBLAS requires the problem shapes being multiples"
466- " of 16"
467- )
468- assert bm % 16 == 0 and bn % 16 == 0 and k % 16 == 0
469-
470- dims .append ((bm , bn , k ))
471- lhs_contig_ .append (lhs_3d .reshape (- 1 ))
472- rhs_contig_ .append (rhs_3d .reshape (- 1 ))
434+ assert kl == kr , f"After shape normalization, contracting dim size mismatch: { kl } != { kr } "
435+ if (bm % 16 != 0 ) or (bn % 16 != 0 ) or (kl % 16 != 0 ):
436+ print ("grouped_gemm input pair {i} has invalid problem shape for lowering: " )
437+ print (f"m = { bm } , n = { bn } , k = { kl } ; " )
438+ print ("cuBLAS requires the problem shapes being multiples of 16" )
439+ assert (bm % 16 == 0 ) and (bn % 16 == 0 ) and (kl % 16 == 0 )
440+
441+ lhs_list_ .append (lhs_3d )
442+ rhs_list_ .append (rhs_3d )
473443 if scaling_mode == ScalingMode .NO_SCALING :
474- lhs_scale_inv_contig_ .append (jnp .ones (1 , dtype = jnp .float32 ))
475- rhs_scale_inv_contig_ .append (jnp .ones (1 , dtype = jnp .float32 ))
444+ lhs_sinv_list_ .append (jnp .ones (1 , dtype = jnp .float32 ))
445+ rhs_sinv_list_ .append (jnp .ones (1 , dtype = jnp .float32 ))
476446 if scaling_mode == ScalingMode .DELAYED_TENSOR_SCALING :
477- lhs_scale_inv_contig_ .append (lhs .scale_inv . reshape ( - 1 ) )
478- rhs_scale_inv_contig_ .append (rhs .scale_inv . reshape ( - 1 ) )
447+ lhs_sinv_list_ .append (lhs .scale_inv )
448+ rhs_sinv_list_ .append (rhs .scale_inv )
479449 if scaling_mode == ScalingMode .MXFP8_1D_SCALING :
480- lhs_scale_inv_contig_ .append (lhs_scale_inv . reshape ( - 1 ) )
481- rhs_scale_inv_contig_ .append (rhs_scale_inv . reshape ( - 1 ) )
450+ lhs_sinv_list_ .append (lhs_scale_inv )
451+ rhs_sinv_list_ .append (rhs_scale_inv )
482452 if bias_list is not None :
483- bias_contig_ .append (bias_list [i ].reshape (- 1 ))
484- out_flat_size += bm * bn
485- out_offsets .append (out_flat_size )
486-
487- lhs_contig = jnp .concatenate (lhs_contig_ )
488- rhs_contig = jnp .concatenate (rhs_contig_ )
489- lhs_scale_inv_contig = jnp .concatenate (lhs_scale_inv_contig_ )
490- rhs_scale_inv_contig = jnp .concatenate (rhs_scale_inv_contig_ )
491- bias_contig = jnp .empty (0 ) if bias_list is None else jnp .concatenate (bias_contig_ )
492- dim_list = jnp .array (dims , dtype = jnp .int32 )
493-
494- # TE/common does not support NVTE_NO_SCALING yet
495- # It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16
496- if scaling_mode == ScalingMode .NO_SCALING :
497- scaling_mode = ScalingMode .DELAYED_TENSOR_SCALING
498-
499- # Perform batched GEMM on flattened inputs
500- out_contig = GroupedGemmPrimitive .outer_primitive .bind (
501- lhs_contig ,
502- lhs_scale_inv_contig ,
503- rhs_contig ,
504- rhs_scale_inv_contig ,
505- bias_contig ,
506- dim_list ,
453+ bias_list_ .append (bias_list [i ])
454+
455+ out_list = GroupedGemmPrimitive .outer_primitive .bind (
456+ * lhs_list_ ,
457+ * rhs_list_ ,
458+ * lhs_sinv_list_ ,
459+ * rhs_sinv_list_ ,
460+ * bias_list_ ,
507461 num_gemms = num_gemms ,
508- scaling_mode = scaling_mode . value ,
462+ scaling_mode = scaling_mode ,
509463 out_dtype = out_dtype ,
510- out_flat_size = out_flat_size ,
464+ has_bias = 1 if bias_list is not None else 0 ,
511465 )
512466
513- # Split the output back into tensors
514- out_offsets = jnp .array (out_offsets )
515- out_flat_list = jnp .split (out_contig , out_offsets [:- 1 ])
516- out_tensors = []
517- for out_flat , (lhs_remain_shape , rhs_remain_shape ) in zip (out_flat_list , remain_shape_list ):
518- out_tensors .append (out_flat .reshape (* lhs_remain_shape , * rhs_remain_shape ))
519-
520- return out_tensors
467+ return out_list
0 commit comments