Skip to content

Commit 98b4c0d

Browse files
[JAX] grouped_gemm() uses variadic arguments (#1658)
* New GroupedGemmPrimitive using variadic args
* Remove squeeze() to reduce D2D memcpy
* Revert to the list append fashion to simplify code

---------

Signed-off-by: Hua Huang <huah@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
1 parent c8e7cc0 commit 98b4c0d

2 files changed

Lines changed: 173 additions & 239 deletions

File tree

transformer_engine/jax/cpp_extensions/gemm.py

Lines changed: 81 additions & 134 deletions
Original file line number | Diff line number | Diff line change
@@ -41,93 +41,73 @@ class GroupedGemmPrimitive(BasePrimitive):
4141

4242
name = "te_grouped_gemm_ffi"
4343
multiple_results = True
44-
impl_static_args = (6, 7, 8, 9)
44+
impl_static_args = ()
4545
inner_primitive = None
4646
outer_primitive = None
4747

4848
@staticmethod
49-
def abstract(
50-
lhs_contig_aval,
51-
lhs_scale_contig_aval,
52-
rhs_contig_aval,
53-
rhs_scale_contig_aval,
54-
bias_contig_aval,
55-
dim_list_aval,
56-
*,
57-
num_gemms,
58-
scaling_mode,
59-
out_dtype,
60-
out_flat_size,
61-
):
62-
del lhs_contig_aval, lhs_scale_contig_aval
63-
del rhs_contig_aval, rhs_scale_contig_aval
64-
del bias_contig_aval, dim_list_aval
65-
del num_gemms, scaling_mode
66-
out_flat_aval = jax.core.ShapedArray(shape=(out_flat_size,), dtype=out_dtype)
67-
wkspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
68-
wkspace_aval = jax.core.ShapedArray(shape=(wkspace_size,), dtype=jnp.uint8)
69-
return (out_flat_aval, wkspace_aval)
49+
def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias):
50+
"""
51+
Args:
52+
*args: Size num_gemms * 4 or num_gemms * 5 depending on has_bias:
53+
args[ 0 : num_gemms] are the lhs tensors,
54+
args[ num_gemms : 2*num_gemms] are the rhs tensors,
55+
args[2*num_gemms : 3*num_gemms] are the lhs scale_inv tensors,
56+
args[3*num_gemms : 4*num_gemms] are the rhs scale_inv tensors,
57+
args[4*num_gemms : 5*num_gemms] are the bias tensors if has_bias is True.
58+
num_gemms: Number of GEMM operations to perform.
59+
scaling_mode: Scaling mode for the GEMM operations.
60+
out_dtype: Data type of the output tensors.
61+
has_bias: Boolean indicating if bias tensors are provided.
62+
63+
Returns:
64+
A tuple of ShapedArray objects of size num_gemms+1:
65+
ret[0 : num_gemms]: GEMM output tensors,
66+
ret[num_gemms]:workspace tensor.
67+
"""
68+
del scaling_mode
69+
expected_num_args = 5 * num_gemms if has_bias else 4 * num_gemms
70+
assert (
71+
len(args) == expected_num_args
72+
), f"Expected {expected_num_args} input arguments, but got {len(args)}"
73+
A_list = args[0:num_gemms]
74+
B_list = args[num_gemms : 2 * num_gemms]
75+
# A and B have shapes [1, m, k] and [1, n, k]
76+
out_list_aval = tuple(
77+
jax.core.ShapedArray((A.shape[1], B.shape[1]), dtype=out_dtype)
78+
for A, B in zip(A_list, B_list)
79+
)
80+
workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
81+
workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
82+
return (*out_list_aval, workspace_aval)
7083

7184
@staticmethod
7285
def outer_abstract(*args, **kwargs):
7386
(out_aval, _) = GroupedGemmPrimitive.abstract(*args, **kwargs)
7487
return out_aval
7588

7689
@staticmethod
77-
def lowering(
78-
ctx,
79-
lhs_contig,
80-
lhs_scale_inv_contig,
81-
rhs_contig,
82-
rhs_scale_inv_contig,
83-
bias_contig,
84-
dim_list,
85-
*,
86-
num_gemms,
87-
scaling_mode,
88-
out_dtype,
89-
out_flat_size,
90-
) -> jnp.ndarray:
91-
del out_dtype, out_flat_size
90+
def lowering(ctx, *args, num_gemms, scaling_mode, out_dtype, has_bias):
91+
del out_dtype
9292
return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)(
9393
ctx,
94-
lhs_contig,
95-
lhs_scale_inv_contig,
96-
rhs_contig,
97-
rhs_scale_inv_contig,
98-
bias_contig,
99-
dim_list,
94+
*args,
10095
num_gemms=num_gemms,
101-
scaling_mode=scaling_mode.value,
96+
scaling_mode=int(scaling_mode),
97+
has_bias=has_bias,
10298
)
10399

