From 3c0f027dd27abaefde2e3638bbf161d185b17704 Mon Sep 17 00:00:00 2001 From: nenomigami Date: Tue, 19 May 2026 22:03:38 +0900 Subject: [PATCH 1/2] Add colvec scale reduce to gemm_dact --- quack/gemm_dact.py | 110 ++++++++++++++++++++++++++++---------- quack/gemm_interface.py | 115 ++++++++++++++++++++++++++++++++++++---- tests/test_linear.py | 31 ++++++++++- 3 files changed, 217 insertions(+), 39 deletions(-) diff --git a/quack/gemm_dact.py b/quack/gemm_dact.py index a5a576a..5cd5779 100644 --- a/quack/gemm_dact.py +++ b/quack/gemm_dact.py @@ -42,7 +42,22 @@ class GemmDActMixin(GemmActMixin): # Different from GemmActSm90, here act_bwd_fn must take in 2 arguments (x, dout) # and return 2 arguments (dx, out) - EpilogueArguments = GemmActMixin.EpilogueArguments + _epi_ops = ( + Scalar("sr_seed", dtype=Int32), + ColVecLoad("mColVecBroadcast"), + TileStore("mAuxOut"), + ColVecReduce("mColVecReduce"), + ) + _extra_param_fields = (("act_fn", cutlass.Constexpr, None),) + + @mlir_namedtuple + class EpilogueArguments(NamedTuple): + mAuxOut: cute.Tensor + act_fn: cutlass.Constexpr[Optional[Callable]] = None + mColVecBroadcast: Optional[cute.Tensor] = None + mColVecReduce: Optional[cute.Tensor] = None + rounding_mode: cutlass.Constexpr[int] = RoundingMode.RN + sr_seed: Optional[Int32 | cute.Tensor] = None @cute.jit def epi_visit_subtile( @@ -53,6 +68,8 @@ def epi_visit_subtile( tRS_rC: Optional[cute.Tensor] = None, ) -> Optional[cute.Tensor]: assert tRS_rC is not None + tDrColVec = epi_loop_tensors.get("mColVecBroadcast") + tDrColVecReduce = epi_loop_tensors.get("mColVecReduce") # We don't add C to the accumulator GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC=None) tRS_rC_acc = cute.make_rmem_tensor_like(tRS_rC, self.acc_dtype) @@ -74,6 +91,40 @@ def epi_visit_subtile( ) else: tRS_rAuxOut = tRS_rC_acc + if const_expr(tDrColVecReduce is not None): + # Accumulate unscaled postact * GEMM output for router-score gradients. + colvec_reduce_accumulate(self, tDrColVecReduce, tRS_rAuxOut, rScale=tRS_rD) + if const_expr(tDrColVec is not None): + if const_expr(self.arch != 100): + tRS_rD.store(tRS_rD.load() * tDrColVec.load().to(tRS_rD.element_type)) + tRS_rAuxOut.store( + tRS_rAuxOut.load() * tDrColVec.load().to(tRS_rAuxOut.element_type) + ) + else: + tDrColVec_mn = layout_utils.convert_layout_zero_stride(tDrColVec, tDrColVec.layout) + tRS_rD_mn = layout_utils.convert_layout_zero_stride(tRS_rD, tDrColVec.layout) + tRS_rAuxOut_mn = layout_utils.convert_layout_zero_stride( + tRS_rAuxOut, tDrColVec.layout + ) + for m in cutlass.range(cute.size(tDrColVec_mn, mode=[0]), unroll_full=True): + for n in cutlass.range( + cute.size(tDrColVec_mn, mode=[1]) // 2, unroll_full=True + ): + scale = (tDrColVec_mn[m, 0], tDrColVec_mn[m, 0]) + tRS_rD_mn[m, 2 * n], tRS_rD_mn[m, 2 * n + 1] = ( + cute.arch.mul_packed_f32x2( + (tRS_rD_mn[m, 2 * n], tRS_rD_mn[m, 2 * n + 1]), scale + ) + ) + tRS_rAuxOut_mn[m, 2 * n], tRS_rAuxOut_mn[m, 2 * n + 1] = ( + cute.arch.mul_packed_f32x2( + ( + tRS_rAuxOut_mn[m, 2 * n], + tRS_rAuxOut_mn[m, 2 * n + 1], + ), + scale, + ) + ) return tRS_rAuxOut @@ -301,30 +352,30 @@ def _compile_gemm_dact( pa_shape = (m, n) if varlen_m else (m, n, l) mAuxOut = fake_tensor(postact_dtype, pa_shape, leading_dim=pa_leading, divisibility=div_pa) + mColVec = None + if colvec_scale_ndim == 2: + mColVec = fake_tensor(colvec_scale_dtype, (l, m), leading_dim=1, divisibility=4) + elif colvec_scale_ndim == 1: + mColVec = fake_tensor(colvec_scale_dtype, (m,), leading_dim=0, divisibility=4) + mColVecReduce = None + n_tiles = cute.sym_int() + if colvec_reduce_ndim == 3: + mColVecReduce = fake_tensor( + colvec_reduce_dtype, + (l, m, n_tiles), + leading_dim=2, + divisibility=1, + ) + elif colvec_reduce_ndim == 2: + mColVecReduce = fake_tensor( + colvec_reduce_dtype, + (m, n_tiles), + leading_dim=1, + divisibility=1, + ) + if is_dgated: act_fn = dgate_fn_map[activation] - - mColVec = None - if colvec_scale_ndim == 2: - mColVec = fake_tensor(colvec_scale_dtype, (l, m), leading_dim=1, divisibility=4) - elif colvec_scale_ndim == 1: - mColVec = fake_tensor(colvec_scale_dtype, (m,), leading_dim=0, divisibility=4) - mColVecReduce = None - n_tiles = cute.sym_int() - if colvec_reduce_ndim == 3: - mColVecReduce = fake_tensor( - colvec_reduce_dtype, - (l, m, n_tiles), - leading_dim=2, - divisibility=1, - ) - elif colvec_reduce_ndim == 2: - mColVecReduce = fake_tensor( - colvec_reduce_dtype, - (m, n_tiles), - leading_dim=1, - divisibility=1, - ) epi_args = GemmCls.EpilogueArguments( mAuxOut, act_fn, @@ -338,7 +389,12 @@ def _set_implicit_dtype(gemm_obj): post_init = _set_implicit_dtype else: act_fn = dact_fn_map[activation] - epi_args = GemmCls.EpilogueArguments(mAuxOut, act_fn) + epi_args = GemmCls.EpilogueArguments( + mAuxOut, + act_fn, + mColVecBroadcast=mColVec, + mColVecReduce=mColVecReduce, + ) post_init = None scheduler_args = make_fake_scheduler_args( @@ -384,7 +440,7 @@ def gemm_dact( persistent: bool = True, is_dynamic_persistent: bool = False, max_swizzle_size: int = 8, - colvec_scale: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m (dgated only) + colvec_scale: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m # (l, m, ceildiv(n, tile_n)), or (total_m, ceildiv(n, tile_n)) if varlen_m (dgated only) colvec_reduce: Optional[Tensor] = None, cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length @@ -394,8 +450,6 @@ def gemm_dact( is_dgated = activation in dgate_fn_map if not is_dgated: assert activation in dact_fn_map, f"Unsupported activation {activation}" - assert colvec_scale is None, "colvec_scale is only supported for gated activations" - assert colvec_reduce is None, "colvec_reduce is only supported for gated activations" gemm_cls_name = "dgated" if is_dgated else "dact" varlen_m = cu_seqlens_m is not None @@ -500,6 +554,8 @@ def gemm_dact( epi_args = GemmDActMixin.EpilogueArguments( PostAct_p, None, + mColVecBroadcast=colvec_scale, + mColVecReduce=colvec_reduce, rounding_mode=None, sr_seed=None, ) diff --git a/quack/gemm_interface.py b/quack/gemm_interface.py index 067cdaa..ac79f7d 100644 --- a/quack/gemm_interface.py +++ b/quack/gemm_interface.py @@ -393,10 +393,17 @@ def gemm_act_tuned( ) +def prune_invalid_gemm_dact_configs(configs, named_args: dict, **kwargs): + kwargs = named_args | kwargs + if kwargs.get("colvec_scale", None) is not None or kwargs.get("colvec_reduce", False): + configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab] + return prune_invalid_gemm_configs(configs, named_args, **kwargs) + + @autotune( configs=[AutotuneConfig(config=c) for c in get_all_configs()], - key=["activation", "dynamic_scheduler"], - prune_configs_by={"early_config_prune": prune_invalid_gemm_configs}, + key=["activation", "colvec_reduce", "dynamic_scheduler"], + prune_configs_by={"early_config_prune": prune_invalid_gemm_dact_configs}, ) def gemm_dact_tuned( # (M, K) or or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m @@ -405,17 +412,20 @@ def gemm_dact_tuned( PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m dx_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m postact_out: Tensor, # (M, N) or (L, N, N) or (total_M, N) if varlen_m + colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m activation: ActActivation = None, + colvec_reduce: bool = False, cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32 A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m dynamic_scheduler: bool = True, config: Optional[GemmConfig] = None, -) -> None: +) -> Optional[Tensor]: if config is None: config = default_config(A.device) varlen_m = cu_seqlens_m is not None if varlen_m: assert not config.swap_ab, "Variable-length sequences not supported with swap_ab" + og_ndim_2 = A.ndim == 2 and not varlen_m if A.ndim == 2 and not varlen_m: A = A.unsqueeze(0) # (1, M, K) B = B.mT # (N, K) or (L, N, K) @@ -431,6 +441,21 @@ def gemm_dact_tuned( PostAct = postact_out.unsqueeze(0) else: PostAct = postact_out + if colvec_scale is not None and colvec_scale.ndim == 1 and not varlen_m: + colvec_scale = colvec_scale.unsqueeze(0) + if colvec_scale is not None: + assert not config.swap_ab, "colvec_scale not supported with swap_ab" + if colvec_reduce: + tile_n = config.tile_n + shape_n = (B.shape[-2] + tile_n - 1) // tile_n + if varlen_m: + total_m = A_idx.shape[0] if A_idx is not None else A.shape[0] + colvec_shape = (total_m, shape_n) + else: + colvec_shape = (A.shape[0], A.shape[-2], shape_n) + colvec_reduce_partial = torch.empty(colvec_shape, dtype=torch.float32, device=A.device) + else: + colvec_reduce_partial = None dynamic_scheduler = dynamic_scheduler or config.is_dynamic_persistent tile_count_semaphore = ( torch.zeros(1, dtype=torch.int32, device=A.device) @@ -454,10 +479,19 @@ def gemm_dact_tuned( persistent=True, is_dynamic_persistent=dynamic_scheduler, max_swizzle_size=config.max_swizzle_size, + colvec_scale=colvec_scale, + colvec_reduce=colvec_reduce_partial, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx, use_tma_gather=config.use_tma_gather, ) + if colvec_reduce: + colvec_reduce_final = colvec_reduce_partial.sum(dim=-1) + if og_ndim_2: + colvec_reduce_final = colvec_reduce_final.squeeze(0) + else: + colvec_reduce_final = None + return colvec_reduce_final def gemm( @@ -1124,8 +1158,8 @@ def gemm_dact( postact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m out_dtype: Optional[torch.dtype] = None, postact_dtype: Optional[torch.dtype] = None, - colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m (dgated only) - colvec_reduce: bool = False, # dgated only + colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m + colvec_reduce: bool = False, cu_seqlens_m: Optional[Tensor] = None, A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m dynamic_scheduler: bool = True, @@ -1180,19 +1214,23 @@ def gemm_dact( results.append(colvec_reduce_final) return tuple(results) else: - gemm_dact_out( + colvec_reduce_final = gemm_dact_out( A, B, PreAct, dx_out, postact_out, + colvec_scale, activation, + colvec_reduce, cu_seqlens_m, A_idx, dynamic_scheduler, tuned, ) results = [dx_out, postact_out] + if colvec_reduce: + results.append(colvec_reduce_final) return tuple(results) @@ -1203,7 +1241,7 @@ def gemm_dact( "quack::gemm_dact_out", mutates_args=("dx_out", "postact_out"), device_types="cuda", - schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a3!) dx_out, Tensor(a4!) postact_out, str? activation=None, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=True, bool tuned=True) -> ()", + schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a3!) dx_out, Tensor(a4!) postact_out, Tensor? colvec_scale=None, str? activation=None, bool colvec_reduce=False, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=True, bool tuned=True) -> Tensor", ) def gemm_dact_out( A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m @@ -1211,15 +1249,32 @@ def gemm_dact_out( PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m dx_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m activation: ActActivation = None, + colvec_reduce: bool = False, cu_seqlens_m: Optional[Tensor] = None, A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m dynamic_scheduler: bool = True, tuned: bool = True, -) -> None: +) -> Tensor: """GEMM with activation gradient and pre-allocated output tensors.""" fn = gemm_dact_tuned if tuned else partial(gemm_dact_tuned.fn, config=None) - fn(A, B, PreAct, dx_out, postact_out, activation, cu_seqlens_m, A_idx, dynamic_scheduler) + result = fn( + A, + B, + PreAct, + dx_out, + postact_out, + colvec_scale, + activation, + colvec_reduce, + cu_seqlens_m, + A_idx, + dynamic_scheduler, + ) + if result is None: + return torch.empty(0, device=A.device, dtype=torch.float32) + return result def gemm_dact_ref( @@ -1829,10 +1884,50 @@ def _rewrite_merge_alpha_beta(kwargs): _register_precompile_fake(gemm_out, gemm_tuned, rewrite=_rewrite_merge_alpha) _register_precompile_fake(gemm_add_out, gemm_tuned, rewrite=_rewrite_merge_alpha_beta) _register_precompile_fake(gemm_act_out, gemm_act_tuned) -_register_precompile_fake(gemm_dact_out, gemm_dact_tuned) _register_precompile_fake(gemm_gated_out, gemm_gated_tuned) +@gemm_dact_out.register_fake +def gemm_dact_out_fake( + A: Tensor, + B: Tensor, + PreAct: Tensor, + dx_out: Tensor, + postact_out: Tensor, + colvec_scale: Optional[Tensor] = None, + activation: str = None, + colvec_reduce: bool = False, + cu_seqlens_m: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, + dynamic_scheduler: bool = True, + tuned: bool = True, +) -> Tensor: + _precompile_default_config( + gemm_dact_tuned, + A, + B, + PreAct, + dx_out, + postact_out, + colvec_scale=colvec_scale, + activation=activation, + colvec_reduce=colvec_reduce, + cu_seqlens_m=cu_seqlens_m, + A_idx=A_idx, + dynamic_scheduler=dynamic_scheduler, + ) + if not colvec_reduce: + return torch.empty(0, dtype=torch.float32, device=A.device) + if cu_seqlens_m is not None: + total_m = A_idx.shape[0] if A_idx is not None else A.shape[0] + out_shape = (total_m,) + elif A.ndim == 2: + out_shape = (A.shape[0],) + else: + out_shape = (A.shape[0], A.shape[-2]) + return torch.empty(out_shape, dtype=torch.float32, device=A.device) + + @gemm_symmetric_out.register_fake def gemm_symmetric_out_fake( A: Tensor, diff --git a/tests/test_linear.py b/tests/test_linear.py index 0b1beb4..f4a9b89 100644 --- a/tests/test_linear.py +++ b/tests/test_linear.py @@ -194,9 +194,11 @@ def test_linear_act(in_features, out_features, has_bias, input_dtype, activation @pytest.mark.parametrize("activation", ["relu", "relu_sq", "gelu_tanh_approx"]) @pytest.mark.parametrize("input_dtype", [torch.bfloat16]) +@pytest.mark.parametrize("colvec_reduce", [False, True]) +@pytest.mark.parametrize("has_colvec_scale", [False, True]) @pytest.mark.parametrize("k", [736, 1024]) @pytest.mark.parametrize("n", [1504, 2048]) -def test_gemm_dact(n, k, input_dtype, activation): +def test_gemm_dact(n, k, has_colvec_scale, colvec_reduce, input_dtype, activation): """Test GEMM with activation gradient computation.""" device = "cuda" torch.random.manual_seed(0) @@ -204,14 +206,39 @@ def test_gemm_dact(n, k, input_dtype, activation): dout_input = torch.randn((m, k), device=device, dtype=input_dtype) weight = torch.randn((n, k), device=device, dtype=input_dtype) / math.sqrt(k) preact = torch.randn((m, n), device=device, dtype=input_dtype, requires_grad=True) + colvec_scale = torch.randn(m, device=device) if has_colvec_scale else None # Disable tuning for faster test - dx, postact = gemm_dact(dout_input, weight.T, preact, activation=activation, tuned=False) + dx, postact, *rest = gemm_dact( + dout_input, + weight.T, + preact, + colvec_scale=colvec_scale, + activation=activation, + colvec_reduce=colvec_reduce, + tuned=False, + ) + if colvec_reduce: + colvec_reduce_out = rest[0] dx_ref, postact_ref = gemm_dact_ref( dout_input.float(), weight.float().T, preact.float(), activation=activation ) dx_pt, postact_pt = gemm_dact_ref(dout_input, weight.T, preact, activation=activation) + if colvec_reduce: + colvec_reduce_ref = (postact_ref * gemm_ref(dout_input.float(), weight.float().T)).sum( + dim=-1 + ) + colvec_reduce_pt = (postact_pt * gemm_ref(dout_input, weight.T)).sum(dim=-1) + if has_colvec_scale: + dx_ref *= colvec_scale.float()[:, None] + postact_ref *= colvec_scale.float()[:, None] + dx_pt *= colvec_scale[:, None] + postact_pt *= colvec_scale[:, None] assert (dx - dx_ref).abs().max() < 2 * (dx_pt - dx_ref).abs().max() + 1e-5 assert (postact - postact_ref).abs().max() < 2 * (postact_pt - postact_ref).abs().max() + 1e-5 + if colvec_reduce: + assert (colvec_reduce_out - colvec_reduce_ref).abs().max() < 2 * ( + colvec_reduce_pt - colvec_reduce_ref + ).abs().max() + 1e-5 @pytest.mark.parametrize("input_dtype", [torch.bfloat16]) From 5ddd283f806a5f858adc1e53fba7b12009021028 Mon Sep 17 00:00:00 2001 From: nenomigami Date: Tue, 19 May 2026 22:22:53 +0900 Subject: [PATCH 2/2] Add gemm_dact correctness and benchmark coverage --- benchmarks/benchmark_gemm_autotuned.py | 155 ++++++++++++++++++++++++- tests/test_linear_varlen_m.py | 30 ++++- 2 files changed, 180 insertions(+), 5 deletions(-) diff --git a/benchmarks/benchmark_gemm_autotuned.py b/benchmarks/benchmark_gemm_autotuned.py index d7933bd..84de3fd 100644 --- a/benchmarks/benchmark_gemm_autotuned.py +++ b/benchmarks/benchmark_gemm_autotuned.py @@ -40,6 +40,8 @@ gemm, gemm_act, gemm_act_tuned, + gemm_dact, + gemm_dact_tuned, gemm_dgated, gemm_dgated_tuned, gemm_tuned, @@ -130,6 +132,43 @@ def _torch_dgated_act(dact_fn, x, w, preact): return dx, postact +def _drelu(x): + return (x > 0).to(x.dtype), F.relu(x) + + +def _drelu_sq(x): + relu_x = F.relu(x) + return 2.0 * relu_x, relu_x.square() + + +def _dgelu_tanh_approx(x): + c1 = math.sqrt(2.0 / math.pi) + c2 = 0.044715 + z = c1 * (x + c2 * x * x * x) + tanh_z = torch.tanh(z) + sech2_z = 1.0 - tanh_z * tanh_z + dz_dx = c1 * (1.0 + 3.0 * c2 * x * x) + postact = 0.5 * x * (1.0 + tanh_z) + dact = 0.5 * (1.0 + tanh_z) + 0.5 * x * sech2_z * dz_dx + return dact, postact + + +def _torch_dact_act(dact_fn, x, w, preact): + """Reference: GEMM + activation backward over preact.""" + dout = F.linear(x, w) + dact, postact = dact_fn(preact) + return dout * dact, postact + + +_dact_act_fns = { + "silu": _dsilu_exp, + "silu-tanh": _dsilu_tanh, + "relu": _drelu, + "relu_sq": _drelu_sq, + "gelu_tanh_approx": _dgelu_tanh_approx, +} + + _dgated_act_fns = { "swiglu": _dsilu_exp, "swiglu-tanh": _dsilu_tanh, @@ -191,6 +230,60 @@ def benchmark_gemm_act( return ms, tf +def benchmark_gemm_dact( + m, + n, + k, + activation="gelu_tanh_approx", + dtype=torch.bfloat16, + repeats=30, + tuned=True, + config=None, +): + """Benchmark fused GEMM + activation backward.""" + a = torch.randn(m, k, device="cuda", dtype=dtype) + b = torch.randn(k, n, device="cuda", dtype=dtype) / math.sqrt(k) + preact = torch.randn(m, n, device="cuda", dtype=dtype) + nflops = 2 * m * n * k + + if config is None: + fn = lambda: gemm_dact(a, b, preact, activation=activation, out_dtype=dtype, tuned=tuned) + else: + dx_out = torch.empty(m, n, device="cuda", dtype=dtype) + postact_out = torch.empty(m, n, device="cuda", dtype=preact.dtype) + fn = lambda: gemm_dact_tuned.fn( + a, + b, + preact, + dx_out, + postact_out, + None, + activation, + False, + None, + None, + True, + config=config, + ) + fn() # warmup / autotune + time.sleep(0.5) + ms = do_bench(fn, warmup=5, rep=repeats) + tf = tflops(nflops, ms) + + w = b.T.contiguous() + ref_fn = torch.compile(lambda: _torch_dact_act(_dact_act_fns[activation], a, w, preact)) + ref_fn() + ref_fn() + time.sleep(0.5) + ms_pt = do_bench(ref_fn, warmup=5, rep=repeats) + tf_pt = tflops(nflops, ms_pt) + + print(f" quack: {ms:.3f}ms {tf:.1f} TFLOPS") + print(f" cuBLAS + torch.compile: {ms_pt:.3f}ms {tf_pt:.1f} TFLOPS") + print(f" speedup: {ms_pt / ms:.2f}x") + return ms, tf + + def benchmark_gemm_dgated( m, n, @@ -287,6 +380,17 @@ def main(): action="store_true", help="Only run the transformer FFN gated backward GEMM benchmark", ) + parser.add_argument( + "--only-dact", + action="store_true", + help="Only run the GEMM activation backward benchmark", + ) + parser.add_argument( + "--dact-activation", + choices=sorted(_dact_act_fns), + default=None, + help="Restrict the activation backward benchmark to one activation", + ) parser.add_argument( "--dgated-activation", choices=sorted(_dgated_act_fns), @@ -332,11 +436,19 @@ def main(): "swiglu-tanh", ] ) + dact_activations = ( + [args.dact_activation] + if args.dact_activation + else [ + "gelu_tanh_approx", + "relu", + ] + ) forced_config = forced_config_from_args(args) ffn = int(args.dim * 3.5) # Llama-3 ratio - if args.only_gated and args.only_dgated: - raise ValueError("--only-gated and --only-dgated are mutually exclusive") + if sum([args.only_gated, args.only_dgated, args.only_dact]) > 1: + raise ValueError("--only-gated, --only-dgated, and --only-dact are mutually exclusive") if args.only_gated: print( @@ -361,6 +473,27 @@ def main(): ) return + if args.only_dact: + print( + f"GEMM activation backward benchmark (workers={os.environ.get('QUACK_COMPILE_WORKERS', '4')})" + ) + print(f" M={M}, N={N}, K={K}, dtype={args.dtype}") + if forced_config is not None: + print(f" forced config: {forced_config}") + for activation in dact_activations: + print(f"\n d{activation}: ({M}, {K}) x ({K}, {N})") + benchmark_gemm_dact( + M, + N, + K, + activation, + dtype, + repeats=args.repeats, + tuned=not args.untuned and forced_config is None, + config=forced_config, + ) + return + if args.only_dgated: print( f"GEMM gated backward benchmark (workers={os.environ.get('QUACK_COMPILE_WORKERS', '4')})" @@ -413,7 +546,23 @@ def main(): ) print() - # --- 3. Transformer-relevant shapes --- + # --- 3. GEMM + activation backward --- + print("=" * 60) + print(f"GEMM + dGeLU: ({M}, {K}) x ({K}, {N})") + print("=" * 60) + benchmark_gemm_dact( + M, + N, + K, + "gelu_tanh_approx", + dtype, + repeats=args.repeats, + tuned=not args.untuned and forced_config is None, + config=forced_config, + ) + print() + + # --- 4. Transformer-relevant shapes --- batch = args.batch dim = args.dim head_dim = 128 diff --git a/tests/test_linear_varlen_m.py b/tests/test_linear_varlen_m.py index 402cf16..bbae2a7 100644 --- a/tests/test_linear_varlen_m.py +++ b/tests/test_linear_varlen_m.py @@ -466,6 +466,8 @@ def test_gemm_act_varlen_m( # @pytest.mark.parametrize("dynamic_scheduler", [False]) @pytest.mark.parametrize("B_major", ["k", "n"]) @pytest.mark.parametrize("input_dtype", [torch.bfloat16]) +@pytest.mark.parametrize("colvec_reduce", [False, True]) +@pytest.mark.parametrize("has_colvec_scale", [False, True]) @pytest.mark.parametrize("n", [1024, 1504]) @pytest.mark.parametrize("k", [512, 768]) @pytest.mark.parametrize("num_groups", [2, 4]) @@ -473,6 +475,8 @@ def test_gemm_dact_varlen_m( num_groups, k, n, + has_colvec_scale, + colvec_reduce, input_dtype, B_major, dynamic_scheduler, @@ -493,17 +497,22 @@ def test_gemm_dact_varlen_m( PreAct = torch.randn((total_m, n), device=device, dtype=input_dtype) * 0.1 if B_major == "k": B = B.permute(0, 2, 1).contiguous().permute(0, 2, 1) + colvec_scale = torch.randn(total_m, device=device) if has_colvec_scale else None # Test with kernel - dx, postact = gemm_dact( + dx, postact, *rest = gemm_dact( A, B, PreAct, + colvec_scale=colvec_scale, activation=activation, + colvec_reduce=colvec_reduce, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx, dynamic_scheduler=dynamic_scheduler, tuned=False, ) + if colvec_reduce: + colvec_reduce_out = rest[0] assert dx.shape == (total_m, n) assert postact.shape == (total_m, n) # Compare with reference @@ -516,12 +525,29 @@ def test_gemm_dact_varlen_m( cu_seqlens_m=cu_seqlens_m, A_idx=A_idx, ) - del A_f, B_f, P_f + del P_f dx_pt, postact_pt = gemm_dact_ref( A, B, PreAct, activation=activation, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx ) + if colvec_reduce: + colvec_reduce_ref = ( + postact_ref * gemm_ref(A_f, B_f, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx) + ).sum(dim=-1) + colvec_reduce_pt = ( + postact_pt * gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx) + ).sum(dim=-1) + del A_f, B_f + if has_colvec_scale: + dx_ref *= colvec_scale.float()[:, None] + postact_ref *= colvec_scale.float()[:, None] + dx_pt *= colvec_scale[:, None] + postact_pt *= colvec_scale[:, None] assert (dx - dx_ref).abs().max() < 2 * (dx_pt - dx_ref).abs().max() + 1e-5 assert (postact - postact_ref).abs().max() < 2 * (postact_pt - postact_ref).abs().max() + 1e-5 + if colvec_reduce: + assert (colvec_reduce_out - colvec_reduce_ref).abs().max() < 2 * ( + colvec_reduce_pt - colvec_reduce_ref + ).abs().max() + 1e-5 @pytest.mark.parametrize("pre_allocate_out", [False, True])