diff --git a/python/tokenspeed/runtime/layers/moe/backends/fp8/flashinfer_cutlass.py b/python/tokenspeed/runtime/layers/moe/backends/fp8/flashinfer_cutlass.py index e4f9ee74..99c6c4df 100644 --- a/python/tokenspeed/runtime/layers/moe/backends/fp8/flashinfer_cutlass.py +++ b/python/tokenspeed/runtime/layers/moe/backends/fp8/flashinfer_cutlass.py @@ -128,8 +128,8 @@ def forward( activation_type=ActivationType.Swiglu, dtype=x.dtype, features={"pre_routed"}, + weight_format="fp8", traits={ - "weight_dtype": "fp8", "tp": self.spec.tp_size > 1, "ep": self.spec.ep_size > 1, "cuda_graph": False, diff --git a/python/tokenspeed/runtime/layers/moe/backends/mxfp4/flashinfer.py b/python/tokenspeed/runtime/layers/moe/backends/mxfp4/flashinfer.py index e82e4db8..3f8bd206 100644 --- a/python/tokenspeed/runtime/layers/moe/backends/mxfp4/flashinfer.py +++ b/python/tokenspeed/runtime/layers/moe/backends/mxfp4/flashinfer.py @@ -409,7 +409,7 @@ def _call_kernel(self, router_logits, x_quant, x_scale, layer, top_k, output): output=output, dtype=torch.bfloat16, features={"self_routing"}, - traits={"weight_dtype": "mxfp4"}, + weight_format="mxfp4", expected_kernel_name="flashinfer_trtllm_fp4_fused_moe", )[0] diff --git a/python/tokenspeed/runtime/layers/moe/backends/mxfp4/triton_kernel.py b/python/tokenspeed/runtime/layers/moe/backends/mxfp4/triton_kernel.py index 1d308e0f..859e9c5f 100644 --- a/python/tokenspeed/runtime/layers/moe/backends/mxfp4/triton_kernel.py +++ b/python/tokenspeed/runtime/layers/moe/backends/mxfp4/triton_kernel.py @@ -301,7 +301,6 @@ def forward( fused_activation=act, dtype=hidden_states.dtype, features={"ragged_metadata", "dispatch_gemm"}, - traits={"weight_dtype": "mxfp4"}, expected_kernel_name="triton_kernels_dispatch_gemm", ) @@ -328,7 +327,6 @@ def forward( n_expts_act=top_k, dtype=hidden_states.dtype, features={"ragged_metadata", "gemm_combine"}, - traits={"weight_dtype": "mxfp4"}, expected_kernel_name="triton_kernels_gemm_combine", ) diff --git a/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_cutedsl.py b/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_cutedsl.py index 32dcd352..4c699edd 100644 --- a/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_cutedsl.py +++ b/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_cutedsl.py @@ -161,8 +161,8 @@ def _call_kernel( capacity=capacity, dtype=x_fp4.dtype, features={"pre_routed"}, + weight_format="nvfp4", traits={ - "weight_dtype": "nvfp4", "tp": False, "ep": True, "cuda_graph": True, diff --git a/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_cutlass.py b/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_cutlass.py index 6bd67d73..c219636c 100644 --- a/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_cutlass.py +++ b/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_cutlass.py @@ -128,8 +128,8 @@ def forward( activation_type=ActivationType.Swiglu, dtype=x.dtype, features={"pre_routed"}, + weight_format="nvfp4", traits={ - "weight_dtype": "nvfp4", "tp": True, "ep": True, "cuda_graph": False, diff --git a/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_trtllm.py b/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_trtllm.py index c762a4b6..184bcd89 100644 --- a/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_trtllm.py +++ b/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_trtllm.py @@ -336,7 +336,7 @@ def forward( tune_max_num_tokens=next_power_of_2(num_tokens), dtype=x.dtype, features={"self_routing"}, - traits={"weight_dtype": "nvfp4"}, + weight_format="nvfp4", expected_kernel_name="flashinfer_trtllm_fp4_fused_moe", ) if do_finalize: diff --git a/python/tokenspeed/runtime/layers/moe/backends/unquantized/flashinfer_cutlass.py b/python/tokenspeed/runtime/layers/moe/backends/unquantized/flashinfer_cutlass.py index 144587e5..f76b4d2f 100644 --- a/python/tokenspeed/runtime/layers/moe/backends/unquantized/flashinfer_cutlass.py +++ b/python/tokenspeed/runtime/layers/moe/backends/unquantized/flashinfer_cutlass.py @@ -105,8 +105,8 @@ def _call_cutlass_kernel(self, x, layer, topk_output): activation_type=ActivationType.Swiglu, dtype=x.dtype, features={"pre_routed"}, + weight_format="bf16", traits={ - "weight_dtype": "bf16", "tp": True, "ep": True, "cuda_graph": False, diff --git a/python/tokenspeed/runtime/layers/moe/backends/unquantized/flashinfer_trtllm.py b/python/tokenspeed/runtime/layers/moe/backends/unquantized/flashinfer_trtllm.py index 4370b446..c4a31168 100644 --- a/python/tokenspeed/runtime/layers/moe/backends/unquantized/flashinfer_trtllm.py +++ b/python/tokenspeed/runtime/layers/moe/backends/unquantized/flashinfer_trtllm.py @@ -203,7 +203,7 @@ def _call_trtllm_kernel(self, router_logits, x, layer, top_k, do_finalize): tune_max_num_tokens=next_power_of_2(x.shape[0]), dtype=x.dtype, features={"self_routing"}, - traits={"weight_dtype": "bf16"}, + weight_format="bf16", expected_kernel_name="flashinfer_trtllm_bf16_fused_moe", ) if do_finalize: diff --git a/tokenspeed-kernel/README.md b/tokenspeed-kernel/README.md index 8bac78f8..ead39fdc 100644 --- a/tokenspeed-kernel/README.md +++ b/tokenspeed-kernel/README.md @@ -37,7 +37,7 @@ choices (still evolving; subject to change): public API (mha_prefill, mm, moe_fused, ...) │ ┌───────────┴───────────┐ - │ select_kernel │ (family, mode, dtype, traits, ...) + │ select_kernel │ (family, mode, format_signature, traits, ...) └───────────┬───────────┘ │ queries ┌──────────┴──────────┐ @@ -57,8 +57,8 @@ choices (still evolving; subject to change): ``` - **Registration** — backends register with `@register_kernel(family, mode, ...)`, - declaring supported dtypes, arch capability requirements, traits (head dim, - GQA factor, ...), and a priority band. + declaring supported `format_signatures`, arch capability requirements, + non-format traits (head dim, GQA factor, ...), and a priority band. - **Auto-selection** — `select_kernel` filters by capability and traits, ranks the survivors with an optional per-family `SelectionOracle` and priority, and returns a callable. Selection accepts an objective (latency, @@ -71,6 +71,7 @@ choices (still evolving; subject to change): tokenspeed_kernel/ __init__.py # Public API re-exports platform.py # PlatformInfo, capability detection + signature.py # TensorFormat, ScaleFormat, FormatSignature registry.py # KernelRegistry, register_kernel, Priority bands selection.py # select_kernel, oracles, overrides profiling.py # ShapeCapture, kernel_scope, Proton bootstrap diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/benchmark/runner.py b/tokenspeed-kernel/python/tokenspeed_kernel/benchmark/runner.py index 0f5a7b35..b9377797 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/benchmark/runner.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/benchmark/runner.py @@ -139,11 +139,16 @@ def _benchmark_one_shape( if not spec_matches_shape_traits(spec, shape): return None + signature = spec.format_signature_for_primary_storage_dtype(dtype) + if signature is None: + return None + generator = get_input_generator( spec.family, spec.mode, dtype=dtype, traits=spec.traits, + format_signature=signature, device="cuda", seed=self.config.seed, ) @@ -224,10 +229,14 @@ def _verify_one_shape( return None, None, None registry = KernelRegistry.get() + signature = spec.format_signature_for_primary_storage_dtype(dtype) + if signature is None: + return None, None, None + ref_specs = registry.get_for_operator( spec.family, spec.mode, - dtype=dtype, + format_signature=signature, solution="reference", ) if not ref_specs: @@ -285,8 +294,10 @@ def _benchmark_kernel_impl( if spec is None: raise ValueError(f"Kernel {kernel_name!r} is not registered") - if dtype not in spec.dtypes: - raise ValueError(f"Kernel {kernel_name!r} does not support dtype={dtype}") + if spec.format_signature_for_primary_storage_dtype(dtype) is None: + raise ValueError( + f"Kernel {kernel_name!r} does not support primary storage dtype={dtype}" + ) platform = current_platform() if not spec.capability.satisfied_by(platform): @@ -337,12 +348,15 @@ def _benchmark_op_impl( """Benchmark all implementations of an op.""" registry = KernelRegistry.get() platform = current_platform() - specs = registry.get_for_operator( - op_family, - op_mode, - platform=platform, - dtype=dtype, - ) + specs = [ + spec + for spec in registry.get_for_operator( + op_family, + op_mode, + platform=platform, + ) + if spec.format_signature_for_primary_storage_dtype(dtype) is not None + ] results: list[BenchmarkResult] = [] for spec in sorted(specs, key=lambda item: (item.solution, item.name)): diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/cli.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/cli.py index 14ddd136..293f648e 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/cli.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/cli.py @@ -80,7 +80,11 @@ def _iter_candidate_specs( specs = [s for s in specs if s.family == family and s.mode == mode] if dtype_filter is not None: - specs = [s for s in specs if dtype_filter in s.dtypes] + specs = [ + s + for s in specs + if s.format_signature_for_primary_storage_dtype(dtype_filter) is not None + ] specs.sort(key=lambda s: (s.family, s.mode, s.name)) return specs @@ -92,7 +96,7 @@ def _iter_dtypes( ) -> Iterable[torch.dtype]: if dtype_filter is not None: return (dtype_filter,) - return sorted(spec.dtypes, key=str) + return sorted(spec.primary_storage_dtypes(), key=str) def main(argv: list[str] | None = None) -> int: diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/gemm.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/gemm.py index e3a98eae..07321a6e 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/gemm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/gemm.py @@ -31,6 +31,7 @@ set_standard_shapes, ) from tokenspeed_kernel.numerics.tolerance import Tolerance, set_family_tolerance +from tokenspeed_kernel.signature import TensorFormat # --------------------------------------------------------------------------- # Tolerance @@ -119,45 +120,95 @@ def _generate_scales(self, shape: tuple[int, ...], dtype) -> torch.Tensor: ) return scales.to(dtype) + def _format(self, role: str) -> TensorFormat | None: + if self.format_signature is None: + return None + return self.format_signature.format_for(role) + + def _block_size( + self, + *formats: TensorFormat | None, + ) -> list[int] | None: + for tensor_format in formats: + scale = tensor_format.scale if tensor_format is not None else None + if scale is not None and scale.block_shape is not None: + return list(scale.block_shape) + return None + + def _scale_for_format( + self, + tensor_format: TensorFormat | None, + role: str, + *, + M: int, + N: int, + K: int, + block_size: list[int] | None, + ) -> torch.Tensor | None: + scale = tensor_format.scale if tensor_format is not None else None + if scale is None: + return None + + if scale.granularity == "block" and tensor_format.format == "mxfp8": + if block_size is None: + raise ValueError("mxfp8 block scale format requires block_shape") + block_n, block_k = block_size + k_tiles = math.ceil(K / block_k) + if role == "a": + return self._generate_scales((M, k_tiles), scale.storage_dtype) + if role == "b": + n_tiles = math.ceil(N / block_n) + return self._generate_scales((n_tiles, k_tiles), scale.storage_dtype) + + if scale.granularity == "channel": + return self._generate_scales( + (M,) if role == "a" else (N,), + scale.storage_dtype, + ) + + return self._generate_scales((1,), scale.storage_dtype) + def generate( self, M: int, N: int, K: int, ) -> dict[str, Any]: - quant = self.traits.get("quant") - scale_type = self.traits.get("scale_type") a_layout = self.traits.get("a_layout") b_layout = self.traits.get("b_layout") + a_format = self._format("a") + b_format = self._format("b") + a_dtype = a_format.storage_dtype if a_format is not None else self.dtype + b_dtype = b_format.storage_dtype if b_format is not None else self.dtype A = ( - self._generate_value((K, M), self.dtype) + self._generate_value((K, M), a_dtype) if a_layout == {"KM"} - else self._generate_value((M, K), self.dtype) + else self._generate_value((M, K), a_dtype) ) B = ( - self._generate_value((K, N), self.dtype) + self._generate_value((K, N), b_dtype) if b_layout == {"KN"} - else self._generate_value((N, K), self.dtype) + else self._generate_value((N, K), b_dtype) ) - A_scales = None - B_scales = None - block_size = None - - if quant == {"mxfp8"}: - block_size = [128, 128] - k_tiles = math.ceil(K / block_size[0]) - n_tiles = math.ceil(N / block_size[1]) - A_scales = self._generate_scales((M, k_tiles), torch.float32) - B_scales = self._generate_scales((n_tiles, k_tiles), torch.float32) - else: - if scale_type == {"per_channel"}: - A_scales = self._generate_scales((M,), torch.float32) - B_scales = self._generate_scales((N,), torch.float32) - else: - A_scales = self._generate_scales((1,), torch.float32) - B_scales = self._generate_scales((1,), torch.float32) + block_size = self._block_size(a_format, b_format) + A_scales = self._scale_for_format( + a_format, + "a", + M=M, + N=N, + K=K, + block_size=block_size, + ) + B_scales = self._scale_for_format( + b_format, + "b", + M=M, + N=N, + K=K, + block_size=block_size, + ) out_dtype = torch.bfloat16 alpha = None diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/inputs.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/inputs.py index c8340a9a..28f8815f 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/inputs.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/inputs.py @@ -23,7 +23,7 @@ from typing import Any, Callable import torch -from tokenspeed_kernel.registry import KernelSpec +from tokenspeed_kernel.signature import FormatSignature __all__ = [ "InputGenerator", @@ -52,6 +52,7 @@ def __init__( dtype: torch.dtype, traits: dict, *, + format_signature: FormatSignature | None = None, device: str | None = None, seed: int = 42, ) -> None: @@ -59,6 +60,7 @@ def __init__( self.op_mode = op_mode self.dtype = dtype self.traits = traits + self.format_signature = format_signature self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") rng_device = "cuda" if self.device.startswith("cuda") else "cpu" @@ -98,6 +100,7 @@ def get_input_generator( dtype: torch.dtype, traits: dict, *, + format_signature: FormatSignature | None = None, device: str | None = None, seed: int = 42, ) -> InputGenerator: @@ -107,7 +110,15 @@ def get_input_generator( raise KeyError( f"No input generator registered for {op_family}.{op_mode}. Known: {known}" ) - return factory(op_family, op_mode, dtype, traits, device=device, seed=seed) + return factory( + op_family, + op_mode, + dtype, + traits, + format_signature=format_signature, + device=device, + seed=seed, + ) def get_standard_shapes(op_family: str, op_mode: str) -> list[dict[str, Any]]: diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py index 2b2c7f1d..014ae4f9 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py @@ -26,8 +26,27 @@ import torch.nn.functional as F from tokenspeed_kernel.platform import Platform from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.signature import ScaleFormat, format_signatures fp8_dtype = Platform.get().fp8e4m3fn.dtype +_FP8_BLOCK_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", + block_shape=(128, 128), +) +_FP8_TENSOR_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="tensor", +) +_MXFP8_FORMAT_SIGNATURES = format_signatures( + ("a", "b"), "mxfp8", {fp8_dtype}, scale=_FP8_BLOCK_SCALE +) +_FP8_TENSOR_FORMAT_SIGNATURES = format_signatures( + ("a", "b"), "fp8", {fp8_dtype}, scale=_FP8_TENSOR_SCALE +) +_DENSE_GEMM_FORMAT_SIGNATURES = format_signatures( + ("a", "b"), "dense", {torch.bfloat16, torch.float16, torch.float32} +) @register_kernel( @@ -35,8 +54,8 @@ "mm", name="torch_mm_fp8_blockscale", solution="reference", - dtypes={fp8_dtype}, - traits={"quant": frozenset({"mxfp8"})}, + signatures=_MXFP8_FORMAT_SIGNATURES, + traits={}, priority=Priority.PORTABLE + 2, tags={"portability"}, ) @@ -90,7 +109,7 @@ def torch_mm_fp8_blockscale( "mm", name="torch_mm_fp8_scaled_mnk", solution="reference", - dtypes={fp8_dtype}, + signatures=_FP8_TENSOR_FORMAT_SIGNATURES, traits={ "b_layout": frozenset({"NK"}), }, @@ -132,7 +151,7 @@ def torch_mm_fp8_scaled_mnk( "mm", name="torch_mm_fp8_scaled_nkm", solution="reference", - dtypes={fp8_dtype}, + signatures=_FP8_TENSOR_FORMAT_SIGNATURES, traits={ "b_layout": frozenset({"KN"}), }, @@ -172,7 +191,7 @@ def torch_mm_fp8_scaled_nkm( "mm", name="torch_mm", solution="reference", - dtypes={torch.bfloat16, torch.float16, torch.float32}, + signatures=_DENSE_GEMM_FORMAT_SIGNATURES, traits={}, priority=Priority.PORTABLE + 3, tags={"determinism", "portability"}, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py index 7905765d..fd2f3bbc 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py @@ -33,6 +33,11 @@ topk_ids_logical_to_physical, ) from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.signature import ( + dense_tensor_format, + format_signature, + format_signatures, +) from tokenspeed_kernel.torch_compile import get_compiler_backend # --------------------------------------------------------------------------- @@ -46,10 +51,14 @@ name="reference_moe_fused", features={"pre_routed"}, solution="reference", - dtypes={torch.float16, torch.bfloat16, torch.float32}, + signatures=frozenset( + format_signature( + x=dense_tensor_format(dtype), weight=dense_tensor_format(torch.bfloat16) + ) + for dtype in {torch.float16, torch.bfloat16, torch.float32} + ), priority=Priority.REFERENCE, traits={ - "weight_dtype": frozenset({"bf16", "fp16", "fp32"}), "tp": frozenset({False}), "ep": frozenset({False}), }, @@ -89,7 +98,9 @@ def fused_moe_forward_native( "route", name="torch_compile_fused_topk_bias", solution="reference", - dtypes={torch.float16, torch.bfloat16, torch.float32}, + signatures=format_signatures( + "logits", "dense", {torch.float16, torch.bfloat16, torch.float32} + ), traits={ "output_type": frozenset({"topk"}), "biased": frozenset({True}), @@ -131,7 +142,9 @@ def fused_topk_bias( "route", name="torch_native_fused_topk", solution="reference", - dtypes={torch.float16, torch.bfloat16, torch.float32}, + signatures=format_signatures( + "logits", "dense", {torch.float16, torch.bfloat16, torch.float32} + ), traits={ "output_type": frozenset({"topk"}), "biased": frozenset({True, False}), @@ -189,7 +202,9 @@ def _mask_topk_ids_padded_region( "route", name="torch_compile_grouped_topk", solution="reference", - dtypes={torch.float16, torch.bfloat16, torch.float32}, + signatures=format_signatures( + "logits", "dense", {torch.float16, torch.bfloat16, torch.float32} + ), traits={ "output_type": frozenset({"topk"}), "biased": frozenset({False}), @@ -270,7 +285,9 @@ def grouped_topk_gpu( "route", name="torch_compile_biased_grouped_topk", solution="reference", - dtypes={torch.float16, torch.bfloat16, torch.float32}, + signatures=format_signatures( + "logits", "dense", {torch.float16, torch.bfloat16, torch.float32} + ), traits={ "output_type": frozenset({"topk"}), "biased": frozenset({True}), @@ -366,7 +383,7 @@ def biased_grouped_topk_gpu( "align_block_size", name="torch_moe_align_block_size", solution="reference", - dtypes={torch.int32}, + signatures=format_signatures("indices", "dense", {torch.int32}), traits={}, priority=10, tags={"determinism", "portability"}, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/quantize.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/quantize.py index 9b3ddc45..9fcaa15a 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/quantize.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/quantize.py @@ -49,7 +49,7 @@ def _quantize_fp8(x_fp32: torch.Tensor, max_abs: torch.Tensor) -> torch.Tensor: "fp8_token_group_128", name="torch_fp8_token_group_128", solution="reference", - dtypes={torch.bfloat16, torch.float16}, + signatures=format_signatures("x", "dense", {torch.bfloat16, torch.float16}), traits={}, priority=10, tags={"determinism", "portability"}, @@ -69,7 +69,7 @@ def torch_fp8_token_group_128(x: torch.Tensor) -> torch.Tensor: "fp8_token", name="torch_fp8_token", solution="reference", - dtypes={torch.bfloat16, torch.float16}, + signatures=format_signatures("x", "dense", {torch.bfloat16, torch.float16}), traits={}, priority=10, tags={"determinism", "portability"}, @@ -86,7 +86,7 @@ def torch_fp8_token(x: torch.Tensor) -> torch.Tensor: "fp8_tensor", name="torch_fp8_tensor", solution="reference", - dtypes={torch.bfloat16, torch.float16}, + signatures=format_signatures("x", "dense", {torch.bfloat16, torch.float16}), traits={}, priority=10, tags={"determinism", "portability"}, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/verify.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/verify.py index c44c8749..f69eca43 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/verify.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/verify.py @@ -38,11 +38,12 @@ ToleranceOverride, get_family_tolerance, ) -from tokenspeed_kernel.registry import KernelRegistry, load_builtin_kernels +from tokenspeed_kernel.registry import KernelRegistry, KernelSpec, load_builtin_kernels from tokenspeed_kernel.selection import ( ref_compatible_with_spec, spec_matches_shape_traits, ) +from tokenspeed_kernel.signature import FormatSignature # isort: split import tokenspeed_kernel.numerics.gemm # noqa: F401 @@ -65,6 +66,49 @@ def _as_tolerance_fn(override: ToleranceOverride | None) -> ToleranceFn | None: ) +def _format_signatures_for_primary_storage_dtype( + spec: KernelSpec, + dtype: torch.dtype, +) -> tuple[FormatSignature, ...]: + return tuple( + signature + for signature in sorted(spec.format_signatures, key=str) + if signature.primary_storage_dtype() == dtype + ) + + +def _compatible_reference_for_signature( + registry: KernelRegistry, + spec: KernelSpec, + signature: FormatSignature, +) -> KernelSpec | None: + ref_specs = registry.get_for_operator( + spec.family, + spec.mode, + format_signature=signature, + solution="reference", + ) + for ref in ref_specs: + if ref.name == spec.name: + continue + if ref_compatible_with_spec(ref, spec): + return ref + return None + + +def _verification_signature_and_reference( + registry: KernelRegistry, + spec: KernelSpec, + dtype: torch.dtype, +) -> tuple[FormatSignature | None, KernelSpec | None]: + signatures = _format_signatures_for_primary_storage_dtype(spec, dtype) + for signature in signatures: + ref_spec = _compatible_reference_for_signature(registry, spec, signature) + if ref_spec is not None: + return signature, ref_spec + return (signatures[0], None) if signatures else (None, None) + + def verify_kernel( kernel_name: str, *, @@ -86,24 +130,11 @@ def verify_kernel( if kernel is None: raise ValueError(f"Kernel implementation for {kernel_name!r} is missing") - ref_specs = registry.get_for_operator( - spec.family, - spec.mode, - dtype=dtype, - solution="reference", - ) - if not ref_specs: + signature, ref_spec = _verification_signature_and_reference(registry, spec, dtype) + if signature is None: raise ValueError( - f"No reference kernel found for {spec.family}.{spec.mode} and dtype={dtype}" + f"Kernel {kernel_name!r} does not support primary storage dtype={dtype}" ) - - ref_spec = None - for ref in ref_specs: - if ref.name == spec.name: - continue - if ref_compatible_with_spec(ref, spec): - ref_spec = ref - break if ref_spec is None: raise ValueError( "No compatible reference kernel found for " @@ -119,6 +150,7 @@ def verify_kernel( spec.mode, dtype=dtype, traits=spec.traits, + format_signature=signature, device=device, seed=seed, ) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/__init__.py index 5d21a41e..fa122679 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/__init__.py @@ -32,9 +32,17 @@ from tokenspeed_kernel.ops.attention.flash_attn import mha_decode_scheduler_metadata from tokenspeed_kernel.profiling import ShapeCapture, kernel_scope from tokenspeed_kernel.selection import select_kernel +from tokenspeed_kernel.signature import dense_tensor_format, format_signature AttentionResult = torch.Tensor | tuple[torch.Tensor, torch.Tensor | None] + +def _attention_format_signature(**roles: torch.Tensor): + return format_signature( + **{role: dense_tensor_format(tensor.dtype) for role, tensor in roles.items()} + ) + + __all__ = [ "mha_prefill", "mha_extend_with_kvcache", @@ -95,10 +103,11 @@ def mha_prefill( "support_sinks": sinks is not None, "return_lse": return_lse, } + signature = _attention_format_signature(q=q, k=k, v=v) kernel = select_kernel( "attention", "mha_prefill", - q.dtype, + signature, traits=traits, solution=solution, override=override, @@ -200,10 +209,11 @@ def mha_extend_with_kvcache( "support_sinks": sinks is not None, "return_lse": return_lse, } + signature = _attention_format_signature(q=q, k_cache=k_cache, v_cache=v_cache) kernel = select_kernel( "attention", "mha_extend_with_kvcache", - q.dtype, + signature, traits=traits, solution=solution, override=override, @@ -308,10 +318,11 @@ def mha_decode_with_kvcache( "support_sinks": sinks is not None, "return_lse": return_lse, } + signature = _attention_format_signature(q=q, k_cache=k_cache, v_cache=v_cache) kernel = select_kernel( "attention", "mha_decode_with_kvcache", - q.dtype, + signature, traits=traits, solution=solution, override=override, @@ -390,10 +401,11 @@ def mha_merge_state( traits = { "head_dim": out_a.shape[-1], } + signature = _attention_format_signature(out_a=out_a, out_b=out_b) kernel = select_kernel( "attention", "mha_merge_state", - out_a.dtype, + signature, traits=traits, solution=solution, override=override, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/cuda/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/cuda/__init__.py index 96e27b8a..7e151da1 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/cuda/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/cuda/__init__.py @@ -7,6 +7,7 @@ current_platform, ) from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.signature import format_signatures platform = current_platform() @@ -22,7 +23,9 @@ min_arch_version=ArchVersion(9, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("out_a", "out_b"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED + 2, traits={}, tags={"throughput"}, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flash_attn/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flash_attn/__init__.py index dacc8200..ae355b86 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flash_attn/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flash_attn/__init__.py @@ -26,6 +26,7 @@ current_platform, ) from tokenspeed_kernel.registry import Priority, error_fn, register_kernel +from tokenspeed_kernel.signature import format_signatures __all__ = [ "flash_attn_func", @@ -79,7 +80,9 @@ min_arch_version=ArchVersion(10, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k", "v"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED + 3, traits={ "head_dim": _FA4_BLACKWELL_PREFILL_HEAD_DIMS, @@ -130,7 +133,9 @@ def fa4_mha_prefill( min_arch_version=ArchVersion(10, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED + 3, traits={ "head_dim": _FA4_BLACKWELL_DECODE_HEAD_DIMS, @@ -184,7 +189,9 @@ def fa4_mha_extend_with_kvcache( min_arch_version=ArchVersion(10, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED + 3, traits={ "head_dim": _FA4_BLACKWELL_DECODE_HEAD_DIMS, @@ -245,7 +252,9 @@ def fa4_mha_decode_with_kvcache( min_arch_version=ArchVersion(9, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k", "v"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED + 3, traits={ "sliding_window": frozenset({False, True}), @@ -294,7 +303,9 @@ def fa3_mha_prefill( min_arch_version=ArchVersion(9, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED + 3, traits={ "sliding_window": frozenset({False, True}), @@ -350,7 +361,9 @@ def fa3_mha_extend_with_kvcache( min_arch_version=ArchVersion(9, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED + 3, traits={ "sliding_window": frozenset({False, True}), diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flashinfer/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flashinfer/__init__.py index eb2d0a1a..e11c6b61 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flashinfer/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flashinfer/__init__.py @@ -30,6 +30,7 @@ current_platform, ) from tokenspeed_kernel.registry import ErrorClass, Priority, error_fn, register_kernel +from tokenspeed_kernel.signature import format_signatures platform = current_platform() @@ -176,7 +177,9 @@ def _get_paged_prefill_wrapper( min_arch_version=ArchVersion(9, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k", "v"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED, traits={ "head_dim": frozenset({64, 128, 256}), @@ -239,7 +242,9 @@ def flashinfer_mha_prefill( min_arch_version=ArchVersion(9, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED, traits={ "head_dim": frozenset({64, 128, 256}), @@ -346,7 +351,9 @@ def flashinfer_mha_extend_with_kvcache( min_arch_version=ArchVersion(9, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED, traits={ "head_dim": frozenset({64, 128, 256}), @@ -420,7 +427,9 @@ def flashinfer_trtllm_mha_extend_with_kvcache( min_arch_version=ArchVersion(9, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED, traits={ "sliding_window": frozenset({False, True}), diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_decode_fp16_gfx950.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_decode_fp16_gfx950.py index 8f8ba596..30742108 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_decode_fp16_gfx950.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_decode_fp16_gfx950.py @@ -36,6 +36,7 @@ ) from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.signature import format_signatures cdna4 = gl.amd.cdna4 async_copy = cdna4.async_copy @@ -615,7 +616,9 @@ def get_config( max_arch_version=ArchVersion(9, 5), vendors=frozenset({"amd"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED, traits={ "head_dim": frozenset({64}), diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_prefill_fp16_gfx950.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_prefill_fp16_gfx950.py index 096728b6..96526f7a 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_prefill_fp16_gfx950.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_prefill_fp16_gfx950.py @@ -36,6 +36,7 @@ ) from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.signature import format_signatures cdna4 = gl.amd.cdna4 async_copy = cdna4.async_copy @@ -975,7 +976,9 @@ def get_config( max_arch_version=ArchVersion(9, 5), vendors=frozenset({"amd"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k", "v"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED, traits={ "head_dim": frozenset({64}), diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/triton/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/triton/__init__.py index 5e1eb2ab..506cbefa 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/triton/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/triton/__init__.py @@ -28,6 +28,7 @@ from tokenspeed_kernel.ops.attention.triton.mha_prefill import prefill_attention_fwd from tokenspeed_kernel.platform import CapabilityRequirement from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.signature import format_signatures @triton.jit @@ -72,7 +73,9 @@ def mha_merge_state_kernel( name="triton_mha_prefill", solution="triton", capability=CapabilityRequirement(vendors=frozenset({"nvidia", "amd"})), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k", "v"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.PORTABLE, traits={ "sliding_window": frozenset({False, True}), @@ -138,7 +141,9 @@ def triton_mha_prefill( name="triton_mha_extend_with_kvcache", solution="triton", capability=CapabilityRequirement(vendors=frozenset({"nvidia", "amd"})), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.PORTABLE, traits={ "sliding_window": frozenset({False, True}), @@ -216,7 +221,9 @@ def triton_mha_extend_with_kvcache( name="triton_mha_decode_with_kvcache_cached", solution="triton", capability=CapabilityRequirement(vendors=frozenset({"nvidia", "amd"})), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.PORTABLE, traits={ "sliding_window": frozenset({False, True}), @@ -289,7 +296,9 @@ def triton_mha_decode_with_kvcache( name="triton_mha_merge_state", solution="triton", capability=CapabilityRequirement(vendors=frozenset({"nvidia", "amd"})), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("out_a", "out_b"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.PORTABLE, traits={}, tags={"portability"}, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/__init__.py index b3096e91..fdde2019 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/__init__.py @@ -23,6 +23,7 @@ import torch from tokenspeed_kernel.profiling import ShapeCapture, kernel_scope from tokenspeed_kernel.selection import select_kernel +from tokenspeed_kernel.signature import dense_tensor_format, format_signature @dataclass @@ -111,10 +112,14 @@ def apply_rope( "has_q_out": output_q_rope is not None, "has_k_out": output_k_rope is not None, } + signature = format_signature( + query=dense_tensor_format(query.dtype), + key=dense_tensor_format(key.dtype), + ) kernel = select_kernel( "embedding", "rope", - query.dtype, + signature, traits=traits, solution=solution, override=override, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/cuda.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/cuda.py index b75caf1d..cebc7ed8 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/cuda.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/cuda.py @@ -22,6 +22,7 @@ import torch from tokenspeed_kernel.platform import CapabilityRequirement, current_platform from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.signature import format_signatures platform = current_platform() @@ -36,7 +37,9 @@ name="cuda_embedding_rope", solution="cuda", capability=CapabilityRequirement(vendors=frozenset({"nvidia"})), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("query", "key"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.PERFORMANT, traits={ "head_size": frozenset({64, 128, 256, 512}), diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/triton.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/triton.py index d6d88533..417c125e 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/triton.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/triton.py @@ -28,6 +28,7 @@ from tokenspeed_kernel._triton import tl, triton from tokenspeed_kernel.platform import CapabilityRequirement from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.signature import format_signatures def _next_power_of_2(n: int) -> int: @@ -361,7 +362,9 @@ def apply_rope_triton( name="triton_embedding_rope", solution="triton", capability=CapabilityRequirement(vendors=frozenset({"amd", "nvidia"})), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("query", "key"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.PORTABLE, traits={ "partial_rotary": frozenset({True, False}), diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py index 93828667..cbab8dc2 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py @@ -32,6 +32,12 @@ from tokenspeed_kernel.platform import Platform from tokenspeed_kernel.profiling import ShapeCapture, kernel_scope from tokenspeed_kernel.selection import select_kernel +from tokenspeed_kernel.signature import ( + ScaleFormat, + dense_tensor_format, + format_signature, + tensor_format, +) logger = logging.getLogger(__name__) @@ -62,12 +68,63 @@ def _infer_scale_type( A_scales: torch.Tensor | None, B_scales: torch.Tensor | None, ) -> str | None: - """For fp8, distinguish per-tensor from per-channel scaling.""" + """For fp8, distinguish tensor/channel/scalar scaling.""" if A_scales is None or B_scales is None: return None if A_scales.numel() == 1 and B_scales.numel() == 1: - return "per_tensor" - return "per_channel" + return "tensor" + return "channel" + + +def _scale_storage_dtype(*scales: torch.Tensor | None) -> torch.dtype: + for scale in scales: + if scale is not None: + return scale.dtype + return torch.float32 + + +def _gemm_format_signature( + A: torch.Tensor, + B: torch.Tensor, + A_scales: torch.Tensor | None, + B_scales: torch.Tensor | None, + out_dtype: torch.dtype, + quant: str | None, + block_size: list[int] | None, +): + _ = out_dtype + if quant == "mxfp8": + scale = ScaleFormat( + storage_dtype=_scale_storage_dtype(A_scales, B_scales), + granularity="block", + block_shape=tuple(block_size) if block_size is not None else None, + ) + a_storage_dtype = _fp8_dtype if A_scales is None else A.dtype + return format_signature( + a=tensor_format("mxfp8", a_storage_dtype, scale=scale), + b=tensor_format("mxfp8", B.dtype, scale=scale), + ) + if quant == "fp8": + scale = ScaleFormat( + storage_dtype=_scale_storage_dtype(A_scales, B_scales), + granularity=_infer_scale_type(A_scales, B_scales) or "unknown", + ) + return format_signature( + a=tensor_format("fp8", A.dtype, scale=scale), + b=tensor_format("fp8", B.dtype, scale=scale), + ) + if quant == "nvfp4": + scale = ScaleFormat( + storage_dtype=_scale_storage_dtype(A_scales, B_scales), + granularity="block", + ) + return format_signature( + a=tensor_format("nvfp4", A.dtype, scale=scale), + b=tensor_format("nvfp4", B.dtype, scale=scale), + ) + return format_signature( + a=dense_tensor_format(A.dtype), b=dense_tensor_format(B.dtype) + ) def _online_quantize_mxfp8( @@ -184,13 +241,6 @@ def mm( M = A.shape[0] N = B.shape[-1] if B.shape[0] == K else B.shape[0] - if quant in ("mxfp8", "fp8"): - select_dtype = _fp8_dtype - elif quant == "nvfp4": - select_dtype = A.dtype - else: - select_dtype = A.dtype - traits: dict[str, object] = { "n_align_16": N % 16 == 0, "k_align_16": K % 16 == 0, @@ -199,18 +249,15 @@ def mm( "k_align_128": K % 128 == 0, } - if quant is not None: - traits["quant"] = quant - - if quant == "fp8": - scale_type = _infer_scale_type(A_scales, B_scales) - if scale_type is not None: - traits["scale_type"] = scale_type + signature = _gemm_format_signature( + A, B, A_scales, B_scales, out_dtype, quant, block_size + ) + select_dtype = signature.primary_storage_dtype() or A.dtype kernel = select_kernel( "gemm", "mm", - select_dtype, + signature, traits=traits, override=override, expected_kernel_name=expected_kernel_name, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/deep_gemm.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/deep_gemm.py index 03a11198..2f621e30 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/deep_gemm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/deep_gemm.py @@ -23,8 +23,17 @@ import torch from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement, Platform from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.signature import ScaleFormat, format_signatures _fp8_dtype = Platform.get().fp8e4m3fn.dtype +_MXFP8_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", + block_shape=(128, 128), +) +_MXFP8_FORMAT_SIGNATURES = format_signatures( + ("a", "b"), "mxfp8", {_fp8_dtype}, scale=_MXFP8_SCALE +) try: from tokenspeed_kernel.thirdparty.deep_gemm import ( @@ -54,9 +63,8 @@ min_arch_version=ArchVersion(9, 0), vendors=frozenset({"nvidia"}), ), - dtypes={_fp8_dtype}, + signatures=_MXFP8_FORMAT_SIGNATURES, traits={ - "quant": frozenset({"mxfp8"}), "n_align_64": frozenset({True}), "k_align_128": frozenset({True}), }, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/flashinfer.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/flashinfer.py index d68fe06a..2cc0ab42 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/flashinfer.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/flashinfer.py @@ -28,11 +28,27 @@ current_platform, ) from tokenspeed_kernel.registry import Priority, error_fn, register_kernel +from tokenspeed_kernel.signature import ScaleFormat, format_signatures platform = current_platform() _fp8_dtype = Platform.get().fp8e4m3fn.dtype _fp4_dtypes: frozenset[torch.dtype] = frozenset({torch.uint8, torch.float4_e2m1fn_x2}) +_MXFP8_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", + block_shape=(128, 128), +) +_NVFP4_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", +) +_MXFP8_FORMAT_SIGNATURES = format_signatures( + ("a", "b"), "mxfp8", {_fp8_dtype}, scale=_MXFP8_SCALE +) +_NVFP4_FORMAT_SIGNATURES = format_signatures( + ("a", "b"), "nvfp4", _fp4_dtypes, scale=_NVFP4_SCALE +) # ---- FlashInfer block-scaled FP8 ---------------------------------------- @@ -59,9 +75,8 @@ min_arch_version=ArchVersion(10, 0), vendors=frozenset({"nvidia"}), ), - dtypes={_fp8_dtype}, + signatures=_MXFP8_FORMAT_SIGNATURES, traits={ - "quant": frozenset({"mxfp8"}), "n_align_128": frozenset({True}), "k_align_128": frozenset({True}), }, @@ -127,10 +142,8 @@ def flashinfer_mm_fp8_blockscale( min_arch_version=ArchVersion(10, 0), vendors=frozenset({"nvidia"}), ), - dtypes=_fp4_dtypes, - traits={ - "quant": frozenset({"nvfp4"}), - }, + signatures=_NVFP4_FORMAT_SIGNATURES, + traits={}, priority=Priority.SPECIALIZED + 2, ) def flashinfer_mm_nvfp4( diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py index d3f19ac5..c96e6e61 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py @@ -31,10 +31,30 @@ from tokenspeed_kernel._triton import tl, triton from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement, Platform from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.signature import ScaleFormat, format_signatures logger = logging.getLogger(__name__) _fp8_dtype = Platform.get().fp8e4m3fn.dtype +_MXFP8_BLOCK_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", + block_shape=(128, 128), +) +_FP8_TENSOR_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="tensor", +) +_FP8_CHANNEL_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="channel", +) +_MXFP8_FORMAT_SIGNATURES = format_signatures( + ("a", "b"), "mxfp8", {_fp8_dtype}, scale=_MXFP8_BLOCK_SCALE +) +_FP8_SCALED_FORMAT_SIGNATURES = format_signatures( + ("a", "b"), "fp8", {_fp8_dtype}, scale=_FP8_TENSOR_SCALE +) | format_signatures(("a", "b"), "fp8", {_fp8_dtype}, scale=_FP8_CHANNEL_SCALE) def prepare_block_fp8_matmul_inputs( @@ -697,10 +717,8 @@ def triton_scaled_mm( min_arch_version=ArchVersion(10, 0), vendors=frozenset({"nvidia"}), ), - dtypes={_fp8_dtype}, - traits={ - "quant": frozenset({"mxfp8"}), - }, + signatures=_MXFP8_FORMAT_SIGNATURES, + traits={}, priority=Priority.PERFORMANT + 3, tags={"portability"}, ) @@ -740,7 +758,7 @@ def triton_mm_fp8_blockscale( min_arch_version=ArchVersion(10, 0), vendors=frozenset({"nvidia"}), ), - dtypes={_fp8_dtype}, + signatures=_FP8_SCALED_FORMAT_SIGNATURES, traits={ "b_layout": frozenset({"KN"}), }, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/trtllm.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/trtllm.py index 195b1be1..d323d01b 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/trtllm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/trtllm.py @@ -28,6 +28,7 @@ import torch from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.signature import ScaleFormat, format_signatures # Re-exported # dsv3_fused_a_gemm supports specific shapes only (see python/tokenspeed/runtime/models/deepseek_v3.py); @@ -35,6 +36,13 @@ from tokenspeed_kernel.thirdparty.trtllm import dsv3_fused_a_gemm # noqa: F401 _fp4_dtypes: frozenset[torch.dtype] = frozenset({torch.uint8, torch.float4_e2m1fn_x2}) +_NVFP4_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", +) +_NVFP4_FORMAT_SIGNATURES = format_signatures( + ("a", "b"), "nvfp4", _fp4_dtypes, scale=_NVFP4_SCALE +) # One stateful torchbind instance per output dtype. Each holds its own # per-shape algo cache inside C++. @@ -59,10 +67,8 @@ def _get_runner(out_dtype: torch.dtype): min_arch_version=ArchVersion(10, 0), vendors=frozenset({"nvidia"}), ), - dtypes=_fp4_dtypes, - traits={ - "quant": frozenset({"nvfp4"}), - }, + signatures=_NVFP4_FORMAT_SIGNATURES, + traits={}, priority=Priority.SPECIALIZED + 3, ) def cublaslt_mm_nvfp4( diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py index d3eeccbc..dff8763a 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py @@ -42,6 +42,12 @@ register_oracle, select_kernel, ) +from tokenspeed_kernel.signature import ( + ScaleFormat, + dense_tensor_format, + format_signature, + tensor_format, +) logger = logging.getLogger(__name__) @@ -88,7 +94,7 @@ def adjust(self, spec, platform, traits): "pre_routed" # caller provides topk_weights/topk_ids (marlin, cutlass, reference) ) -# Weight format trait values — used via traits={"weight_dtype": ...} +# Weight format values used by moe_fused(weight_format=...). WEIGHT_BF16 = "bf16" # dense bfloat16 weights WEIGHT_FP8 = "fp8" # FP8 block-scaled weights WEIGHT_MXFP4 = "mxfp4" # MXFP4 block-scaled weights @@ -116,6 +122,59 @@ def adjust(self, spec, platform, traits): ) +_FP8_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", +) +_NVFP4_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", +) +_MXFP4_SCALE = ScaleFormat( + storage_dtype=torch.uint8, + granularity="block", + block_shape=(32,), +) + + +def _single_dense_tensor_format_signature(role: str, storage_dtype: torch.dtype): + return format_signature(**{role: dense_tensor_format(storage_dtype)}) + + +def _moe_dispatch_format_signature(storage_dtype: torch.dtype, traits: Optional[dict]): + comm_strategy = (traits or {}).get("comm_strategy") + if storage_dtype == torch.int32 or comm_strategy == "local": + return _single_dense_tensor_format_signature("indices", storage_dtype) + return _single_dense_tensor_format_signature("x", storage_dtype) + + +def _moe_fused_format_signature( + storage_dtype: torch.dtype, + weight_format: str, +): + if weight_format == WEIGHT_FP8: + weight = tensor_format("fp8", torch.float8_e4m3fn, scale=_FP8_SCALE) + elif weight_format == WEIGHT_NVFP4: + weight = tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE) + elif weight_format == WEIGHT_MXFP4: + weight = tensor_format("mxfp4", torch.uint8, scale=_MXFP4_SCALE) + elif weight_format == WEIGHT_BF16: + weight = dense_tensor_format(torch.bfloat16) + else: + raise ValueError(f"Unsupported MoE fused weight_format={weight_format!r}") + + if storage_dtype == torch.uint8 and weight_format == WEIGHT_NVFP4: + x = tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE) + elif storage_dtype == torch.uint8 and weight_format == WEIGHT_MXFP4: + x = tensor_format("mxfp4", torch.uint8, scale=_MXFP4_SCALE) + elif storage_dtype == torch.float8_e4m3fn: + x = tensor_format("fp8", storage_dtype, scale=_FP8_SCALE) + else: + x = dense_tensor_format(storage_dtype) + + return format_signature(x=x, weight=weight) + + def moe_route( *args, dtype: torch.dtype = torch.bfloat16, @@ -134,10 +193,11 @@ def moe_route( * ``{"biased": True/False}``: whether correction_bias is applied. * ``{"grouped": True/False}``: whether grouped expert selection is used. """ + signature = _single_dense_tensor_format_signature("logits", dtype) kernel = select_kernel( "moe", "route", - dtype, + signature, features=frozenset(features) if features else None, traits=traits or {}, expected_kernel_name=expected_kernel_name, @@ -157,11 +217,13 @@ def moe_dispatch( Returns ``(sorted_token_ids, expert_ids, num_tokens_post_padded)``. """ + selection_traits = traits or {"comm_strategy": "local"} + signature = _moe_dispatch_format_signature(dtype, selection_traits) kernel = select_kernel( "moe", "dispatch", - dtype, - traits=traits or {"comm_strategy": "local"}, + signature, + traits=selection_traits, expected_kernel_name=expected_kernel_name, ) @@ -189,10 +251,11 @@ def moe_experts( * ``{"dispatch_gemm"}``: gather/dispatch tokens then GEMM (uses gather_indx). * ``{"gemm_combine"}``: GEMM then scatter/combine results (uses scatter_indx). """ + signature = _single_dense_tensor_format_signature("x", dtype) kernel = select_kernel( "moe", "experts", - dtype, + signature, features=frozenset(features) if features else None, traits=traits or {}, expected_kernel_name=expected_kernel_name, @@ -209,10 +272,11 @@ def moe_combine( **kwargs, ): """Combine expert outputs with weighted reduction.""" + signature = _single_dense_tensor_format_signature("x", dtype) kernel = select_kernel( "moe", "combine", - dtype, + signature, traits=traits or {}, expected_kernel_name=expected_kernel_name, ) @@ -224,6 +288,7 @@ def moe_fused( *args, dtype: torch.dtype = torch.bfloat16, features: Optional[Set[str]] = None, + weight_format: str = WEIGHT_BF16, traits: Optional[dict] = None, expected_kernel_name: Optional[str] = None, **kwargs, @@ -235,18 +300,29 @@ def moe_fused( * ``{"self_routing"}``: kernel does routing internally (trtllm). * ``{"pre_routed"}``: routing already done by caller (cutlass, reference). - Weight format traits (pass via ``traits``): + Args: + weight_format: Weight tensor encoding used for the expert weights. + Supported values are: + + * ``"bf16"``: dense bfloat16 weights with no scale tensor. + * ``"fp8"``: FP8 E4M3 weights with float32 block scales. + * ``"mxfp4"``: packed MXFP4 weights stored as uint8 with uint8 + block scales over 32-value blocks. + * ``"nvfp4"``: packed NVFP4 weights stored as uint8 with float32 + block scales. - * ``{"weight_dtype": "bf16"}``: dense bfloat16 weights. - * ``{"weight_dtype": "fp8"}``: FP8 block-scaled weights. - * ``{"weight_dtype": "mxfp4"}``: MXFP4 block-scaled weights. + The activation/input format is selected by ``dtype``. When + ``dtype=torch.uint8``, ``weight_format`` disambiguates whether the + input is interpreted as MXFP4 or NVFP4. """ + signature = _moe_fused_format_signature(dtype, weight_format) + selection_traits = dict(traits or {}) kernel = select_kernel( "moe", "fused", - dtype, + signature, features=frozenset(features) if features else None, - traits=traits or {}, + traits=selection_traits, expected_kernel_name=expected_kernel_name, ) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/cuda.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/cuda.py index 462c4d84..6a33f1dd 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/cuda.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/cuda.py @@ -29,6 +29,7 @@ import torch from tokenspeed_kernel.registry import Priority, error_fn, register_kernel +from tokenspeed_kernel.signature import format_signatures try: from tokenspeed_kernel.thirdparty.cuda.activation import ( @@ -52,7 +53,9 @@ "route", name="cuda_routing_flash", solution="cuda", - dtypes={torch.float16, torch.bfloat16, torch.float32}, + signatures=format_signatures( + "logits", "dense", {torch.float16, torch.bfloat16, torch.float32} + ), traits={ "output_type": frozenset({"topk"}), "biased": frozenset({True}), diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/deepep.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/deepep.py index 746fbfba..cb6faa2f 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/deepep.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/deepep.py @@ -25,6 +25,7 @@ import torch from tokenspeed_kernel._triton import tl, triton from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.signature import format_signatures logger = logging.getLogger(__name__) @@ -277,7 +278,7 @@ def _fwd_kernel_ep_scatter_2( "dispatch", name="deepep_moe_scatter", solution="deep_ep", - dtypes={torch.float8_e4m3fn, torch.bfloat16}, + signatures=format_signatures("x", "dense", {torch.float8_e4m3fn, torch.bfloat16}), traits={ "comm_strategy": frozenset({"deep_ep"}), }, @@ -414,7 +415,7 @@ def _fwd_kernel_ep_gather( "combine", name="deepep_moe_gather", solution="deep_ep", - dtypes={torch.bfloat16}, + signatures=format_signatures("x", "dense", {torch.bfloat16}), traits={ "comm_strategy": frozenset({"deep_ep"}), }, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py index 48ecae3b..af80b9ef 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py @@ -23,9 +23,94 @@ import torch from tokenspeed_kernel.platform import current_platform from tokenspeed_kernel.registry import Priority, error_fn, register_kernel +from tokenspeed_kernel.signature import ( + ScaleFormat, + dense_tensor_format, + format_signature, + tensor_format, +) platform = current_platform() + +_FP8_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", +) +_NVFP4_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", +) +_MXFP4_SCALE = ScaleFormat( + storage_dtype=torch.uint8, + granularity="block", + block_shape=(32,), +) +_BF16_FUSED_FORMAT_SIGNATURES = frozenset( + { + format_signature( + x=dense_tensor_format(torch.bfloat16), + weight=dense_tensor_format(torch.bfloat16), + ) + } +) +_CUTLASS_FUSED_FORMAT_SIGNATURES = frozenset( + { + format_signature( + x=dense_tensor_format(torch.bfloat16), + weight=dense_tensor_format(torch.bfloat16), + ), + format_signature( + x=dense_tensor_format(torch.float16), + weight=dense_tensor_format(torch.bfloat16), + ), + format_signature( + x=dense_tensor_format(torch.bfloat16), + weight=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), + ), + format_signature( + x=dense_tensor_format(torch.float16), + weight=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), + ), + format_signature( + x=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), + weight=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), + ), + format_signature( + x=tensor_format("fp8", torch.float8_e4m3fn, scale=_FP8_SCALE), + weight=tensor_format("fp8", torch.float8_e4m3fn, scale=_FP8_SCALE), + ), + } +) +_FP4_FUSED_FORMAT_SIGNATURES = frozenset( + { + format_signature( + x=dense_tensor_format(torch.bfloat16), + weight=tensor_format("mxfp4", torch.uint8, scale=_MXFP4_SCALE), + ), + format_signature( + x=dense_tensor_format(torch.bfloat16), + weight=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), + ), + format_signature( + x=tensor_format("mxfp4", torch.uint8, scale=_MXFP4_SCALE), + weight=tensor_format("mxfp4", torch.uint8, scale=_MXFP4_SCALE), + ), + format_signature( + x=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), + weight=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), + ), + } +) +_CUTEDSL_NVFP4_FORMAT_SIGNATURES = frozenset( + { + format_signature( + x=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), + weight=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), + ) + } +) + cutlass_fused_moe = error_fn fp4_quantize = error_fn mxfp8_quantize = error_fn @@ -101,8 +186,8 @@ name="flashinfer_trtllm_bf16_fused_moe", features={"self_routing"}, solution="trtllm", - dtypes={torch.bfloat16}, - traits={"weight_dtype": frozenset({"bf16"})}, + signatures=_BF16_FUSED_FORMAT_SIGNATURES, + traits={}, priority=Priority.SPECIALIZED, tags={"throughput"}, )(trtllm_bf16_moe) @@ -114,9 +199,8 @@ name="flashinfer_cutlass_fused_moe", features={"pre_routed"}, solution="flashinfer", - dtypes={torch.bfloat16, torch.float16, torch.uint8, torch.float8_e4m3fn}, + signatures=_CUTLASS_FUSED_FORMAT_SIGNATURES, traits={ - "weight_dtype": frozenset({"bf16", "fp8", "nvfp4"}), "tp": frozenset({True, False}), "ep": frozenset({True, False}), "cuda_graph": frozenset({False}), @@ -132,8 +216,8 @@ name="flashinfer_trtllm_fp4_fused_moe", features={"self_routing"}, solution="trtllm", - dtypes={torch.uint8, torch.bfloat16}, - traits={"weight_dtype": frozenset({"mxfp4", "nvfp4"})}, + signatures=_FP4_FUSED_FORMAT_SIGNATURES, + traits={}, priority=Priority.SPECIALIZED, tags={"throughput"}, )(trtllm_fp4_block_scale_moe) @@ -166,9 +250,8 @@ def _get_or_create_cutedsl_wrapper(wrapper_kwargs: dict) -> CuteDslMoEWrapper: name="flashinfer_cutedsl_nvfp4_fused_moe", features={"pre_routed"}, solution="cutedsl", - dtypes={torch.uint8}, + signatures=_CUTEDSL_NVFP4_FORMAT_SIGNATURES, traits={ - "weight_dtype": frozenset({"nvfp4"}), "tp": frozenset({False}), "ep": frozenset({True, False}), "cuda_graph": frozenset({True, False}), diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton.py index b54ecfe3..73d9afa6 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton.py @@ -33,6 +33,7 @@ ExpertLocationDispatchInfo, ) from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.signature import format_signatures from tokenspeed_kernel.thirdparty.trtllm import ( moe_align_block_size as _moe_align_block_size, ) @@ -328,7 +329,7 @@ def _biased_grouped_topk_reference( "route", name="triton_minimax_biased_grouped_topk", solution="triton", - dtypes={torch.float32}, + signatures=format_signatures("logits", "dense", {torch.float32}), traits={ "output_type": frozenset({"topk"}), "biased": frozenset({True}), @@ -776,7 +777,9 @@ def _normalize_fp8_group_scale_layout( name="triton_moe_fused_experts", features={"dispatch_sorted"}, solution="triton", - dtypes={torch.float16, torch.bfloat16, torch.float8_e4m3fn}, + signatures=format_signatures( + "x", "dense", {torch.float16, torch.bfloat16, torch.float8_e4m3fn} + ), priority=Priority.PERFORMANT + 2, tags={"portability"}, ) @@ -978,7 +981,7 @@ def _moe_sum_reduce_kernel( "combine", name="triton_moe_sum_reduce", solution="triton", - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures("x", "dense", {torch.float16, torch.bfloat16}), traits={"comm_strategy": frozenset({None})}, priority=Priority.PERFORMANT + 2, tags={"portability"}, @@ -1023,7 +1026,7 @@ def moe_sum_reduce_triton( "combine", name="torch_compile_moe_sum_reduce", solution="reference", - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures("x", "dense", {torch.float16, torch.bfloat16}), traits={"comm_strategy": frozenset({None})}, priority=Priority.PORTABLE + 1, tags={"portability"}, @@ -1044,7 +1047,7 @@ def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor): "dispatch", name="triton_moe_align_block_size", solution="triton", - dtypes={torch.int32}, + signatures=format_signatures("indices", "dense", {torch.int32}), traits={ "comm_strategy": frozenset({"local"}), }, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton_kernels.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton_kernels.py index 35fea0a2..3d0d1db6 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton_kernels.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton_kernels.py @@ -29,6 +29,7 @@ import torch from tokenspeed_kernel.platform import current_platform from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.signature import format_signatures try: import triton_kernels.matmul_details.opt_flags as opt_flags @@ -92,7 +93,9 @@ def _triton_kernels_routing(logits, n_expts_act, sm_first=False, dtype=None): "route", name="triton_kernels_routing", solution="triton", - dtypes={torch.float16, torch.bfloat16, torch.float32}, + signatures=format_signatures( + "logits", "dense", {torch.float16, torch.bfloat16, torch.float32} + ), traits={"output_type": frozenset({"ragged_metadata"})}, priority=Priority.PERFORMANT + 2, tags={"portability"}, @@ -175,7 +178,9 @@ def _matmul( _matmul_common = dict( solution="triton", - dtypes={torch.float16, torch.bfloat16, torch.uint8}, + signatures=format_signatures( + "x", "dense", {torch.float16, torch.bfloat16, torch.uint8} + ), priority=Priority.PERFORMANT + 2, tags={"portability"}, ) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/trtllm.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/trtllm.py index d0fad73b..b25b8529 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/trtllm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/trtllm.py @@ -29,6 +29,7 @@ ) from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement from tokenspeed_kernel.registry import register_kernel +from tokenspeed_kernel.signature import format_signatures from tokenspeed_kernel.thirdparty.trtllm import ( moe_align_block_size as _trtllm_moe_align_block_size, ) @@ -43,7 +44,7 @@ min_arch_version=ArchVersion(9, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.int32}, + signatures=format_signatures("indices", "dense", {torch.int32}), traits={}, priority=12, tags={"latency"}, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/__init__.py index 96325c32..2b3819fd 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/__init__.py @@ -21,6 +21,7 @@ import torch from tokenspeed_kernel.profiling import ShapeCapture, kernel_scope from tokenspeed_kernel.selection import select_kernel +from tokenspeed_kernel.signature import dense_tensor_format, format_signature __all__ = [ "quantize_fp8", @@ -67,10 +68,11 @@ def quantize_fp8( traits = { "has_scale": scale is not None, } + signature = format_signature(x=dense_tensor_format(x.dtype)) kernel = select_kernel( "quantization", "fp8", - x.dtype, + signature, traits=traits, solution=solution, override=override, @@ -146,10 +148,11 @@ def quantize_fp8_with_scale( "granularity": granularity_trait, "scale_encoding": scale_encoding, } + signature = format_signature(x=dense_tensor_format(x.dtype)) kernel = select_kernel( "quantization", "fp8_with_scale", - x.dtype, + signature, traits=traits, solution=solution, override=override, @@ -204,10 +207,11 @@ def quantize_mxfp8( """ traits = {} + signature = format_signature(x=dense_tensor_format(x.dtype)) kernel = select_kernel( "quantization", "mxfp8", - x.dtype, + signature, traits=traits, solution=solution, override=override, @@ -264,10 +268,11 @@ def quantize_nvfp4( "scale_layout": scale_layout, "has_scale": scale is not None, } + signature = format_signature(x=dense_tensor_format(x.dtype)) kernel = select_kernel( "quantization", "nvfp4", - x.dtype, + signature, traits=traits, solution=solution, override=override, @@ -335,10 +340,11 @@ def quantize_mxfp4( "has_global_scale": global_scale is not None, "scale_encoding": "ue8m0", } + signature = format_signature(x=dense_tensor_format(x.dtype)) kernel = select_kernel( "quantization", "mxfp4", - x.dtype, + signature, traits=traits, solution=solution, override=override, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/flashinfer.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/flashinfer.py index 6e800536..e4124b52 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/flashinfer.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/flashinfer.py @@ -23,6 +23,7 @@ current_platform, ) from tokenspeed_kernel.registry import Priority, error_fn, register_kernel +from tokenspeed_kernel.signature import format_signatures platform = current_platform() @@ -46,7 +47,7 @@ "mxfp8", name="flashinfer_quantize_mxfp8", solution="flashinfer", - dtypes={torch.bfloat16, torch.float16}, + signatures=format_signatures("x", "dense", {torch.bfloat16, torch.float16}), traits={}, priority=Priority.PERFORMANT, ) @@ -72,7 +73,7 @@ def flashinfer_quantize_mxfp8( min_arch_version=ArchVersion(10, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.bfloat16, torch.float16}, + signatures=format_signatures("x", "dense", {torch.bfloat16, torch.float16}), traits={ "has_scale": frozenset({True}), }, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/triton.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/triton.py index cbe4c3ce..47dfc768 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/triton.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/triton.py @@ -19,6 +19,7 @@ import torch from tokenspeed_kernel._triton import tl, triton from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.signature import format_signatures @triton.jit @@ -191,7 +192,7 @@ def fp8_quantize( "fp8", name="triton_quantize_fp8", solution="triton", - dtypes={torch.bfloat16, torch.float16}, + signatures=format_signatures("x", "dense", {torch.bfloat16, torch.float16}), traits={"has_scale": frozenset({True, False})}, priority=Priority.PORTABLE, ) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/trtllm.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/trtllm.py index b07865f1..7b57cc8d 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/trtllm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/trtllm.py @@ -22,6 +22,7 @@ import torch from tokenspeed_kernel.platform import current_platform from tokenspeed_kernel.registry import Priority, error_fn, register_kernel +from tokenspeed_kernel.signature import format_signatures platform = current_platform() @@ -63,7 +64,7 @@ def trtllm_fp8_tensor(x: torch.Tensor) -> torch.Tensor: "fp8_with_scale", name="trtllm_quantize_fp8_with_scale", solution="trtllm", - dtypes={torch.bfloat16, torch.float16}, + signatures=format_signatures("x", "dense", {torch.bfloat16, torch.float16}), traits={ "granularity": frozenset({"tensor", "token", "token_group_128"}), "scale_encoding": frozenset({"float32", "ue8m0"}), diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/plugins/README.md b/tokenspeed-kernel/python/tokenspeed_kernel/plugins/README.md index 76685a1c..e931e5e6 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/plugins/README.md +++ b/tokenspeed-kernel/python/tokenspeed_kernel/plugins/README.md @@ -43,6 +43,7 @@ my-kernels-plugin/ import torch from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement +from tokenspeed_kernel.signature import format_signatures from tokenspeed_kernel.registry import register_kernel @@ -53,7 +54,9 @@ def register() -> None: "attention", "decode", solution="my_custom", - dtypes={torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.bfloat16} + ), capability=CapabilityRequirement( vendors=frozenset({"nvidia"}), min_arch_version=ArchVersion(9, 0), @@ -110,10 +113,17 @@ points at all — call `register_kernel(...)` directly: ```python import torch +from tokenspeed_kernel.signature import format_signatures from tokenspeed_kernel.registry import register_kernel -@register_kernel("gemm", "mm", solution="experiment", dtypes={torch.bfloat16}, priority=15) +@register_kernel( + "gemm", + "mm", + solution="experiment", + signatures=format_signatures(("a", "b"), "dense", {torch.bfloat16}), + priority=15, +) def my_experimental_gemm(a, b, **kwargs): ... ``` diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/registry.py b/tokenspeed-kernel/python/tokenspeed_kernel/registry.py index 8b27d77e..905187f0 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/registry.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/registry.py @@ -32,6 +32,7 @@ from tokenspeed_kernel.selection import SelectedKernel from tokenspeed_kernel.platform import CapabilityRequirement, PlatformInfo +from tokenspeed_kernel.signature import FormatSignature logger = logging.getLogger(__name__) @@ -145,7 +146,7 @@ class KernelSpec: # Capabilities capability: CapabilityRequirement = field(default_factory=CapabilityRequirement) - dtypes: frozenset[torch.dtype] = frozenset() # Supported data types + format_signatures: frozenset[FormatSignature] = frozenset() # Op-specific traits, e.g. {"head_dim": frozenset({64, 128, 256}), "persistent": frozenset({True})} traits: dict[str, frozenset[Any]] = field(default_factory=dict) @@ -158,6 +159,29 @@ class KernelSpec: frozenset() ) # Standard tags: "throughput", "latency", "determinism", "portability" + def supports_format_signature(self, format_signature: FormatSignature) -> bool: + return format_signature in self.format_signatures + + def format_signature_for_primary_storage_dtype( + self, + primary_storage_dtype: torch.dtype, + ) -> FormatSignature | None: + """Return the first signature matching a dtype-oriented filter.""" + for signature in sorted(self.format_signatures, key=str): + if signature.primary_storage_dtype() == primary_storage_dtype: + return signature + return None + + def primary_storage_dtypes(self) -> frozenset[torch.dtype]: + return frozenset( + dtype + for dtype in ( + signature.primary_storage_dtype() + for signature in self.format_signatures + ) + if dtype is not None + ) + class KernelRegistry: """Central registry for all kernel implementations.""" @@ -223,7 +247,7 @@ def get_for_operator( *, features: frozenset[str] | None = None, platform: PlatformInfo | None = None, - dtype: torch.dtype | None = None, + format_signature: FormatSignature | None = None, tags: set[str] | None = None, solution: str | None = None, ) -> list[KernelSpec]: @@ -234,8 +258,8 @@ def get_for_operator( specs = [s for s in specs if features.issubset(s.features)] if platform: specs = [s for s in specs if s.capability.satisfied_by(platform)] - if dtype: - specs = [s for s in specs if dtype in s.dtypes] + if format_signature: + specs = [s for s in specs if s.supports_format_signature(format_signature)] if tags: specs = [s for s in specs if tags.issubset(s.tags)] if solution: @@ -293,7 +317,7 @@ def register_kernel( features: set[str] | None = None, solution: str, capability: CapabilityRequirement | None = None, - dtypes: set[torch.dtype], + signatures: set[FormatSignature] | frozenset[FormatSignature], traits: dict[str, frozenset[Any]] | None = None, priority: Priority | int = Priority.PERFORMANT + 2, tags: set[str] | None = None, @@ -307,6 +331,8 @@ def register_kernel( Example:: + from tokenspeed_kernel.signature import format_signatures + @register_kernel( "attention", "decode", features={"paged"}, @@ -315,7 +341,9 @@ def register_kernel( min_arch_version=ArchVersion(10, 0), required_features=frozenset({"tensor_core:f8"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k", "v"), "dense", {torch.float16, torch.bfloat16} + ), # Narrowly gated on SM100 + tcgen05 → SPECIALIZED band. priority=Priority.SPECIALIZED + 1, tags={"latency", "determinism"}, @@ -334,7 +362,7 @@ def decorator(fn: Callable) -> Callable: mode=mode, solution=solution, features=frozenset(features or set()), - dtypes=frozenset(dtypes), + format_signatures=frozenset(signatures), capability=capability or CapabilityRequirement(), traits=traits or {}, priority=priority_int, @@ -362,7 +390,8 @@ def describe_kernel(name: str) -> str: f" Operator: {spec.family}.{spec.mode}", f" Solution: {spec.solution}", f" Priority: {spec.priority} ({band_str})", - f" Dtypes: {', '.join(str(d) for d in spec.dtypes)}", + " Format signatures: " + + ("; ".join(str(p) for p in spec.format_signatures) or "none"), f" Platform: {spec.capability}", f" Tags: {', '.join(spec.tags) or 'none'}", ] diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/selection.py b/tokenspeed-kernel/python/tokenspeed_kernel/selection.py index 01a6e1ba..61569459 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/selection.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/selection.py @@ -28,9 +28,9 @@ from pathlib import Path from typing import Any, Callable, Generator -import torch from tokenspeed_kernel.platform import PlatformInfo, current_platform from tokenspeed_kernel.registry import KernelRegistry, KernelSpec +from tokenspeed_kernel.signature import FormatSignature logger = logging.getLogger(__name__) @@ -313,7 +313,7 @@ def _get_config_override(family: str, mode: str) -> _ConfigOverrideEntry | None: def _make_cache_key( family: str, mode: str, - dtype: object, + format_signature: FormatSignature, arch: str, objective: SelectionObjective, features: frozenset[str] | None, @@ -323,7 +323,16 @@ def _make_cache_key( """Build a hashable cache key including selection-relevant traits.""" traits_key = tuple(sorted(traits.items())) if traits else () mods_key = frozenset(features) if features else frozenset() - return (family, mode, dtype, arch, objective, mods_key, traits_key, solution) + return ( + family, + mode, + format_signature, + arch, + objective, + mods_key, + traits_key, + solution, + ) def _score_priority(spec: KernelSpec) -> int: @@ -478,7 +487,7 @@ def _resolve_override( registry: KernelRegistry, family: str, mode: str, - dtype: object, + format_signature: object, override: str, platform: PlatformInfo, ) -> SelectedKernel: @@ -494,14 +503,14 @@ def _resolve_override( return SelectedKernel(name=kernel_name, impl=impl) raise NoKernelFoundError( - f"Override '{override}' not found for {family}.{mode} ({dtype})" + f"Override '{override}' not found for {family}.{mode} ({format_signature})" ) def _log_selection( family: str, mode: str, - dtype: object, + format_signature: object, winner: KernelSpec, scored: list[tuple[KernelSpec, ScoreBreakdown]], platform: PlatformInfo, @@ -517,7 +526,7 @@ def _log_selection( "[tokenspeed_kernel] %s.%s(%s) -> %s (%s, %s)", family, mode, - dtype, + format_signature, winner.name, breakdown, platform.arch, @@ -527,7 +536,7 @@ def _log_selection( "[tokenspeed_kernel] %s.%s(%s) -> %s (%s)", family, mode, - dtype, + format_signature, winner.name, platform.arch, ) @@ -536,7 +545,7 @@ def _log_selection( def select_kernel( family: str, mode: str, - dtype: torch.dtype, + format_signature: FormatSignature, *, features: frozenset[str] | None = None, platform: PlatformInfo | None = None, @@ -548,21 +557,21 @@ def select_kernel( ) -> SelectedKernel: """Select the best kernel for an operation. - On first call for a given (family, mode, dtype, platform, objective, traits, + On first call for a given (family, mode, format_signature, platform, objective, traits, solution) combination, runs the full selection pipeline. Subsequent calls with the same arguments return the cached result — a single dict lookup. Args: family: Operator family (e.g., "attention") mode: Operator mode (e.g., "decode") - dtype: Primary data type + format_signature: Role-indexed tensor format signature features: Required operator features (e.g., {"paged"}) platform: Hardware to match (auto-detected if None) objective: Selection objective (see SelectionObjective enum) traits: Op-specific trait values that affect kernel applicability (e.g., {"head_dim": 128, "num_kv_heads": 8}) solution: Restrict selection to a registered solution while preserving - normal platform, dtype, and trait filtering. + normal platform, format signature, and trait filtering. override: Force a specific kernel name or solution string expected_kernel_name: Debug-only hint. When set, a warning is logged if the selected kernel differs from this name. The @@ -609,7 +618,14 @@ def select_kernel( # Fast path: check cache (skipped when override is active) cache_key = _make_cache_key( - family, mode, dtype, platform.arch, objective, features, traits, solution + family, + mode, + format_signature, + platform.arch, + objective, + features, + traits, + solution, ) if override is None: cached = registry.cache_get(cache_key) @@ -617,7 +633,9 @@ def select_kernel( return cached if override: - return _resolve_override(registry, family, mode, dtype, override, platform) + return _resolve_override( + registry, family, mode, format_signature, override, platform + ) # Get candidates (same filtering for both strategies) candidates = registry.get_for_operator( @@ -625,14 +643,14 @@ def select_kernel( mode, features=features, platform=platform, - dtype=dtype, + format_signature=format_signature, solution=solution, ) solution_clause = f" with solution {solution!r}" if solution else "" if not candidates: raise NoKernelFoundError( - f"No kernel found for {family}.{mode} ({dtype})" + f"No kernel found for {family}.{mode} ({format_signature})" f"{solution_clause} on {platform.device_name}" ) @@ -641,7 +659,7 @@ def select_kernel( if not candidates: raise NoKernelFoundError( - f"No kernel found for {family}.{mode} ({dtype})" + f"No kernel found for {family}.{mode} ({format_signature})" f"{solution_clause} with traits {traits} on {platform.device_name}" ) @@ -653,7 +671,7 @@ def select_kernel( candidates, family, mode, - dtype, + format_signature, platform, traits, _policy.autotune_params, @@ -662,7 +680,7 @@ def select_kernel( scored = _rank_by_objective(candidates, objective, platform, traits) winner = scored[0][0] - _log_selection(family, mode, dtype, winner, scored, platform, objective) + _log_selection(family, mode, format_signature, winner, scored, platform, objective) if expected_kernel_name and winner.name != expected_kernel_name: logger.warning( @@ -670,7 +688,7 @@ def select_kernel( "expected '%s'. Score breakdown — selected: %s", family, mode, - dtype, + format_signature, winner.name, expected_kernel_name, next((s for sp, s in scored if sp.name == winner.name), "N/A"), @@ -686,7 +704,7 @@ def _autotune_select( candidates: list[KernelSpec], family: str, mode: str, - dtype: object, + format_signature: object, platform: PlatformInfo, traits: dict[str, Any] | None, params: AutotuneParams, @@ -704,7 +722,7 @@ def _autotune_select( "[tokenspeed_kernel:autotune] falling back to heuristic for %s.%s(%s)", family, mode, - dtype, + format_signature, ) return winner, scored @@ -729,7 +747,7 @@ def kernel_override( def explain_selection( family: str, mode: str, - dtype: torch.dtype, + format_signature: FormatSignature, *, features: frozenset[str] | None = None, platform: PlatformInfo | None = None, @@ -764,7 +782,7 @@ def explain_selection( mode, features=features, platform=platform, - dtype=dtype, + format_signature=format_signature, solution=solution, ) @@ -777,7 +795,7 @@ def explain_selection( filtered_out = [s for s in all_specs if s.name not in filtered_names] lines = [ - f"Op: {family}.{mode} ({dtype})", + f"Op: {family}.{mode} ({format_signature})", f"Platform: {platform.device_name} ({platform.arch})", f"Solution: {solution or 'any'}", f"Objective: {objective.value}", @@ -813,10 +831,12 @@ def explain_selection( f"arch mismatch (requires " f"{spec.capability.min_arch_version})" ) - if dtype and dtype not in spec.dtypes: + if format_signature and not spec.supports_format_signature( + format_signature + ): reasons.append( - f"dtype mismatch (supports " - f"{', '.join(str(d) for d in spec.dtypes)})" + f"format signature mismatch (supports " + f"{', '.join(str(d) for d in spec.format_signatures)})" ) if solution and spec.solution != solution: reasons.append(f"solution mismatch (is {spec.solution!r})") @@ -827,27 +847,41 @@ def explain_selection( def warmup_selection( - ops: list[tuple[str, str, torch.dtype, dict | None]] | None = None, + ops: list[tuple[str, str, FormatSignature, dict | None]] | None = None, ) -> None: - """Pre-resolve kernel selection for the given ops (or all registered ops). + """Pre-resolve kernel selection for explicit op signatures. - Call this at model init to front-load both heuristic and autotune costs. - After warmup, every select_kernel() call on the hot path is a cache hit. + Pass ``ops`` from model initialization to front-load heuristic and autotune + costs for the actual hot-path call sites. Each entry must include the exact + ``FormatSignature`` and trait values used by runtime selection. + + When ``ops`` is ``None``, this performs only a deterministic smoke warmup: + one representative signature for each registered operator, with no traits. + That path verifies the registry and fills a small cache sample, but it does + not warm all supported signatures, trait combinations, or feature-specific + call paths. """ if ops is None: - ops = [ - (f, m, torch.bfloat16, None) - for f, m in KernelRegistry.get().list_operators() - ] - - for family, mode, dtype, traits in ops: + ops = [] + registry = KernelRegistry.get() + for family, mode in registry.list_operators(): + specs = registry.get_for_operator(family, mode) + if not specs or not specs[0].format_signatures: + continue + # No-arg warmup is intentionally a smoke path. Pick a stable + # representative from the highest-priority spec; callers that need + # comprehensive warmup should pass explicit op signatures. + format_signature = sorted(specs[0].format_signatures, key=str)[0] + ops.append((family, mode, format_signature, None)) + + for family, mode, format_signature, traits in ops: try: - select_kernel(family, mode, dtype, traits=traits) + select_kernel(family, mode, format_signature, traits=traits) except NoKernelFoundError: logger.debug( "[tokenspeed_kernel] warmup: no kernel for %s.%s(%s)", family, mode, - dtype, + format_signature, ) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/signature.py b/tokenspeed-kernel/python/tokenspeed_kernel/signature.py new file mode 100644 index 00000000..b6f7331d --- /dev/null +++ b/tokenspeed-kernel/python/tokenspeed_kernel/signature.py @@ -0,0 +1,247 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Iterable + +if TYPE_CHECKING: + import torch + +__all__ = [ + "FormatSignature", + "ScaleFormat", + "TensorFormat", + "dense_tensor_format", + "format_signature", + "tensor_format", + "format_signatures", +] + + +@dataclass(frozen=True) +class ScaleFormat: + """Metadata representation for one tensor scale sidecar. + + Args: + storage_dtype: Physical dtype used by the scale tensor. + granularity: Scale granularity, such as "tensor", "channel", "block". + block_shape: Logical block shape covered by each scale value when + granularity is block-based. + """ + + storage_dtype: torch.dtype + granularity: str + block_shape: tuple[int, ...] | None = None + + def __str__(self) -> str: + parts = [self.granularity, f"storage={self.storage_dtype}"] + if self.block_shape is not None: + parts.append(f"block={self.block_shape}") + return "scale(" + ", ".join(parts) + ")" + + +@dataclass(frozen=True) +class TensorFormat: + """Metadata representation for one logical tensor. + + Args: + storage_dtype: Physical dtype used by the main tensor payload. + format: Logical representation format, such as "dense", "fp8", + "mxfp8", "mxfp4", or "nvfp4". + scale: Optional scale sidecar metadata bundled with this tensor role. + """ + + storage_dtype: torch.dtype + format: str = "dense" + scale: ScaleFormat | None = None + + def __str__(self) -> str: + if self.scale is None: + return f"{self.format}[storage={self.storage_dtype}]" + return f"{self.format}[storage={self.storage_dtype}, {self.scale}]" + + +@dataclass(frozen=True) +class FormatSignature: + """One concrete set of role-indexed tensor formats for all tensor operands. + + Each role appears at most once and maps to exactly one ``TensorFormat``. + A kernel that supports alternatives for a role represents them as multiple + ``FormatSignature`` values in ``KernelSpec.format_signatures`` rather than + as multiple formats inside one signature. + """ + + roles: tuple[tuple[str, TensorFormat], ...] + + def __post_init__(self) -> None: + seen: set[str] = set() + normalized: list[tuple[str, TensorFormat]] = [] + for role, tensor_format in sorted(self.roles, key=lambda item: item[0]): + if role in seen: + raise ValueError(f"duplicate format role {role!r}") + seen.add(role) + normalized.append((role, tensor_format)) + object.__setattr__(self, "roles", tuple(normalized)) + + def format_for(self, role: str) -> TensorFormat | None: + """Return the format for role, or None if it is absent.""" + for role_name, tensor_format in self.roles: + if role_name == role: + return tensor_format + return None + + def primary_storage_dtype(self) -> torch.dtype | None: + """Return the dtype used by dtype-oriented tooling and filters. + + The registry and selection path use the full signature. This helper is + only for compatibility with user-facing commands that still ask for a + single dtype filter, such as numerics and benchmark CLIs. + """ + preferred_roles = ( + "q", + "a", + "x", + "input", + "logits", + "out_a", + "hidden", + ) + for role in preferred_roles: + tensor_format = self.format_for(role) + if tensor_format is not None: + return tensor_format.storage_dtype + return self.roles[0][1].storage_dtype if self.roles else None + + def __str__(self) -> str: + return ( + ", ".join(f"{role}={tensor_format}" for role, tensor_format in self.roles) + or "none" + ) + + +def tensor_format( + format: str, + storage_dtype: torch.dtype, + *, + scale: ScaleFormat | None = None, +) -> TensorFormat: + """Construct a format for one tensor role. + + Args: + format: Logical representation format, such as "dense", "fp8", + "mxfp8", "mxfp4", or "nvfp4". + storage_dtype: Physical dtype used by the main tensor payload. + scale: Optional scale sidecar metadata bundled with this tensor role. + """ + return TensorFormat(storage_dtype=storage_dtype, format=format, scale=scale) + + +def dense_tensor_format(storage_dtype: torch.dtype) -> TensorFormat: + """Construct a dense, unscaled tensor format for storage_dtype.""" + return tensor_format("dense", storage_dtype) + + +def format_signature(**roles: TensorFormat) -> FormatSignature: + """Construct one concrete role-indexed format signature. + + Keyword names are logical tensor roles for the operator, for example + a/b for GEMM or q/k_cache/v_cache for attention. + Values are the exact formats required for those roles. Each role gets + exactly one ``TensorFormat``; represent alternatives by constructing + multiple ``FormatSignature`` values. + + Examples: + >>> import torch + >>> format_signature( + ... a=dense_tensor_format(torch.bfloat16), + ... b=tensor_format("mxfp4", torch.uint8), + ... ) + + This creates one signature equivalent to: + + >>> FormatSignature( + ... ( + ... ("a", dense_tensor_format(torch.bfloat16)), + ... ("b", tensor_format("mxfp4", torch.uint8)), + ... ) + ... ) + """ + return FormatSignature(tuple(roles.items())) + + +def format_signatures( + roles: str | Iterable[str], + format: str, + storage_dtypes: Iterable[torch.dtype], + *, + scale: ScaleFormat | None = None, +) -> frozenset[FormatSignature]: + """Construct same-format signatures for each storage dtype. + + Args: + roles: Logical tensor roles. Pass a string for one role or an iterable + for multiple roles. + format: Logical representation format assigned to every role. + storage_dtypes: Physical dtypes used by every role, one signature per + dtype. + scale: Optional scale sidecar metadata assigned to every role. + + This helper expands dtype alternatives into separate signatures; it does + not put multiple formats on one role. Use ``format_signature`` directly for + mixed-role combinations such as dense activations with quantized weights. + + Examples: + >>> import torch + >>> format_signatures( + ... ("q", "k_cache", "v_cache"), + ... "dense", + ... {torch.float16, torch.bfloat16}, + ... ) + + This expands to a ``frozenset`` containing one signature per dtype, + equivalent to: + + >>> frozenset( + ... { + ... format_signature( + ... q=dense_tensor_format(torch.float16), + ... k_cache=dense_tensor_format(torch.float16), + ... v_cache=dense_tensor_format(torch.float16), + ... ), + ... format_signature( + ... q=dense_tensor_format(torch.bfloat16), + ... k_cache=dense_tensor_format(torch.bfloat16), + ... v_cache=dense_tensor_format(torch.bfloat16), + ... ), + ... } + ... ) + """ + normalized_roles = (roles,) if isinstance(roles, str) else tuple(roles) + return frozenset( + format_signature( + **{ + role: tensor_format(format, storage_dtype, scale=scale) + for role in normalized_roles + } + ) + for storage_dtype in storage_dtypes + ) diff --git a/tokenspeed-kernel/test/conftest.py b/tokenspeed-kernel/test/conftest.py index 3796ef63..22d0a9f8 100644 --- a/tokenspeed-kernel/test/conftest.py +++ b/tokenspeed-kernel/test/conftest.py @@ -54,13 +54,16 @@ def _require( ) -> None: from tokenspeed_kernel.platform import current_platform - specs = KernelRegistry.get().get_for_operator( - family, - mode, - platform=current_platform(), - dtype=dtype, - solution=solution, - ) + specs = [ + spec + for spec in KernelRegistry.get().get_for_operator( + family, + mode, + platform=current_platform(), + solution=solution, + ) + if spec.format_signature_for_primary_storage_dtype(dtype) is not None + ] if not specs: pytest.skip(f"{family}.{mode} solution {solution!r} is not registered") diff --git a/tokenspeed-kernel/test/test_benchmark.py b/tokenspeed-kernel/test/test_benchmark.py index 19eccf70..0b014e12 100644 --- a/tokenspeed-kernel/test/test_benchmark.py +++ b/tokenspeed-kernel/test/test_benchmark.py @@ -33,6 +33,7 @@ from tokenspeed_kernel.platform import Platform from tokenspeed_kernel.profiling import ProfilingConfig from tokenspeed_kernel.registry import KernelRegistry, KernelSpec +from tokenspeed_kernel.signature import format_signatures pytestmark = [ pytest.mark.usefixtures("fresh_registry"), @@ -66,7 +67,7 @@ def _register_test_gemm_kernels() -> None: family="gemm", mode="mm", solution="reference", - dtypes=frozenset({dtype}), + format_signatures=format_signatures(("a", "b"), "dense", {dtype}), traits=_TEST_GEMM_TRAITS, priority=0, ) @@ -75,7 +76,7 @@ def _register_test_gemm_kernels() -> None: family="gemm", mode="mm", solution="triton", - dtypes=frozenset({dtype}), + format_signatures=format_signatures(("a", "b"), "dense", {dtype}), traits=_TEST_GEMM_TRAITS, priority=10, ) diff --git a/tokenspeed-kernel/test/test_kernel_api_selection.py b/tokenspeed-kernel/test/test_kernel_api_selection.py index f1d5488c..fc84e23b 100644 --- a/tokenspeed-kernel/test/test_kernel_api_selection.py +++ b/tokenspeed-kernel/test/test_kernel_api_selection.py @@ -158,6 +158,78 @@ def _mm_mxfp8() -> torch.Tensor: ) +def test_gemm_mxfp8_online_activation_signature_uses_quantized_storage() -> None: + a = torch.empty((4, 128), dtype=torch.bfloat16) + b = torch.empty((128, 128), dtype=_fp8_dtype()) + b_scales = torch.empty((1, 1), dtype=torch.float32) + + signature = _gemm_pkg._gemm_format_signature( + a, + b, + None, + b_scales, + torch.bfloat16, + "mxfp8", + [128, 128], + ) + + a_format = signature.format_for("a") + b_format = signature.format_for("b") + assert a_format is not None + assert b_format is not None + assert a_format.storage_dtype == _fp8_dtype() + assert b_format.storage_dtype == _fp8_dtype() + + +def test_gemm_fp8_scaled_signature_uses_fp8_format_with_scale() -> None: + a = torch.empty((4, 128), dtype=_fp8_dtype()) + b = torch.empty((128, 128), dtype=_fp8_dtype()) + a_scales = torch.empty((1,), dtype=torch.float32) + b_scales = torch.empty((1,), dtype=torch.float32) + + signature = _gemm_pkg._gemm_format_signature( + a, + b, + a_scales, + b_scales, + torch.bfloat16, + "fp8", + None, + ) + + for role in ("a", "b"): + tensor_format = signature.format_for(role) + assert tensor_format is not None + assert tensor_format.format == "fp8" + assert tensor_format.storage_dtype == _fp8_dtype() + assert tensor_format.scale is not None + assert tensor_format.scale.granularity == "tensor" + assert tensor_format.scale.storage_dtype == torch.float32 + + +def test_gemm_fp8_scaled_signature_uses_channel_granularity() -> None: + a = torch.empty((4, 128), dtype=_fp8_dtype()) + b = torch.empty((128, 128), dtype=_fp8_dtype()) + a_scales = torch.empty((4,), dtype=torch.float32) + b_scales = torch.empty((128,), dtype=torch.float32) + + signature = _gemm_pkg._gemm_format_signature( + a, + b, + a_scales, + b_scales, + torch.bfloat16, + "fp8", + None, + ) + + for role in ("a", "b"): + tensor_format = signature.format_for(role) + assert tensor_format is not None + assert tensor_format.scale is not None + assert tensor_format.scale.granularity == "channel" + + def _mm_nvfp4() -> torch.Tensor: a = torch.empty((4, 64), dtype=torch.uint8) b = torch.empty((128, 64), dtype=torch.uint8) @@ -324,7 +396,7 @@ def _moe_fused_self_routing_bf16() -> object: return tokenspeed_kernel.moe_fused( dtype=torch.bfloat16, features={"self_routing"}, - traits={"weight_dtype": "bf16"}, + weight_format="bf16", ) @@ -332,7 +404,7 @@ def _moe_fused_self_routing_mxfp4() -> object: return tokenspeed_kernel.moe_fused( dtype=torch.bfloat16, features={"self_routing"}, - traits={"weight_dtype": "mxfp4"}, + weight_format="mxfp4", ) @@ -340,8 +412,8 @@ def _moe_fused_prerouted_nvfp4_cutlass() -> object: return tokenspeed_kernel.moe_fused( dtype=torch.bfloat16, features={"pre_routed"}, + weight_format="nvfp4", traits={ - "weight_dtype": "nvfp4", "tp": True, "ep": True, "cuda_graph": False, @@ -353,8 +425,8 @@ def _moe_fused_prerouted_nvfp4_cutedsl() -> object: return tokenspeed_kernel.moe_fused( dtype=torch.uint8, features={"pre_routed"}, + weight_format="nvfp4", traits={ - "weight_dtype": "nvfp4", "tp": False, "ep": True, "cuda_graph": True, @@ -366,7 +438,8 @@ def _moe_fused_prerouted_bf16_reference() -> object: return tokenspeed_kernel.moe_fused( dtype=torch.bfloat16, features={"pre_routed"}, - traits={"weight_dtype": "bf16", "tp": False, "ep": False}, + weight_format="bf16", + traits={"tp": False, "ep": False}, ) diff --git a/tokenspeed-kernel/test/test_numerics.py b/tokenspeed-kernel/test/test_numerics.py index 4e4f9b2d..38118412 100644 --- a/tokenspeed-kernel/test/test_numerics.py +++ b/tokenspeed-kernel/test/test_numerics.py @@ -23,10 +23,15 @@ import pytest import torch from tokenspeed_kernel.numerics.comparison import compare_outputs, format_comparison +from tokenspeed_kernel.numerics.inputs import get_input_generator from tokenspeed_kernel.numerics.tolerance import Tolerance -from tokenspeed_kernel.numerics.verify import verify_kernel +from tokenspeed_kernel.numerics.verify import ( + _verification_signature_and_reference, + verify_kernel, +) from tokenspeed_kernel.platform import Platform from tokenspeed_kernel.registry import KernelRegistry, KernelSpec, load_builtin_kernels +from tokenspeed_kernel.signature import ScaleFormat, format_signatures _fp8_dtype = Platform.get().fp8e4m3fn.dtype @@ -59,6 +64,93 @@ def test_treats_inf_as_mismatch(self) -> None: assert result.num_mismatches == 1 +def test_gemm_input_generator_uses_signature_scale_metadata() -> None: + scale = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", + block_shape=(128, 128), + ) + signature = next( + iter(format_signatures(("a", "b"), "mxfp8", {_fp8_dtype}, scale=scale)) + ) + generator = get_input_generator( + "gemm", + "mm", + dtype=_fp8_dtype, + traits={}, + format_signature=signature, + device="cpu", + ) + + inputs = generator.generate(M=4, N=256, K=128) + + assert inputs["A"].dtype == _fp8_dtype + assert inputs["B"].dtype == _fp8_dtype + assert inputs["A_scales"].shape == (4, 1) + assert inputs["B_scales"].shape == (2, 1) + assert inputs["A_scales"].dtype == torch.float32 + assert inputs["B_scales"].dtype == torch.float32 + assert inputs["block_size"] == [128, 128] + + +def test_gemm_input_generator_requires_mxfp8_block_shape() -> None: + scale = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", + ) + signature = next( + iter(format_signatures(("a", "b"), "mxfp8", {_fp8_dtype}, scale=scale)) + ) + generator = get_input_generator( + "gemm", + "mm", + dtype=_fp8_dtype, + traits={}, + format_signature=signature, + device="cpu", + ) + + with pytest.raises(ValueError, match="requires block_shape"): + generator.generate(M=4, N=256, K=128) + + +def test_verification_uses_signature_with_compatible_reference(fresh_registry) -> None: + tensor_scale = ScaleFormat(storage_dtype=torch.float32, granularity="tensor") + channel_scale = ScaleFormat(storage_dtype=torch.float32, granularity="channel") + tensor_signature = next( + iter(format_signatures(("a", "b"), "fp8", {_fp8_dtype}, scale=tensor_scale)) + ) + channel_signature = next( + iter(format_signatures(("a", "b"), "fp8", {_fp8_dtype}, scale=channel_scale)) + ) + ref_spec = KernelSpec( + name="test_tensor_scale_reference", + family="gemm", + mode="mm", + solution="reference", + format_signatures=frozenset({tensor_signature}), + traits={"b_layout": frozenset({"KN"})}, + ) + test_spec = KernelSpec( + name="test_fp8_scaled", + family="gemm", + mode="mm", + solution="triton", + format_signatures=frozenset({channel_signature, tensor_signature}), + traits={"b_layout": frozenset({"KN"})}, + ) + registry = KernelRegistry.get() + registry.register(ref_spec, lambda **_kwargs: None) + registry.register(test_spec, lambda **_kwargs: None) + + signature, reference = _verification_signature_and_reference( + registry, test_spec, _fp8_dtype + ) + + assert signature == tensor_signature + assert reference is ref_spec + + class TestNumericsVerification: def _get_verifiable_specs( dtype: torch.dtype, family: str | None = None @@ -72,13 +164,16 @@ def _get_verifiable_specs( continue # Only run kernels that have a paired reference for this dtype; # otherwise verify_kernel raises ValueError and the test errors. - has_reference = any( - s.solution == "reference" - for s in registry.get_for_operator(family_name, mode, dtype=dtype) - ) + op_specs = registry.get_for_operator(family_name, mode) + dtype_specs = [ + s + for s in op_specs + if s.format_signature_for_primary_storage_dtype(dtype) is not None + ] + has_reference = any(s.solution == "reference" for s in dtype_specs) if not has_reference: continue - for spec in registry.get_for_operator(family_name, mode, dtype=dtype): + for spec in dtype_specs: if spec.solution == "reference": continue if spec.solution == "deep_gemm": diff --git a/tokenspeed-kernel/test/test_plugins.py b/tokenspeed-kernel/test/test_plugins.py index 867aca3c..07730990 100644 --- a/tokenspeed-kernel/test/test_plugins.py +++ b/tokenspeed-kernel/test/test_plugins.py @@ -38,6 +38,7 @@ ) from tokenspeed_kernel.plugins.cli import main as cli_main from tokenspeed_kernel.registry import KernelRegistry, register_kernel +from tokenspeed_kernel.signature import format_signatures # --------------------------------------------------------------------------- # Fixtures and helpers @@ -104,7 +105,7 @@ def _make_register( mode: str = "mm", solution: str | None = None, priority: int = 12, - dtypes=None, + storage_dtypes=None, capability: CapabilityRequirement | None = None, ) -> Callable[[], None]: sol = solution or name @@ -115,7 +116,9 @@ def register(): mode, name=name, solution=sol, - dtypes=dtypes or {torch.bfloat16}, + signatures=format_signatures( + ("a", "b"), "dense", storage_dtypes or {torch.bfloat16} + ), priority=priority, capability=capability, ) @@ -147,7 +150,12 @@ class TestExplicitRegistration: """The decorator path is what plugins use; test it works standalone.""" def test_register_kernel_decorator(self, fresh_plugins): - @register_kernel("gemm", "mm", solution="manual", dtypes={torch.bfloat16}) + @register_kernel( + "gemm", + "mm", + solution="manual", + signatures=format_signatures(("a", "b"), "dense", {torch.bfloat16}), + ) def impl(a, b): return a @ b @@ -202,7 +210,7 @@ def register(): "mm", name=f"k_{name}", solution=name, - dtypes={torch.bfloat16}, + signatures=format_signatures(("a", "b"), "dense", {torch.bfloat16}), ) def impl(): return name @@ -232,7 +240,7 @@ def register(): "mm", name="once_kernel", solution="once", - dtypes={torch.bfloat16}, + signatures=format_signatures(("a", "b"), "dense", {torch.bfloat16}), ) def impl(): return None @@ -253,7 +261,7 @@ def register(): "mm", name="force_kernel", solution="force", - dtypes={torch.bfloat16}, + signatures=format_signatures(("a", "b"), "dense", {torch.bfloat16}), ) def impl(): return None @@ -274,7 +282,7 @@ def test_higher_priority_plugin_wins(self, fresh_plugins, patch_entry_points): "mm", name="builtin_gemm", solution="builtin", - dtypes={torch.bfloat16}, + signatures=format_signatures(("a", "b"), "dense", {torch.bfloat16}), priority=10, ) def builtin(a, b): @@ -363,7 +371,7 @@ def test_equal_priority_emits_warning(self, fresh_plugins, patch_entry_points): "mm", name="builtin_eq", solution="builtin", - dtypes={torch.bfloat16}, + signatures=format_signatures(("a", "b"), "dense", {torch.bfloat16}), priority=14, ) def b(a, x): diff --git a/tokenspeed-kernel/test/test_registry.py b/tokenspeed-kernel/test/test_registry.py index 7bf0c056..c9aecd8e 100644 --- a/tokenspeed-kernel/test/test_registry.py +++ b/tokenspeed-kernel/test/test_registry.py @@ -29,6 +29,13 @@ describe_kernel, register_kernel, ) +from tokenspeed_kernel.signature import ( + ScaleFormat, + dense_tensor_format, + format_signature, + format_signatures, + tensor_format, +) from utils import dummy_impl, register_all_samples pytestmark = pytest.mark.usefixtures("fresh_registry") @@ -46,13 +53,31 @@ def test_default_values(self): assert spec.solution == "" assert spec.priority == 10 assert spec.tags == frozenset() - assert spec.dtypes == frozenset() + assert spec.format_signatures == frozenset() def test_hashable_without_dict_traits(self): spec = KernelSpec(name="k1", family="attention", mode="decode", traits={}) with pytest.raises(TypeError): hash(spec) + def test_format_signature_bundles_scale_metadata(self): + scale = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", + block_shape=(32,), + ) + mixed = format_signature( + a=dense_tensor_format(torch.bfloat16), + b=tensor_format("mxfp4", torch.uint8, scale=scale), + ) + dense = format_signature( + a=dense_tensor_format(torch.bfloat16), + b=dense_tensor_format(torch.uint8), + ) + + assert mixed != dense + assert mixed.format_for("b").scale == scale + def test_equality(self): spec1 = KernelSpec(name="k1", family="attention", mode="decode") spec2 = KernelSpec(name="k1", family="attention", mode="decode") @@ -167,11 +192,16 @@ def test_filter_by_platform( assert "aiter_decode" in mi350_names assert "triton_decode" in mi350_names - def test_filter_by_dtype(self, sample_specs): + def test_filter_by_signature(self, sample_specs): reg = KernelRegistry.get() register_all_samples(reg, sample_specs) - fp32 = reg.get_for_operator("attention", "decode", dtype=torch.float32) + signature = next( + iter( + format_signatures(("q", "k_cache", "v_cache"), "dense", {torch.float32}) + ) + ) + fp32 = reg.get_for_operator("attention", "decode", format_signature=signature) names = {s.name for s in fp32} assert "reference_decode" in names assert "flashinfer_decode" not in names @@ -287,7 +317,7 @@ def test_basic_decorator(self): "gemm", "mm", solution="reference", - dtypes={torch.bfloat16}, + signatures=format_signatures(("a", "b"), "dense", {torch.bfloat16}), priority=12, ) def my_torch_gemm(a, b): @@ -298,7 +328,10 @@ def my_torch_gemm(a, b): assert spec is not None assert spec.solution == "reference" assert spec.priority == 12 - assert torch.bfloat16 in spec.dtypes + assert ( + next(iter(format_signatures(("a", "b"), "dense", {torch.bfloat16}))) + in spec.format_signatures + ) impl = reg.get_impl("reference_gemm_mm") assert impl is my_torch_gemm @@ -309,7 +342,9 @@ def test_custom_name(self): "decode", name="my_custom_kernel", solution="custom", - dtypes={torch.float16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16} + ), ) def some_func(): pass @@ -326,7 +361,9 @@ def test_decorator_with_features_and_tags(self): capability=CapabilityRequirement( min_arch_version=ArchVersion(8, 0), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), tags={"determinism", "latency"}, ) def decorated_kernel(): @@ -344,7 +381,7 @@ def test_decorator_returns_original_function(self): "gemm", "mm", solution="test", - dtypes={torch.float16}, + signatures=format_signatures(("a", "b"), "dense", {torch.float16}), ) def original(x): return x * 2 diff --git a/tokenspeed-kernel/test/test_runtime_callsite_selection.py b/tokenspeed-kernel/test/test_runtime_callsite_selection.py index 1adcbd6e..4e34560a 100644 --- a/tokenspeed-kernel/test/test_runtime_callsite_selection.py +++ b/tokenspeed-kernel/test/test_runtime_callsite_selection.py @@ -47,6 +47,7 @@ import torch from tokenspeed_kernel.registry import KernelRegistry from tokenspeed_kernel.selection import select_kernel +from tokenspeed_kernel.signature import FormatSignature # -- Pre-import so they can be reloaded into the fresh registry. -- @@ -143,8 +144,10 @@ def _extract_features(node: ast.AST) -> Optional[set[str]]: return None -# A ``CallSite`` tuple: (family, mode, dtype|None, features|None, traits, expected_name, location) -CallSite = tuple[str, str, Optional[torch.dtype], Optional[set], dict, str, str] +# A ``CallSite`` tuple: (family, mode, dtype|None, features|None, traits, weight_format, expected_name, location) +CallSite = tuple[ + str, str, Optional[torch.dtype], Optional[set], dict, Optional[str], str, str +] def _collect_call_sites(search_dir: Path) -> list[CallSite]: @@ -204,9 +207,6 @@ def _collect_call_sites(search_dir: Path) -> list[CallSite]: features = _extract_features(feat_node) # -- traits -- - # For moe_* the traits come from the ``traits=`` kwarg. - # For mm() there is no ``traits=`` kwarg; instead ``quant=`` - # is the key selection-relevant parameter. traits: dict[str, Any] = {} traits_node = kwargs.get("traits") if traits_node is not None: @@ -214,18 +214,29 @@ def _collect_call_sites(search_dir: Path) -> list[CallSite]: if parsed is not None: traits = parsed - if family == "gemm": - quant_node = kwargs.get("quant") - if quant_node is not None: - qval, qok = _try_literal(quant_node) - if qok and isinstance(qval, str): - traits["quant"] = qval + weight_format: Optional[str] = None + weight_format_node = kwargs.get("weight_format") + if weight_format_node is not None: + value, ok = _try_literal(weight_format_node) + if ok and isinstance(value, str): + weight_format = value lineno = getattr(node, "lineno", 0) rel = py_path.relative_to(search_dir.parent.parent.parent) location = f"{rel}:{lineno}" - sites.append((family, mode, dtype, features, traits, expected, location)) + sites.append( + ( + family, + mode, + dtype, + features, + traits, + weight_format, + expected, + location, + ) + ) return sites @@ -251,6 +262,7 @@ def _collect_call_sites(search_dir: Path) -> list[CallSite]: torch.bfloat16, {"dispatch_sorted"}, {}, + None, "triton_moe_fused_experts", "manual:triton_common/experts", ), @@ -261,6 +273,7 @@ def _collect_call_sites(search_dir: Path) -> list[CallSite]: torch.bfloat16, None, {"num_tokens": 128, "comm_strategy": None}, + None, "triton_moe_sum_reduce", "manual:triton_common/combine_large", ), @@ -270,6 +283,7 @@ def _collect_call_sites(search_dir: Path) -> list[CallSite]: torch.bfloat16, None, {"num_tokens": 8, "comm_strategy": None}, + None, "torch_compile_moe_sum_reduce", "manual:triton_common/combine_small", ), @@ -289,24 +303,36 @@ def _collect_call_sites(search_dir: Path) -> list[CallSite]: def _infer_dtype(expected_name: str) -> torch.dtype: - """Pick a compatible dtype from the kernel's registered spec. - - Prefers common dtypes over exotic ones (e.g. ``float4_e2m1fn_x2``) - to avoid hitting platform capability gaps in tests. - """ spec = KernelRegistry.get().get_by_name(expected_name) - if spec is not None and spec.dtypes: + if spec is not None: + storage_dtypes = spec.primary_storage_dtypes() for dt in _DTYPE_PREFERENCE: - if dt in spec.dtypes: + if dt in storage_dtypes: return dt - return next(iter(spec.dtypes)) + if storage_dtypes: + return next(iter(storage_dtypes)) return torch.bfloat16 +def _infer_format_signature( + family: str, + mode: str, + dtype: torch.dtype, + weight_format: Optional[str], + spec, +) -> FormatSignature: + if family == "moe" and mode == "fused": + return _moe_pkg._moe_fused_format_signature(dtype, weight_format or "bf16") + signature = spec.format_signature_for_primary_storage_dtype(dtype) + if signature is None: + pytest.skip(f"Kernel {spec.name!r} has no signature for {dtype}") + return signature + + def _site_id(site: CallSite) -> str: """Generate a readable test-id from a call-site tuple.""" - expected = site[5] - location = site[6] + expected = site[6] + location = site[7] return f"{expected}@{location}" @@ -330,7 +356,7 @@ def _site_id(site: CallSite) -> str: ) def test_kernel_selection(site, platform_name, request): platform = request.getfixturevalue(platform_name) - family, mode, raw_dtype, features, traits, expected, location = site + family, mode, raw_dtype, features, traits, weight_format, expected, location = site reg = KernelRegistry.get() spec = reg.get_by_name(expected) @@ -343,11 +369,12 @@ def test_kernel_selection(site, platform_name, request): ) dtype = raw_dtype or _infer_dtype(expected) + signature = _infer_format_signature(family, mode, dtype, weight_format, spec) result = select_kernel( family, mode, - dtype, + signature, features=frozenset(features) if features else None, traits=traits, platform=platform, @@ -355,5 +382,5 @@ def test_kernel_selection(site, platform_name, request): assert result.name == expected, ( f"Expected '{expected}' but got '{result.name}' " f"at {location} on {platform.device_name} " - f"for {family}.{mode}(dtype={dtype}, features={features}, traits={traits})" + f"for {family}.{mode}(signature={signature}, features={features}, traits={traits})" ) diff --git a/tokenspeed-kernel/test/test_selection.py b/tokenspeed-kernel/test/test_selection.py index 21711217..c5b92171 100644 --- a/tokenspeed-kernel/test/test_selection.py +++ b/tokenspeed-kernel/test/test_selection.py @@ -55,10 +55,27 @@ spec_matches_traits, warmup_selection, ) +from tokenspeed_kernel.signature import ( + ScaleFormat, + dense_tensor_format, + format_signature, + format_signatures, + tensor_format, +) from utils import register_all_samples pytestmark = pytest.mark.usefixtures("fresh_registry") +ATTN_DECODE_BF16 = next( + iter(format_signatures(("q", "k_cache", "v_cache"), "dense", {torch.bfloat16})) +) +ATTN_PREFILL_BF16 = next( + iter(format_signatures(("q", "k", "v"), "dense", {torch.bfloat16})) +) +GEMM_BF16 = next(iter(format_signatures(("a", "b"), "dense", {torch.bfloat16}))) +GEMM_FP16 = next(iter(format_signatures(("a", "b"), "dense", {torch.float16}))) +INPUT_BF16 = next(iter(format_signatures("input", "dense", {torch.bfloat16}))) + class TestSelectionObjective: def test_all_enum_values(self): @@ -192,7 +209,7 @@ def test_rank_orders_lexicographically(self, sample_specs, h100_platform): "attention", "decode", platform=h100_platform, - dtype=torch.bfloat16, + format_signature=ATTN_DECODE_BF16, ) scored = _rank_by_objective( candidates, @@ -390,7 +407,7 @@ def test_non_alignment_traits_do_not_affect_shape_matching(self): name="k", family="f", mode="m", - traits={"quant": frozenset({"mxfp8"})}, + traits={"persistent": frozenset({True})}, ) assert spec_matches_shape_traits(spec, {"N": 30, "K": 70}) @@ -401,7 +418,7 @@ def test_deterministic(self): k1 = _make_cache_key( "attn", "dec", - torch.bfloat16, + INPUT_BF16, "sm_90", SelectionObjective.DEFAULT, None, @@ -410,7 +427,7 @@ def test_deterministic(self): k2 = _make_cache_key( "attn", "dec", - torch.bfloat16, + INPUT_BF16, "sm_90", SelectionObjective.DEFAULT, None, @@ -422,7 +439,7 @@ def test_different_objective(self): k1 = _make_cache_key( "attn", "dec", - torch.bfloat16, + INPUT_BF16, "sm_90", SelectionObjective.DEFAULT, None, @@ -431,7 +448,7 @@ def test_different_objective(self): k2 = _make_cache_key( "attn", "dec", - torch.bfloat16, + INPUT_BF16, "sm_90", SelectionObjective.LATENCY, None, @@ -443,7 +460,7 @@ def test_traits_order_independent(self): k1 = _make_cache_key( "a", "d", - torch.float16, + GEMM_FP16, "sm_90", SelectionObjective.DEFAULT, None, @@ -452,7 +469,7 @@ def test_traits_order_independent(self): k2 = _make_cache_key( "a", "d", - torch.float16, + GEMM_FP16, "sm_90", SelectionObjective.DEFAULT, None, @@ -464,10 +481,10 @@ def test_features_order_independent(self): f1 = frozenset({"paged", "mla"}) f2 = frozenset({"mla", "paged"}) k1 = _make_cache_key( - "a", "d", torch.float16, "sm_90", SelectionObjective.DEFAULT, f1, None + "a", "d", GEMM_FP16, "sm_90", SelectionObjective.DEFAULT, f1, None ) k2 = _make_cache_key( - "a", "d", torch.float16, "sm_90", SelectionObjective.DEFAULT, f2, None + "a", "d", GEMM_FP16, "sm_90", SelectionObjective.DEFAULT, f2, None ) assert k1 == k2 @@ -475,7 +492,7 @@ def test_solution_is_selection_relevant(self): k1 = _make_cache_key( "a", "d", - torch.float16, + GEMM_FP16, "sm_90", SelectionObjective.DEFAULT, None, @@ -485,7 +502,7 @@ def test_solution_is_selection_relevant(self): k2 = _make_cache_key( "a", "d", - torch.float16, + GEMM_FP16, "sm_90", SelectionObjective.DEFAULT, None, @@ -503,7 +520,7 @@ def test_basic_selection(self, sample_specs, h100_platform): impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, ) assert callable(impl) @@ -513,16 +530,16 @@ def test_cached_on_second_call(self, sample_specs, h100_platform): register_all_samples(reg, sample_specs) impl1 = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) impl2 = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) assert impl1 is impl2 def test_no_kernel_raises(self, h100_platform): with pytest.raises(NoKernelFoundError): - select_kernel("nonexistent", "op", torch.bfloat16, platform=h100_platform) + select_kernel("nonexistent", "op", INPUT_BF16, platform=h100_platform) def test_no_kernel_after_trait_filter(self, h100_platform): reg = KernelRegistry.get() @@ -532,7 +549,7 @@ def test_no_kernel_after_trait_filter(self, h100_platform): mode="m", solution="triton", priority=10, - dtypes=frozenset({torch.bfloat16}), + format_signatures=frozenset({INPUT_BF16}), traits={"head_dim": frozenset({64})}, ) reg.register(spec, lambda: None) @@ -541,11 +558,101 @@ def test_no_kernel_after_trait_filter(self, h100_platform): select_kernel( "trait_op", "m", - torch.bfloat16, + INPUT_BF16, platform=h100_platform, traits={"head_dim": 128}, ) + def test_selects_exact_mixed_operand_signature(self, h100_platform): + reg = KernelRegistry.get() + scale = ScaleFormat( + storage_dtype=torch.uint8, + granularity="block", + block_shape=(32,), + ) + mixed_signature = format_signature( + a=dense_tensor_format(torch.bfloat16), + b=tensor_format("mxfp4", torch.uint8, scale=scale), + ) + dense_uint8_signature = format_signature( + a=dense_tensor_format(torch.bfloat16), + b=dense_tensor_format(torch.uint8), + ) + + reg.register( + KernelSpec( + name="dense_uint8", + family="gemm", + mode="mm", + solution="test", + format_signatures=frozenset({dense_uint8_signature}), + priority=19, + ), + lambda: "dense_uint8", + ) + reg.register( + KernelSpec( + name="mixed_mxfp4", + family="gemm", + mode="mm", + solution="test", + format_signatures=frozenset({mixed_signature}), + priority=5, + ), + lambda: "mixed_mxfp4", + ) + + impl = select_kernel( + "gemm", + "mm", + mixed_signature, + platform=h100_platform, + ) + + assert impl() == "mixed_mxfp4" + + def test_selects_each_registered_format_signature(self, h100_platform): + reg = KernelRegistry.get() + fp8_scale = ScaleFormat( + storage_dtype=torch.float32, + granularity="tensor", + ) + fp8_signature = format_signature( + a=tensor_format("fp8", torch.float8_e4m3fn, scale=fp8_scale), + b=tensor_format("fp8", torch.float8_e4m3fn, scale=fp8_scale), + ) + + reg.register( + KernelSpec( + name="dense_multi", + family="gemm", + mode="mm", + solution="test", + format_signatures=frozenset({GEMM_BF16, GEMM_FP16}), + priority=10, + ), + lambda: "dense_multi", + ) + reg.register( + KernelSpec( + name="fp8_scaled", + family="gemm", + mode="mm", + solution="test", + format_signatures=frozenset({fp8_signature}), + priority=10, + ), + lambda: "fp8_scaled", + ) + + bf16_impl = select_kernel("gemm", "mm", GEMM_BF16, platform=h100_platform) + fp16_impl = select_kernel("gemm", "mm", GEMM_FP16, platform=h100_platform) + fp8_impl = select_kernel("gemm", "mm", fp8_signature, platform=h100_platform) + + assert bf16_impl() == "dense_multi" + assert fp16_impl() == "dense_multi" + assert fp8_impl() == "fp8_scaled" + def test_override_by_name(self, sample_specs, h100_platform): reg = KernelRegistry.get() register_all_samples(reg, sample_specs) @@ -553,7 +660,7 @@ def test_override_by_name(self, sample_specs, h100_platform): impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, override="reference_decode", ) @@ -566,7 +673,7 @@ def test_override_by_solution(self, sample_specs, h100_platform): impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, override="triton", ) @@ -580,7 +687,7 @@ def test_solution_filter_preserves_trait_filtering(self, h100_platform): family="attention", mode="prefill", solution="fa4", - dtypes=frozenset({torch.bfloat16}), + format_signatures=frozenset({ATTN_PREFILL_BF16}), traits={"head_dim": frozenset({128})}, priority=15, ), @@ -592,7 +699,7 @@ def test_solution_filter_preserves_trait_filtering(self, h100_platform): family="attention", mode="prefill", solution="triton", - dtypes=frozenset({torch.bfloat16}), + format_signatures=frozenset({ATTN_PREFILL_BF16}), traits={"head_dim": frozenset({256})}, priority=10, ), @@ -602,7 +709,7 @@ def test_solution_filter_preserves_trait_filtering(self, h100_platform): impl = select_kernel( "attention", "prefill", - torch.bfloat16, + ATTN_PREFILL_BF16, platform=h100_platform, solution="fa4", traits={"head_dim": 128}, @@ -613,7 +720,7 @@ def test_solution_filter_preserves_trait_filtering(self, h100_platform): select_kernel( "attention", "prefill", - torch.bfloat16, + ATTN_PREFILL_BF16, platform=h100_platform, solution="fa4", traits={"head_dim": 256}, @@ -627,7 +734,7 @@ def test_override_not_found_raises(self, sample_specs, h100_platform): select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, override="nonexistent_kernel", ) @@ -643,7 +750,7 @@ def test_env_override(self, sample_specs, h100_platform): impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, ) assert impl() == "reference_decode" @@ -655,7 +762,7 @@ def test_portability_objective_prefers_triton(self, sample_specs, h100_platform) impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, objective=SelectionObjective.PORTABILITY, ) @@ -669,7 +776,7 @@ def test_debug_objective_prefers_reference(self, sample_specs, h100_platform): impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, objective=SelectionObjective.DEBUG, ) @@ -682,7 +789,7 @@ def test_amd_platform_selects_aiter(self, sample_specs, mi300_platform): impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=mi300_platform, ) assert impl() == "aiter_decode" @@ -694,7 +801,7 @@ def test_amd_mi350_platform_selects_aiter(self, sample_specs, mi350_platform): impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=mi350_platform, ) assert impl() == "aiter_decode" @@ -721,7 +828,7 @@ def adjust(self, spec, platform, traits): impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, ) assert impl() == "triton_decode" @@ -736,7 +843,7 @@ def test_context_manager_overrides(self, sample_specs, h100_platform): impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, ) assert impl() == "reference_decode" @@ -751,7 +858,7 @@ def test_context_manager_restores(self, sample_specs, h100_platform): impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, ) assert impl() != "reference_decode" or True @@ -762,18 +869,18 @@ def test_nested_override(self, sample_specs, h100_platform): with kernel_override("attention", "decode", "reference_decode"): impl1 = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) assert impl1() == "reference_decode" with kernel_override("attention", "decode", "triton_decode"): impl2 = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) assert impl2() == "triton_decode" impl3 = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) assert impl3() == "reference_decode" @@ -783,7 +890,7 @@ def test_set_policy_clears_cache(self, sample_specs, h100_platform): reg = KernelRegistry.get() register_all_samples(reg, sample_specs) - select_kernel("attention", "decode", torch.bfloat16, platform=h100_platform) + select_kernel("attention", "decode", ATTN_DECODE_BF16, platform=h100_platform) set_selection_policy( SelectionPolicy(default_strategy=SelectionStrategy.AUTOTUNE) ) @@ -798,7 +905,7 @@ def test_output_contains_expected_sections(self, sample_specs, h100_platform): explanation = explain_selection( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, ) assert "attention.decode" in explanation @@ -813,7 +920,7 @@ def test_filtered_out_section(self, sample_specs, h100_platform): explanation = explain_selection( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, ) assert "Filtered out" in explanation @@ -823,7 +930,7 @@ def test_empty_candidates(self, h100_platform): explanation = explain_selection( "nonexistent", "op", - torch.bfloat16, + INPUT_BF16, platform=h100_platform, ) assert "0 matched" in explanation @@ -853,8 +960,8 @@ def test_warmup_explicit_ops(self, sample_specs, h100_platform): try: warmup_selection( ops=[ - ("attention", "decode", torch.bfloat16, None), - ("gemm", "mm", torch.bfloat16, None), + ("attention", "decode", ATTN_DECODE_BF16, None), + ("gemm", "mm", GEMM_BF16, None), ] ) assert len(reg._selection_cache) >= 2 @@ -867,7 +974,7 @@ def test_warmup_skips_missing_ops(self, h100_platform): Platform.override(h100_platform) try: - warmup_selection(ops=[("nonexistent", "op", torch.bfloat16, None)]) + warmup_selection(ops=[("nonexistent", "op", INPUT_BF16, None)]) finally: Platform.reset() @@ -885,7 +992,7 @@ def test_autotune_falls_back_to_heuristic(self, sample_specs, h100_platform): impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, ) assert callable(impl) @@ -1054,7 +1161,7 @@ def test_config_override_by_name(self, sample_specs, h100_platform, tmp_path): ) impl = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) assert impl() == "reference_decode" @@ -1068,7 +1175,7 @@ def test_config_override_by_solution(self, sample_specs, h100_platform, tmp_path ) impl = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) assert impl() == "triton_decode" @@ -1084,7 +1191,7 @@ def test_config_override_objective(self, sample_specs, h100_platform, tmp_path): ) impl = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) assert impl() == "reference_decode" @@ -1103,7 +1210,7 @@ def test_api_override_takes_priority_over_config( impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, override="triton_decode", ) @@ -1126,7 +1233,7 @@ def test_env_var_override_takes_priority_over_config( {"TOKENSPEED_KERNEL_OVERRIDE_ATTENTION_DECODE": "triton_decode"}, ): impl = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) assert impl() == "triton_decode" @@ -1144,7 +1251,7 @@ def test_context_manager_override_takes_priority_over_config( with kernel_override("attention", "decode", "triton_decode"): impl = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) assert impl() == "triton_decode" @@ -1163,7 +1270,7 @@ def test_explicit_objective_takes_priority_over_config( impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, objective=SelectionObjective.PORTABILITY, ) @@ -1182,7 +1289,9 @@ def test_config_override_not_found_raises( ) with pytest.raises(NoKernelFoundError, match="Override"): - select_kernel("attention", "decode", torch.bfloat16, platform=h100_platform) + select_kernel( + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform + ) def test_config_with_invalid_objective_falls_back( self, sample_specs, h100_platform, tmp_path @@ -1199,7 +1308,7 @@ def test_config_with_invalid_objective_falls_back( ) impl = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) assert callable(impl) @@ -1214,11 +1323,11 @@ def test_unrelated_ops_unaffected(self, sample_specs, h100_platform, tmp_path): ) attn_impl = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) assert attn_impl() == "reference_decode" - gemm_impl = select_kernel("gemm", "mm", torch.bfloat16, platform=h100_platform) + gemm_impl = select_kernel("gemm", "mm", GEMM_BF16, platform=h100_platform) assert gemm_impl() != "reference_decode" @@ -1248,7 +1357,7 @@ def _register_kernel(name: str, solution: str, impl) -> None: family="gemm", mode="mm", solution=solution, - dtypes=frozenset({torch.float16}), + format_signatures=frozenset({GEMM_FP16}), priority=50, ) KernelRegistry.get().register(spec, impl) diff --git a/tokenspeed-kernel/test/utils.py b/tokenspeed-kernel/test/utils.py index ec2bd0ae..414be3d0 100644 --- a/tokenspeed-kernel/test/utils.py +++ b/tokenspeed-kernel/test/utils.py @@ -24,7 +24,10 @@ import torch from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement -from tokenspeed_kernel.registry import KernelRegistry, KernelSpec +from tokenspeed_kernel.registry import KernelRegistry, register_kernel +from tokenspeed_kernel.signature import FormatSignature, format_signatures + +SampleRegistration = tuple[dict, Callable] def dummy_impl(name: str) -> Callable: @@ -35,17 +38,45 @@ def impl(*args, **kwargs): return impl -def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: - specs: dict[str, tuple[KernelSpec, Callable]] = {} +def _sample_registration( + name: str, + family: str, + mode: str, + solution: str, + signatures: frozenset[FormatSignature], + *, + features: frozenset[str] | None = None, + capability: CapabilityRequirement | None = None, + priority: int = 10, + tags: frozenset[str] | None = None, +) -> SampleRegistration: + return ( + { + "family": family, + "mode": mode, + "name": name, + "solution": solution, + "features": features, + "capability": capability, + "signatures": signatures, + "priority": priority, + "tags": tags, + }, + dummy_impl(name), + ) + - specs["flashinfer_decode"] = ( - KernelSpec( - name="flashinfer_decode", - family="attention", - mode="decode", - solution="flashinfer", +def make_sample_specs() -> dict[str, SampleRegistration]: + return { + "flashinfer_decode": _sample_registration( + "flashinfer_decode", + "attention", + "decode", + "flashinfer", + format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), features=frozenset({"paged"}), - dtypes=frozenset({torch.float16, torch.bfloat16}), capability=CapabilityRequirement( vendors=frozenset({"nvidia"}), min_arch_version=ArchVersion(8, 0), @@ -53,30 +84,26 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: priority=18, tags=frozenset({"latency"}), ), - dummy_impl("flashinfer_decode"), - ) - - specs["triton_decode"] = ( - KernelSpec( - name="triton_decode", - family="attention", - mode="decode", - solution="triton", + "triton_decode": _sample_registration( + "triton_decode", + "attention", + "decode", + "triton", + format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), features=frozenset({"paged"}), - dtypes=frozenset({torch.float16, torch.bfloat16}), priority=10, tags=frozenset({"portability"}), ), - dummy_impl("triton_decode"), - ) - - specs["cutlass_prefill"] = ( - KernelSpec( - name="cutlass_prefill", - family="attention", - mode="prefill", - solution="cutlass", - dtypes=frozenset({torch.float16, torch.bfloat16}), + "cutlass_prefill": _sample_registration( + "cutlass_prefill", + "attention", + "prefill", + "cutlass", + format_signatures( + ("q", "k", "v"), "dense", {torch.float16, torch.bfloat16} + ), capability=CapabilityRequirement( vendors=frozenset({"nvidia"}), min_arch_version=ArchVersion(9, 0), @@ -84,48 +111,40 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: priority=16, tags=frozenset({"throughput"}), ), - dummy_impl("cutlass_prefill"), - ) - - specs["reference_decode"] = ( - KernelSpec( - name="reference_decode", - family="attention", - mode="decode", - solution="reference", + "reference_decode": _sample_registration( + "reference_decode", + "attention", + "decode", + "reference", + format_signatures( + ("q", "k_cache", "v_cache"), + "dense", + {torch.float16, torch.bfloat16, torch.float32}, + ), features=frozenset({"paged"}), - dtypes=frozenset({torch.float16, torch.bfloat16, torch.float32}), capability=CapabilityRequirement(), priority=10, tags=frozenset({"determinism", "portability"}), ), - dummy_impl("reference_decode"), - ) - - specs["aiter_decode"] = ( - KernelSpec( - name="aiter_decode", - family="attention", - mode="decode", - solution="aiter", - features=frozenset({"paged"}), - dtypes=frozenset({torch.float16, torch.bfloat16}), - capability=CapabilityRequirement( - vendors=frozenset({"amd"}), + "aiter_decode": _sample_registration( + "aiter_decode", + "attention", + "decode", + "aiter", + format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} ), + features=frozenset({"paged"}), + capability=CapabilityRequirement(vendors=frozenset({"amd"})), priority=16, tags=frozenset({"latency", "portability"}), ), - dummy_impl("aiter_decode"), - ) - - specs["cutlass_gemm"] = ( - KernelSpec( - name="cutlass_gemm", - family="gemm", - mode="mm", - solution="cutlass", - dtypes=frozenset({torch.float16, torch.bfloat16}), + "cutlass_gemm": _sample_registration( + "cutlass_gemm", + "gemm", + "mm", + "cutlass", + format_signatures(("a", "b"), "dense", {torch.float16, torch.bfloat16}), capability=CapabilityRequirement( vendors=frozenset({"nvidia"}), min_arch_version=ArchVersion(8, 0), @@ -133,29 +152,21 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: priority=15, tags=frozenset({"throughput", "latency"}), ), - dummy_impl("cutlass_gemm"), - ) - - specs["triton_gemm"] = ( - KernelSpec( - name="triton_gemm", - family="gemm", - mode="mm", - solution="triton", - dtypes=frozenset({torch.float16, torch.bfloat16}), + "triton_gemm": _sample_registration( + "triton_gemm", + "gemm", + "mm", + "triton", + format_signatures(("a", "b"), "dense", {torch.float16, torch.bfloat16}), priority=10, tags=frozenset({"portability"}), ), - dummy_impl("triton_gemm"), - ) - - specs["cutlass_grouped_gemm"] = ( - KernelSpec( - name="cutlass_grouped_gemm", - family="gemm", - mode="grouped_mm", - solution="cutlass", - dtypes=frozenset({torch.float16, torch.bfloat16}), + "cutlass_grouped_gemm": _sample_registration( + "cutlass_grouped_gemm", + "gemm", + "grouped_mm", + "cutlass", + format_signatures(("a", "b"), "dense", {torch.float16, torch.bfloat16}), capability=CapabilityRequirement( vendors=frozenset({"nvidia"}), min_arch_version=ArchVersion(9, 0), @@ -163,42 +174,34 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: priority=16, tags=frozenset({"throughput"}), ), - dummy_impl("cutlass_grouped_gemm"), - ) - - specs["triton_grouped_gemm"] = ( - KernelSpec( - name="triton_grouped_gemm", - family="gemm", - mode="grouped_mm", - solution="triton", - dtypes=frozenset({torch.float16, torch.bfloat16}), + "triton_grouped_gemm": _sample_registration( + "triton_grouped_gemm", + "gemm", + "grouped_mm", + "triton", + format_signatures(("a", "b"), "dense", {torch.float16, torch.bfloat16}), priority=10, tags=frozenset({"portability"}), ), - dummy_impl("triton_grouped_gemm"), - ) - - specs["triton_fused_moe"] = ( - KernelSpec( - name="triton_fused_moe", - family="moe", - mode="fused", - solution="triton", - dtypes=frozenset({torch.float16, torch.bfloat16}), + "triton_fused_moe": _sample_registration( + "triton_fused_moe", + "moe", + "fused", + "triton", + format_signatures( + ("x", "weight"), "dense", {torch.float16, torch.bfloat16} + ), priority=12, tags=frozenset({"throughput", "portability"}), ), - dummy_impl("triton_fused_moe"), - ) - - specs["cutlass_fused_moe"] = ( - KernelSpec( - name="cutlass_fused_moe", - family="moe", - mode="fused", - solution="cutlass", - dtypes=frozenset({torch.float16, torch.bfloat16}), + "cutlass_fused_moe": _sample_registration( + "cutlass_fused_moe", + "moe", + "fused", + "cutlass", + format_signatures( + ("x", "weight"), "dense", {torch.float16, torch.bfloat16} + ), capability=CapabilityRequirement( vendors=frozenset({"nvidia"}), min_arch_version=ArchVersion(9, 0), @@ -206,29 +209,21 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: priority=15, tags=frozenset({"latency", "throughput"}), ), - dummy_impl("cutlass_fused_moe"), - ) - - specs["triton_modular_moe"] = ( - KernelSpec( - name="triton_modular_moe", - family="moe", - mode="modular", - solution="triton", - dtypes=frozenset({torch.float16, torch.bfloat16}), + "triton_modular_moe": _sample_registration( + "triton_modular_moe", + "moe", + "modular", + "triton", + format_signatures("x", "dense", {torch.float16, torch.bfloat16}), priority=10, tags=frozenset({"determinism", "portability"}), ), - dummy_impl("triton_modular_moe"), - ) - - specs["cutlass_modular_moe"] = ( - KernelSpec( - name="cutlass_modular_moe", - family="moe", - mode="modular", - solution="cutlass", - dtypes=frozenset({torch.float16, torch.bfloat16}), + "cutlass_modular_moe": _sample_registration( + "cutlass_modular_moe", + "moe", + "modular", + "cutlass", + format_signatures("x", "dense", {torch.float16, torch.bfloat16}), capability=CapabilityRequirement( vendors=frozenset({"nvidia"}), min_arch_version=ArchVersion(8, 0), @@ -236,12 +231,13 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: priority=14, tags=frozenset({"throughput"}), ), - dummy_impl("cutlass_modular_moe"), - ) - - return specs + } -def register_all_samples(registry: KernelRegistry, specs: dict) -> None: - for spec, impl in specs.values(): - registry.register(spec, impl) +def register_all_samples( + registry: KernelRegistry, samples: dict[str, SampleRegistration] +) -> None: + if registry is not KernelRegistry.get(): + raise ValueError("sample registrations must target the active KernelRegistry") + for options, impl in samples.values(): + register_kernel(**options)(impl)