104100
@staticmethod
105-
def impl(
106-
lhs_contig,
107-
lhs_scale_inv_contig,
108-
rhs_contig,
109-
rhs_scale_inv_contig,
110-
bias_contig,
111-
dim_list,
112-
num_gemms,
113-
scaling_mode,
114-
out_dtype,
115-
out_flat_size,
116-
) -> jnp.ndarray:
101+
def impl(*args, num_gemms, scaling_mode, out_dtype, has_bias):
117102
assert GroupedGemmPrimitive.inner_primitive is not None
118103
out = GroupedGemmPrimitive.inner_primitive.bind(
119-
lhs_contig,
120-
lhs_scale_inv_contig,
121-
rhs_contig,
122-
rhs_scale_inv_contig,
123-
bias_contig,
124-
dim_list,
104+
*args,
125105
num_gemms=num_gemms,
126-
scaling_mode=scaling_mode,
106+
scaling_mode=scaling_mode.value,
127107
out_dtype=out_dtype,
128-
out_flat_size=out_flat_size,
108+
has_bias=has_bias,
129109
)
130-
return out[0] # out is [out_flat, wkspace], only return out_flat
110+
return out[:-1] # out is [out_list, wkspace], only return out_list
131111

132112

133113
register_primitive(GroupedGemmPrimitive)
@@ -366,6 +346,7 @@ def swizzled_scale(scales):
366346
rows, cols = scales.shape
367347
scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4)
368348
scales = jnp.transpose(scales, (0, 3, 2, 1, 4))
349+
scales = scales.reshape(rows, cols)
369350
return scales
370351

371352

@@ -380,18 +361,12 @@ def grouped_gemm(
380361
len(lhs_list) == len(rhs_list) == len(contracting_dims_list)
381362
), "lhs_list, rhs_list, contracting_dims_list must have the same length"
382363

383-
# Flatten inputs and save their shapes
384-
num_gemms = len(lhs_list)
385-
out_flat_size = 0
386-
dims = []
387-
lhs_contig_ = []
388-
rhs_contig_ = []
389-
lhs_scale_inv_contig_ = []
390-
rhs_scale_inv_contig_ = []
391-
bias_contig_ = []
392-
out_offsets = []
393-
remain_shape_list = []
394364
num_gemms = len(lhs_list)
365+
lhs_list_ = []
366+
rhs_list_ = []
367+
lhs_sinv_list_ = []
368+
rhs_sinv_list_ = []
369+
bias_list_ = []
395370
for i in range(num_gemms):
396371
lhs = lhs_list[i]
397372
rhs = rhs_list[i]
@@ -402,7 +377,7 @@ def grouped_gemm(
402377
lhs_shape = lhs.data.shape
403378
rhs_shape = rhs.data.shape
404379
out_dtype = lhs.dq_dtype
405-
# For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout
380+
# For ScaledTensors and DELAYED_TENSOR_SCALING, need to handle internal data_layout
406381
if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
407382
assert not (
408383
lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
@@ -427,6 +402,7 @@ def grouped_gemm(
427402
lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract)
428403
rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract)
429404

405+
# Note: do not squeeze() for {lhs, rhs}_3d, it will trigger a D2D memcpy
430406
if scaling_mode == ScalingMode.NO_SCALING:
431407
lhs_3d = _shape_normalization(lhs, lhs_dn)
432408
rhs_3d = _shape_normalization(rhs, rhs_dn)
@@ -438,13 +414,13 @@ def grouped_gemm(
438414
rhs_3d = _shape_normalization(rhs.data, rhs_dn)
439415
lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn)
440416
rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn)
417+
# swizzled_scale requires a matrix
441418
lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze())
442419
rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze())
443420
else:
444421
raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}")
445422

446-
# Note: if _shape_normalization() is updated to support non-TN, need to update here
447-
# already_transposed doesn't matter for the output shape
423+
# Note: already_transposed doesn't matter for the output shape
448424
# x.shape = [B, D1, D2]
449425
# contracting_dims = (2, ) --> output.shape = [1, B * D1, D2]
450426
# contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1]
@@ -455,66 +431,37 @@ def grouped_gemm(
455431
bn = rhs_remain_shape[0]
456432
kl = lhs_3d.shape[-1]
457433
kr = rhs_3d.shape[-1]
458-
remain_shape_list.append(((bm,), (bn,)))
459-
assert kl == kr, f"lhs_3d.shape[-1] ({kl}) != rhs_3d.shape[-1] ({kr})"
460-
k = kl
461-
462-
if (bm % 16 != 0) or (bn % 16 != 0) or (k % 16 != 0):
463-
print(f"grouped_gemm input pair {i} has invalid problem shape for lowering: ")
464-
print(
465-
f"m = {bm}, n = {bn}, k = {k}; cuBLAS requires the problem shapes being multiples"
466-
" of 16"
467-
)
468-
assert bm % 16 == 0 and bn % 16 == 0 and k % 16 == 0
469-
470-
dims.append((bm, bn, k))
471-
lhs_contig_.append(lhs_3d.reshape(-1))
472-
rhs_contig_.append(rhs_3d.reshape(-1))
434+
assert kl == kr, f"After shape normalization, contracting dim size mismatch: {kl} != {kr}"
435+
if (bm % 16 != 0) or (bn % 16 != 0) or (kl % 16 != 0):
436+
print("grouped_gemm input pair {i} has invalid problem shape for lowering: ")
437+
print(f"m = {bm}, n = {bn}, k = {kl}; ")
438+
print("cuBLAS requires the problem shapes being multiples of 16")
439+
assert (bm % 16 == 0) and (bn % 16 == 0) and (kl % 16 == 0)
440+
441+
lhs_list_.append(lhs_3d)
442+
rhs_list_.append(rhs_3d)
473443
if scaling_mode == ScalingMode.NO_SCALING:
474-
lhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32))
475-
rhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32))
444+
lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32))
445+
rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32))
476446
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
477-
lhs_scale_inv_contig_.append(lhs.scale_inv.reshape(-1))
478-
rhs_scale_inv_contig_.append(rhs.scale_inv.reshape(-1))
447+
lhs_sinv_list_.append(lhs.scale_inv)
448+
rhs_sinv_list_.append(rhs.scale_inv)
479449
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
480-
lhs_scale_inv_contig_.append(lhs_scale_inv.reshape(-1))
481-
rhs_scale_inv_contig_.append(rhs_scale_inv.reshape(-1))
450+
lhs_sinv_list_.append(lhs_scale_inv)
451+
rhs_sinv_list_.append(rhs_scale_inv)
482452
if bias_list is not None:
483-
bias_contig_.append(bias_list[i].reshape(-1))
484-
out_flat_size += bm * bn
485-
out_offsets.append(out_flat_size)
486-
487-
lhs_contig = jnp.concatenate(lhs_contig_)
488-
rhs_contig = jnp.concatenate(rhs_contig_)
489-
lhs_scale_inv_contig = jnp.concatenate(lhs_scale_inv_contig_)
490-
rhs_scale_inv_contig = jnp.concatenate(rhs_scale_inv_contig_)
491-
bias_contig = jnp.empty(0) if bias_list is None else jnp.concatenate(bias_contig_)
492-
dim_list = jnp.array(dims, dtype=jnp.int32)
493-
494-
# TE/common does not support NVTE_NO_SCALING yet
495-
# It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16
496-
if scaling_mode == ScalingMode.NO_SCALING:
497-
scaling_mode = ScalingMode.DELAYED_TENSOR_SCALING
498-
499-
# Perform batched GEMM on flattened inputs
500-
out_contig = GroupedGemmPrimitive.outer_primitive.bind(
501-
lhs_contig,
502-
lhs_scale_inv_contig,
503-
rhs_contig,
504-
rhs_scale_inv_contig,
505-
bias_contig,
506-
dim_list,
453+
bias_list_.append(bias_list[i])
454+
455+
out_list = GroupedGemmPrimitive.outer_primitive.bind(
456+
*lhs_list_,
457+
*rhs_list_,
458+
*lhs_sinv_list_,
459+
*rhs_sinv_list_,
460+
*bias_list_,
507461
num_gemms=num_gemms,
508-
scaling_mode=scaling_mode.value,
462+
scaling_mode=scaling_mode,
509463
out_dtype=out_dtype,
510-
out_flat_size=out_flat_size,
464+
has_bias=1 if bias_list is not None else 0,
511465
)
512466

513-
# Split the output back into tensors
514-
out_offsets = jnp.array(out_offsets)
515-
out_flat_list = jnp.split(out_contig, out_offsets[:-1])
516-
out_tensors = []
517-
for out_flat, (lhs_remain_shape, rhs_remain_shape) in zip(out_flat_list, remain_shape_list):
518-
out_tensors.append(out_flat.reshape(*lhs_remain_shape, *rhs_remain_shape))
519-
520-
return out_tensors
467+
return out_list

0 commit comments

Comments
 (0)