From 518ca9692d4e5adf11994637c4ad5618d00b57ce Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sat, 23 May 2026 06:22:33 +0000 Subject: [PATCH 01/16] Use operand format signatures for kernel selection Replace KernelSpec.dtypes with format_signatures so registrations can describe role-specific operand tensor formats instead of a flat dtype set. Model each tensor role with TensorFormat and optional ScaleFormat, using storage_dtype consistently for payload and scale sidecar storage. Convert in-tree GEMM, attention, embedding, quantization, MoE, reference, and plugin registrations to the new FormatSignature API. Move quantized GEMM and MoE fused format choices out of selection traits, with MoE fused call sites passing weight_format explicitly. Update selection, numerics, benchmark, docs, and tests to resolve by full FormatSignature while preserving dtype-oriented convenience filters through primary_storage_dtype helpers. Signed-off-by: Lei Zhang --- .../moe/backends/fp8/flashinfer_cutlass.py | 2 +- .../layers/moe/backends/mxfp4/flashinfer.py | 2 +- .../moe/backends/mxfp4/triton_kernel.py | 2 - .../moe/backends/nvfp4/flashinfer_cutedsl.py | 2 +- .../moe/backends/nvfp4/flashinfer_cutlass.py | 2 +- .../moe/backends/nvfp4/flashinfer_trtllm.py | 2 +- .../unquantized/flashinfer_cutlass.py | 2 +- .../backends/unquantized/flashinfer_trtllm.py | 2 +- tokenspeed-kernel/README.md | 7 +- .../tokenspeed_kernel/benchmark/runner.py | 32 ++- .../python/tokenspeed_kernel/numerics/cli.py | 8 +- .../python/tokenspeed_kernel/numerics/gemm.py | 95 ++++++-- .../tokenspeed_kernel/numerics/inputs.py | 15 +- .../numerics/reference/gemm.py | 39 +++- .../numerics/reference/moe.py | 34 ++- .../numerics/reference/quantize.py | 6 +- .../tokenspeed_kernel/numerics/verify.py | 9 +- .../ops/attention/__init__.py | 23 +- .../ops/attention/cuda/__init__.py | 10 +- .../ops/attention/flash_attn/__init__.py | 31 ++- .../ops/attention/flashinfer/__init__.py | 24 ++- .../attention/gluon/mha_decode_fp16_gfx950.py | 10 +- .../gluon/mha_prefill_fp16_gfx950.py | 10 +- .../ops/attention/triton/__init__.py | 22 +- .../ops/embedding/__init__.py | 10 +- .../tokenspeed_kernel/ops/embedding/cuda.py | 4 +- .../tokenspeed_kernel/ops/embedding/triton.py | 10 +- .../tokenspeed_kernel/ops/gemm/__init__.py | 78 +++++-- .../tokenspeed_kernel/ops/gemm/deep_gemm.py | 21 +- .../tokenspeed_kernel/ops/gemm/flashinfer.py | 36 +++- .../tokenspeed_kernel/ops/gemm/triton.py | 39 +++- .../tokenspeed_kernel/ops/gemm/trtllm.py | 23 +- .../tokenspeed_kernel/ops/moe/__init__.py | 94 ++++++-- .../python/tokenspeed_kernel/ops/moe/cuda.py | 11 +- .../tokenspeed_kernel/ops/moe/deepep.py | 10 +- .../tokenspeed_kernel/ops/moe/flashinfer.py | 105 ++++++++- .../tokenspeed_kernel/ops/moe/triton.py | 18 +- .../ops/moe/triton_kernels.py | 14 +- .../tokenspeed_kernel/ops/moe/trtllm.py | 3 +- .../ops/quantization/__init__.py | 19 +- .../ops/quantization/flashinfer.py | 11 +- .../ops/quantization/triton.py | 8 +- .../ops/quantization/trtllm.py | 9 +- .../tokenspeed_kernel/plugins/README.md | 14 +- .../python/tokenspeed_kernel/registry.py | 46 +++- .../python/tokenspeed_kernel/selection.py | 103 +++++---- .../python/tokenspeed_kernel/signature.py | 202 ++++++++++++++++++ tokenspeed-kernel/test/conftest.py | 17 +- tokenspeed-kernel/test/test_benchmark.py | 10 +- .../test/test_kernel_api_selection.py | 34 ++- tokenspeed-kernel/test/test_numerics.py | 54 ++++- tokenspeed-kernel/test/test_plugins.py | 29 ++- tokenspeed-kernel/test/test_registry.py | 54 ++++- .../test/test_runtime_callsite_selection.py | 77 ++++--- tokenspeed-kernel/test/test_selection.py | 132 +++++++----- tokenspeed-kernel/test/utils.py | 60 ++++-- 56 files changed, 1382 insertions(+), 364 deletions(-) create mode 100644 tokenspeed-kernel/python/tokenspeed_kernel/signature.py diff --git a/python/tokenspeed/runtime/layers/moe/backends/fp8/flashinfer_cutlass.py b/python/tokenspeed/runtime/layers/moe/backends/fp8/flashinfer_cutlass.py index e4f9ee749..99c6c4df5 100644 --- a/python/tokenspeed/runtime/layers/moe/backends/fp8/flashinfer_cutlass.py +++ b/python/tokenspeed/runtime/layers/moe/backends/fp8/flashinfer_cutlass.py @@ -128,8 +128,8 @@ def forward( activation_type=ActivationType.Swiglu, dtype=x.dtype, features={"pre_routed"}, + weight_format="fp8", traits={ - "weight_dtype": "fp8", "tp": self.spec.tp_size > 1, "ep": self.spec.ep_size > 1, "cuda_graph": False, diff --git a/python/tokenspeed/runtime/layers/moe/backends/mxfp4/flashinfer.py b/python/tokenspeed/runtime/layers/moe/backends/mxfp4/flashinfer.py index e82e4db8a..3f8bd2066 100644 --- a/python/tokenspeed/runtime/layers/moe/backends/mxfp4/flashinfer.py +++ b/python/tokenspeed/runtime/layers/moe/backends/mxfp4/flashinfer.py @@ -409,7 +409,7 @@ def _call_kernel(self, router_logits, x_quant, x_scale, layer, top_k, output): output=output, dtype=torch.bfloat16, features={"self_routing"}, - traits={"weight_dtype": "mxfp4"}, + weight_format="mxfp4", expected_kernel_name="flashinfer_trtllm_fp4_fused_moe", )[0] diff --git a/python/tokenspeed/runtime/layers/moe/backends/mxfp4/triton_kernel.py b/python/tokenspeed/runtime/layers/moe/backends/mxfp4/triton_kernel.py index 1d308e0ff..859e9c5f1 100644 --- a/python/tokenspeed/runtime/layers/moe/backends/mxfp4/triton_kernel.py +++ b/python/tokenspeed/runtime/layers/moe/backends/mxfp4/triton_kernel.py @@ -301,7 +301,6 @@ def forward( fused_activation=act, dtype=hidden_states.dtype, features={"ragged_metadata", "dispatch_gemm"}, - traits={"weight_dtype": "mxfp4"}, expected_kernel_name="triton_kernels_dispatch_gemm", ) @@ -328,7 +327,6 @@ def forward( n_expts_act=top_k, dtype=hidden_states.dtype, features={"ragged_metadata", "gemm_combine"}, - traits={"weight_dtype": "mxfp4"}, expected_kernel_name="triton_kernels_gemm_combine", ) diff --git a/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_cutedsl.py b/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_cutedsl.py index 32dcd352b..4c699edd4 100644 --- a/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_cutedsl.py +++ b/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_cutedsl.py @@ -161,8 +161,8 @@ def _call_kernel( capacity=capacity, dtype=x_fp4.dtype, features={"pre_routed"}, + weight_format="nvfp4", traits={ - "weight_dtype": "nvfp4", "tp": False, "ep": True, "cuda_graph": True, diff --git a/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_cutlass.py b/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_cutlass.py index 6bd67d73d..c219636c8 100644 --- a/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_cutlass.py +++ b/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_cutlass.py @@ -128,8 +128,8 @@ def forward( activation_type=ActivationType.Swiglu, dtype=x.dtype, features={"pre_routed"}, + weight_format="nvfp4", traits={ - "weight_dtype": "nvfp4", "tp": True, "ep": True, "cuda_graph": False, diff --git a/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_trtllm.py b/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_trtllm.py index c762a4b62..184bcd89a 100644 --- a/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_trtllm.py +++ b/python/tokenspeed/runtime/layers/moe/backends/nvfp4/flashinfer_trtllm.py @@ -336,7 +336,7 @@ def forward( tune_max_num_tokens=next_power_of_2(num_tokens), dtype=x.dtype, features={"self_routing"}, - traits={"weight_dtype": "nvfp4"}, + weight_format="nvfp4", expected_kernel_name="flashinfer_trtllm_fp4_fused_moe", ) if do_finalize: diff --git a/python/tokenspeed/runtime/layers/moe/backends/unquantized/flashinfer_cutlass.py b/python/tokenspeed/runtime/layers/moe/backends/unquantized/flashinfer_cutlass.py index 144587e5f..f76b4d2f3 100644 --- a/python/tokenspeed/runtime/layers/moe/backends/unquantized/flashinfer_cutlass.py +++ b/python/tokenspeed/runtime/layers/moe/backends/unquantized/flashinfer_cutlass.py @@ -105,8 +105,8 @@ def _call_cutlass_kernel(self, x, layer, topk_output): activation_type=ActivationType.Swiglu, dtype=x.dtype, features={"pre_routed"}, + weight_format="bf16", traits={ - "weight_dtype": "bf16", "tp": True, "ep": True, "cuda_graph": False, diff --git a/python/tokenspeed/runtime/layers/moe/backends/unquantized/flashinfer_trtllm.py b/python/tokenspeed/runtime/layers/moe/backends/unquantized/flashinfer_trtllm.py index 4370b4462..c4a31168d 100644 --- a/python/tokenspeed/runtime/layers/moe/backends/unquantized/flashinfer_trtllm.py +++ b/python/tokenspeed/runtime/layers/moe/backends/unquantized/flashinfer_trtllm.py @@ -203,7 +203,7 @@ def _call_trtllm_kernel(self, router_logits, x, layer, top_k, do_finalize): tune_max_num_tokens=next_power_of_2(x.shape[0]), dtype=x.dtype, features={"self_routing"}, - traits={"weight_dtype": "bf16"}, + weight_format="bf16", expected_kernel_name="flashinfer_trtllm_bf16_fused_moe", ) if do_finalize: diff --git a/tokenspeed-kernel/README.md b/tokenspeed-kernel/README.md index 8bac78f83..ead39fdc9 100644 --- a/tokenspeed-kernel/README.md +++ b/tokenspeed-kernel/README.md @@ -37,7 +37,7 @@ choices (still evolving; subject to change): public API (mha_prefill, mm, moe_fused, ...) │ ┌───────────┴───────────┐ - │ select_kernel │ (family, mode, dtype, traits, ...) + │ select_kernel │ (family, mode, format_signature, traits, ...) └───────────┬───────────┘ │ queries ┌──────────┴──────────┐ @@ -57,8 +57,8 @@ choices (still evolving; subject to change): ``` - **Registration** — backends register with `@register_kernel(family, mode, ...)`, - declaring supported dtypes, arch capability requirements, traits (head dim, - GQA factor, ...), and a priority band. + declaring supported `format_signatures`, arch capability requirements, + non-format traits (head dim, GQA factor, ...), and a priority band. - **Auto-selection** — `select_kernel` filters by capability and traits, ranks the survivors with an optional per-family `SelectionOracle` and priority, and returns a callable. Selection accepts an objective (latency, @@ -71,6 +71,7 @@ choices (still evolving; subject to change): tokenspeed_kernel/ __init__.py # Public API re-exports platform.py # PlatformInfo, capability detection + signature.py # TensorFormat, ScaleFormat, FormatSignature registry.py # KernelRegistry, register_kernel, Priority bands selection.py # select_kernel, oracles, overrides profiling.py # ShapeCapture, kernel_scope, Proton bootstrap diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/benchmark/runner.py b/tokenspeed-kernel/python/tokenspeed_kernel/benchmark/runner.py index 0f5a7b35e..508dfbc7f 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/benchmark/runner.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/benchmark/runner.py @@ -139,11 +139,16 @@ def _benchmark_one_shape( if not spec_matches_shape_traits(spec, shape): return None + signature = spec.format_signature_for_storage_dtype(dtype) + if signature is None: + return None + generator = get_input_generator( spec.family, spec.mode, dtype=dtype, traits=spec.traits, + format_signature=signature, device="cuda", seed=self.config.seed, ) @@ -224,10 +229,14 @@ def _verify_one_shape( return None, None, None registry = KernelRegistry.get() + signature = spec.format_signature_for_storage_dtype(dtype) + if signature is None: + return None, None, None + ref_specs = registry.get_for_operator( spec.family, spec.mode, - dtype=dtype, + format_signature=signature, solution="reference", ) if not ref_specs: @@ -285,8 +294,10 @@ def _benchmark_kernel_impl( if spec is None: raise ValueError(f"Kernel {kernel_name!r} is not registered") - if dtype not in spec.dtypes: - raise ValueError(f"Kernel {kernel_name!r} does not support dtype={dtype}") + if spec.format_signature_for_storage_dtype(dtype) is None: + raise ValueError( + f"Kernel {kernel_name!r} does not support primary storage dtype={dtype}" + ) platform = current_platform() if not spec.capability.satisfied_by(platform): @@ -337,12 +348,15 @@ def _benchmark_op_impl( """Benchmark all implementations of an op.""" registry = KernelRegistry.get() platform = current_platform() - specs = registry.get_for_operator( - op_family, - op_mode, - platform=platform, - dtype=dtype, - ) + specs = [ + spec + for spec in registry.get_for_operator( + op_family, + op_mode, + platform=platform, + ) + if spec.format_signature_for_storage_dtype(dtype) is not None + ] results: list[BenchmarkResult] = [] for spec in sorted(specs, key=lambda item: (item.solution, item.name)): diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/cli.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/cli.py index 14ddd136d..62e361e7f 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/cli.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/cli.py @@ -80,7 +80,11 @@ def _iter_candidate_specs( specs = [s for s in specs if s.family == family and s.mode == mode] if dtype_filter is not None: - specs = [s for s in specs if dtype_filter in s.dtypes] + specs = [ + s + for s in specs + if s.format_signature_for_storage_dtype(dtype_filter) is not None + ] specs.sort(key=lambda s: (s.family, s.mode, s.name)) return specs @@ -92,7 +96,7 @@ def _iter_dtypes( ) -> Iterable[torch.dtype]: if dtype_filter is not None: return (dtype_filter,) - return sorted(spec.dtypes, key=str) + return sorted(spec.primary_storage_dtypes(), key=str) def main(argv: list[str] | None = None) -> int: diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/gemm.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/gemm.py index e3a98eae3..976ae20e0 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/gemm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/gemm.py @@ -31,6 +31,7 @@ set_standard_shapes, ) from tokenspeed_kernel.numerics.tolerance import Tolerance, set_family_tolerance +from tokenspeed_kernel.signature import TensorFormat # --------------------------------------------------------------------------- # Tolerance @@ -119,45 +120,93 @@ def _generate_scales(self, shape: tuple[int, ...], dtype) -> torch.Tensor: ) return scales.to(dtype) + def _format(self, role: str) -> TensorFormat | None: + if self.format_signature is None: + return None + return self.format_signature.format_for(role) + + def _block_size( + self, + *formats: TensorFormat | None, + ) -> list[int] | None: + for tensor_format in formats: + scale = tensor_format.scale if tensor_format is not None else None + if scale is not None and scale.block_shape is not None: + return list(scale.block_shape) + return None + + def _scale_for_format( + self, + tensor_format: TensorFormat | None, + role: str, + *, + M: int, + N: int, + K: int, + block_size: list[int] | None, + ) -> torch.Tensor | None: + scale = tensor_format.scale if tensor_format is not None else None + if scale is None: + return None + + if scale.granularity == "block" and scale.layout == "mxfp8": + block_n, block_k = block_size or [128, 128] + k_tiles = math.ceil(K / block_k) + if role == "a": + return self._generate_scales((M, k_tiles), scale.storage_dtype) + if role == "b": + n_tiles = math.ceil(N / block_n) + return self._generate_scales((n_tiles, k_tiles), scale.storage_dtype) + + if scale.granularity == "per_channel": + return self._generate_scales( + (M,) if role == "a" else (N,), + scale.storage_dtype, + ) + + return self._generate_scales((1,), scale.storage_dtype) + def generate( self, M: int, N: int, K: int, ) -> dict[str, Any]: - quant = self.traits.get("quant") - scale_type = self.traits.get("scale_type") a_layout = self.traits.get("a_layout") b_layout = self.traits.get("b_layout") + a_format = self._format("a") + b_format = self._format("b") + a_dtype = a_format.storage_dtype if a_format is not None else self.dtype + b_dtype = b_format.storage_dtype if b_format is not None else self.dtype A = ( - self._generate_value((K, M), self.dtype) + self._generate_value((K, M), a_dtype) if a_layout == {"KM"} - else self._generate_value((M, K), self.dtype) + else self._generate_value((M, K), a_dtype) ) B = ( - self._generate_value((K, N), self.dtype) + self._generate_value((K, N), b_dtype) if b_layout == {"KN"} - else self._generate_value((N, K), self.dtype) + else self._generate_value((N, K), b_dtype) ) - A_scales = None - B_scales = None - block_size = None - - if quant == {"mxfp8"}: - block_size = [128, 128] - k_tiles = math.ceil(K / block_size[0]) - n_tiles = math.ceil(N / block_size[1]) - A_scales = self._generate_scales((M, k_tiles), torch.float32) - B_scales = self._generate_scales((n_tiles, k_tiles), torch.float32) - else: - if scale_type == {"per_channel"}: - A_scales = self._generate_scales((M,), torch.float32) - B_scales = self._generate_scales((N,), torch.float32) - else: - A_scales = self._generate_scales((1,), torch.float32) - B_scales = self._generate_scales((1,), torch.float32) + block_size = self._block_size(a_format, b_format) + A_scales = self._scale_for_format( + a_format, + "a", + M=M, + N=N, + K=K, + block_size=block_size, + ) + B_scales = self._scale_for_format( + b_format, + "b", + M=M, + N=N, + K=K, + block_size=block_size, + ) out_dtype = torch.bfloat16 alpha = None diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/inputs.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/inputs.py index c8340a9a6..28f8815fe 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/inputs.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/inputs.py @@ -23,7 +23,7 @@ from typing import Any, Callable import torch -from tokenspeed_kernel.registry import KernelSpec +from tokenspeed_kernel.signature import FormatSignature __all__ = [ "InputGenerator", @@ -52,6 +52,7 @@ def __init__( dtype: torch.dtype, traits: dict, *, + format_signature: FormatSignature | None = None, device: str | None = None, seed: int = 42, ) -> None: @@ -59,6 +60,7 @@ def __init__( self.op_mode = op_mode self.dtype = dtype self.traits = traits + self.format_signature = format_signature self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") rng_device = "cuda" if self.device.startswith("cuda") else "cpu" @@ -98,6 +100,7 @@ def get_input_generator( dtype: torch.dtype, traits: dict, *, + format_signature: FormatSignature | None = None, device: str | None = None, seed: int = 42, ) -> InputGenerator: @@ -107,7 +110,15 @@ def get_input_generator( raise KeyError( f"No input generator registered for {op_family}.{op_mode}. Known: {known}" ) - return factory(op_family, op_mode, dtype, traits, device=device, seed=seed) + return factory( + op_family, + op_mode, + dtype, + traits, + format_signature=format_signature, + device=device, + seed=seed, + ) def get_standard_shapes(op_family: str, op_mode: str) -> list[dict[str, Any]]: diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py index 2b2c7f1d7..ee8709612 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py @@ -25,9 +25,36 @@ import torch import torch.nn.functional as F from tokenspeed_kernel.platform import Platform -from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.registry import ( + Priority, + register_kernel, +) +from tokenspeed_kernel.signature import ( + ScaleFormat, + format_signatures, +) fp8_dtype = Platform.get().fp8e4m3fn.dtype +_FP8_BLOCK_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", + block_shape=(128, 128), + layout="mxfp8", +) +_FP8_TENSOR_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="per_tensor", + layout="scaled", +) +_MXFP8_FORMAT_SIGNATURES = format_signatures( + ("a", "b"), "mxfp8", {fp8_dtype}, scale=_FP8_BLOCK_SCALE +) +_FP8_PER_TENSOR_FORMAT_SIGNATURES = format_signatures( + ("a", "b"), "fp8", {fp8_dtype}, scale=_FP8_TENSOR_SCALE +) +_DENSE_GEMM_FORMAT_SIGNATURES = format_signatures( + ("a", "b"), "dense", {torch.bfloat16, torch.float16, torch.float32} +) @register_kernel( @@ -35,8 +62,8 @@ "mm", name="torch_mm_fp8_blockscale", solution="reference", - dtypes={fp8_dtype}, - traits={"quant": frozenset({"mxfp8"})}, + signatures=_MXFP8_FORMAT_SIGNATURES, + traits={}, priority=Priority.PORTABLE + 2, tags={"portability"}, ) @@ -90,7 +117,7 @@ def torch_mm_fp8_blockscale( "mm", name="torch_mm_fp8_scaled_mnk", solution="reference", - dtypes={fp8_dtype}, + signatures=_FP8_PER_TENSOR_FORMAT_SIGNATURES, traits={ "b_layout": frozenset({"NK"}), }, @@ -132,7 +159,7 @@ def torch_mm_fp8_scaled_mnk( "mm", name="torch_mm_fp8_scaled_nkm", solution="reference", - dtypes={fp8_dtype}, + signatures=_FP8_PER_TENSOR_FORMAT_SIGNATURES, traits={ "b_layout": frozenset({"KN"}), }, @@ -172,7 +199,7 @@ def torch_mm_fp8_scaled_nkm( "mm", name="torch_mm", solution="reference", - dtypes={torch.bfloat16, torch.float16, torch.float32}, + signatures=_DENSE_GEMM_FORMAT_SIGNATURES, traits={}, priority=Priority.PORTABLE + 3, tags={"determinism", "portability"}, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py index 7905765d1..f630cb270 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py @@ -32,7 +32,15 @@ ExpertLocationDispatchInfo, topk_ids_logical_to_physical, ) -from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.registry import ( + Priority, + register_kernel, +) +from tokenspeed_kernel.signature import ( + dense_format, + format_signature, + format_signatures, +) from tokenspeed_kernel.torch_compile import get_compiler_backend # --------------------------------------------------------------------------- @@ -46,10 +54,12 @@ name="reference_moe_fused", features={"pre_routed"}, solution="reference", - dtypes={torch.float16, torch.bfloat16, torch.float32}, + signatures=frozenset( + format_signature(x=dense_format(dtype), weight=dense_format(torch.bfloat16)) + for dtype in {torch.float16, torch.bfloat16, torch.float32} + ), priority=Priority.REFERENCE, traits={ - "weight_dtype": frozenset({"bf16", "fp16", "fp32"}), "tp": frozenset({False}), "ep": frozenset({False}), }, @@ -89,7 +99,9 @@ def fused_moe_forward_native( "route", name="torch_compile_fused_topk_bias", solution="reference", - dtypes={torch.float16, torch.bfloat16, torch.float32}, + signatures=format_signatures( + "logits", "dense", {torch.float16, torch.bfloat16, torch.float32} + ), traits={ "output_type": frozenset({"topk"}), "biased": frozenset({True}), @@ -131,7 +143,9 @@ def fused_topk_bias( "route", name="torch_native_fused_topk", solution="reference", - dtypes={torch.float16, torch.bfloat16, torch.float32}, + signatures=format_signatures( + "logits", "dense", {torch.float16, torch.bfloat16, torch.float32} + ), traits={ "output_type": frozenset({"topk"}), "biased": frozenset({True, False}), @@ -189,7 +203,9 @@ def _mask_topk_ids_padded_region( "route", name="torch_compile_grouped_topk", solution="reference", - dtypes={torch.float16, torch.bfloat16, torch.float32}, + signatures=format_signatures( + "logits", "dense", {torch.float16, torch.bfloat16, torch.float32} + ), traits={ "output_type": frozenset({"topk"}), "biased": frozenset({False}), @@ -270,7 +286,9 @@ def grouped_topk_gpu( "route", name="torch_compile_biased_grouped_topk", solution="reference", - dtypes={torch.float16, torch.bfloat16, torch.float32}, + signatures=format_signatures( + "logits", "dense", {torch.float16, torch.bfloat16, torch.float32} + ), traits={ "output_type": frozenset({"topk"}), "biased": frozenset({True}), @@ -366,7 +384,7 @@ def biased_grouped_topk_gpu( "align_block_size", name="torch_moe_align_block_size", solution="reference", - dtypes={torch.int32}, + signatures=format_signatures("indices", "dense", {torch.int32}), traits={}, priority=10, tags={"determinism", "portability"}, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/quantize.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/quantize.py index 9b3ddc453..9fcaa15ab 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/quantize.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/quantize.py @@ -49,7 +49,7 @@ def _quantize_fp8(x_fp32: torch.Tensor, max_abs: torch.Tensor) -> torch.Tensor: "fp8_token_group_128", name="torch_fp8_token_group_128", solution="reference", - dtypes={torch.bfloat16, torch.float16}, + signatures=format_signatures("x", "dense", {torch.bfloat16, torch.float16}), traits={}, priority=10, tags={"determinism", "portability"}, @@ -69,7 +69,7 @@ def torch_fp8_token_group_128(x: torch.Tensor) -> torch.Tensor: "fp8_token", name="torch_fp8_token", solution="reference", - dtypes={torch.bfloat16, torch.float16}, + signatures=format_signatures("x", "dense", {torch.bfloat16, torch.float16}), traits={}, priority=10, tags={"determinism", "portability"}, @@ -86,7 +86,7 @@ def torch_fp8_token(x: torch.Tensor) -> torch.Tensor: "fp8_tensor", name="torch_fp8_tensor", solution="reference", - dtypes={torch.bfloat16, torch.float16}, + signatures=format_signatures("x", "dense", {torch.bfloat16, torch.float16}), traits={}, priority=10, tags={"determinism", "portability"}, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/verify.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/verify.py index c44c8749b..7636c051e 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/verify.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/verify.py @@ -86,10 +86,16 @@ def verify_kernel( if kernel is None: raise ValueError(f"Kernel implementation for {kernel_name!r} is missing") + signature = spec.format_signature_for_storage_dtype(dtype) + if signature is None: + raise ValueError( + f"Kernel {kernel_name!r} does not support primary storage dtype={dtype}" + ) + ref_specs = registry.get_for_operator( spec.family, spec.mode, - dtype=dtype, + format_signature=signature, solution="reference", ) if not ref_specs: @@ -119,6 +125,7 @@ def verify_kernel( spec.mode, dtype=dtype, traits=spec.traits, + format_signature=signature, device=device, seed=seed, ) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/__init__.py index 5d21a41e9..1309f9494 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/__init__.py @@ -32,9 +32,20 @@ from tokenspeed_kernel.ops.attention.flash_attn import mha_decode_scheduler_metadata from tokenspeed_kernel.profiling import ShapeCapture, kernel_scope from tokenspeed_kernel.selection import select_kernel +from tokenspeed_kernel.signature import ( + dense_format, + format_signature, +) AttentionResult = torch.Tensor | tuple[torch.Tensor, torch.Tensor | None] + +def _attention_format_signature(**roles: torch.Tensor): + return format_signature( + **{role: dense_format(tensor.dtype) for role, tensor in roles.items()} + ) + + __all__ = [ "mha_prefill", "mha_extend_with_kvcache", @@ -95,10 +106,11 @@ def mha_prefill( "support_sinks": sinks is not None, "return_lse": return_lse, } + signature = _attention_format_signature(q=q, k=k, v=v) kernel = select_kernel( "attention", "mha_prefill", - q.dtype, + signature, traits=traits, solution=solution, override=override, @@ -200,10 +212,11 @@ def mha_extend_with_kvcache( "support_sinks": sinks is not None, "return_lse": return_lse, } + signature = _attention_format_signature(q=q, k_cache=k_cache, v_cache=v_cache) kernel = select_kernel( "attention", "mha_extend_with_kvcache", - q.dtype, + signature, traits=traits, solution=solution, override=override, @@ -308,10 +321,11 @@ def mha_decode_with_kvcache( "support_sinks": sinks is not None, "return_lse": return_lse, } + signature = _attention_format_signature(q=q, k_cache=k_cache, v_cache=v_cache) kernel = select_kernel( "attention", "mha_decode_with_kvcache", - q.dtype, + signature, traits=traits, solution=solution, override=override, @@ -390,10 +404,11 @@ def mha_merge_state( traits = { "head_dim": out_a.shape[-1], } + signature = _attention_format_signature(out_a=out_a, out_b=out_b) kernel = select_kernel( "attention", "mha_merge_state", - out_a.dtype, + signature, traits=traits, solution=solution, override=override, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/cuda/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/cuda/__init__.py index 96e27b8ac..9f765d829 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/cuda/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/cuda/__init__.py @@ -6,7 +6,11 @@ CapabilityRequirement, current_platform, ) -from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.registry import ( + Priority, + register_kernel, +) +from tokenspeed_kernel.signature import format_signatures platform = current_platform() @@ -22,7 +26,9 @@ min_arch_version=ArchVersion(9, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("out_a", "out_b"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED + 2, traits={}, tags={"throughput"}, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flash_attn/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flash_attn/__init__.py index dacc82004..b1305098e 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flash_attn/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flash_attn/__init__.py @@ -25,7 +25,12 @@ CapabilityRequirement, current_platform, ) -from tokenspeed_kernel.registry import Priority, error_fn, register_kernel +from tokenspeed_kernel.registry import ( + Priority, + error_fn, + register_kernel, +) +from tokenspeed_kernel.signature import format_signatures __all__ = [ "flash_attn_func", @@ -79,7 +84,9 @@ min_arch_version=ArchVersion(10, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k", "v"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED + 3, traits={ "head_dim": _FA4_BLACKWELL_PREFILL_HEAD_DIMS, @@ -130,7 +137,9 @@ def fa4_mha_prefill( min_arch_version=ArchVersion(10, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED + 3, traits={ "head_dim": _FA4_BLACKWELL_DECODE_HEAD_DIMS, @@ -184,7 +193,9 @@ def fa4_mha_extend_with_kvcache( min_arch_version=ArchVersion(10, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED + 3, traits={ "head_dim": _FA4_BLACKWELL_DECODE_HEAD_DIMS, @@ -245,7 +256,9 @@ def fa4_mha_decode_with_kvcache( min_arch_version=ArchVersion(9, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k", "v"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED + 3, traits={ "sliding_window": frozenset({False, True}), @@ -294,7 +307,9 @@ def fa3_mha_prefill( min_arch_version=ArchVersion(9, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED + 3, traits={ "sliding_window": frozenset({False, True}), @@ -350,7 +365,9 @@ def fa3_mha_extend_with_kvcache( min_arch_version=ArchVersion(9, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED + 3, traits={ "sliding_window": frozenset({False, True}), diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flashinfer/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flashinfer/__init__.py index 61b13ceb4..1355d2f2d 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flashinfer/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flashinfer/__init__.py @@ -29,7 +29,13 @@ CapabilityRequirement, current_platform, ) -from tokenspeed_kernel.registry import ErrorClass, Priority, error_fn, register_kernel +from tokenspeed_kernel.registry import ( + ErrorClass, + Priority, + error_fn, + register_kernel, +) +from tokenspeed_kernel.signature import format_signatures platform = current_platform() @@ -170,7 +176,9 @@ def _get_paged_prefill_wrapper( min_arch_version=ArchVersion(9, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k", "v"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED, traits={ "head_dim": frozenset({64, 128, 256}), @@ -233,7 +241,9 @@ def flashinfer_mha_prefill( min_arch_version=ArchVersion(9, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED, traits={ "head_dim": frozenset({64, 128, 256}), @@ -340,7 +350,9 @@ def flashinfer_mha_extend_with_kvcache( min_arch_version=ArchVersion(9, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED, traits={ "head_dim": frozenset({64, 128, 256}), @@ -414,7 +426,9 @@ def flashinfer_trtllm_mha_extend_with_kvcache( min_arch_version=ArchVersion(9, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED, traits={ "sliding_window": frozenset({False, True}), diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_decode_fp16_gfx950.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_decode_fp16_gfx950.py index 8f8ba5963..49278cda7 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_decode_fp16_gfx950.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_decode_fp16_gfx950.py @@ -35,7 +35,11 @@ maximum, ) from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement -from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.registry import ( + Priority, + register_kernel, +) +from tokenspeed_kernel.signature import format_signatures cdna4 = gl.amd.cdna4 async_copy = cdna4.async_copy @@ -615,7 +619,9 @@ def get_config( max_arch_version=ArchVersion(9, 5), vendors=frozenset({"amd"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED, traits={ "head_dim": frozenset({64}), diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_prefill_fp16_gfx950.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_prefill_fp16_gfx950.py index 096728b6c..27b00da9a 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_prefill_fp16_gfx950.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_prefill_fp16_gfx950.py @@ -35,7 +35,11 @@ maximum, ) from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement -from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.registry import ( + Priority, + register_kernel, +) +from tokenspeed_kernel.signature import format_signatures cdna4 = gl.amd.cdna4 async_copy = cdna4.async_copy @@ -975,7 +979,9 @@ def get_config( max_arch_version=ArchVersion(9, 5), vendors=frozenset({"amd"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k", "v"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.SPECIALIZED, traits={ "head_dim": frozenset({64}), diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/triton/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/triton/__init__.py index 5e1eb2abe..179db22d4 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/triton/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/triton/__init__.py @@ -27,7 +27,11 @@ from tokenspeed_kernel.ops.attention.triton.mha_decode import decode_attention_fwd from tokenspeed_kernel.ops.attention.triton.mha_prefill import prefill_attention_fwd from tokenspeed_kernel.platform import CapabilityRequirement -from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.registry import ( + Priority, + register_kernel, +) +from tokenspeed_kernel.signature import format_signatures @triton.jit @@ -72,7 +76,9 @@ def mha_merge_state_kernel( name="triton_mha_prefill", solution="triton", capability=CapabilityRequirement(vendors=frozenset({"nvidia", "amd"})), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k", "v"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.PORTABLE, traits={ "sliding_window": frozenset({False, True}), @@ -138,7 +144,9 @@ def triton_mha_prefill( name="triton_mha_extend_with_kvcache", solution="triton", capability=CapabilityRequirement(vendors=frozenset({"nvidia", "amd"})), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.PORTABLE, traits={ "sliding_window": frozenset({False, True}), @@ -216,7 +224,9 @@ def triton_mha_extend_with_kvcache( name="triton_mha_decode_with_kvcache_cached", solution="triton", capability=CapabilityRequirement(vendors=frozenset({"nvidia", "amd"})), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.PORTABLE, traits={ "sliding_window": frozenset({False, True}), @@ -289,7 +299,9 @@ def triton_mha_decode_with_kvcache( name="triton_mha_merge_state", solution="triton", capability=CapabilityRequirement(vendors=frozenset({"nvidia", "amd"})), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("out_a", "out_b"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.PORTABLE, traits={}, tags={"portability"}, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/__init__.py index b3096e91e..93d79f444 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/__init__.py @@ -23,6 +23,10 @@ import torch from tokenspeed_kernel.profiling import ShapeCapture, kernel_scope from tokenspeed_kernel.selection import select_kernel +from tokenspeed_kernel.signature import ( + dense_format, + format_signature, +) @dataclass @@ -111,10 +115,14 @@ def apply_rope( "has_q_out": output_q_rope is not None, "has_k_out": output_k_rope is not None, } + signature = format_signature( + query=dense_format(query.dtype), + key=dense_format(key.dtype), + ) kernel = select_kernel( "embedding", "rope", - query.dtype, + signature, traits=traits, solution=solution, override=override, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/cuda.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/cuda.py index b75caf1dc..0f3ae6daa 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/cuda.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/cuda.py @@ -36,7 +36,9 @@ name="cuda_embedding_rope", solution="cuda", capability=CapabilityRequirement(vendors=frozenset({"nvidia"})), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("query", "key"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.PERFORMANT, traits={ "head_size": frozenset({64, 128, 256, 512}), diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/triton.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/triton.py index d6d88533f..d895b319d 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/triton.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/triton.py @@ -27,7 +27,11 @@ import torch from tokenspeed_kernel._triton import tl, triton from tokenspeed_kernel.platform import CapabilityRequirement -from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.registry import ( + Priority, + register_kernel, +) +from tokenspeed_kernel.signature import format_signatures def _next_power_of_2(n: int) -> int: @@ -361,7 +365,9 @@ def apply_rope_triton( name="triton_embedding_rope", solution="triton", capability=CapabilityRequirement(vendors=frozenset({"amd", "nvidia"})), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("query", "key"), "dense", {torch.float16, torch.bfloat16} + ), priority=Priority.PORTABLE, traits={ "partial_rotary": frozenset({True, False}), diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py index 938286678..6a7d58bf6 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py @@ -32,6 +32,12 @@ from tokenspeed_kernel.platform import Platform from tokenspeed_kernel.profiling import ShapeCapture, kernel_scope from tokenspeed_kernel.selection import select_kernel +from tokenspeed_kernel.signature import ( + ScaleFormat, + dense_format, + format_signature, + tensor_format, +) logger = logging.getLogger(__name__) @@ -70,6 +76,58 @@ def _infer_scale_type( return "per_channel" +def _scale_storage_dtype(*scales: torch.Tensor | None) -> torch.dtype: + for scale in scales: + if scale is not None: + return scale.dtype + return torch.float32 + + +def _gemm_format_signature( + A: torch.Tensor, + B: torch.Tensor, + A_scales: torch.Tensor | None, + B_scales: torch.Tensor | None, + out_dtype: torch.dtype, + quant: str | None, + block_size: list[int] | None, +): + _ = out_dtype + if quant == "mxfp8": + scale = ScaleFormat( + storage_dtype=_scale_storage_dtype(A_scales, B_scales), + granularity="block", + block_shape=tuple(block_size) if block_size is not None else None, + layout="mxfp8", + ) + a_storage_dtype = _fp8_dtype if A_scales is None else A.dtype + return format_signature( + a=tensor_format("mxfp8", a_storage_dtype, scale=scale), + b=tensor_format("mxfp8", B.dtype, scale=scale), + ) + if quant == "fp8": + scale = ScaleFormat( + storage_dtype=_scale_storage_dtype(A_scales, B_scales), + granularity=_infer_scale_type(A_scales, B_scales) or "unknown", + layout="scaled", + ) + return format_signature( + a=tensor_format("fp8", A.dtype, scale=scale), + b=tensor_format("fp8", B.dtype, scale=scale), + ) + if quant == "nvfp4": + scale = ScaleFormat( + storage_dtype=_scale_storage_dtype(A_scales, B_scales), + granularity="block", + layout="nvfp4", + ) + return format_signature( + a=tensor_format("nvfp4", A.dtype, scale=scale), + b=tensor_format("nvfp4", B.dtype, scale=scale), + ) + return format_signature(a=dense_format(A.dtype), b=dense_format(B.dtype)) + + def _online_quantize_mxfp8( A: torch.Tensor, block_size: list[int], @@ -184,13 +242,6 @@ def mm( M = A.shape[0] N = B.shape[-1] if B.shape[0] == K else B.shape[0] - if quant in ("mxfp8", "fp8"): - select_dtype = _fp8_dtype - elif quant == "nvfp4": - select_dtype = A.dtype - else: - select_dtype = A.dtype - traits: dict[str, object] = { "n_align_16": N % 16 == 0, "k_align_16": K % 16 == 0, @@ -199,18 +250,15 @@ def mm( "k_align_128": K % 128 == 0, } - if quant is not None: - traits["quant"] = quant - - if quant == "fp8": - scale_type = _infer_scale_type(A_scales, B_scales) - if scale_type is not None: - traits["scale_type"] = scale_type + signature = _gemm_format_signature( + A, B, A_scales, B_scales, out_dtype, quant, block_size + ) + select_dtype = signature.primary_storage_dtype() or A.dtype kernel = select_kernel( "gemm", "mm", - select_dtype, + signature, traits=traits, override=override, expected_kernel_name=expected_kernel_name, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/deep_gemm.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/deep_gemm.py index 03a111989..18d353911 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/deep_gemm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/deep_gemm.py @@ -22,9 +22,25 @@ import torch from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement, Platform -from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.registry import ( + Priority, + register_kernel, +) +from tokenspeed_kernel.signature import ( + ScaleFormat, + format_signatures, +) _fp8_dtype = Platform.get().fp8e4m3fn.dtype +_MXFP8_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", + block_shape=(128, 128), + layout="mxfp8", +) +_MXFP8_FORMAT_SIGNATURES = format_signatures( + ("a", "b"), "mxfp8", {_fp8_dtype}, scale=_MXFP8_SCALE +) try: from tokenspeed_kernel.thirdparty.deep_gemm import ( @@ -54,9 +70,8 @@ min_arch_version=ArchVersion(9, 0), vendors=frozenset({"nvidia"}), ), - dtypes={_fp8_dtype}, + signatures=_MXFP8_FORMAT_SIGNATURES, traits={ - "quant": frozenset({"mxfp8"}), "n_align_64": frozenset({True}), "k_align_128": frozenset({True}), }, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/flashinfer.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/flashinfer.py index d68fe06a8..6cee58fb8 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/flashinfer.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/flashinfer.py @@ -27,12 +27,37 @@ Platform, current_platform, ) -from tokenspeed_kernel.registry import Priority, error_fn, register_kernel +from tokenspeed_kernel.registry import ( + Priority, + error_fn, + register_kernel, +) +from tokenspeed_kernel.signature import ( + ScaleFormat, + format_signatures, +) platform = current_platform() _fp8_dtype = Platform.get().fp8e4m3fn.dtype _fp4_dtypes: frozenset[torch.dtype] = frozenset({torch.uint8, torch.float4_e2m1fn_x2}) +_MXFP8_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", + block_shape=(128, 128), + layout="mxfp8", +) +_NVFP4_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", + layout="nvfp4", +) +_MXFP8_FORMAT_SIGNATURES = format_signatures( + ("a", "b"), "mxfp8", {_fp8_dtype}, scale=_MXFP8_SCALE +) +_NVFP4_FORMAT_SIGNATURES = format_signatures( + ("a", "b"), "nvfp4", _fp4_dtypes, scale=_NVFP4_SCALE +) # ---- FlashInfer block-scaled FP8 ---------------------------------------- @@ -59,9 +84,8 @@ min_arch_version=ArchVersion(10, 0), vendors=frozenset({"nvidia"}), ), - dtypes={_fp8_dtype}, + signatures=_MXFP8_FORMAT_SIGNATURES, traits={ - "quant": frozenset({"mxfp8"}), "n_align_128": frozenset({True}), "k_align_128": frozenset({True}), }, @@ -127,10 +151,8 @@ def flashinfer_mm_fp8_blockscale( min_arch_version=ArchVersion(10, 0), vendors=frozenset({"nvidia"}), ), - dtypes=_fp4_dtypes, - traits={ - "quant": frozenset({"nvfp4"}), - }, + signatures=_NVFP4_FORMAT_SIGNATURES, + traits={}, priority=Priority.SPECIALIZED + 2, ) def flashinfer_mm_nvfp4( diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py index d3f19ac5a..9d7348604 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py @@ -30,11 +30,40 @@ import torch from tokenspeed_kernel._triton import tl, triton from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement, Platform -from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.registry import ( + Priority, + register_kernel, +) +from tokenspeed_kernel.signature import ( + ScaleFormat, + format_signatures, +) logger = logging.getLogger(__name__) _fp8_dtype = Platform.get().fp8e4m3fn.dtype +_MXFP8_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", + block_shape=(128, 128), + layout="mxfp8", +) +_FP8_PER_TENSOR_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="per_tensor", + layout="scaled", +) +_FP8_PER_CHANNEL_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="per_channel", + layout="scaled", +) +_MXFP8_FORMAT_SIGNATURES = format_signatures( + ("a", "b"), "mxfp8", {_fp8_dtype}, scale=_MXFP8_SCALE +) +_FP8_SCALED_FORMAT_SIGNATURES = format_signatures( + ("a", "b"), "fp8", {_fp8_dtype}, scale=_FP8_PER_TENSOR_SCALE +) | format_signatures(("a", "b"), "fp8", {_fp8_dtype}, scale=_FP8_PER_CHANNEL_SCALE) def prepare_block_fp8_matmul_inputs( @@ -697,10 +726,8 @@ def triton_scaled_mm( min_arch_version=ArchVersion(10, 0), vendors=frozenset({"nvidia"}), ), - dtypes={_fp8_dtype}, - traits={ - "quant": frozenset({"mxfp8"}), - }, + signatures=_MXFP8_FORMAT_SIGNATURES, + traits={}, priority=Priority.PERFORMANT + 3, tags={"portability"}, ) @@ -740,7 +767,7 @@ def triton_mm_fp8_blockscale( min_arch_version=ArchVersion(10, 0), vendors=frozenset({"nvidia"}), ), - dtypes={_fp8_dtype}, + signatures=_FP8_SCALED_FORMAT_SIGNATURES, traits={ "b_layout": frozenset({"KN"}), }, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/trtllm.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/trtllm.py index 195b1be12..9d57b51be 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/trtllm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/trtllm.py @@ -27,7 +27,14 @@ import torch from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement -from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.registry import ( + Priority, + register_kernel, +) +from tokenspeed_kernel.signature import ( + ScaleFormat, + format_signatures, +) # Re-exported # dsv3_fused_a_gemm supports specific shapes only (see python/tokenspeed/runtime/models/deepseek_v3.py); @@ -35,6 +42,14 @@ from tokenspeed_kernel.thirdparty.trtllm import dsv3_fused_a_gemm # noqa: F401 _fp4_dtypes: frozenset[torch.dtype] = frozenset({torch.uint8, torch.float4_e2m1fn_x2}) +_NVFP4_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", + layout="nvfp4", +) +_NVFP4_FORMAT_SIGNATURES = format_signatures( + ("a", "b"), "nvfp4", _fp4_dtypes, scale=_NVFP4_SCALE +) # One stateful torchbind instance per output dtype. Each holds its own # per-shape algo cache inside C++. @@ -59,10 +74,8 @@ def _get_runner(out_dtype: torch.dtype): min_arch_version=ArchVersion(10, 0), vendors=frozenset({"nvidia"}), ), - dtypes=_fp4_dtypes, - traits={ - "quant": frozenset({"nvfp4"}), - }, + signatures=_NVFP4_FORMAT_SIGNATURES, + traits={}, priority=Priority.SPECIALIZED + 3, ) def cublaslt_mm_nvfp4( diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py index d3eeccbcc..d5e8fd0b7 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py @@ -42,6 +42,12 @@ register_oracle, select_kernel, ) +from tokenspeed_kernel.signature import ( + ScaleFormat, + dense_format, + format_signature, + tensor_format, +) logger = logging.getLogger(__name__) @@ -88,7 +94,7 @@ def adjust(self, spec, platform, traits): "pre_routed" # caller provides topk_weights/topk_ids (marlin, cutlass, reference) ) -# Weight format trait values — used via traits={"weight_dtype": ...} +# Weight format values used by moe_fused(weight_format=...). WEIGHT_BF16 = "bf16" # dense bfloat16 weights WEIGHT_FP8 = "fp8" # FP8 block-scaled weights WEIGHT_MXFP4 = "mxfp4" # MXFP4 block-scaled weights @@ -116,6 +122,62 @@ def adjust(self, spec, platform, traits): ) +_FP8_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", + layout="fp8", +) +_NVFP4_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", + layout="nvfp4", +) +_MXFP4_SCALE = ScaleFormat( + storage_dtype=torch.uint8, + granularity="block", + block_shape=(32,), + layout="ue8m0", +) + + +def _single_dense_format_signature(role: str, storage_dtype: torch.dtype): + return format_signature(**{role: dense_format(storage_dtype)}) + + +def _moe_dispatch_format_signature(storage_dtype: torch.dtype, traits: Optional[dict]): + comm_strategy = (traits or {}).get("comm_strategy") + if storage_dtype == torch.int32 or comm_strategy == "local": + return _single_dense_format_signature("indices", storage_dtype) + return _single_dense_format_signature("x", storage_dtype) + + +def _moe_fused_format_signature( + storage_dtype: torch.dtype, + weight_format: str, +): + if weight_format == WEIGHT_FP8: + weight = tensor_format("fp8", torch.float8_e4m3fn, scale=_FP8_SCALE) + elif weight_format == WEIGHT_NVFP4: + weight = tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE) + elif weight_format == WEIGHT_MXFP4: + weight = tensor_format("mxfp4", torch.uint8, scale=_MXFP4_SCALE) + elif weight_format == WEIGHT_BF16: + weight = dense_format(torch.bfloat16) + else: + raise ValueError(f"Unsupported MoE fused weight_format={weight_format!r}") + + if storage_dtype == torch.uint8 and weight_format == WEIGHT_NVFP4: + x = tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE) + elif storage_dtype == torch.uint8 and weight_format == WEIGHT_MXFP4: + x = tensor_format("mxfp4", torch.uint8, scale=_MXFP4_SCALE) + elif storage_dtype == torch.float8_e4m3fn: + x = tensor_format("fp8", storage_dtype, scale=_FP8_SCALE) + else: + x = dense_format(storage_dtype) + + return format_signature(x=x, weight=weight) + + def moe_route( *args, dtype: torch.dtype = torch.bfloat16, @@ -134,10 +196,11 @@ def moe_route( * ``{"biased": True/False}``: whether correction_bias is applied. * ``{"grouped": True/False}``: whether grouped expert selection is used. """ + signature = _single_dense_format_signature("logits", dtype) kernel = select_kernel( "moe", "route", - dtype, + signature, features=frozenset(features) if features else None, traits=traits or {}, expected_kernel_name=expected_kernel_name, @@ -157,11 +220,13 @@ def moe_dispatch( Returns ``(sorted_token_ids, expert_ids, num_tokens_post_padded)``. """ + selection_traits = traits or {"comm_strategy": "local"} + signature = _moe_dispatch_format_signature(dtype, selection_traits) kernel = select_kernel( "moe", "dispatch", - dtype, - traits=traits or {"comm_strategy": "local"}, + signature, + traits=selection_traits, expected_kernel_name=expected_kernel_name, ) @@ -189,10 +254,11 @@ def moe_experts( * ``{"dispatch_gemm"}``: gather/dispatch tokens then GEMM (uses gather_indx). * ``{"gemm_combine"}``: GEMM then scatter/combine results (uses scatter_indx). """ + signature = _single_dense_format_signature("x", dtype) kernel = select_kernel( "moe", "experts", - dtype, + signature, features=frozenset(features) if features else None, traits=traits or {}, expected_kernel_name=expected_kernel_name, @@ -209,10 +275,11 @@ def moe_combine( **kwargs, ): """Combine expert outputs with weighted reduction.""" + signature = _single_dense_format_signature("x", dtype) kernel = select_kernel( "moe", "combine", - dtype, + signature, traits=traits or {}, expected_kernel_name=expected_kernel_name, ) @@ -224,6 +291,7 @@ def moe_fused( *args, dtype: torch.dtype = torch.bfloat16, features: Optional[Set[str]] = None, + weight_format: str = WEIGHT_BF16, traits: Optional[dict] = None, expected_kernel_name: Optional[str] = None, **kwargs, @@ -235,18 +303,18 @@ def moe_fused( * ``{"self_routing"}``: kernel does routing internally (trtllm). * ``{"pre_routed"}``: routing already done by caller (cutlass, reference). - Weight format traits (pass via ``traits``): - - * ``{"weight_dtype": "bf16"}``: dense bfloat16 weights. - * ``{"weight_dtype": "fp8"}``: FP8 block-scaled weights. - * ``{"weight_dtype": "mxfp4"}``: MXFP4 block-scaled weights. + Args: + weight_format: Weight tensor encoding. Supported values are ``"bf16"``, + ``"fp8"``, ``"mxfp4"``, and ``"nvfp4"``. """ + signature = _moe_fused_format_signature(dtype, weight_format) + selection_traits = dict(traits or {}) kernel = select_kernel( "moe", "fused", - dtype, + signature, features=frozenset(features) if features else None, - traits=traits or {}, + traits=selection_traits, expected_kernel_name=expected_kernel_name, ) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/cuda.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/cuda.py index 462c4d849..6594d6b58 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/cuda.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/cuda.py @@ -28,7 +28,12 @@ from __future__ import annotations import torch -from tokenspeed_kernel.registry import Priority, error_fn, register_kernel +from tokenspeed_kernel.registry import ( + Priority, + error_fn, + register_kernel, +) +from tokenspeed_kernel.signature import format_signatures try: from tokenspeed_kernel.thirdparty.cuda.activation import ( @@ -52,7 +57,9 @@ "route", name="cuda_routing_flash", solution="cuda", - dtypes={torch.float16, torch.bfloat16, torch.float32}, + signatures=format_signatures( + "logits", "dense", {torch.float16, torch.bfloat16, torch.float32} + ), traits={ "output_type": frozenset({"topk"}), "biased": frozenset({True}), diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/deepep.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/deepep.py index 746fbfba7..9d88fe12c 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/deepep.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/deepep.py @@ -24,7 +24,11 @@ import torch from tokenspeed_kernel._triton import tl, triton -from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.registry import ( + Priority, + register_kernel, +) +from tokenspeed_kernel.signature import format_signatures logger = logging.getLogger(__name__) @@ -277,7 +281,7 @@ def _fwd_kernel_ep_scatter_2( "dispatch", name="deepep_moe_scatter", solution="deep_ep", - dtypes={torch.float8_e4m3fn, torch.bfloat16}, + signatures=format_signatures("x", "dense", {torch.float8_e4m3fn, torch.bfloat16}), traits={ "comm_strategy": frozenset({"deep_ep"}), }, @@ -414,7 +418,7 @@ def _fwd_kernel_ep_gather( "combine", name="deepep_moe_gather", solution="deep_ep", - dtypes={torch.bfloat16}, + signatures=format_signatures("x", "dense", {torch.bfloat16}), traits={ "comm_strategy": frozenset({"deep_ep"}), }, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py index 48ecae3bd..b01c66479 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py @@ -22,10 +22,99 @@ import torch from tokenspeed_kernel.platform import current_platform -from tokenspeed_kernel.registry import Priority, error_fn, register_kernel +from tokenspeed_kernel.registry import ( + Priority, + error_fn, + register_kernel, +) +from tokenspeed_kernel.signature import ( + ScaleFormat, + dense_format, + format_signature, + tensor_format, +) platform = current_platform() + +_FP8_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", + layout="fp8", +) +_NVFP4_SCALE = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", + layout="nvfp4", +) +_MXFP4_SCALE = ScaleFormat( + storage_dtype=torch.uint8, + granularity="block", + block_shape=(32,), + layout="ue8m0", +) +_BF16_FUSED_FORMAT_SIGNATURES = frozenset( + { + format_signature( + x=dense_format(torch.bfloat16), weight=dense_format(torch.bfloat16) + ) + } +) +_CUTLASS_FUSED_FORMAT_SIGNATURES = frozenset( + { + format_signature( + x=dense_format(torch.bfloat16), weight=dense_format(torch.bfloat16) + ), + format_signature( + x=dense_format(torch.float16), weight=dense_format(torch.bfloat16) + ), + format_signature( + x=dense_format(torch.bfloat16), + weight=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), + ), + format_signature( + x=dense_format(torch.float16), + weight=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), + ), + format_signature( + x=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), + weight=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), + ), + format_signature( + x=tensor_format("fp8", torch.float8_e4m3fn, scale=_FP8_SCALE), + weight=tensor_format("fp8", torch.float8_e4m3fn, scale=_FP8_SCALE), + ), + } +) +_FP4_FUSED_FORMAT_SIGNATURES = frozenset( + { + format_signature( + x=dense_format(torch.bfloat16), + weight=tensor_format("mxfp4", torch.uint8, scale=_MXFP4_SCALE), + ), + format_signature( + x=dense_format(torch.bfloat16), + weight=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), + ), + format_signature( + x=tensor_format("mxfp4", torch.uint8, scale=_MXFP4_SCALE), + weight=tensor_format("mxfp4", torch.uint8, scale=_MXFP4_SCALE), + ), + format_signature( + x=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), + weight=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), + ), + } +) +_CUTEDSL_NVFP4_FORMAT_SIGNATURES = frozenset( + { + format_signature( + x=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), + weight=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), + ) + } +) + cutlass_fused_moe = error_fn fp4_quantize = error_fn mxfp8_quantize = error_fn @@ -101,8 +190,8 @@ name="flashinfer_trtllm_bf16_fused_moe", features={"self_routing"}, solution="trtllm", - dtypes={torch.bfloat16}, - traits={"weight_dtype": frozenset({"bf16"})}, + signatures=_BF16_FUSED_FORMAT_SIGNATURES, + traits={}, priority=Priority.SPECIALIZED, tags={"throughput"}, )(trtllm_bf16_moe) @@ -114,9 +203,8 @@ name="flashinfer_cutlass_fused_moe", features={"pre_routed"}, solution="flashinfer", - dtypes={torch.bfloat16, torch.float16, torch.uint8, torch.float8_e4m3fn}, + signatures=_CUTLASS_FUSED_FORMAT_SIGNATURES, traits={ - "weight_dtype": frozenset({"bf16", "fp8", "nvfp4"}), "tp": frozenset({True, False}), "ep": frozenset({True, False}), "cuda_graph": frozenset({False}), @@ -132,8 +220,8 @@ name="flashinfer_trtllm_fp4_fused_moe", features={"self_routing"}, solution="trtllm", - dtypes={torch.uint8, torch.bfloat16}, - traits={"weight_dtype": frozenset({"mxfp4", "nvfp4"})}, + signatures=_FP4_FUSED_FORMAT_SIGNATURES, + traits={}, priority=Priority.SPECIALIZED, tags={"throughput"}, )(trtllm_fp4_block_scale_moe) @@ -166,9 +254,8 @@ def _get_or_create_cutedsl_wrapper(wrapper_kwargs: dict) -> CuteDslMoEWrapper: name="flashinfer_cutedsl_nvfp4_fused_moe", features={"pre_routed"}, solution="cutedsl", - dtypes={torch.uint8}, + signatures=_CUTEDSL_NVFP4_FORMAT_SIGNATURES, traits={ - "weight_dtype": frozenset({"nvfp4"}), "tp": frozenset({False}), "ep": frozenset({True, False}), "cuda_graph": frozenset({True, False}), diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton.py index b54ecfe39..70078e8ab 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton.py @@ -32,7 +32,11 @@ from tokenspeed_kernel.ops.moe.expert_location_dispatch import ( ExpertLocationDispatchInfo, ) -from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.registry import ( + Priority, + register_kernel, +) +from tokenspeed_kernel.signature import format_signatures from tokenspeed_kernel.thirdparty.trtllm import ( moe_align_block_size as _moe_align_block_size, ) @@ -328,7 +332,7 @@ def _biased_grouped_topk_reference( "route", name="triton_minimax_biased_grouped_topk", solution="triton", - dtypes={torch.float32}, + signatures=format_signatures("logits", "dense", {torch.float32}), traits={ "output_type": frozenset({"topk"}), "biased": frozenset({True}), @@ -776,7 +780,9 @@ def _normalize_fp8_group_scale_layout( name="triton_moe_fused_experts", features={"dispatch_sorted"}, solution="triton", - dtypes={torch.float16, torch.bfloat16, torch.float8_e4m3fn}, + signatures=format_signatures( + "x", "dense", {torch.float16, torch.bfloat16, torch.float8_e4m3fn} + ), priority=Priority.PERFORMANT + 2, tags={"portability"}, ) @@ -978,7 +984,7 @@ def _moe_sum_reduce_kernel( "combine", name="triton_moe_sum_reduce", solution="triton", - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures("x", "dense", {torch.float16, torch.bfloat16}), traits={"comm_strategy": frozenset({None})}, priority=Priority.PERFORMANT + 2, tags={"portability"}, @@ -1023,7 +1029,7 @@ def moe_sum_reduce_triton( "combine", name="torch_compile_moe_sum_reduce", solution="reference", - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures("x", "dense", {torch.float16, torch.bfloat16}), traits={"comm_strategy": frozenset({None})}, priority=Priority.PORTABLE + 1, tags={"portability"}, @@ -1044,7 +1050,7 @@ def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor): "dispatch", name="triton_moe_align_block_size", solution="triton", - dtypes={torch.int32}, + signatures=format_signatures("indices", "dense", {torch.int32}), traits={ "comm_strategy": frozenset({"local"}), }, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton_kernels.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton_kernels.py index 35fea0a23..73a74b311 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton_kernels.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton_kernels.py @@ -28,7 +28,11 @@ import tokenspeed_kernel.thirdparty.triton_kernels # noqa: F401 import torch from tokenspeed_kernel.platform import current_platform -from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.registry import ( + Priority, + register_kernel, +) +from tokenspeed_kernel.signature import format_signatures try: import triton_kernels.matmul_details.opt_flags as opt_flags @@ -92,7 +96,9 @@ def _triton_kernels_routing(logits, n_expts_act, sm_first=False, dtype=None): "route", name="triton_kernels_routing", solution="triton", - dtypes={torch.float16, torch.bfloat16, torch.float32}, + signatures=format_signatures( + "logits", "dense", {torch.float16, torch.bfloat16, torch.float32} + ), traits={"output_type": frozenset({"ragged_metadata"})}, priority=Priority.PERFORMANT + 2, tags={"portability"}, @@ -175,7 +181,9 @@ def _matmul( _matmul_common = dict( solution="triton", - dtypes={torch.float16, torch.bfloat16, torch.uint8}, + signatures=format_signatures( + "x", "dense", {torch.float16, torch.bfloat16, torch.uint8} + ), priority=Priority.PERFORMANT + 2, tags={"portability"}, ) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/trtllm.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/trtllm.py index d0fad73b1..b25b85291 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/trtllm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/trtllm.py @@ -29,6 +29,7 @@ ) from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement from tokenspeed_kernel.registry import register_kernel +from tokenspeed_kernel.signature import format_signatures from tokenspeed_kernel.thirdparty.trtllm import ( moe_align_block_size as _trtllm_moe_align_block_size, ) @@ -43,7 +44,7 @@ min_arch_version=ArchVersion(9, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.int32}, + signatures=format_signatures("indices", "dense", {torch.int32}), traits={}, priority=12, tags={"latency"}, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/__init__.py index 96325c328..1338ed55d 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/__init__.py @@ -21,6 +21,10 @@ import torch from tokenspeed_kernel.profiling import ShapeCapture, kernel_scope from tokenspeed_kernel.selection import select_kernel +from tokenspeed_kernel.signature import ( + dense_format, + format_signature, +) __all__ = [ "quantize_fp8", @@ -67,10 +71,11 @@ def quantize_fp8( traits = { "has_scale": scale is not None, } + signature = format_signature(x=dense_format(x.dtype)) kernel = select_kernel( "quantization", "fp8", - x.dtype, + signature, traits=traits, solution=solution, override=override, @@ -146,10 +151,11 @@ def quantize_fp8_with_scale( "granularity": granularity_trait, "scale_encoding": scale_encoding, } + signature = format_signature(x=dense_format(x.dtype)) kernel = select_kernel( "quantization", "fp8_with_scale", - x.dtype, + signature, traits=traits, solution=solution, override=override, @@ -204,10 +210,11 @@ def quantize_mxfp8( """ traits = {} + signature = format_signature(x=dense_format(x.dtype)) kernel = select_kernel( "quantization", "mxfp8", - x.dtype, + signature, traits=traits, solution=solution, override=override, @@ -264,10 +271,11 @@ def quantize_nvfp4( "scale_layout": scale_layout, "has_scale": scale is not None, } + signature = format_signature(x=dense_format(x.dtype)) kernel = select_kernel( "quantization", "nvfp4", - x.dtype, + signature, traits=traits, solution=solution, override=override, @@ -335,10 +343,11 @@ def quantize_mxfp4( "has_global_scale": global_scale is not None, "scale_encoding": "ue8m0", } + signature = format_signature(x=dense_format(x.dtype)) kernel = select_kernel( "quantization", "mxfp4", - x.dtype, + signature, traits=traits, solution=solution, override=override, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/flashinfer.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/flashinfer.py index 6e8005365..99d4fe331 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/flashinfer.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/flashinfer.py @@ -22,7 +22,12 @@ CapabilityRequirement, current_platform, ) -from tokenspeed_kernel.registry import Priority, error_fn, register_kernel +from tokenspeed_kernel.registry import ( + Priority, + error_fn, + register_kernel, +) +from tokenspeed_kernel.signature import format_signatures platform = current_platform() @@ -46,7 +51,7 @@ "mxfp8", name="flashinfer_quantize_mxfp8", solution="flashinfer", - dtypes={torch.bfloat16, torch.float16}, + signatures=format_signatures("x", "dense", {torch.bfloat16, torch.float16}), traits={}, priority=Priority.PERFORMANT, ) @@ -72,7 +77,7 @@ def flashinfer_quantize_mxfp8( min_arch_version=ArchVersion(10, 0), vendors=frozenset({"nvidia"}), ), - dtypes={torch.bfloat16, torch.float16}, + signatures=format_signatures("x", "dense", {torch.bfloat16, torch.float16}), traits={ "has_scale": frozenset({True}), }, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/triton.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/triton.py index cbe4c3ce2..30ded5e34 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/triton.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/triton.py @@ -18,7 +18,11 @@ import torch from tokenspeed_kernel._triton import tl, triton -from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.registry import ( + Priority, + register_kernel, +) +from tokenspeed_kernel.signature import format_signatures @triton.jit @@ -191,7 +195,7 @@ def fp8_quantize( "fp8", name="triton_quantize_fp8", solution="triton", - dtypes={torch.bfloat16, torch.float16}, + signatures=format_signatures("x", "dense", {torch.bfloat16, torch.float16}), traits={"has_scale": frozenset({True, False})}, priority=Priority.PORTABLE, ) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/trtllm.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/trtllm.py index b07865f11..bd6e9be08 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/trtllm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/trtllm.py @@ -21,7 +21,12 @@ import torch from tokenspeed_kernel.platform import current_platform -from tokenspeed_kernel.registry import Priority, error_fn, register_kernel +from tokenspeed_kernel.registry import ( + Priority, + error_fn, + register_kernel, +) +from tokenspeed_kernel.signature import format_signatures platform = current_platform() @@ -63,7 +68,7 @@ def trtllm_fp8_tensor(x: torch.Tensor) -> torch.Tensor: "fp8_with_scale", name="trtllm_quantize_fp8_with_scale", solution="trtllm", - dtypes={torch.bfloat16, torch.float16}, + signatures=format_signatures("x", "dense", {torch.bfloat16, torch.float16}), traits={ "granularity": frozenset({"tensor", "token", "token_group_128"}), "scale_encoding": frozenset({"float32", "ue8m0"}), diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/plugins/README.md b/tokenspeed-kernel/python/tokenspeed_kernel/plugins/README.md index 76685a1c5..e931e5e6a 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/plugins/README.md +++ b/tokenspeed-kernel/python/tokenspeed_kernel/plugins/README.md @@ -43,6 +43,7 @@ my-kernels-plugin/ import torch from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement +from tokenspeed_kernel.signature import format_signatures from tokenspeed_kernel.registry import register_kernel @@ -53,7 +54,9 @@ def register() -> None: "attention", "decode", solution="my_custom", - dtypes={torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.bfloat16} + ), capability=CapabilityRequirement( vendors=frozenset({"nvidia"}), min_arch_version=ArchVersion(9, 0), @@ -110,10 +113,17 @@ points at all — call `register_kernel(...)` directly: ```python import torch +from tokenspeed_kernel.signature import format_signatures from tokenspeed_kernel.registry import register_kernel -@register_kernel("gemm", "mm", solution="experiment", dtypes={torch.bfloat16}, priority=15) +@register_kernel( + "gemm", + "mm", + solution="experiment", + signatures=format_signatures(("a", "b"), "dense", {torch.bfloat16}), + priority=15, +) def my_experimental_gemm(a, b, **kwargs): ... ``` diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/registry.py b/tokenspeed-kernel/python/tokenspeed_kernel/registry.py index 8b27d77e7..2dc84dfb6 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/registry.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/registry.py @@ -32,16 +32,19 @@ from tokenspeed_kernel.selection import SelectedKernel from tokenspeed_kernel.platform import CapabilityRequirement, PlatformInfo +from tokenspeed_kernel.signature import FormatSignature logger = logging.getLogger(__name__) __all__ = [ + "ErrorClass", "KernelSpec", "KernelRegistry", "Priority", "load_builtin_kernels", "register_kernel", "describe_kernel", + "error_fn", ] @@ -145,7 +148,7 @@ class KernelSpec: # Capabilities capability: CapabilityRequirement = field(default_factory=CapabilityRequirement) - dtypes: frozenset[torch.dtype] = frozenset() # Supported data types + format_signatures: frozenset[FormatSignature] = frozenset() # Op-specific traits, e.g. {"head_dim": frozenset({64, 128, 256}), "persistent": frozenset({True})} traits: dict[str, frozenset[Any]] = field(default_factory=dict) @@ -158,6 +161,28 @@ class KernelSpec: frozenset() ) # Standard tags: "throughput", "latency", "determinism", "portability" + def supports_format_signature(self, format_signature: FormatSignature) -> bool: + return format_signature in self.format_signatures + + def format_signature_for_storage_dtype( + self, + storage_dtype: torch.dtype, + ) -> FormatSignature | None: + for signature in sorted(self.format_signatures, key=str): + if signature.primary_storage_dtype() == storage_dtype: + return signature + return None + + def primary_storage_dtypes(self) -> frozenset[torch.dtype]: + return frozenset( + dtype + for dtype in ( + signature.primary_storage_dtype() + for signature in self.format_signatures + ) + if dtype is not None + ) + class KernelRegistry: """Central registry for all kernel implementations.""" @@ -223,7 +248,7 @@ def get_for_operator( *, features: frozenset[str] | None = None, platform: PlatformInfo | None = None, - dtype: torch.dtype | None = None, + format_signature: FormatSignature | None = None, tags: set[str] | None = None, solution: str | None = None, ) -> list[KernelSpec]: @@ -234,8 +259,8 @@ def get_for_operator( specs = [s for s in specs if features.issubset(s.features)] if platform: specs = [s for s in specs if s.capability.satisfied_by(platform)] - if dtype: - specs = [s for s in specs if dtype in s.dtypes] + if format_signature: + specs = [s for s in specs if s.supports_format_signature(format_signature)] if tags: specs = [s for s in specs if tags.issubset(s.tags)] if solution: @@ -293,7 +318,7 @@ def register_kernel( features: set[str] | None = None, solution: str, capability: CapabilityRequirement | None = None, - dtypes: set[torch.dtype], + signatures: set[FormatSignature] | frozenset[FormatSignature], traits: dict[str, frozenset[Any]] | None = None, priority: Priority | int = Priority.PERFORMANT + 2, tags: set[str] | None = None, @@ -307,6 +332,8 @@ def register_kernel( Example:: + from tokenspeed_kernel.signature import format_signatures + @register_kernel( "attention", "decode", features={"paged"}, @@ -315,7 +342,9 @@ def register_kernel( min_arch_version=ArchVersion(10, 0), required_features=frozenset({"tensor_core:f8"}), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k", "v"), "dense", {torch.float16, torch.bfloat16} + ), # Narrowly gated on SM100 + tcgen05 → SPECIALIZED band. priority=Priority.SPECIALIZED + 1, tags={"latency", "determinism"}, @@ -334,7 +363,7 @@ def decorator(fn: Callable) -> Callable: mode=mode, solution=solution, features=frozenset(features or set()), - dtypes=frozenset(dtypes), + format_signatures=frozenset(signatures), capability=capability or CapabilityRequirement(), traits=traits or {}, priority=priority_int, @@ -362,7 +391,8 @@ def describe_kernel(name: str) -> str: f" Operator: {spec.family}.{spec.mode}", f" Solution: {spec.solution}", f" Priority: {spec.priority} ({band_str})", - f" Dtypes: {', '.join(str(d) for d in spec.dtypes)}", + " Format signatures: " + + ("; ".join(str(p) for p in spec.format_signatures) or "none"), f" Platform: {spec.capability}", f" Tags: {', '.join(spec.tags) or 'none'}", ] diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/selection.py b/tokenspeed-kernel/python/tokenspeed_kernel/selection.py index 01a6e1ba3..28e390172 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/selection.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/selection.py @@ -28,9 +28,12 @@ from pathlib import Path from typing import Any, Callable, Generator -import torch from tokenspeed_kernel.platform import PlatformInfo, current_platform -from tokenspeed_kernel.registry import KernelRegistry, KernelSpec +from tokenspeed_kernel.registry import ( + KernelRegistry, + KernelSpec, +) +from tokenspeed_kernel.signature import FormatSignature logger = logging.getLogger(__name__) @@ -313,7 +316,7 @@ def _get_config_override(family: str, mode: str) -> _ConfigOverrideEntry | None: def _make_cache_key( family: str, mode: str, - dtype: object, + format_signature: FormatSignature, arch: str, objective: SelectionObjective, features: frozenset[str] | None, @@ -323,7 +326,16 @@ def _make_cache_key( """Build a hashable cache key including selection-relevant traits.""" traits_key = tuple(sorted(traits.items())) if traits else () mods_key = frozenset(features) if features else frozenset() - return (family, mode, dtype, arch, objective, mods_key, traits_key, solution) + return ( + family, + mode, + format_signature, + arch, + objective, + mods_key, + traits_key, + solution, + ) def _score_priority(spec: KernelSpec) -> int: @@ -478,7 +490,7 @@ def _resolve_override( registry: KernelRegistry, family: str, mode: str, - dtype: object, + format_signature: object, override: str, platform: PlatformInfo, ) -> SelectedKernel: @@ -494,14 +506,14 @@ def _resolve_override( return SelectedKernel(name=kernel_name, impl=impl) raise NoKernelFoundError( - f"Override '{override}' not found for {family}.{mode} ({dtype})" + f"Override '{override}' not found for {family}.{mode} ({format_signature})" ) def _log_selection( family: str, mode: str, - dtype: object, + format_signature: object, winner: KernelSpec, scored: list[tuple[KernelSpec, ScoreBreakdown]], platform: PlatformInfo, @@ -517,7 +529,7 @@ def _log_selection( "[tokenspeed_kernel] %s.%s(%s) -> %s (%s, %s)", family, mode, - dtype, + format_signature, winner.name, breakdown, platform.arch, @@ -527,7 +539,7 @@ def _log_selection( "[tokenspeed_kernel] %s.%s(%s) -> %s (%s)", family, mode, - dtype, + format_signature, winner.name, platform.arch, ) @@ -536,7 +548,7 @@ def _log_selection( def select_kernel( family: str, mode: str, - dtype: torch.dtype, + format_signature: FormatSignature, *, features: frozenset[str] | None = None, platform: PlatformInfo | None = None, @@ -548,21 +560,21 @@ def select_kernel( ) -> SelectedKernel: """Select the best kernel for an operation. - On first call for a given (family, mode, dtype, platform, objective, traits, + On first call for a given (family, mode, format_signature, platform, objective, traits, solution) combination, runs the full selection pipeline. Subsequent calls with the same arguments return the cached result — a single dict lookup. Args: family: Operator family (e.g., "attention") mode: Operator mode (e.g., "decode") - dtype: Primary data type + format_signature: Role-indexed tensor format signature features: Required operator features (e.g., {"paged"}) platform: Hardware to match (auto-detected if None) objective: Selection objective (see SelectionObjective enum) traits: Op-specific trait values that affect kernel applicability (e.g., {"head_dim": 128, "num_kv_heads": 8}) solution: Restrict selection to a registered solution while preserving - normal platform, dtype, and trait filtering. + normal platform, format signature, and trait filtering. override: Force a specific kernel name or solution string expected_kernel_name: Debug-only hint. When set, a warning is logged if the selected kernel differs from this name. The @@ -609,7 +621,14 @@ def select_kernel( # Fast path: check cache (skipped when override is active) cache_key = _make_cache_key( - family, mode, dtype, platform.arch, objective, features, traits, solution + family, + mode, + format_signature, + platform.arch, + objective, + features, + traits, + solution, ) if override is None: cached = registry.cache_get(cache_key) @@ -617,7 +636,9 @@ def select_kernel( return cached if override: - return _resolve_override(registry, family, mode, dtype, override, platform) + return _resolve_override( + registry, family, mode, format_signature, override, platform + ) # Get candidates (same filtering for both strategies) candidates = registry.get_for_operator( @@ -625,14 +646,14 @@ def select_kernel( mode, features=features, platform=platform, - dtype=dtype, + format_signature=format_signature, solution=solution, ) solution_clause = f" with solution {solution!r}" if solution else "" if not candidates: raise NoKernelFoundError( - f"No kernel found for {family}.{mode} ({dtype})" + f"No kernel found for {family}.{mode} ({format_signature})" f"{solution_clause} on {platform.device_name}" ) @@ -641,7 +662,7 @@ def select_kernel( if not candidates: raise NoKernelFoundError( - f"No kernel found for {family}.{mode} ({dtype})" + f"No kernel found for {family}.{mode} ({format_signature})" f"{solution_clause} with traits {traits} on {platform.device_name}" ) @@ -653,7 +674,7 @@ def select_kernel( candidates, family, mode, - dtype, + format_signature, platform, traits, _policy.autotune_params, @@ -662,7 +683,7 @@ def select_kernel( scored = _rank_by_objective(candidates, objective, platform, traits) winner = scored[0][0] - _log_selection(family, mode, dtype, winner, scored, platform, objective) + _log_selection(family, mode, format_signature, winner, scored, platform, objective) if expected_kernel_name and winner.name != expected_kernel_name: logger.warning( @@ -670,7 +691,7 @@ def select_kernel( "expected '%s'. Score breakdown — selected: %s", family, mode, - dtype, + format_signature, winner.name, expected_kernel_name, next((s for sp, s in scored if sp.name == winner.name), "N/A"), @@ -686,7 +707,7 @@ def _autotune_select( candidates: list[KernelSpec], family: str, mode: str, - dtype: object, + format_signature: object, platform: PlatformInfo, traits: dict[str, Any] | None, params: AutotuneParams, @@ -704,7 +725,7 @@ def _autotune_select( "[tokenspeed_kernel:autotune] falling back to heuristic for %s.%s(%s)", family, mode, - dtype, + format_signature, ) return winner, scored @@ -729,7 +750,7 @@ def kernel_override( def explain_selection( family: str, mode: str, - dtype: torch.dtype, + format_signature: FormatSignature, *, features: frozenset[str] | None = None, platform: PlatformInfo | None = None, @@ -764,7 +785,7 @@ def explain_selection( mode, features=features, platform=platform, - dtype=dtype, + format_signature=format_signature, solution=solution, ) @@ -777,7 +798,7 @@ def explain_selection( filtered_out = [s for s in all_specs if s.name not in filtered_names] lines = [ - f"Op: {family}.{mode} ({dtype})", + f"Op: {family}.{mode} ({format_signature})", f"Platform: {platform.device_name} ({platform.arch})", f"Solution: {solution or 'any'}", f"Objective: {objective.value}", @@ -813,10 +834,12 @@ def explain_selection( f"arch mismatch (requires " f"{spec.capability.min_arch_version})" ) - if dtype and dtype not in spec.dtypes: + if format_signature and not spec.supports_format_signature( + format_signature + ): reasons.append( - f"dtype mismatch (supports " - f"{', '.join(str(d) for d in spec.dtypes)})" + f"format signature mismatch (supports " + f"{', '.join(str(d) for d in spec.format_signatures)})" ) if solution and spec.solution != solution: reasons.append(f"solution mismatch (is {spec.solution!r})") @@ -827,7 +850,7 @@ def explain_selection( def warmup_selection( - ops: list[tuple[str, str, torch.dtype, dict | None]] | None = None, + ops: list[tuple[str, str, FormatSignature, dict | None]] | None = None, ) -> None: """Pre-resolve kernel selection for the given ops (or all registered ops). @@ -836,18 +859,22 @@ def warmup_selection( """ if ops is None: - ops = [ - (f, m, torch.bfloat16, None) - for f, m in KernelRegistry.get().list_operators() - ] - - for family, mode, dtype, traits in ops: + ops = [] + registry = KernelRegistry.get() + for family, mode in registry.list_operators(): + specs = registry.get_for_operator(family, mode) + if not specs or not specs[0].format_signatures: + continue + format_signature = sorted(specs[0].format_signatures, key=str)[0] + ops.append((family, mode, format_signature, None)) + + for family, mode, format_signature, traits in ops: try: - select_kernel(family, mode, dtype, traits=traits) + select_kernel(family, mode, format_signature, traits=traits) except NoKernelFoundError: logger.debug( "[tokenspeed_kernel] warmup: no kernel for %s.%s(%s)", family, mode, - dtype, + format_signature, ) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/signature.py b/tokenspeed-kernel/python/tokenspeed_kernel/signature.py new file mode 100644 index 000000000..a8cb30b03 --- /dev/null +++ b/tokenspeed-kernel/python/tokenspeed_kernel/signature.py @@ -0,0 +1,202 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Iterable + +if TYPE_CHECKING: + import torch + +__all__ = [ + "FormatSignature", + "ScaleFormat", + "TensorFormat", + "dense_format", + "format_signature", + "tensor_format", + "format_signatures", +] + + +@dataclass(frozen=True) +class ScaleFormat: + """Representation metadata for a tensor scale sidecar. + + Args: + storage_dtype: Physical dtype used by the scale tensor. + granularity: Scale granularity, such as "per_tensor", + "per_channel", or "block". + block_shape: Logical block shape covered by each scale value when + granularity is block-based. + layout: Optional backend or format layout name for the scale tensor. + """ + + storage_dtype: torch.dtype + granularity: str + block_shape: tuple[int, ...] | None = None + layout: str | None = None + + def __str__(self) -> str: + parts = [self.granularity, f"storage={self.storage_dtype}"] + if self.block_shape is not None: + parts.append(f"block={self.block_shape}") + if self.layout is not None: + parts.append(f"layout={self.layout}") + return "scale(" + ", ".join(parts) + ")" + + +@dataclass(frozen=True) +class TensorFormat: + """Storage representation for one logical tensor role. + + Args: + storage_dtype: Physical dtype used by the main tensor payload. + format: Logical representation format, such as "dense", + "fp8", "mxfp4", or "nvfp4". + scale: Optional scale sidecar metadata bundled with this tensor role. + """ + + storage_dtype: torch.dtype + format: str = "dense" + scale: ScaleFormat | None = None + + def __str__(self) -> str: + if self.scale is None: + return f"{self.format}[storage={self.storage_dtype}]" + return f"{self.format}[storage={self.storage_dtype}, {self.scale}]" + + +@dataclass(frozen=True) +class FormatSignature: + """Role-indexed tensor formats for one supported operand-format combination.""" + + roles: tuple[tuple[str, TensorFormat], ...] + + def __post_init__(self) -> None: + seen: set[str] = set() + normalized: list[tuple[str, TensorFormat]] = [] + for role, tensor_format in sorted(self.roles, key=lambda item: item[0]): + if role in seen: + raise ValueError(f"duplicate format role {role!r}") + seen.add(role) + normalized.append((role, tensor_format)) + object.__setattr__(self, "roles", tuple(normalized)) + + def format_for(self, role: str) -> TensorFormat | None: + """Return the format for role, or None if it is absent.""" + for role_name, tensor_format in self.roles: + if role_name == role: + return tensor_format + return None + + def primary_storage_dtype(self) -> torch.dtype | None: + """Return the dtype used by dtype-oriented tooling and filters. + + The registry and selection path use the full signature. This helper is + only for compatibility with user-facing commands that still ask for a + single dtype filter, such as numerics and benchmark CLIs. + """ + preferred_roles = ( + "q", + "a", + "x", + "input", + "logits", + "out_a", + "hidden", + ) + for role in preferred_roles: + tensor_format = self.format_for(role) + if tensor_format is not None: + return tensor_format.storage_dtype + return self.roles[0][1].storage_dtype if self.roles else None + + def __str__(self) -> str: + return ( + ", ".join(f"{role}={tensor_format}" for role, tensor_format in self.roles) + or "none" + ) + + +def tensor_format( + format: str, + storage_dtype: torch.dtype, + *, + scale: ScaleFormat | None = None, +) -> TensorFormat: + """Construct a format for one tensor role. + + Args: + format: Logical representation format, such as "dense", "fp8", + "mxfp4", or "nvfp4". + storage_dtype: Physical dtype used by the main tensor payload. + scale: Optional scale sidecar metadata bundled with this tensor role. + """ + return TensorFormat(storage_dtype=storage_dtype, format=format, scale=scale) + + +def dense_format(storage_dtype: torch.dtype) -> TensorFormat: + """Construct a dense, unscaled tensor format for storage_dtype.""" + return tensor_format("dense", storage_dtype) + + +def format_signature(**roles: TensorFormat) -> FormatSignature: + """Construct a role-indexed format signature. + + Keyword names are logical tensor roles for the operator, for example + a/b for GEMM or q/k_cache/v_cache for attention. + Values are the exact formats required for those roles. + """ + return FormatSignature(tuple(roles.items())) + + +def format_signatures( + roles: str | Iterable[str], + format: str, + storage_dtypes: Iterable[torch.dtype], + *, + scale: ScaleFormat | None = None, +) -> frozenset[FormatSignature]: + """Construct same-format signatures for each dtype. + + Args: + roles: Logical tensor roles. Pass a string for one role or an iterable + for multiple roles. + format: Logical representation format assigned to every role. + storage_dtypes: Physical dtypes used by every role, one signature per + dtype. + scale: Optional scale sidecar metadata assigned to every role. + + Use ``format="dense"`` for dense same-format signatures. Use + ``format_signature`` directly for mixed-role combinations such as dense + activations with quantized weights. + """ + normalized_roles = (roles,) if isinstance(roles, str) else tuple(roles) + return frozenset( + format_signature( + **{ + role: tensor_format(format, storage_dtype, scale=scale) + for role in normalized_roles + } + ) + for storage_dtype in storage_dtypes + ) diff --git a/tokenspeed-kernel/test/conftest.py b/tokenspeed-kernel/test/conftest.py index 3796ef632..fc7491e99 100644 --- a/tokenspeed-kernel/test/conftest.py +++ b/tokenspeed-kernel/test/conftest.py @@ -54,13 +54,16 @@ def _require( ) -> None: from tokenspeed_kernel.platform import current_platform - specs = KernelRegistry.get().get_for_operator( - family, - mode, - platform=current_platform(), - dtype=dtype, - solution=solution, - ) + specs = [ + spec + for spec in KernelRegistry.get().get_for_operator( + family, + mode, + platform=current_platform(), + solution=solution, + ) + if spec.format_signature_for_storage_dtype(dtype) is not None + ] if not specs: pytest.skip(f"{family}.{mode} solution {solution!r} is not registered") diff --git a/tokenspeed-kernel/test/test_benchmark.py b/tokenspeed-kernel/test/test_benchmark.py index 19eccf70f..82d97649e 100644 --- a/tokenspeed-kernel/test/test_benchmark.py +++ b/tokenspeed-kernel/test/test_benchmark.py @@ -32,7 +32,11 @@ from tokenspeed_kernel.benchmark.throughput import ThroughputCalculator from tokenspeed_kernel.platform import Platform from tokenspeed_kernel.profiling import ProfilingConfig -from tokenspeed_kernel.registry import KernelRegistry, KernelSpec +from tokenspeed_kernel.registry import ( + KernelRegistry, + KernelSpec, +) +from tokenspeed_kernel.signature import format_signatures pytestmark = [ pytest.mark.usefixtures("fresh_registry"), @@ -66,7 +70,7 @@ def _register_test_gemm_kernels() -> None: family="gemm", mode="mm", solution="reference", - dtypes=frozenset({dtype}), + format_signatures=format_signatures(("a", "b"), "dense", {dtype}), traits=_TEST_GEMM_TRAITS, priority=0, ) @@ -75,7 +79,7 @@ def _register_test_gemm_kernels() -> None: family="gemm", mode="mm", solution="triton", - dtypes=frozenset({dtype}), + format_signatures=format_signatures(("a", "b"), "dense", {dtype}), traits=_TEST_GEMM_TRAITS, priority=10, ) diff --git a/tokenspeed-kernel/test/test_kernel_api_selection.py b/tokenspeed-kernel/test/test_kernel_api_selection.py index f1d5488c8..6e306145e 100644 --- a/tokenspeed-kernel/test/test_kernel_api_selection.py +++ b/tokenspeed-kernel/test/test_kernel_api_selection.py @@ -158,6 +158,29 @@ def _mm_mxfp8() -> torch.Tensor: ) +def test_gemm_mxfp8_online_activation_signature_uses_quantized_storage() -> None: + a = torch.empty((4, 128), dtype=torch.bfloat16) + b = torch.empty((128, 128), dtype=_fp8_dtype()) + b_scales = torch.empty((1, 1), dtype=torch.float32) + + signature = _gemm_pkg._gemm_format_signature( + a, + b, + None, + b_scales, + torch.bfloat16, + "mxfp8", + [128, 128], + ) + + a_format = signature.format_for("a") + b_format = signature.format_for("b") + assert a_format is not None + assert b_format is not None + assert a_format.storage_dtype == _fp8_dtype() + assert b_format.storage_dtype == _fp8_dtype() + + def _mm_nvfp4() -> torch.Tensor: a = torch.empty((4, 64), dtype=torch.uint8) b = torch.empty((128, 64), dtype=torch.uint8) @@ -324,7 +347,7 @@ def _moe_fused_self_routing_bf16() -> object: return tokenspeed_kernel.moe_fused( dtype=torch.bfloat16, features={"self_routing"}, - traits={"weight_dtype": "bf16"}, + weight_format="bf16", ) @@ -332,7 +355,7 @@ def _moe_fused_self_routing_mxfp4() -> object: return tokenspeed_kernel.moe_fused( dtype=torch.bfloat16, features={"self_routing"}, - traits={"weight_dtype": "mxfp4"}, + weight_format="mxfp4", ) @@ -340,8 +363,8 @@ def _moe_fused_prerouted_nvfp4_cutlass() -> object: return tokenspeed_kernel.moe_fused( dtype=torch.bfloat16, features={"pre_routed"}, + weight_format="nvfp4", traits={ - "weight_dtype": "nvfp4", "tp": True, "ep": True, "cuda_graph": False, @@ -353,8 +376,8 @@ def _moe_fused_prerouted_nvfp4_cutedsl() -> object: return tokenspeed_kernel.moe_fused( dtype=torch.uint8, features={"pre_routed"}, + weight_format="nvfp4", traits={ - "weight_dtype": "nvfp4", "tp": False, "ep": True, "cuda_graph": True, @@ -366,7 +389,8 @@ def _moe_fused_prerouted_bf16_reference() -> object: return tokenspeed_kernel.moe_fused( dtype=torch.bfloat16, features={"pre_routed"}, - traits={"weight_dtype": "bf16", "tp": False, "ep": False}, + weight_format="bf16", + traits={"tp": False, "ep": False}, ) diff --git a/tokenspeed-kernel/test/test_numerics.py b/tokenspeed-kernel/test/test_numerics.py index 4e4f9b2d4..f256c7cac 100644 --- a/tokenspeed-kernel/test/test_numerics.py +++ b/tokenspeed-kernel/test/test_numerics.py @@ -23,10 +23,19 @@ import pytest import torch from tokenspeed_kernel.numerics.comparison import compare_outputs, format_comparison +from tokenspeed_kernel.numerics.inputs import get_input_generator from tokenspeed_kernel.numerics.tolerance import Tolerance from tokenspeed_kernel.numerics.verify import verify_kernel from tokenspeed_kernel.platform import Platform -from tokenspeed_kernel.registry import KernelRegistry, KernelSpec, load_builtin_kernels +from tokenspeed_kernel.registry import ( + KernelRegistry, + KernelSpec, + load_builtin_kernels, +) +from tokenspeed_kernel.signature import ( + ScaleFormat, + format_signatures, +) _fp8_dtype = Platform.get().fp8e4m3fn.dtype @@ -59,6 +68,36 @@ def test_treats_inf_as_mismatch(self) -> None: assert result.num_mismatches == 1 +def test_gemm_input_generator_uses_signature_scale_metadata() -> None: + scale = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", + block_shape=(128, 128), + layout="mxfp8", + ) + signature = next( + iter(format_signatures(("a", "b"), "mxfp8", {_fp8_dtype}, scale=scale)) + ) + generator = get_input_generator( + "gemm", + "mm", + dtype=_fp8_dtype, + traits={}, + format_signature=signature, + device="cpu", + ) + + inputs = generator.generate(M=4, N=256, K=128) + + assert inputs["A"].dtype == _fp8_dtype + assert inputs["B"].dtype == _fp8_dtype + assert inputs["A_scales"].shape == (4, 1) + assert inputs["B_scales"].shape == (2, 1) + assert inputs["A_scales"].dtype == torch.float32 + assert inputs["B_scales"].dtype == torch.float32 + assert inputs["block_size"] == [128, 128] + + class TestNumericsVerification: def _get_verifiable_specs( dtype: torch.dtype, family: str | None = None @@ -72,13 +111,16 @@ def _get_verifiable_specs( continue # Only run kernels that have a paired reference for this dtype; # otherwise verify_kernel raises ValueError and the test errors. - has_reference = any( - s.solution == "reference" - for s in registry.get_for_operator(family_name, mode, dtype=dtype) - ) + op_specs = registry.get_for_operator(family_name, mode) + dtype_specs = [ + s + for s in op_specs + if s.format_signature_for_storage_dtype(dtype) is not None + ] + has_reference = any(s.solution == "reference" for s in dtype_specs) if not has_reference: continue - for spec in registry.get_for_operator(family_name, mode, dtype=dtype): + for spec in dtype_specs: if spec.solution == "reference": continue if spec.solution == "deep_gemm": diff --git a/tokenspeed-kernel/test/test_plugins.py b/tokenspeed-kernel/test/test_plugins.py index 867aca3cc..8b18ced64 100644 --- a/tokenspeed-kernel/test/test_plugins.py +++ b/tokenspeed-kernel/test/test_plugins.py @@ -37,7 +37,11 @@ reset_plugins, ) from tokenspeed_kernel.plugins.cli import main as cli_main -from tokenspeed_kernel.registry import KernelRegistry, register_kernel +from tokenspeed_kernel.registry import ( + KernelRegistry, + register_kernel, +) +from tokenspeed_kernel.signature import format_signatures # --------------------------------------------------------------------------- # Fixtures and helpers @@ -104,7 +108,7 @@ def _make_register( mode: str = "mm", solution: str | None = None, priority: int = 12, - dtypes=None, + storage_dtypes=None, capability: CapabilityRequirement | None = None, ) -> Callable[[], None]: sol = solution or name @@ -115,7 +119,9 @@ def register(): mode, name=name, solution=sol, - dtypes=dtypes or {torch.bfloat16}, + signatures=format_signatures( + ("a", "b"), "dense", storage_dtypes or {torch.bfloat16} + ), priority=priority, capability=capability, ) @@ -147,7 +153,12 @@ class TestExplicitRegistration: """The decorator path is what plugins use; test it works standalone.""" def test_register_kernel_decorator(self, fresh_plugins): - @register_kernel("gemm", "mm", solution="manual", dtypes={torch.bfloat16}) + @register_kernel( + "gemm", + "mm", + solution="manual", + signatures=format_signatures(("a", "b"), "dense", {torch.bfloat16}), + ) def impl(a, b): return a @ b @@ -202,7 +213,7 @@ def register(): "mm", name=f"k_{name}", solution=name, - dtypes={torch.bfloat16}, + signatures=format_signatures(("a", "b"), "dense", {torch.bfloat16}), ) def impl(): return name @@ -232,7 +243,7 @@ def register(): "mm", name="once_kernel", solution="once", - dtypes={torch.bfloat16}, + signatures=format_signatures(("a", "b"), "dense", {torch.bfloat16}), ) def impl(): return None @@ -253,7 +264,7 @@ def register(): "mm", name="force_kernel", solution="force", - dtypes={torch.bfloat16}, + signatures=format_signatures(("a", "b"), "dense", {torch.bfloat16}), ) def impl(): return None @@ -274,7 +285,7 @@ def test_higher_priority_plugin_wins(self, fresh_plugins, patch_entry_points): "mm", name="builtin_gemm", solution="builtin", - dtypes={torch.bfloat16}, + signatures=format_signatures(("a", "b"), "dense", {torch.bfloat16}), priority=10, ) def builtin(a, b): @@ -363,7 +374,7 @@ def test_equal_priority_emits_warning(self, fresh_plugins, patch_entry_points): "mm", name="builtin_eq", solution="builtin", - dtypes={torch.bfloat16}, + signatures=format_signatures(("a", "b"), "dense", {torch.bfloat16}), priority=14, ) def b(a, x): diff --git a/tokenspeed-kernel/test/test_registry.py b/tokenspeed-kernel/test/test_registry.py index 7bf0c0567..ce242dcc1 100644 --- a/tokenspeed-kernel/test/test_registry.py +++ b/tokenspeed-kernel/test/test_registry.py @@ -29,6 +29,13 @@ describe_kernel, register_kernel, ) +from tokenspeed_kernel.signature import ( + ScaleFormat, + dense_format, + format_signature, + format_signatures, + tensor_format, +) from utils import dummy_impl, register_all_samples pytestmark = pytest.mark.usefixtures("fresh_registry") @@ -46,13 +53,32 @@ def test_default_values(self): assert spec.solution == "" assert spec.priority == 10 assert spec.tags == frozenset() - assert spec.dtypes == frozenset() + assert spec.format_signatures == frozenset() def test_hashable_without_dict_traits(self): spec = KernelSpec(name="k1", family="attention", mode="decode", traits={}) with pytest.raises(TypeError): hash(spec) + def test_format_signature_bundles_scale_metadata(self): + scale = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", + block_shape=(32,), + layout="ue8m0", + ) + mixed = format_signature( + a=dense_format(torch.bfloat16), + b=tensor_format("mxfp4", torch.uint8, scale=scale), + ) + dense = format_signature( + a=dense_format(torch.bfloat16), + b=dense_format(torch.uint8), + ) + + assert mixed != dense + assert mixed.format_for("b").scale == scale + def test_equality(self): spec1 = KernelSpec(name="k1", family="attention", mode="decode") spec2 = KernelSpec(name="k1", family="attention", mode="decode") @@ -167,11 +193,16 @@ def test_filter_by_platform( assert "aiter_decode" in mi350_names assert "triton_decode" in mi350_names - def test_filter_by_dtype(self, sample_specs): + def test_filter_by_signature(self, sample_specs): reg = KernelRegistry.get() register_all_samples(reg, sample_specs) - fp32 = reg.get_for_operator("attention", "decode", dtype=torch.float32) + signature = next( + iter( + format_signatures(("q", "k_cache", "v_cache"), "dense", {torch.float32}) + ) + ) + fp32 = reg.get_for_operator("attention", "decode", format_signature=signature) names = {s.name for s in fp32} assert "reference_decode" in names assert "flashinfer_decode" not in names @@ -287,7 +318,7 @@ def test_basic_decorator(self): "gemm", "mm", solution="reference", - dtypes={torch.bfloat16}, + signatures=format_signatures(("a", "b"), "dense", {torch.bfloat16}), priority=12, ) def my_torch_gemm(a, b): @@ -298,7 +329,10 @@ def my_torch_gemm(a, b): assert spec is not None assert spec.solution == "reference" assert spec.priority == 12 - assert torch.bfloat16 in spec.dtypes + assert ( + next(iter(format_signatures(("a", "b"), "dense", {torch.bfloat16}))) + in spec.format_signatures + ) impl = reg.get_impl("reference_gemm_mm") assert impl is my_torch_gemm @@ -309,7 +343,9 @@ def test_custom_name(self): "decode", name="my_custom_kernel", solution="custom", - dtypes={torch.float16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16} + ), ) def some_func(): pass @@ -326,7 +362,9 @@ def test_decorator_with_features_and_tags(self): capability=CapabilityRequirement( min_arch_version=ArchVersion(8, 0), ), - dtypes={torch.float16, torch.bfloat16}, + signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), tags={"determinism", "latency"}, ) def decorated_kernel(): @@ -344,7 +382,7 @@ def test_decorator_returns_original_function(self): "gemm", "mm", solution="test", - dtypes={torch.float16}, + signatures=format_signatures(("a", "b"), "dense", {torch.float16}), ) def original(x): return x * 2 diff --git a/tokenspeed-kernel/test/test_runtime_callsite_selection.py b/tokenspeed-kernel/test/test_runtime_callsite_selection.py index 1adcbd6e7..7a2be75d6 100644 --- a/tokenspeed-kernel/test/test_runtime_callsite_selection.py +++ b/tokenspeed-kernel/test/test_runtime_callsite_selection.py @@ -47,6 +47,7 @@ import torch from tokenspeed_kernel.registry import KernelRegistry from tokenspeed_kernel.selection import select_kernel +from tokenspeed_kernel.signature import FormatSignature # -- Pre-import so they can be reloaded into the fresh registry. -- @@ -143,8 +144,10 @@ def _extract_features(node: ast.AST) -> Optional[set[str]]: return None -# A ``CallSite`` tuple: (family, mode, dtype|None, features|None, traits, expected_name, location) -CallSite = tuple[str, str, Optional[torch.dtype], Optional[set], dict, str, str] +# A ``CallSite`` tuple: (family, mode, dtype|None, features|None, traits, weight_format, expected_name, location) +CallSite = tuple[ + str, str, Optional[torch.dtype], Optional[set], dict, Optional[str], str, str +] def _collect_call_sites(search_dir: Path) -> list[CallSite]: @@ -204,9 +207,6 @@ def _collect_call_sites(search_dir: Path) -> list[CallSite]: features = _extract_features(feat_node) # -- traits -- - # For moe_* the traits come from the ``traits=`` kwarg. - # For mm() there is no ``traits=`` kwarg; instead ``quant=`` - # is the key selection-relevant parameter. traits: dict[str, Any] = {} traits_node = kwargs.get("traits") if traits_node is not None: @@ -214,18 +214,29 @@ def _collect_call_sites(search_dir: Path) -> list[CallSite]: if parsed is not None: traits = parsed - if family == "gemm": - quant_node = kwargs.get("quant") - if quant_node is not None: - qval, qok = _try_literal(quant_node) - if qok and isinstance(qval, str): - traits["quant"] = qval + weight_format: Optional[str] = None + weight_format_node = kwargs.get("weight_format") + if weight_format_node is not None: + value, ok = _try_literal(weight_format_node) + if ok and isinstance(value, str): + weight_format = value lineno = getattr(node, "lineno", 0) rel = py_path.relative_to(search_dir.parent.parent.parent) location = f"{rel}:{lineno}" - sites.append((family, mode, dtype, features, traits, expected, location)) + sites.append( + ( + family, + mode, + dtype, + features, + traits, + weight_format, + expected, + location, + ) + ) return sites @@ -251,6 +262,7 @@ def _collect_call_sites(search_dir: Path) -> list[CallSite]: torch.bfloat16, {"dispatch_sorted"}, {}, + None, "triton_moe_fused_experts", "manual:triton_common/experts", ), @@ -261,6 +273,7 @@ def _collect_call_sites(search_dir: Path) -> list[CallSite]: torch.bfloat16, None, {"num_tokens": 128, "comm_strategy": None}, + None, "triton_moe_sum_reduce", "manual:triton_common/combine_large", ), @@ -270,6 +283,7 @@ def _collect_call_sites(search_dir: Path) -> list[CallSite]: torch.bfloat16, None, {"num_tokens": 8, "comm_strategy": None}, + None, "torch_compile_moe_sum_reduce", "manual:triton_common/combine_small", ), @@ -289,24 +303,36 @@ def _collect_call_sites(search_dir: Path) -> list[CallSite]: def _infer_dtype(expected_name: str) -> torch.dtype: - """Pick a compatible dtype from the kernel's registered spec. - - Prefers common dtypes over exotic ones (e.g. ``float4_e2m1fn_x2``) - to avoid hitting platform capability gaps in tests. - """ spec = KernelRegistry.get().get_by_name(expected_name) - if spec is not None and spec.dtypes: + if spec is not None: + storage_dtypes = spec.primary_storage_dtypes() for dt in _DTYPE_PREFERENCE: - if dt in spec.dtypes: + if dt in storage_dtypes: return dt - return next(iter(spec.dtypes)) + if storage_dtypes: + return next(iter(storage_dtypes)) return torch.bfloat16 +def _infer_format_signature( + family: str, + mode: str, + dtype: torch.dtype, + weight_format: Optional[str], + spec, +) -> FormatSignature: + if family == "moe" and mode == "fused": + return _moe_pkg._moe_fused_format_signature(dtype, weight_format or "bf16") + signature = spec.format_signature_for_storage_dtype(dtype) + if signature is None: + pytest.skip(f"Kernel {spec.name!r} has no signature for {dtype}") + return signature + + def _site_id(site: CallSite) -> str: """Generate a readable test-id from a call-site tuple.""" - expected = site[5] - location = site[6] + expected = site[6] + location = site[7] return f"{expected}@{location}" @@ -330,7 +356,7 @@ def _site_id(site: CallSite) -> str: ) def test_kernel_selection(site, platform_name, request): platform = request.getfixturevalue(platform_name) - family, mode, raw_dtype, features, traits, expected, location = site + family, mode, raw_dtype, features, traits, weight_format, expected, location = site reg = KernelRegistry.get() spec = reg.get_by_name(expected) @@ -343,11 +369,12 @@ def test_kernel_selection(site, platform_name, request): ) dtype = raw_dtype or _infer_dtype(expected) + signature = _infer_format_signature(family, mode, dtype, weight_format, spec) result = select_kernel( family, mode, - dtype, + signature, features=frozenset(features) if features else None, traits=traits, platform=platform, @@ -355,5 +382,5 @@ def test_kernel_selection(site, platform_name, request): assert result.name == expected, ( f"Expected '{expected}' but got '{result.name}' " f"at {location} on {platform.device_name} " - f"for {family}.{mode}(dtype={dtype}, features={features}, traits={traits})" + f"for {family}.{mode}(signature={signature}, features={features}, traits={traits})" ) diff --git a/tokenspeed-kernel/test/test_selection.py b/tokenspeed-kernel/test/test_selection.py index 217112174..e5314fe5a 100644 --- a/tokenspeed-kernel/test/test_selection.py +++ b/tokenspeed-kernel/test/test_selection.py @@ -27,7 +27,10 @@ import tokenspeed_kernel.ops.gemm as gemm import torch from tokenspeed_kernel.platform import PlatformInfo -from tokenspeed_kernel.registry import KernelRegistry, KernelSpec +from tokenspeed_kernel.registry import ( + KernelRegistry, + KernelSpec, +) from tokenspeed_kernel.selection import ( AutotuneParams, NoKernelFoundError, @@ -55,10 +58,23 @@ spec_matches_traits, warmup_selection, ) +from tokenspeed_kernel.signature import ( + format_signatures, +) from utils import register_all_samples pytestmark = pytest.mark.usefixtures("fresh_registry") +ATTN_DECODE_BF16 = next( + iter(format_signatures(("q", "k_cache", "v_cache"), "dense", {torch.bfloat16})) +) +ATTN_PREFILL_BF16 = next( + iter(format_signatures(("q", "k", "v"), "dense", {torch.bfloat16})) +) +GEMM_BF16 = next(iter(format_signatures(("a", "b"), "dense", {torch.bfloat16}))) +GEMM_FP16 = next(iter(format_signatures(("a", "b"), "dense", {torch.float16}))) +INPUT_BF16 = next(iter(format_signatures("input", "dense", {torch.bfloat16}))) + class TestSelectionObjective: def test_all_enum_values(self): @@ -192,7 +208,7 @@ def test_rank_orders_lexicographically(self, sample_specs, h100_platform): "attention", "decode", platform=h100_platform, - dtype=torch.bfloat16, + format_signature=ATTN_DECODE_BF16, ) scored = _rank_by_objective( candidates, @@ -390,7 +406,7 @@ def test_non_alignment_traits_do_not_affect_shape_matching(self): name="k", family="f", mode="m", - traits={"quant": frozenset({"mxfp8"})}, + traits={"persistent": frozenset({True})}, ) assert spec_matches_shape_traits(spec, {"N": 30, "K": 70}) @@ -401,7 +417,7 @@ def test_deterministic(self): k1 = _make_cache_key( "attn", "dec", - torch.bfloat16, + INPUT_BF16, "sm_90", SelectionObjective.DEFAULT, None, @@ -410,7 +426,7 @@ def test_deterministic(self): k2 = _make_cache_key( "attn", "dec", - torch.bfloat16, + INPUT_BF16, "sm_90", SelectionObjective.DEFAULT, None, @@ -422,7 +438,7 @@ def test_different_objective(self): k1 = _make_cache_key( "attn", "dec", - torch.bfloat16, + INPUT_BF16, "sm_90", SelectionObjective.DEFAULT, None, @@ -431,7 +447,7 @@ def test_different_objective(self): k2 = _make_cache_key( "attn", "dec", - torch.bfloat16, + INPUT_BF16, "sm_90", SelectionObjective.LATENCY, None, @@ -443,7 +459,7 @@ def test_traits_order_independent(self): k1 = _make_cache_key( "a", "d", - torch.float16, + GEMM_FP16, "sm_90", SelectionObjective.DEFAULT, None, @@ -452,7 +468,7 @@ def test_traits_order_independent(self): k2 = _make_cache_key( "a", "d", - torch.float16, + GEMM_FP16, "sm_90", SelectionObjective.DEFAULT, None, @@ -464,10 +480,10 @@ def test_features_order_independent(self): f1 = frozenset({"paged", "mla"}) f2 = frozenset({"mla", "paged"}) k1 = _make_cache_key( - "a", "d", torch.float16, "sm_90", SelectionObjective.DEFAULT, f1, None + "a", "d", GEMM_FP16, "sm_90", SelectionObjective.DEFAULT, f1, None ) k2 = _make_cache_key( - "a", "d", torch.float16, "sm_90", SelectionObjective.DEFAULT, f2, None + "a", "d", GEMM_FP16, "sm_90", SelectionObjective.DEFAULT, f2, None ) assert k1 == k2 @@ -475,7 +491,7 @@ def test_solution_is_selection_relevant(self): k1 = _make_cache_key( "a", "d", - torch.float16, + GEMM_FP16, "sm_90", SelectionObjective.DEFAULT, None, @@ -485,7 +501,7 @@ def test_solution_is_selection_relevant(self): k2 = _make_cache_key( "a", "d", - torch.float16, + GEMM_FP16, "sm_90", SelectionObjective.DEFAULT, None, @@ -503,7 +519,7 @@ def test_basic_selection(self, sample_specs, h100_platform): impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, ) assert callable(impl) @@ -513,16 +529,16 @@ def test_cached_on_second_call(self, sample_specs, h100_platform): register_all_samples(reg, sample_specs) impl1 = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) impl2 = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) assert impl1 is impl2 def test_no_kernel_raises(self, h100_platform): with pytest.raises(NoKernelFoundError): - select_kernel("nonexistent", "op", torch.bfloat16, platform=h100_platform) + select_kernel("nonexistent", "op", INPUT_BF16, platform=h100_platform) def test_no_kernel_after_trait_filter(self, h100_platform): reg = KernelRegistry.get() @@ -532,7 +548,7 @@ def test_no_kernel_after_trait_filter(self, h100_platform): mode="m", solution="triton", priority=10, - dtypes=frozenset({torch.bfloat16}), + format_signatures=frozenset({INPUT_BF16}), traits={"head_dim": frozenset({64})}, ) reg.register(spec, lambda: None) @@ -541,7 +557,7 @@ def test_no_kernel_after_trait_filter(self, h100_platform): select_kernel( "trait_op", "m", - torch.bfloat16, + INPUT_BF16, platform=h100_platform, traits={"head_dim": 128}, ) @@ -553,7 +569,7 @@ def test_override_by_name(self, sample_specs, h100_platform): impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, override="reference_decode", ) @@ -566,7 +582,7 @@ def test_override_by_solution(self, sample_specs, h100_platform): impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, override="triton", ) @@ -580,7 +596,7 @@ def test_solution_filter_preserves_trait_filtering(self, h100_platform): family="attention", mode="prefill", solution="fa4", - dtypes=frozenset({torch.bfloat16}), + format_signatures=frozenset({ATTN_PREFILL_BF16}), traits={"head_dim": frozenset({128})}, priority=15, ), @@ -592,7 +608,7 @@ def test_solution_filter_preserves_trait_filtering(self, h100_platform): family="attention", mode="prefill", solution="triton", - dtypes=frozenset({torch.bfloat16}), + format_signatures=frozenset({ATTN_PREFILL_BF16}), traits={"head_dim": frozenset({256})}, priority=10, ), @@ -602,7 +618,7 @@ def test_solution_filter_preserves_trait_filtering(self, h100_platform): impl = select_kernel( "attention", "prefill", - torch.bfloat16, + ATTN_PREFILL_BF16, platform=h100_platform, solution="fa4", traits={"head_dim": 128}, @@ -613,7 +629,7 @@ def test_solution_filter_preserves_trait_filtering(self, h100_platform): select_kernel( "attention", "prefill", - torch.bfloat16, + ATTN_PREFILL_BF16, platform=h100_platform, solution="fa4", traits={"head_dim": 256}, @@ -627,7 +643,7 @@ def test_override_not_found_raises(self, sample_specs, h100_platform): select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, override="nonexistent_kernel", ) @@ -643,7 +659,7 @@ def test_env_override(self, sample_specs, h100_platform): impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, ) assert impl() == "reference_decode" @@ -655,7 +671,7 @@ def test_portability_objective_prefers_triton(self, sample_specs, h100_platform) impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, objective=SelectionObjective.PORTABILITY, ) @@ -669,7 +685,7 @@ def test_debug_objective_prefers_reference(self, sample_specs, h100_platform): impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, objective=SelectionObjective.DEBUG, ) @@ -682,7 +698,7 @@ def test_amd_platform_selects_aiter(self, sample_specs, mi300_platform): impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=mi300_platform, ) assert impl() == "aiter_decode" @@ -694,7 +710,7 @@ def test_amd_mi350_platform_selects_aiter(self, sample_specs, mi350_platform): impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=mi350_platform, ) assert impl() == "aiter_decode" @@ -721,7 +737,7 @@ def adjust(self, spec, platform, traits): impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, ) assert impl() == "triton_decode" @@ -736,7 +752,7 @@ def test_context_manager_overrides(self, sample_specs, h100_platform): impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, ) assert impl() == "reference_decode" @@ -751,7 +767,7 @@ def test_context_manager_restores(self, sample_specs, h100_platform): impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, ) assert impl() != "reference_decode" or True @@ -762,18 +778,18 @@ def test_nested_override(self, sample_specs, h100_platform): with kernel_override("attention", "decode", "reference_decode"): impl1 = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) assert impl1() == "reference_decode" with kernel_override("attention", "decode", "triton_decode"): impl2 = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) assert impl2() == "triton_decode" impl3 = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) assert impl3() == "reference_decode" @@ -783,7 +799,7 @@ def test_set_policy_clears_cache(self, sample_specs, h100_platform): reg = KernelRegistry.get() register_all_samples(reg, sample_specs) - select_kernel("attention", "decode", torch.bfloat16, platform=h100_platform) + select_kernel("attention", "decode", ATTN_DECODE_BF16, platform=h100_platform) set_selection_policy( SelectionPolicy(default_strategy=SelectionStrategy.AUTOTUNE) ) @@ -798,7 +814,7 @@ def test_output_contains_expected_sections(self, sample_specs, h100_platform): explanation = explain_selection( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, ) assert "attention.decode" in explanation @@ -813,7 +829,7 @@ def test_filtered_out_section(self, sample_specs, h100_platform): explanation = explain_selection( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, ) assert "Filtered out" in explanation @@ -823,7 +839,7 @@ def test_empty_candidates(self, h100_platform): explanation = explain_selection( "nonexistent", "op", - torch.bfloat16, + INPUT_BF16, platform=h100_platform, ) assert "0 matched" in explanation @@ -853,8 +869,8 @@ def test_warmup_explicit_ops(self, sample_specs, h100_platform): try: warmup_selection( ops=[ - ("attention", "decode", torch.bfloat16, None), - ("gemm", "mm", torch.bfloat16, None), + ("attention", "decode", ATTN_DECODE_BF16, None), + ("gemm", "mm", GEMM_BF16, None), ] ) assert len(reg._selection_cache) >= 2 @@ -867,7 +883,7 @@ def test_warmup_skips_missing_ops(self, h100_platform): Platform.override(h100_platform) try: - warmup_selection(ops=[("nonexistent", "op", torch.bfloat16, None)]) + warmup_selection(ops=[("nonexistent", "op", INPUT_BF16, None)]) finally: Platform.reset() @@ -885,7 +901,7 @@ def test_autotune_falls_back_to_heuristic(self, sample_specs, h100_platform): impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, ) assert callable(impl) @@ -1054,7 +1070,7 @@ def test_config_override_by_name(self, sample_specs, h100_platform, tmp_path): ) impl = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) assert impl() == "reference_decode" @@ -1068,7 +1084,7 @@ def test_config_override_by_solution(self, sample_specs, h100_platform, tmp_path ) impl = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) assert impl() == "triton_decode" @@ -1084,7 +1100,7 @@ def test_config_override_objective(self, sample_specs, h100_platform, tmp_path): ) impl = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) assert impl() == "reference_decode" @@ -1103,7 +1119,7 @@ def test_api_override_takes_priority_over_config( impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, override="triton_decode", ) @@ -1126,7 +1142,7 @@ def test_env_var_override_takes_priority_over_config( {"TOKENSPEED_KERNEL_OVERRIDE_ATTENTION_DECODE": "triton_decode"}, ): impl = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) assert impl() == "triton_decode" @@ -1144,7 +1160,7 @@ def test_context_manager_override_takes_priority_over_config( with kernel_override("attention", "decode", "triton_decode"): impl = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) assert impl() == "triton_decode" @@ -1163,7 +1179,7 @@ def test_explicit_objective_takes_priority_over_config( impl = select_kernel( "attention", "decode", - torch.bfloat16, + ATTN_DECODE_BF16, platform=h100_platform, objective=SelectionObjective.PORTABILITY, ) @@ -1182,7 +1198,9 @@ def test_config_override_not_found_raises( ) with pytest.raises(NoKernelFoundError, match="Override"): - select_kernel("attention", "decode", torch.bfloat16, platform=h100_platform) + select_kernel( + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform + ) def test_config_with_invalid_objective_falls_back( self, sample_specs, h100_platform, tmp_path @@ -1199,7 +1217,7 @@ def test_config_with_invalid_objective_falls_back( ) impl = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) assert callable(impl) @@ -1214,11 +1232,11 @@ def test_unrelated_ops_unaffected(self, sample_specs, h100_platform, tmp_path): ) attn_impl = select_kernel( - "attention", "decode", torch.bfloat16, platform=h100_platform + "attention", "decode", ATTN_DECODE_BF16, platform=h100_platform ) assert attn_impl() == "reference_decode" - gemm_impl = select_kernel("gemm", "mm", torch.bfloat16, platform=h100_platform) + gemm_impl = select_kernel("gemm", "mm", GEMM_BF16, platform=h100_platform) assert gemm_impl() != "reference_decode" @@ -1248,7 +1266,7 @@ def _register_kernel(name: str, solution: str, impl) -> None: family="gemm", mode="mm", solution=solution, - dtypes=frozenset({torch.float16}), + format_signatures=frozenset({GEMM_FP16}), priority=50, ) KernelRegistry.get().register(spec, impl) diff --git a/tokenspeed-kernel/test/utils.py b/tokenspeed-kernel/test/utils.py index ec2bd0ae6..95bd0a668 100644 --- a/tokenspeed-kernel/test/utils.py +++ b/tokenspeed-kernel/test/utils.py @@ -24,7 +24,11 @@ import torch from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement -from tokenspeed_kernel.registry import KernelRegistry, KernelSpec +from tokenspeed_kernel.registry import ( + KernelRegistry, + KernelSpec, +) +from tokenspeed_kernel.signature import format_signatures def dummy_impl(name: str) -> Callable: @@ -45,7 +49,9 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: mode="decode", solution="flashinfer", features=frozenset({"paged"}), - dtypes=frozenset({torch.float16, torch.bfloat16}), + format_signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), capability=CapabilityRequirement( vendors=frozenset({"nvidia"}), min_arch_version=ArchVersion(8, 0), @@ -63,7 +69,9 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: mode="decode", solution="triton", features=frozenset({"paged"}), - dtypes=frozenset({torch.float16, torch.bfloat16}), + format_signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), priority=10, tags=frozenset({"portability"}), ), @@ -76,7 +84,9 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: family="attention", mode="prefill", solution="cutlass", - dtypes=frozenset({torch.float16, torch.bfloat16}), + format_signatures=format_signatures( + ("q", "k", "v"), "dense", {torch.float16, torch.bfloat16} + ), capability=CapabilityRequirement( vendors=frozenset({"nvidia"}), min_arch_version=ArchVersion(9, 0), @@ -94,7 +104,11 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: mode="decode", solution="reference", features=frozenset({"paged"}), - dtypes=frozenset({torch.float16, torch.bfloat16, torch.float32}), + format_signatures=format_signatures( + ("q", "k_cache", "v_cache"), + "dense", + {torch.float16, torch.bfloat16, torch.float32}, + ), capability=CapabilityRequirement(), priority=10, tags=frozenset({"determinism", "portability"}), @@ -109,7 +123,9 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: mode="decode", solution="aiter", features=frozenset({"paged"}), - dtypes=frozenset({torch.float16, torch.bfloat16}), + format_signatures=format_signatures( + ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} + ), capability=CapabilityRequirement( vendors=frozenset({"amd"}), ), @@ -125,7 +141,9 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: family="gemm", mode="mm", solution="cutlass", - dtypes=frozenset({torch.float16, torch.bfloat16}), + format_signatures=format_signatures( + ("a", "b"), "dense", {torch.float16, torch.bfloat16} + ), capability=CapabilityRequirement( vendors=frozenset({"nvidia"}), min_arch_version=ArchVersion(8, 0), @@ -142,7 +160,9 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: family="gemm", mode="mm", solution="triton", - dtypes=frozenset({torch.float16, torch.bfloat16}), + format_signatures=format_signatures( + ("a", "b"), "dense", {torch.float16, torch.bfloat16} + ), priority=10, tags=frozenset({"portability"}), ), @@ -155,7 +175,9 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: family="gemm", mode="grouped_mm", solution="cutlass", - dtypes=frozenset({torch.float16, torch.bfloat16}), + format_signatures=format_signatures( + ("a", "b"), "dense", {torch.float16, torch.bfloat16} + ), capability=CapabilityRequirement( vendors=frozenset({"nvidia"}), min_arch_version=ArchVersion(9, 0), @@ -172,7 +194,9 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: family="gemm", mode="grouped_mm", solution="triton", - dtypes=frozenset({torch.float16, torch.bfloat16}), + format_signatures=format_signatures( + ("a", "b"), "dense", {torch.float16, torch.bfloat16} + ), priority=10, tags=frozenset({"portability"}), ), @@ -185,7 +209,9 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: family="moe", mode="fused", solution="triton", - dtypes=frozenset({torch.float16, torch.bfloat16}), + format_signatures=format_signatures( + ("x", "weight"), "dense", {torch.float16, torch.bfloat16} + ), priority=12, tags=frozenset({"throughput", "portability"}), ), @@ -198,7 +224,9 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: family="moe", mode="fused", solution="cutlass", - dtypes=frozenset({torch.float16, torch.bfloat16}), + format_signatures=format_signatures( + ("x", "weight"), "dense", {torch.float16, torch.bfloat16} + ), capability=CapabilityRequirement( vendors=frozenset({"nvidia"}), min_arch_version=ArchVersion(9, 0), @@ -215,7 +243,9 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: family="moe", mode="modular", solution="triton", - dtypes=frozenset({torch.float16, torch.bfloat16}), + format_signatures=format_signatures( + "x", "dense", {torch.float16, torch.bfloat16} + ), priority=10, tags=frozenset({"determinism", "portability"}), ), @@ -228,7 +258,9 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: family="moe", mode="modular", solution="cutlass", - dtypes=frozenset({torch.float16, torch.bfloat16}), + format_signatures=format_signatures( + "x", "dense", {torch.float16, torch.bfloat16} + ), capability=CapabilityRequirement( vendors=frozenset({"nvidia"}), min_arch_version=ArchVersion(8, 0), From 1c5b9672b8bcba88e5f0691441ba2b8bfb7b8262 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sat, 23 May 2026 20:25:57 +0000 Subject: [PATCH 02/16] Restore registry import formatting Keep registry imports aligned with main where the imported symbol set did not change. This leaves only the functional KernelSpec import removal in numerics/inputs.py as a semantic change. Signed-off-by: Lei Zhang --- .../python/tokenspeed_kernel/numerics/reference/gemm.py | 5 +---- .../python/tokenspeed_kernel/numerics/reference/moe.py | 5 +---- .../tokenspeed_kernel/ops/attention/cuda/__init__.py | 5 +---- .../tokenspeed_kernel/ops/attention/flash_attn/__init__.py | 6 +----- .../tokenspeed_kernel/ops/attention/flashinfer/__init__.py | 7 +------ .../ops/attention/gluon/mha_decode_fp16_gfx950.py | 5 +---- .../ops/attention/gluon/mha_prefill_fp16_gfx950.py | 5 +---- .../tokenspeed_kernel/ops/attention/triton/__init__.py | 5 +---- .../python/tokenspeed_kernel/ops/embedding/triton.py | 5 +---- .../python/tokenspeed_kernel/ops/gemm/deep_gemm.py | 5 +---- .../python/tokenspeed_kernel/ops/gemm/flashinfer.py | 6 +----- .../python/tokenspeed_kernel/ops/gemm/triton.py | 5 +---- .../python/tokenspeed_kernel/ops/gemm/trtllm.py | 5 +---- tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/cuda.py | 6 +----- .../python/tokenspeed_kernel/ops/moe/deepep.py | 5 +---- .../python/tokenspeed_kernel/ops/moe/flashinfer.py | 6 +----- .../python/tokenspeed_kernel/ops/moe/triton.py | 5 +---- .../python/tokenspeed_kernel/ops/moe/triton_kernels.py | 5 +---- .../tokenspeed_kernel/ops/quantization/flashinfer.py | 6 +----- .../python/tokenspeed_kernel/ops/quantization/triton.py | 5 +---- .../python/tokenspeed_kernel/ops/quantization/trtllm.py | 6 +----- tokenspeed-kernel/python/tokenspeed_kernel/selection.py | 5 +---- tokenspeed-kernel/test/test_benchmark.py | 5 +---- tokenspeed-kernel/test/test_numerics.py | 6 +----- tokenspeed-kernel/test/test_plugins.py | 5 +---- tokenspeed-kernel/test/test_selection.py | 5 +---- tokenspeed-kernel/test/utils.py | 5 +---- 27 files changed, 27 insertions(+), 117 deletions(-) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py index ee8709612..583645672 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py @@ -25,10 +25,7 @@ import torch import torch.nn.functional as F from tokenspeed_kernel.platform import Platform -from tokenspeed_kernel.registry import ( - Priority, - register_kernel, -) +from tokenspeed_kernel.registry import Priority, register_kernel from tokenspeed_kernel.signature import ( ScaleFormat, format_signatures, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py index f630cb270..0e116a575 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py @@ -32,10 +32,7 @@ ExpertLocationDispatchInfo, topk_ids_logical_to_physical, ) -from tokenspeed_kernel.registry import ( - Priority, - register_kernel, -) +from tokenspeed_kernel.registry import Priority, register_kernel from tokenspeed_kernel.signature import ( dense_format, format_signature, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/cuda/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/cuda/__init__.py index 9f765d829..7e151da1d 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/cuda/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/cuda/__init__.py @@ -6,10 +6,7 @@ CapabilityRequirement, current_platform, ) -from tokenspeed_kernel.registry import ( - Priority, - register_kernel, -) +from tokenspeed_kernel.registry import Priority, register_kernel from tokenspeed_kernel.signature import format_signatures platform = current_platform() diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flash_attn/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flash_attn/__init__.py index b1305098e..ae355b86e 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flash_attn/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flash_attn/__init__.py @@ -25,11 +25,7 @@ CapabilityRequirement, current_platform, ) -from tokenspeed_kernel.registry import ( - Priority, - error_fn, - register_kernel, -) +from tokenspeed_kernel.registry import Priority, error_fn, register_kernel from tokenspeed_kernel.signature import format_signatures __all__ = [ diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flashinfer/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flashinfer/__init__.py index 1355d2f2d..39e1a076c 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flashinfer/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/flashinfer/__init__.py @@ -29,12 +29,7 @@ CapabilityRequirement, current_platform, ) -from tokenspeed_kernel.registry import ( - ErrorClass, - Priority, - error_fn, - register_kernel, -) +from tokenspeed_kernel.registry import ErrorClass, Priority, error_fn, register_kernel from tokenspeed_kernel.signature import format_signatures platform = current_platform() diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_decode_fp16_gfx950.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_decode_fp16_gfx950.py index 49278cda7..307421087 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_decode_fp16_gfx950.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_decode_fp16_gfx950.py @@ -35,10 +35,7 @@ maximum, ) from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement -from tokenspeed_kernel.registry import ( - Priority, - register_kernel, -) +from tokenspeed_kernel.registry import Priority, register_kernel from tokenspeed_kernel.signature import format_signatures cdna4 = gl.amd.cdna4 diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_prefill_fp16_gfx950.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_prefill_fp16_gfx950.py index 27b00da9a..96526f7a1 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_prefill_fp16_gfx950.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/gluon/mha_prefill_fp16_gfx950.py @@ -35,10 +35,7 @@ maximum, ) from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement -from tokenspeed_kernel.registry import ( - Priority, - register_kernel, -) +from tokenspeed_kernel.registry import Priority, register_kernel from tokenspeed_kernel.signature import format_signatures cdna4 = gl.amd.cdna4 diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/triton/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/triton/__init__.py index 179db22d4..506cbefa9 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/triton/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/triton/__init__.py @@ -27,10 +27,7 @@ from tokenspeed_kernel.ops.attention.triton.mha_decode import decode_attention_fwd from tokenspeed_kernel.ops.attention.triton.mha_prefill import prefill_attention_fwd from tokenspeed_kernel.platform import CapabilityRequirement -from tokenspeed_kernel.registry import ( - Priority, - register_kernel, -) +from tokenspeed_kernel.registry import Priority, register_kernel from tokenspeed_kernel.signature import format_signatures diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/triton.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/triton.py index d895b319d..417c125e7 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/triton.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/triton.py @@ -27,10 +27,7 @@ import torch from tokenspeed_kernel._triton import tl, triton from tokenspeed_kernel.platform import CapabilityRequirement -from tokenspeed_kernel.registry import ( - Priority, - register_kernel, -) +from tokenspeed_kernel.registry import Priority, register_kernel from tokenspeed_kernel.signature import format_signatures diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/deep_gemm.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/deep_gemm.py index 18d353911..658198dac 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/deep_gemm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/deep_gemm.py @@ -22,10 +22,7 @@ import torch from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement, Platform -from tokenspeed_kernel.registry import ( - Priority, - register_kernel, -) +from tokenspeed_kernel.registry import Priority, register_kernel from tokenspeed_kernel.signature import ( ScaleFormat, format_signatures, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/flashinfer.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/flashinfer.py index 6cee58fb8..00f2a3ccf 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/flashinfer.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/flashinfer.py @@ -27,11 +27,7 @@ Platform, current_platform, ) -from tokenspeed_kernel.registry import ( - Priority, - error_fn, - register_kernel, -) +from tokenspeed_kernel.registry import Priority, error_fn, register_kernel from tokenspeed_kernel.signature import ( ScaleFormat, format_signatures, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py index 9d7348604..b8cd4d38f 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py @@ -30,10 +30,7 @@ import torch from tokenspeed_kernel._triton import tl, triton from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement, Platform -from tokenspeed_kernel.registry import ( - Priority, - register_kernel, -) +from tokenspeed_kernel.registry import Priority, register_kernel from tokenspeed_kernel.signature import ( ScaleFormat, format_signatures, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/trtllm.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/trtllm.py index 9d57b51be..b0a534442 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/trtllm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/trtllm.py @@ -27,10 +27,7 @@ import torch from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement -from tokenspeed_kernel.registry import ( - Priority, - register_kernel, -) +from tokenspeed_kernel.registry import Priority, register_kernel from tokenspeed_kernel.signature import ( ScaleFormat, format_signatures, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/cuda.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/cuda.py index 6594d6b58..6a33f1dd5 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/cuda.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/cuda.py @@ -28,11 +28,7 @@ from __future__ import annotations import torch -from tokenspeed_kernel.registry import ( - Priority, - error_fn, - register_kernel, -) +from tokenspeed_kernel.registry import Priority, error_fn, register_kernel from tokenspeed_kernel.signature import format_signatures try: diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/deepep.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/deepep.py index 9d88fe12c..cb6faa2fa 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/deepep.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/deepep.py @@ -24,10 +24,7 @@ import torch from tokenspeed_kernel._triton import tl, triton -from tokenspeed_kernel.registry import ( - Priority, - register_kernel, -) +from tokenspeed_kernel.registry import Priority, register_kernel from tokenspeed_kernel.signature import format_signatures logger = logging.getLogger(__name__) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py index b01c66479..dde96f63e 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py @@ -22,11 +22,7 @@ import torch from tokenspeed_kernel.platform import current_platform -from tokenspeed_kernel.registry import ( - Priority, - error_fn, - register_kernel, -) +from tokenspeed_kernel.registry import Priority, error_fn, register_kernel from tokenspeed_kernel.signature import ( ScaleFormat, dense_format, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton.py index 70078e8ab..73d9afa65 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton.py @@ -32,10 +32,7 @@ from tokenspeed_kernel.ops.moe.expert_location_dispatch import ( ExpertLocationDispatchInfo, ) -from tokenspeed_kernel.registry import ( - Priority, - register_kernel, -) +from tokenspeed_kernel.registry import Priority, register_kernel from tokenspeed_kernel.signature import format_signatures from tokenspeed_kernel.thirdparty.trtllm import ( moe_align_block_size as _moe_align_block_size, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton_kernels.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton_kernels.py index 73a74b311..3d0d1db67 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton_kernels.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/triton_kernels.py @@ -28,10 +28,7 @@ import tokenspeed_kernel.thirdparty.triton_kernels # noqa: F401 import torch from tokenspeed_kernel.platform import current_platform -from tokenspeed_kernel.registry import ( - Priority, - register_kernel, -) +from tokenspeed_kernel.registry import Priority, register_kernel from tokenspeed_kernel.signature import format_signatures try: diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/flashinfer.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/flashinfer.py index 99d4fe331..e4124b52f 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/flashinfer.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/flashinfer.py @@ -22,11 +22,7 @@ CapabilityRequirement, current_platform, ) -from tokenspeed_kernel.registry import ( - Priority, - error_fn, - register_kernel, -) +from tokenspeed_kernel.registry import Priority, error_fn, register_kernel from tokenspeed_kernel.signature import format_signatures platform = current_platform() diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/triton.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/triton.py index 30ded5e34..47dfc7682 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/triton.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/triton.py @@ -18,10 +18,7 @@ import torch from tokenspeed_kernel._triton import tl, triton -from tokenspeed_kernel.registry import ( - Priority, - register_kernel, -) +from tokenspeed_kernel.registry import Priority, register_kernel from tokenspeed_kernel.signature import format_signatures diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/trtllm.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/trtllm.py index bd6e9be08..7b57cc8d4 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/trtllm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/trtllm.py @@ -21,11 +21,7 @@ import torch from tokenspeed_kernel.platform import current_platform -from tokenspeed_kernel.registry import ( - Priority, - error_fn, - register_kernel, -) +from tokenspeed_kernel.registry import Priority, error_fn, register_kernel from tokenspeed_kernel.signature import format_signatures platform = current_platform() diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/selection.py b/tokenspeed-kernel/python/tokenspeed_kernel/selection.py index 28e390172..37d7614d7 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/selection.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/selection.py @@ -29,10 +29,7 @@ from typing import Any, Callable, Generator from tokenspeed_kernel.platform import PlatformInfo, current_platform -from tokenspeed_kernel.registry import ( - KernelRegistry, - KernelSpec, -) +from tokenspeed_kernel.registry import KernelRegistry, KernelSpec from tokenspeed_kernel.signature import FormatSignature logger = logging.getLogger(__name__) diff --git a/tokenspeed-kernel/test/test_benchmark.py b/tokenspeed-kernel/test/test_benchmark.py index 82d97649e..0b014e12b 100644 --- a/tokenspeed-kernel/test/test_benchmark.py +++ b/tokenspeed-kernel/test/test_benchmark.py @@ -32,10 +32,7 @@ from tokenspeed_kernel.benchmark.throughput import ThroughputCalculator from tokenspeed_kernel.platform import Platform from tokenspeed_kernel.profiling import ProfilingConfig -from tokenspeed_kernel.registry import ( - KernelRegistry, - KernelSpec, -) +from tokenspeed_kernel.registry import KernelRegistry, KernelSpec from tokenspeed_kernel.signature import format_signatures pytestmark = [ diff --git a/tokenspeed-kernel/test/test_numerics.py b/tokenspeed-kernel/test/test_numerics.py index f256c7cac..7784044aa 100644 --- a/tokenspeed-kernel/test/test_numerics.py +++ b/tokenspeed-kernel/test/test_numerics.py @@ -27,11 +27,7 @@ from tokenspeed_kernel.numerics.tolerance import Tolerance from tokenspeed_kernel.numerics.verify import verify_kernel from tokenspeed_kernel.platform import Platform -from tokenspeed_kernel.registry import ( - KernelRegistry, - KernelSpec, - load_builtin_kernels, -) +from tokenspeed_kernel.registry import KernelRegistry, KernelSpec, load_builtin_kernels from tokenspeed_kernel.signature import ( ScaleFormat, format_signatures, diff --git a/tokenspeed-kernel/test/test_plugins.py b/tokenspeed-kernel/test/test_plugins.py index 8b18ced64..07730990e 100644 --- a/tokenspeed-kernel/test/test_plugins.py +++ b/tokenspeed-kernel/test/test_plugins.py @@ -37,10 +37,7 @@ reset_plugins, ) from tokenspeed_kernel.plugins.cli import main as cli_main -from tokenspeed_kernel.registry import ( - KernelRegistry, - register_kernel, -) +from tokenspeed_kernel.registry import KernelRegistry, register_kernel from tokenspeed_kernel.signature import format_signatures # --------------------------------------------------------------------------- diff --git a/tokenspeed-kernel/test/test_selection.py b/tokenspeed-kernel/test/test_selection.py index e5314fe5a..c3a0357e3 100644 --- a/tokenspeed-kernel/test/test_selection.py +++ b/tokenspeed-kernel/test/test_selection.py @@ -27,10 +27,7 @@ import tokenspeed_kernel.ops.gemm as gemm import torch from tokenspeed_kernel.platform import PlatformInfo -from tokenspeed_kernel.registry import ( - KernelRegistry, - KernelSpec, -) +from tokenspeed_kernel.registry import KernelRegistry, KernelSpec from tokenspeed_kernel.selection import ( AutotuneParams, NoKernelFoundError, diff --git a/tokenspeed-kernel/test/utils.py b/tokenspeed-kernel/test/utils.py index 95bd0a668..60c823ab0 100644 --- a/tokenspeed-kernel/test/utils.py +++ b/tokenspeed-kernel/test/utils.py @@ -24,10 +24,7 @@ import torch from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement -from tokenspeed_kernel.registry import ( - KernelRegistry, - KernelSpec, -) +from tokenspeed_kernel.registry import KernelRegistry, KernelSpec from tokenspeed_kernel.signature import format_signatures From ed45c706e242c57245d440f4720a5fbbdf43047f Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sat, 23 May 2026 20:30:13 +0000 Subject: [PATCH 03/16] Use single-line small signature imports Collapse parenthesized tokenspeed_kernel.signature imports when they import three or fewer symbols. Keep larger imports wrapped, and split the one long three-symbol import so formatter hooks do not rewrap it. Signed-off-by: Lei Zhang --- .../python/tokenspeed_kernel/numerics/reference/gemm.py | 5 +---- .../python/tokenspeed_kernel/numerics/reference/moe.py | 9 ++++----- .../python/tokenspeed_kernel/ops/attention/__init__.py | 5 +---- .../python/tokenspeed_kernel/ops/embedding/__init__.py | 5 +---- .../python/tokenspeed_kernel/ops/gemm/deep_gemm.py | 5 +---- .../python/tokenspeed_kernel/ops/gemm/flashinfer.py | 5 +---- .../python/tokenspeed_kernel/ops/gemm/triton.py | 5 +---- .../python/tokenspeed_kernel/ops/gemm/trtllm.py | 5 +---- .../tokenspeed_kernel/ops/quantization/__init__.py | 5 +---- tokenspeed-kernel/test/test_numerics.py | 5 +---- tokenspeed-kernel/test/test_selection.py | 4 +--- 11 files changed, 14 insertions(+), 44 deletions(-) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py index 583645672..c782fd1c5 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py @@ -26,10 +26,7 @@ import torch.nn.functional as F from tokenspeed_kernel.platform import Platform from tokenspeed_kernel.registry import Priority, register_kernel -from tokenspeed_kernel.signature import ( - ScaleFormat, - format_signatures, -) +from tokenspeed_kernel.signature import ScaleFormat, format_signatures fp8_dtype = Platform.get().fp8e4m3fn.dtype _FP8_BLOCK_SCALE = ScaleFormat( diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py index 0e116a575..ef1129a6e 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py @@ -33,11 +33,10 @@ topk_ids_logical_to_physical, ) from tokenspeed_kernel.registry import Priority, register_kernel -from tokenspeed_kernel.signature import ( - dense_format, - format_signature, - format_signatures, -) +from tokenspeed_kernel.signature import dense_format, format_signature + +# isort: split +from tokenspeed_kernel.signature import format_signatures from tokenspeed_kernel.torch_compile import get_compiler_backend # --------------------------------------------------------------------------- diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/__init__.py index 1309f9494..cf2a0208e 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/__init__.py @@ -32,10 +32,7 @@ from tokenspeed_kernel.ops.attention.flash_attn import mha_decode_scheduler_metadata from tokenspeed_kernel.profiling import ShapeCapture, kernel_scope from tokenspeed_kernel.selection import select_kernel -from tokenspeed_kernel.signature import ( - dense_format, - format_signature, -) +from tokenspeed_kernel.signature import dense_format, format_signature AttentionResult = torch.Tensor | tuple[torch.Tensor, torch.Tensor | None] diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/__init__.py index 93d79f444..cb85a3150 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/__init__.py @@ -23,10 +23,7 @@ import torch from tokenspeed_kernel.profiling import ShapeCapture, kernel_scope from tokenspeed_kernel.selection import select_kernel -from tokenspeed_kernel.signature import ( - dense_format, - format_signature, -) +from tokenspeed_kernel.signature import dense_format, format_signature @dataclass diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/deep_gemm.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/deep_gemm.py index 658198dac..12286d699 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/deep_gemm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/deep_gemm.py @@ -23,10 +23,7 @@ import torch from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement, Platform from tokenspeed_kernel.registry import Priority, register_kernel -from tokenspeed_kernel.signature import ( - ScaleFormat, - format_signatures, -) +from tokenspeed_kernel.signature import ScaleFormat, format_signatures _fp8_dtype = Platform.get().fp8e4m3fn.dtype _MXFP8_SCALE = ScaleFormat( diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/flashinfer.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/flashinfer.py index 00f2a3ccf..7c98f608e 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/flashinfer.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/flashinfer.py @@ -28,10 +28,7 @@ current_platform, ) from tokenspeed_kernel.registry import Priority, error_fn, register_kernel -from tokenspeed_kernel.signature import ( - ScaleFormat, - format_signatures, -) +from tokenspeed_kernel.signature import ScaleFormat, format_signatures platform = current_platform() _fp8_dtype = Platform.get().fp8e4m3fn.dtype diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py index b8cd4d38f..67f183ced 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py @@ -31,10 +31,7 @@ from tokenspeed_kernel._triton import tl, triton from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement, Platform from tokenspeed_kernel.registry import Priority, register_kernel -from tokenspeed_kernel.signature import ( - ScaleFormat, - format_signatures, -) +from tokenspeed_kernel.signature import ScaleFormat, format_signatures logger = logging.getLogger(__name__) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/trtllm.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/trtllm.py index b0a534442..63efd84c6 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/trtllm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/trtllm.py @@ -28,10 +28,7 @@ import torch from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement from tokenspeed_kernel.registry import Priority, register_kernel -from tokenspeed_kernel.signature import ( - ScaleFormat, - format_signatures, -) +from tokenspeed_kernel.signature import ScaleFormat, format_signatures # Re-exported # dsv3_fused_a_gemm supports specific shapes only (see python/tokenspeed/runtime/models/deepseek_v3.py); diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/__init__.py index 1338ed55d..f21da1f49 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/__init__.py @@ -21,10 +21,7 @@ import torch from tokenspeed_kernel.profiling import ShapeCapture, kernel_scope from tokenspeed_kernel.selection import select_kernel -from tokenspeed_kernel.signature import ( - dense_format, - format_signature, -) +from tokenspeed_kernel.signature import dense_format, format_signature __all__ = [ "quantize_fp8", diff --git a/tokenspeed-kernel/test/test_numerics.py b/tokenspeed-kernel/test/test_numerics.py index 7784044aa..ba2f1cfd5 100644 --- a/tokenspeed-kernel/test/test_numerics.py +++ b/tokenspeed-kernel/test/test_numerics.py @@ -28,10 +28,7 @@ from tokenspeed_kernel.numerics.verify import verify_kernel from tokenspeed_kernel.platform import Platform from tokenspeed_kernel.registry import KernelRegistry, KernelSpec, load_builtin_kernels -from tokenspeed_kernel.signature import ( - ScaleFormat, - format_signatures, -) +from tokenspeed_kernel.signature import ScaleFormat, format_signatures _fp8_dtype = Platform.get().fp8e4m3fn.dtype diff --git a/tokenspeed-kernel/test/test_selection.py b/tokenspeed-kernel/test/test_selection.py index c3a0357e3..b92a8858c 100644 --- a/tokenspeed-kernel/test/test_selection.py +++ b/tokenspeed-kernel/test/test_selection.py @@ -55,9 +55,7 @@ spec_matches_traits, warmup_selection, ) -from tokenspeed_kernel.signature import ( - format_signatures, -) +from tokenspeed_kernel.signature import format_signatures from utils import register_all_samples pytestmark = pytest.mark.usefixtures("fresh_registry") From 9f3505e6c272d6159568c02a4a91808d2294a4d3 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sat, 23 May 2026 20:39:36 +0000 Subject: [PATCH 04/16] Simplify operand scale signatures Remove ScaleFormat.layout and let TensorFormat.format carry representation-specific names such as mxfp8, mxfp4, and nvfp4. Keep ScaleFormat focused on scale storage, granularity, and optional block shape. Use dense TensorFormat entries for scaled FP8 payloads, with ScaleFormat recording tensor or channel granularity. Update GEMM numerics generation to identify MXFP8 block scales from the tensor format instead of duplicated scale layout metadata. Signed-off-by: Lei Zhang --- .../python/tokenspeed_kernel/numerics/gemm.py | 4 +- .../numerics/reference/gemm.py | 12 ++--- .../tokenspeed_kernel/ops/gemm/__init__.py | 13 ++--- .../tokenspeed_kernel/ops/gemm/deep_gemm.py | 1 - .../tokenspeed_kernel/ops/gemm/flashinfer.py | 2 - .../tokenspeed_kernel/ops/gemm/triton.py | 15 +++--- .../tokenspeed_kernel/ops/gemm/trtllm.py | 1 - .../tokenspeed_kernel/ops/moe/__init__.py | 7 +-- .../tokenspeed_kernel/ops/moe/flashinfer.py | 7 +-- .../python/tokenspeed_kernel/signature.py | 16 +++--- .../test/test_kernel_api_selection.py | 49 +++++++++++++++++++ tokenspeed-kernel/test/test_numerics.py | 1 - tokenspeed-kernel/test/test_registry.py | 1 - 13 files changed, 78 insertions(+), 51 deletions(-) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/gemm.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/gemm.py index 976ae20e0..51d12663d 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/gemm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/gemm.py @@ -149,7 +149,7 @@ def _scale_for_format( if scale is None: return None - if scale.granularity == "block" and scale.layout == "mxfp8": + if scale.granularity == "block" and tensor_format.format == "mxfp8": block_n, block_k = block_size or [128, 128] k_tiles = math.ceil(K / block_k) if role == "a": @@ -158,7 +158,7 @@ def _scale_for_format( n_tiles = math.ceil(N / block_n) return self._generate_scales((n_tiles, k_tiles), scale.storage_dtype) - if scale.granularity == "per_channel": + if scale.granularity == "channel": return self._generate_scales( (M,) if role == "a" else (N,), scale.storage_dtype, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py index c782fd1c5..e46102a31 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py @@ -33,18 +33,16 @@ storage_dtype=torch.float32, granularity="block", block_shape=(128, 128), - layout="mxfp8", ) _FP8_TENSOR_SCALE = ScaleFormat( storage_dtype=torch.float32, - granularity="per_tensor", - layout="scaled", + granularity="tensor", ) _MXFP8_FORMAT_SIGNATURES = format_signatures( ("a", "b"), "mxfp8", {fp8_dtype}, scale=_FP8_BLOCK_SCALE ) -_FP8_PER_TENSOR_FORMAT_SIGNATURES = format_signatures( - ("a", "b"), "fp8", {fp8_dtype}, scale=_FP8_TENSOR_SCALE +_FP8_TENSOR_FORMAT_SIGNATURES = format_signatures( + ("a", "b"), "dense", {fp8_dtype}, scale=_FP8_TENSOR_SCALE ) _DENSE_GEMM_FORMAT_SIGNATURES = format_signatures( ("a", "b"), "dense", {torch.bfloat16, torch.float16, torch.float32} @@ -111,7 +109,7 @@ def torch_mm_fp8_blockscale( "mm", name="torch_mm_fp8_scaled_mnk", solution="reference", - signatures=_FP8_PER_TENSOR_FORMAT_SIGNATURES, + signatures=_FP8_TENSOR_FORMAT_SIGNATURES, traits={ "b_layout": frozenset({"NK"}), }, @@ -153,7 +151,7 @@ def torch_mm_fp8_scaled_mnk( "mm", name="torch_mm_fp8_scaled_nkm", solution="reference", - signatures=_FP8_PER_TENSOR_FORMAT_SIGNATURES, + signatures=_FP8_TENSOR_FORMAT_SIGNATURES, traits={ "b_layout": frozenset({"KN"}), }, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py index 6a7d58bf6..013150417 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py @@ -68,12 +68,12 @@ def _infer_scale_type( A_scales: torch.Tensor | None, B_scales: torch.Tensor | None, ) -> str | None: - """For fp8, distinguish per-tensor from per-channel scaling.""" + """For fp8, distinguish tensor-scalar from channel/vector scaling.""" if A_scales is None or B_scales is None: return None if A_scales.numel() == 1 and B_scales.numel() == 1: - return "per_tensor" - return "per_channel" + return "tensor" + return "channel" def _scale_storage_dtype(*scales: torch.Tensor | None) -> torch.dtype: @@ -98,7 +98,6 @@ def _gemm_format_signature( storage_dtype=_scale_storage_dtype(A_scales, B_scales), granularity="block", block_shape=tuple(block_size) if block_size is not None else None, - layout="mxfp8", ) a_storage_dtype = _fp8_dtype if A_scales is None else A.dtype return format_signature( @@ -109,17 +108,15 @@ def _gemm_format_signature( scale = ScaleFormat( storage_dtype=_scale_storage_dtype(A_scales, B_scales), granularity=_infer_scale_type(A_scales, B_scales) or "unknown", - layout="scaled", ) return format_signature( - a=tensor_format("fp8", A.dtype, scale=scale), - b=tensor_format("fp8", B.dtype, scale=scale), + a=tensor_format("dense", A.dtype, scale=scale), + b=tensor_format("dense", B.dtype, scale=scale), ) if quant == "nvfp4": scale = ScaleFormat( storage_dtype=_scale_storage_dtype(A_scales, B_scales), granularity="block", - layout="nvfp4", ) return format_signature( a=tensor_format("nvfp4", A.dtype, scale=scale), diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/deep_gemm.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/deep_gemm.py index 12286d699..2f621e30a 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/deep_gemm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/deep_gemm.py @@ -30,7 +30,6 @@ storage_dtype=torch.float32, granularity="block", block_shape=(128, 128), - layout="mxfp8", ) _MXFP8_FORMAT_SIGNATURES = format_signatures( ("a", "b"), "mxfp8", {_fp8_dtype}, scale=_MXFP8_SCALE diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/flashinfer.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/flashinfer.py index 7c98f608e..2cc0ab42d 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/flashinfer.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/flashinfer.py @@ -38,12 +38,10 @@ storage_dtype=torch.float32, granularity="block", block_shape=(128, 128), - layout="mxfp8", ) _NVFP4_SCALE = ScaleFormat( storage_dtype=torch.float32, granularity="block", - layout="nvfp4", ) _MXFP8_FORMAT_SIGNATURES = format_signatures( ("a", "b"), "mxfp8", {_fp8_dtype}, scale=_MXFP8_SCALE diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py index 67f183ced..80dd1ae2d 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py @@ -40,24 +40,21 @@ storage_dtype=torch.float32, granularity="block", block_shape=(128, 128), - layout="mxfp8", ) -_FP8_PER_TENSOR_SCALE = ScaleFormat( +_FP8_TENSOR_SCALE = ScaleFormat( storage_dtype=torch.float32, - granularity="per_tensor", - layout="scaled", + granularity="tensor", ) -_FP8_PER_CHANNEL_SCALE = ScaleFormat( +_FP8_CHANNEL_SCALE = ScaleFormat( storage_dtype=torch.float32, - granularity="per_channel", - layout="scaled", + granularity="channel", ) _MXFP8_FORMAT_SIGNATURES = format_signatures( ("a", "b"), "mxfp8", {_fp8_dtype}, scale=_MXFP8_SCALE ) _FP8_SCALED_FORMAT_SIGNATURES = format_signatures( - ("a", "b"), "fp8", {_fp8_dtype}, scale=_FP8_PER_TENSOR_SCALE -) | format_signatures(("a", "b"), "fp8", {_fp8_dtype}, scale=_FP8_PER_CHANNEL_SCALE) + ("a", "b"), "dense", {_fp8_dtype}, scale=_FP8_TENSOR_SCALE +) | format_signatures(("a", "b"), "dense", {_fp8_dtype}, scale=_FP8_CHANNEL_SCALE) def prepare_block_fp8_matmul_inputs( diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/trtllm.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/trtllm.py index 63efd84c6..d323d01b2 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/trtllm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/trtllm.py @@ -39,7 +39,6 @@ _NVFP4_SCALE = ScaleFormat( storage_dtype=torch.float32, granularity="block", - layout="nvfp4", ) _NVFP4_FORMAT_SIGNATURES = format_signatures( ("a", "b"), "nvfp4", _fp4_dtypes, scale=_NVFP4_SCALE diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py index d5e8fd0b7..de5f73c37 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py @@ -125,18 +125,15 @@ def adjust(self, spec, platform, traits): _FP8_SCALE = ScaleFormat( storage_dtype=torch.float32, granularity="block", - layout="fp8", ) _NVFP4_SCALE = ScaleFormat( storage_dtype=torch.float32, granularity="block", - layout="nvfp4", ) _MXFP4_SCALE = ScaleFormat( storage_dtype=torch.uint8, granularity="block", block_shape=(32,), - layout="ue8m0", ) @@ -156,7 +153,7 @@ def _moe_fused_format_signature( weight_format: str, ): if weight_format == WEIGHT_FP8: - weight = tensor_format("fp8", torch.float8_e4m3fn, scale=_FP8_SCALE) + weight = tensor_format("dense", torch.float8_e4m3fn, scale=_FP8_SCALE) elif weight_format == WEIGHT_NVFP4: weight = tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE) elif weight_format == WEIGHT_MXFP4: @@ -171,7 +168,7 @@ def _moe_fused_format_signature( elif storage_dtype == torch.uint8 and weight_format == WEIGHT_MXFP4: x = tensor_format("mxfp4", torch.uint8, scale=_MXFP4_SCALE) elif storage_dtype == torch.float8_e4m3fn: - x = tensor_format("fp8", storage_dtype, scale=_FP8_SCALE) + x = tensor_format("dense", storage_dtype, scale=_FP8_SCALE) else: x = dense_format(storage_dtype) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py index dde96f63e..5bb68e581 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py @@ -36,18 +36,15 @@ _FP8_SCALE = ScaleFormat( storage_dtype=torch.float32, granularity="block", - layout="fp8", ) _NVFP4_SCALE = ScaleFormat( storage_dtype=torch.float32, granularity="block", - layout="nvfp4", ) _MXFP4_SCALE = ScaleFormat( storage_dtype=torch.uint8, granularity="block", block_shape=(32,), - layout="ue8m0", ) _BF16_FUSED_FORMAT_SIGNATURES = frozenset( { @@ -77,8 +74,8 @@ weight=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), ), format_signature( - x=tensor_format("fp8", torch.float8_e4m3fn, scale=_FP8_SCALE), - weight=tensor_format("fp8", torch.float8_e4m3fn, scale=_FP8_SCALE), + x=tensor_format("dense", torch.float8_e4m3fn, scale=_FP8_SCALE), + weight=tensor_format("dense", torch.float8_e4m3fn, scale=_FP8_SCALE), ), } ) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/signature.py b/tokenspeed-kernel/python/tokenspeed_kernel/signature.py index a8cb30b03..6eebf6b47 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/signature.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/signature.py @@ -43,24 +43,20 @@ class ScaleFormat: Args: storage_dtype: Physical dtype used by the scale tensor. - granularity: Scale granularity, such as "per_tensor", - "per_channel", or "block". + granularity: Scale granularity, such as "tensor", "channel", + or "block". block_shape: Logical block shape covered by each scale value when granularity is block-based. - layout: Optional backend or format layout name for the scale tensor. """ storage_dtype: torch.dtype granularity: str block_shape: tuple[int, ...] | None = None - layout: str | None = None def __str__(self) -> str: parts = [self.granularity, f"storage={self.storage_dtype}"] if self.block_shape is not None: parts.append(f"block={self.block_shape}") - if self.layout is not None: - parts.append(f"layout={self.layout}") return "scale(" + ", ".join(parts) + ")" @@ -71,7 +67,8 @@ class TensorFormat: Args: storage_dtype: Physical dtype used by the main tensor payload. format: Logical representation format, such as "dense", - "fp8", "mxfp4", or "nvfp4". + "mxfp8", "mxfp4", or "nvfp4". Use "dense" for ordinary + dense tensors, including dense FP8 payloads with optional scales. scale: Optional scale sidecar metadata bundled with this tensor role. """ @@ -146,8 +143,9 @@ def tensor_format( """Construct a format for one tensor role. Args: - format: Logical representation format, such as "dense", "fp8", - "mxfp4", or "nvfp4". + format: Logical representation format, such as "dense", "mxfp8", + "mxfp4", or "nvfp4". Use "dense" for ordinary dense tensors, + including dense FP8 payloads with optional scales. storage_dtype: Physical dtype used by the main tensor payload. scale: Optional scale sidecar metadata bundled with this tensor role. """ diff --git a/tokenspeed-kernel/test/test_kernel_api_selection.py b/tokenspeed-kernel/test/test_kernel_api_selection.py index 6e306145e..7b89acd65 100644 --- a/tokenspeed-kernel/test/test_kernel_api_selection.py +++ b/tokenspeed-kernel/test/test_kernel_api_selection.py @@ -181,6 +181,55 @@ def test_gemm_mxfp8_online_activation_signature_uses_quantized_storage() -> None assert b_format.storage_dtype == _fp8_dtype() +def test_gemm_fp8_scaled_signature_uses_dense_format_with_scale() -> None: + a = torch.empty((4, 128), dtype=_fp8_dtype()) + b = torch.empty((128, 128), dtype=_fp8_dtype()) + a_scales = torch.empty((1,), dtype=torch.float32) + b_scales = torch.empty((1,), dtype=torch.float32) + + signature = _gemm_pkg._gemm_format_signature( + a, + b, + a_scales, + b_scales, + torch.bfloat16, + "fp8", + None, + ) + + for role in ("a", "b"): + tensor_format = signature.format_for(role) + assert tensor_format is not None + assert tensor_format.format == "dense" + assert tensor_format.storage_dtype == _fp8_dtype() + assert tensor_format.scale is not None + assert tensor_format.scale.granularity == "tensor" + assert tensor_format.scale.storage_dtype == torch.float32 + + +def test_gemm_fp8_scaled_signature_uses_channel_granularity() -> None: + a = torch.empty((4, 128), dtype=_fp8_dtype()) + b = torch.empty((128, 128), dtype=_fp8_dtype()) + a_scales = torch.empty((4,), dtype=torch.float32) + b_scales = torch.empty((128,), dtype=torch.float32) + + signature = _gemm_pkg._gemm_format_signature( + a, + b, + a_scales, + b_scales, + torch.bfloat16, + "fp8", + None, + ) + + for role in ("a", "b"): + tensor_format = signature.format_for(role) + assert tensor_format is not None + assert tensor_format.scale is not None + assert tensor_format.scale.granularity == "channel" + + def _mm_nvfp4() -> torch.Tensor: a = torch.empty((4, 64), dtype=torch.uint8) b = torch.empty((128, 64), dtype=torch.uint8) diff --git a/tokenspeed-kernel/test/test_numerics.py b/tokenspeed-kernel/test/test_numerics.py index ba2f1cfd5..a5dc76f7c 100644 --- a/tokenspeed-kernel/test/test_numerics.py +++ b/tokenspeed-kernel/test/test_numerics.py @@ -66,7 +66,6 @@ def test_gemm_input_generator_uses_signature_scale_metadata() -> None: storage_dtype=torch.float32, granularity="block", block_shape=(128, 128), - layout="mxfp8", ) signature = next( iter(format_signatures(("a", "b"), "mxfp8", {_fp8_dtype}, scale=scale)) diff --git a/tokenspeed-kernel/test/test_registry.py b/tokenspeed-kernel/test/test_registry.py index ce242dcc1..bf1f23b3f 100644 --- a/tokenspeed-kernel/test/test_registry.py +++ b/tokenspeed-kernel/test/test_registry.py @@ -65,7 +65,6 @@ def test_format_signature_bundles_scale_metadata(self): storage_dtype=torch.float32, granularity="block", block_shape=(32,), - layout="ue8m0", ) mixed = format_signature( a=dense_format(torch.bfloat16), From 36240d1700956e87794cf6ee1b273c47ff858db2 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sat, 23 May 2026 20:48:17 +0000 Subject: [PATCH 05/16] Keep FP8 as an explicit tensor format Restore fp8 as a distinct TensorFormat.format value for scaled FP8 GEMM and MoE signatures, while keeping ScaleFormat.layout removed and scale granularity normalized to tensor, channel, and block. Update test sample registrations to use register_kernel(..., signatures=...) instead of constructing KernelSpec(format_signatures=...) directly. Signed-off-by: Lei Zhang --- .../numerics/reference/gemm.py | 2 +- .../tokenspeed_kernel/ops/gemm/__init__.py | 4 +- .../tokenspeed_kernel/ops/gemm/triton.py | 4 +- .../tokenspeed_kernel/ops/moe/__init__.py | 4 +- .../tokenspeed_kernel/ops/moe/flashinfer.py | 4 +- .../python/tokenspeed_kernel/signature.py | 8 +- .../test/test_kernel_api_selection.py | 4 +- tokenspeed-kernel/test/utils.py | 281 ++++++++---------- 8 files changed, 138 insertions(+), 173 deletions(-) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py index e46102a31..014ae4f99 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/gemm.py @@ -42,7 +42,7 @@ ("a", "b"), "mxfp8", {fp8_dtype}, scale=_FP8_BLOCK_SCALE ) _FP8_TENSOR_FORMAT_SIGNATURES = format_signatures( - ("a", "b"), "dense", {fp8_dtype}, scale=_FP8_TENSOR_SCALE + ("a", "b"), "fp8", {fp8_dtype}, scale=_FP8_TENSOR_SCALE ) _DENSE_GEMM_FORMAT_SIGNATURES = format_signatures( ("a", "b"), "dense", {torch.bfloat16, torch.float16, torch.float32} diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py index 013150417..b777a4ba4 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py @@ -110,8 +110,8 @@ def _gemm_format_signature( granularity=_infer_scale_type(A_scales, B_scales) or "unknown", ) return format_signature( - a=tensor_format("dense", A.dtype, scale=scale), - b=tensor_format("dense", B.dtype, scale=scale), + a=tensor_format("fp8", A.dtype, scale=scale), + b=tensor_format("fp8", B.dtype, scale=scale), ) if quant == "nvfp4": scale = ScaleFormat( diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py index 80dd1ae2d..06aff7789 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py @@ -53,8 +53,8 @@ ("a", "b"), "mxfp8", {_fp8_dtype}, scale=_MXFP8_SCALE ) _FP8_SCALED_FORMAT_SIGNATURES = format_signatures( - ("a", "b"), "dense", {_fp8_dtype}, scale=_FP8_TENSOR_SCALE -) | format_signatures(("a", "b"), "dense", {_fp8_dtype}, scale=_FP8_CHANNEL_SCALE) + ("a", "b"), "fp8", {_fp8_dtype}, scale=_FP8_TENSOR_SCALE +) | format_signatures(("a", "b"), "fp8", {_fp8_dtype}, scale=_FP8_CHANNEL_SCALE) def prepare_block_fp8_matmul_inputs( diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py index de5f73c37..1644997f3 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py @@ -153,7 +153,7 @@ def _moe_fused_format_signature( weight_format: str, ): if weight_format == WEIGHT_FP8: - weight = tensor_format("dense", torch.float8_e4m3fn, scale=_FP8_SCALE) + weight = tensor_format("fp8", torch.float8_e4m3fn, scale=_FP8_SCALE) elif weight_format == WEIGHT_NVFP4: weight = tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE) elif weight_format == WEIGHT_MXFP4: @@ -168,7 +168,7 @@ def _moe_fused_format_signature( elif storage_dtype == torch.uint8 and weight_format == WEIGHT_MXFP4: x = tensor_format("mxfp4", torch.uint8, scale=_MXFP4_SCALE) elif storage_dtype == torch.float8_e4m3fn: - x = tensor_format("dense", storage_dtype, scale=_FP8_SCALE) + x = tensor_format("fp8", storage_dtype, scale=_FP8_SCALE) else: x = dense_format(storage_dtype) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py index 5bb68e581..7dd549a89 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py @@ -74,8 +74,8 @@ weight=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), ), format_signature( - x=tensor_format("dense", torch.float8_e4m3fn, scale=_FP8_SCALE), - weight=tensor_format("dense", torch.float8_e4m3fn, scale=_FP8_SCALE), + x=tensor_format("fp8", torch.float8_e4m3fn, scale=_FP8_SCALE), + weight=tensor_format("fp8", torch.float8_e4m3fn, scale=_FP8_SCALE), ), } ) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/signature.py b/tokenspeed-kernel/python/tokenspeed_kernel/signature.py index 6eebf6b47..4372d13a9 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/signature.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/signature.py @@ -67,8 +67,7 @@ class TensorFormat: Args: storage_dtype: Physical dtype used by the main tensor payload. format: Logical representation format, such as "dense", - "mxfp8", "mxfp4", or "nvfp4". Use "dense" for ordinary - dense tensors, including dense FP8 payloads with optional scales. + "fp8", "mxfp8", "mxfp4", or "nvfp4". scale: Optional scale sidecar metadata bundled with this tensor role. """ @@ -143,9 +142,8 @@ def tensor_format( """Construct a format for one tensor role. Args: - format: Logical representation format, such as "dense", "mxfp8", - "mxfp4", or "nvfp4". Use "dense" for ordinary dense tensors, - including dense FP8 payloads with optional scales. + format: Logical representation format, such as "dense", "fp8", + "mxfp8", "mxfp4", or "nvfp4". storage_dtype: Physical dtype used by the main tensor payload. scale: Optional scale sidecar metadata bundled with this tensor role. """ diff --git a/tokenspeed-kernel/test/test_kernel_api_selection.py b/tokenspeed-kernel/test/test_kernel_api_selection.py index 7b89acd65..fc84e23b5 100644 --- a/tokenspeed-kernel/test/test_kernel_api_selection.py +++ b/tokenspeed-kernel/test/test_kernel_api_selection.py @@ -181,7 +181,7 @@ def test_gemm_mxfp8_online_activation_signature_uses_quantized_storage() -> None assert b_format.storage_dtype == _fp8_dtype() -def test_gemm_fp8_scaled_signature_uses_dense_format_with_scale() -> None: +def test_gemm_fp8_scaled_signature_uses_fp8_format_with_scale() -> None: a = torch.empty((4, 128), dtype=_fp8_dtype()) b = torch.empty((128, 128), dtype=_fp8_dtype()) a_scales = torch.empty((1,), dtype=torch.float32) @@ -200,7 +200,7 @@ def test_gemm_fp8_scaled_signature_uses_dense_format_with_scale() -> None: for role in ("a", "b"): tensor_format = signature.format_for(role) assert tensor_format is not None - assert tensor_format.format == "dense" + assert tensor_format.format == "fp8" assert tensor_format.storage_dtype == _fp8_dtype() assert tensor_format.scale is not None assert tensor_format.scale.granularity == "tensor" diff --git a/tokenspeed-kernel/test/utils.py b/tokenspeed-kernel/test/utils.py index 60c823ab0..414be3d04 100644 --- a/tokenspeed-kernel/test/utils.py +++ b/tokenspeed-kernel/test/utils.py @@ -24,8 +24,10 @@ import torch from tokenspeed_kernel.platform import ArchVersion, CapabilityRequirement -from tokenspeed_kernel.registry import KernelRegistry, KernelSpec -from tokenspeed_kernel.signature import format_signatures +from tokenspeed_kernel.registry import KernelRegistry, register_kernel +from tokenspeed_kernel.signature import FormatSignature, format_signatures + +SampleRegistration = tuple[dict, Callable] def dummy_impl(name: str) -> Callable: @@ -36,19 +38,45 @@ def impl(*args, **kwargs): return impl -def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: - specs: dict[str, tuple[KernelSpec, Callable]] = {} +def _sample_registration( + name: str, + family: str, + mode: str, + solution: str, + signatures: frozenset[FormatSignature], + *, + features: frozenset[str] | None = None, + capability: CapabilityRequirement | None = None, + priority: int = 10, + tags: frozenset[str] | None = None, +) -> SampleRegistration: + return ( + { + "family": family, + "mode": mode, + "name": name, + "solution": solution, + "features": features, + "capability": capability, + "signatures": signatures, + "priority": priority, + "tags": tags, + }, + dummy_impl(name), + ) - specs["flashinfer_decode"] = ( - KernelSpec( - name="flashinfer_decode", - family="attention", - mode="decode", - solution="flashinfer", - features=frozenset({"paged"}), - format_signatures=format_signatures( + +def make_sample_specs() -> dict[str, SampleRegistration]: + return { + "flashinfer_decode": _sample_registration( + "flashinfer_decode", + "attention", + "decode", + "flashinfer", + format_signatures( ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} ), + features=frozenset({"paged"}), capability=CapabilityRequirement( vendors=frozenset({"nvidia"}), min_arch_version=ArchVersion(8, 0), @@ -56,32 +84,24 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: priority=18, tags=frozenset({"latency"}), ), - dummy_impl("flashinfer_decode"), - ) - - specs["triton_decode"] = ( - KernelSpec( - name="triton_decode", - family="attention", - mode="decode", - solution="triton", - features=frozenset({"paged"}), - format_signatures=format_signatures( + "triton_decode": _sample_registration( + "triton_decode", + "attention", + "decode", + "triton", + format_signatures( ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} ), + features=frozenset({"paged"}), priority=10, tags=frozenset({"portability"}), ), - dummy_impl("triton_decode"), - ) - - specs["cutlass_prefill"] = ( - KernelSpec( - name="cutlass_prefill", - family="attention", - mode="prefill", - solution="cutlass", - format_signatures=format_signatures( + "cutlass_prefill": _sample_registration( + "cutlass_prefill", + "attention", + "prefill", + "cutlass", + format_signatures( ("q", "k", "v"), "dense", {torch.float16, torch.bfloat16} ), capability=CapabilityRequirement( @@ -91,56 +111,40 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: priority=16, tags=frozenset({"throughput"}), ), - dummy_impl("cutlass_prefill"), - ) - - specs["reference_decode"] = ( - KernelSpec( - name="reference_decode", - family="attention", - mode="decode", - solution="reference", - features=frozenset({"paged"}), - format_signatures=format_signatures( + "reference_decode": _sample_registration( + "reference_decode", + "attention", + "decode", + "reference", + format_signatures( ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16, torch.float32}, ), + features=frozenset({"paged"}), capability=CapabilityRequirement(), priority=10, tags=frozenset({"determinism", "portability"}), ), - dummy_impl("reference_decode"), - ) - - specs["aiter_decode"] = ( - KernelSpec( - name="aiter_decode", - family="attention", - mode="decode", - solution="aiter", - features=frozenset({"paged"}), - format_signatures=format_signatures( + "aiter_decode": _sample_registration( + "aiter_decode", + "attention", + "decode", + "aiter", + format_signatures( ("q", "k_cache", "v_cache"), "dense", {torch.float16, torch.bfloat16} ), - capability=CapabilityRequirement( - vendors=frozenset({"amd"}), - ), + features=frozenset({"paged"}), + capability=CapabilityRequirement(vendors=frozenset({"amd"})), priority=16, tags=frozenset({"latency", "portability"}), ), - dummy_impl("aiter_decode"), - ) - - specs["cutlass_gemm"] = ( - KernelSpec( - name="cutlass_gemm", - family="gemm", - mode="mm", - solution="cutlass", - format_signatures=format_signatures( - ("a", "b"), "dense", {torch.float16, torch.bfloat16} - ), + "cutlass_gemm": _sample_registration( + "cutlass_gemm", + "gemm", + "mm", + "cutlass", + format_signatures(("a", "b"), "dense", {torch.float16, torch.bfloat16}), capability=CapabilityRequirement( vendors=frozenset({"nvidia"}), min_arch_version=ArchVersion(8, 0), @@ -148,33 +152,21 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: priority=15, tags=frozenset({"throughput", "latency"}), ), - dummy_impl("cutlass_gemm"), - ) - - specs["triton_gemm"] = ( - KernelSpec( - name="triton_gemm", - family="gemm", - mode="mm", - solution="triton", - format_signatures=format_signatures( - ("a", "b"), "dense", {torch.float16, torch.bfloat16} - ), + "triton_gemm": _sample_registration( + "triton_gemm", + "gemm", + "mm", + "triton", + format_signatures(("a", "b"), "dense", {torch.float16, torch.bfloat16}), priority=10, tags=frozenset({"portability"}), ), - dummy_impl("triton_gemm"), - ) - - specs["cutlass_grouped_gemm"] = ( - KernelSpec( - name="cutlass_grouped_gemm", - family="gemm", - mode="grouped_mm", - solution="cutlass", - format_signatures=format_signatures( - ("a", "b"), "dense", {torch.float16, torch.bfloat16} - ), + "cutlass_grouped_gemm": _sample_registration( + "cutlass_grouped_gemm", + "gemm", + "grouped_mm", + "cutlass", + format_signatures(("a", "b"), "dense", {torch.float16, torch.bfloat16}), capability=CapabilityRequirement( vendors=frozenset({"nvidia"}), min_arch_version=ArchVersion(9, 0), @@ -182,46 +174,32 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: priority=16, tags=frozenset({"throughput"}), ), - dummy_impl("cutlass_grouped_gemm"), - ) - - specs["triton_grouped_gemm"] = ( - KernelSpec( - name="triton_grouped_gemm", - family="gemm", - mode="grouped_mm", - solution="triton", - format_signatures=format_signatures( - ("a", "b"), "dense", {torch.float16, torch.bfloat16} - ), + "triton_grouped_gemm": _sample_registration( + "triton_grouped_gemm", + "gemm", + "grouped_mm", + "triton", + format_signatures(("a", "b"), "dense", {torch.float16, torch.bfloat16}), priority=10, tags=frozenset({"portability"}), ), - dummy_impl("triton_grouped_gemm"), - ) - - specs["triton_fused_moe"] = ( - KernelSpec( - name="triton_fused_moe", - family="moe", - mode="fused", - solution="triton", - format_signatures=format_signatures( + "triton_fused_moe": _sample_registration( + "triton_fused_moe", + "moe", + "fused", + "triton", + format_signatures( ("x", "weight"), "dense", {torch.float16, torch.bfloat16} ), priority=12, tags=frozenset({"throughput", "portability"}), ), - dummy_impl("triton_fused_moe"), - ) - - specs["cutlass_fused_moe"] = ( - KernelSpec( - name="cutlass_fused_moe", - family="moe", - mode="fused", - solution="cutlass", - format_signatures=format_signatures( + "cutlass_fused_moe": _sample_registration( + "cutlass_fused_moe", + "moe", + "fused", + "cutlass", + format_signatures( ("x", "weight"), "dense", {torch.float16, torch.bfloat16} ), capability=CapabilityRequirement( @@ -231,33 +209,21 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: priority=15, tags=frozenset({"latency", "throughput"}), ), - dummy_impl("cutlass_fused_moe"), - ) - - specs["triton_modular_moe"] = ( - KernelSpec( - name="triton_modular_moe", - family="moe", - mode="modular", - solution="triton", - format_signatures=format_signatures( - "x", "dense", {torch.float16, torch.bfloat16} - ), + "triton_modular_moe": _sample_registration( + "triton_modular_moe", + "moe", + "modular", + "triton", + format_signatures("x", "dense", {torch.float16, torch.bfloat16}), priority=10, tags=frozenset({"determinism", "portability"}), ), - dummy_impl("triton_modular_moe"), - ) - - specs["cutlass_modular_moe"] = ( - KernelSpec( - name="cutlass_modular_moe", - family="moe", - mode="modular", - solution="cutlass", - format_signatures=format_signatures( - "x", "dense", {torch.float16, torch.bfloat16} - ), + "cutlass_modular_moe": _sample_registration( + "cutlass_modular_moe", + "moe", + "modular", + "cutlass", + format_signatures("x", "dense", {torch.float16, torch.bfloat16}), capability=CapabilityRequirement( vendors=frozenset({"nvidia"}), min_arch_version=ArchVersion(8, 0), @@ -265,12 +231,13 @@ def make_sample_specs() -> dict[str, tuple[KernelSpec, Callable]]: priority=14, tags=frozenset({"throughput"}), ), - dummy_impl("cutlass_modular_moe"), - ) - - return specs + } -def register_all_samples(registry: KernelRegistry, specs: dict) -> None: - for spec, impl in specs.values(): - registry.register(spec, impl) +def register_all_samples( + registry: KernelRegistry, samples: dict[str, SampleRegistration] +) -> None: + if registry is not KernelRegistry.get(): + raise ValueError("sample registrations must target the active KernelRegistry") + for options, impl in samples.values(): + register_kernel(**options)(impl) From f1e1be9c3cb729b62827c47950581851d753432b Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sat, 23 May 2026 21:06:02 +0000 Subject: [PATCH 06/16] Clarify primary dtype signature lookup Rename KernelSpec.format_signature_for_storage_dtype to format_signature_for_primary_storage_dtype so callers see that the helper matches the signature primary storage dtype used by dtype-oriented filters. Signed-off-by: Lei Zhang --- .../python/tokenspeed_kernel/benchmark/runner.py | 8 ++++---- .../python/tokenspeed_kernel/numerics/cli.py | 2 +- .../python/tokenspeed_kernel/numerics/verify.py | 2 +- .../python/tokenspeed_kernel/ops/gemm/__init__.py | 2 +- tokenspeed-kernel/python/tokenspeed_kernel/registry.py | 7 ++++--- tokenspeed-kernel/test/conftest.py | 2 +- tokenspeed-kernel/test/test_numerics.py | 2 +- tokenspeed-kernel/test/test_runtime_callsite_selection.py | 2 +- 8 files changed, 14 insertions(+), 13 deletions(-) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/benchmark/runner.py b/tokenspeed-kernel/python/tokenspeed_kernel/benchmark/runner.py index 508dfbc7f..b9377797c 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/benchmark/runner.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/benchmark/runner.py @@ -139,7 +139,7 @@ def _benchmark_one_shape( if not spec_matches_shape_traits(spec, shape): return None - signature = spec.format_signature_for_storage_dtype(dtype) + signature = spec.format_signature_for_primary_storage_dtype(dtype) if signature is None: return None @@ -229,7 +229,7 @@ def _verify_one_shape( return None, None, None registry = KernelRegistry.get() - signature = spec.format_signature_for_storage_dtype(dtype) + signature = spec.format_signature_for_primary_storage_dtype(dtype) if signature is None: return None, None, None @@ -294,7 +294,7 @@ def _benchmark_kernel_impl( if spec is None: raise ValueError(f"Kernel {kernel_name!r} is not registered") - if spec.format_signature_for_storage_dtype(dtype) is None: + if spec.format_signature_for_primary_storage_dtype(dtype) is None: raise ValueError( f"Kernel {kernel_name!r} does not support primary storage dtype={dtype}" ) @@ -355,7 +355,7 @@ def _benchmark_op_impl( op_mode, platform=platform, ) - if spec.format_signature_for_storage_dtype(dtype) is not None + if spec.format_signature_for_primary_storage_dtype(dtype) is not None ] results: list[BenchmarkResult] = [] diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/cli.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/cli.py index 62e361e7f..293f648e2 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/cli.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/cli.py @@ -83,7 +83,7 @@ def _iter_candidate_specs( specs = [ s for s in specs - if s.format_signature_for_storage_dtype(dtype_filter) is not None + if s.format_signature_for_primary_storage_dtype(dtype_filter) is not None ] specs.sort(key=lambda s: (s.family, s.mode, s.name)) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/verify.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/verify.py index 7636c051e..314835462 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/verify.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/verify.py @@ -86,7 +86,7 @@ def verify_kernel( if kernel is None: raise ValueError(f"Kernel implementation for {kernel_name!r} is missing") - signature = spec.format_signature_for_storage_dtype(dtype) + signature = spec.format_signature_for_primary_storage_dtype(dtype) if signature is None: raise ValueError( f"Kernel {kernel_name!r} does not support primary storage dtype={dtype}" diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py index b777a4ba4..b5c20d444 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py @@ -68,7 +68,7 @@ def _infer_scale_type( A_scales: torch.Tensor | None, B_scales: torch.Tensor | None, ) -> str | None: - """For fp8, distinguish tensor-scalar from channel/vector scaling.""" + """For fp8, distinguish tensor/channel/scalar scaling.""" if A_scales is None or B_scales is None: return None if A_scales.numel() == 1 and B_scales.numel() == 1: diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/registry.py b/tokenspeed-kernel/python/tokenspeed_kernel/registry.py index 2dc84dfb6..9df5f5ccd 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/registry.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/registry.py @@ -164,12 +164,13 @@ class KernelSpec: def supports_format_signature(self, format_signature: FormatSignature) -> bool: return format_signature in self.format_signatures - def format_signature_for_storage_dtype( + def format_signature_for_primary_storage_dtype( self, - storage_dtype: torch.dtype, + primary_storage_dtype: torch.dtype, ) -> FormatSignature | None: + """Return the first signature matching a dtype-oriented filter.""" for signature in sorted(self.format_signatures, key=str): - if signature.primary_storage_dtype() == storage_dtype: + if signature.primary_storage_dtype() == primary_storage_dtype: return signature return None diff --git a/tokenspeed-kernel/test/conftest.py b/tokenspeed-kernel/test/conftest.py index fc7491e99..22d0a9f8c 100644 --- a/tokenspeed-kernel/test/conftest.py +++ b/tokenspeed-kernel/test/conftest.py @@ -62,7 +62,7 @@ def _require( platform=current_platform(), solution=solution, ) - if spec.format_signature_for_storage_dtype(dtype) is not None + if spec.format_signature_for_primary_storage_dtype(dtype) is not None ] if not specs: pytest.skip(f"{family}.{mode} solution {solution!r} is not registered") diff --git a/tokenspeed-kernel/test/test_numerics.py b/tokenspeed-kernel/test/test_numerics.py index a5dc76f7c..0660f3a8c 100644 --- a/tokenspeed-kernel/test/test_numerics.py +++ b/tokenspeed-kernel/test/test_numerics.py @@ -107,7 +107,7 @@ def _get_verifiable_specs( dtype_specs = [ s for s in op_specs - if s.format_signature_for_storage_dtype(dtype) is not None + if s.format_signature_for_primary_storage_dtype(dtype) is not None ] has_reference = any(s.solution == "reference" for s in dtype_specs) if not has_reference: diff --git a/tokenspeed-kernel/test/test_runtime_callsite_selection.py b/tokenspeed-kernel/test/test_runtime_callsite_selection.py index 7a2be75d6..4e34560ae 100644 --- a/tokenspeed-kernel/test/test_runtime_callsite_selection.py +++ b/tokenspeed-kernel/test/test_runtime_callsite_selection.py @@ -323,7 +323,7 @@ def _infer_format_signature( ) -> FormatSignature: if family == "moe" and mode == "fused": return _moe_pkg._moe_fused_format_signature(dtype, weight_format or "bf16") - signature = spec.format_signature_for_storage_dtype(dtype) + signature = spec.format_signature_for_primary_storage_dtype(dtype) if signature is None: pytest.skip(f"Kernel {spec.name!r} has no signature for {dtype}") return signature From fe90cf3fb939f90133ac9c08a8d7a7114c429128 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sat, 23 May 2026 21:12:05 +0000 Subject: [PATCH 07/16] Rename dense tensor format helper Rename dense_format to dense_tensor_format so the helper name matches TensorFormat and the tensor_format helper. Update call sites and add examples to format_signature and format_signatures docstrings. Signed-off-by: Lei Zhang --- .../numerics/reference/moe.py | 6 ++++-- .../ops/attention/__init__.py | 4 ++-- .../ops/embedding/__init__.py | 6 +++--- .../tokenspeed_kernel/ops/gemm/__init__.py | 6 ++++-- .../tokenspeed_kernel/ops/moe/__init__.py | 20 +++++++++---------- .../tokenspeed_kernel/ops/moe/flashinfer.py | 19 ++++++++++-------- .../ops/quantization/__init__.py | 12 +++++------ .../python/tokenspeed_kernel/signature.py | 19 ++++++++++++++++-- tokenspeed-kernel/test/test_registry.py | 8 ++++---- 9 files changed, 61 insertions(+), 39 deletions(-) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py index ef1129a6e..fa56fc022 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py @@ -33,7 +33,7 @@ topk_ids_logical_to_physical, ) from tokenspeed_kernel.registry import Priority, register_kernel -from tokenspeed_kernel.signature import dense_format, format_signature +from tokenspeed_kernel.signature import dense_tensor_format, format_signature # isort: split from tokenspeed_kernel.signature import format_signatures @@ -51,7 +51,9 @@ features={"pre_routed"}, solution="reference", signatures=frozenset( - format_signature(x=dense_format(dtype), weight=dense_format(torch.bfloat16)) + format_signature( + x=dense_tensor_format(dtype), weight=dense_tensor_format(torch.bfloat16) + ) for dtype in {torch.float16, torch.bfloat16, torch.float32} ), priority=Priority.REFERENCE, diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/__init__.py index cf2a0208e..fa1226795 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/attention/__init__.py @@ -32,14 +32,14 @@ from tokenspeed_kernel.ops.attention.flash_attn import mha_decode_scheduler_metadata from tokenspeed_kernel.profiling import ShapeCapture, kernel_scope from tokenspeed_kernel.selection import select_kernel -from tokenspeed_kernel.signature import dense_format, format_signature +from tokenspeed_kernel.signature import dense_tensor_format, format_signature AttentionResult = torch.Tensor | tuple[torch.Tensor, torch.Tensor | None] def _attention_format_signature(**roles: torch.Tensor): return format_signature( - **{role: dense_format(tensor.dtype) for role, tensor in roles.items()} + **{role: dense_tensor_format(tensor.dtype) for role, tensor in roles.items()} ) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/__init__.py index cb85a3150..fdde20194 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/__init__.py @@ -23,7 +23,7 @@ import torch from tokenspeed_kernel.profiling import ShapeCapture, kernel_scope from tokenspeed_kernel.selection import select_kernel -from tokenspeed_kernel.signature import dense_format, format_signature +from tokenspeed_kernel.signature import dense_tensor_format, format_signature @dataclass @@ -113,8 +113,8 @@ def apply_rope( "has_k_out": output_k_rope is not None, } signature = format_signature( - query=dense_format(query.dtype), - key=dense_format(key.dtype), + query=dense_tensor_format(query.dtype), + key=dense_tensor_format(key.dtype), ) kernel = select_kernel( "embedding", diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py index b5c20d444..cbab8dc26 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/__init__.py @@ -34,7 +34,7 @@ from tokenspeed_kernel.selection import select_kernel from tokenspeed_kernel.signature import ( ScaleFormat, - dense_format, + dense_tensor_format, format_signature, tensor_format, ) @@ -122,7 +122,9 @@ def _gemm_format_signature( a=tensor_format("nvfp4", A.dtype, scale=scale), b=tensor_format("nvfp4", B.dtype, scale=scale), ) - return format_signature(a=dense_format(A.dtype), b=dense_format(B.dtype)) + return format_signature( + a=dense_tensor_format(A.dtype), b=dense_tensor_format(B.dtype) + ) def _online_quantize_mxfp8( diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py index 1644997f3..bcf0a30f8 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py @@ -44,7 +44,7 @@ ) from tokenspeed_kernel.signature import ( ScaleFormat, - dense_format, + dense_tensor_format, format_signature, tensor_format, ) @@ -137,15 +137,15 @@ def adjust(self, spec, platform, traits): ) -def _single_dense_format_signature(role: str, storage_dtype: torch.dtype): - return format_signature(**{role: dense_format(storage_dtype)}) +def _single_dense_tensor_format_signature(role: str, storage_dtype: torch.dtype): + return format_signature(**{role: dense_tensor_format(storage_dtype)}) def _moe_dispatch_format_signature(storage_dtype: torch.dtype, traits: Optional[dict]): comm_strategy = (traits or {}).get("comm_strategy") if storage_dtype == torch.int32 or comm_strategy == "local": - return _single_dense_format_signature("indices", storage_dtype) - return _single_dense_format_signature("x", storage_dtype) + return _single_dense_tensor_format_signature("indices", storage_dtype) + return _single_dense_tensor_format_signature("x", storage_dtype) def _moe_fused_format_signature( @@ -159,7 +159,7 @@ def _moe_fused_format_signature( elif weight_format == WEIGHT_MXFP4: weight = tensor_format("mxfp4", torch.uint8, scale=_MXFP4_SCALE) elif weight_format == WEIGHT_BF16: - weight = dense_format(torch.bfloat16) + weight = dense_tensor_format(torch.bfloat16) else: raise ValueError(f"Unsupported MoE fused weight_format={weight_format!r}") @@ -170,7 +170,7 @@ def _moe_fused_format_signature( elif storage_dtype == torch.float8_e4m3fn: x = tensor_format("fp8", storage_dtype, scale=_FP8_SCALE) else: - x = dense_format(storage_dtype) + x = dense_tensor_format(storage_dtype) return format_signature(x=x, weight=weight) @@ -193,7 +193,7 @@ def moe_route( * ``{"biased": True/False}``: whether correction_bias is applied. * ``{"grouped": True/False}``: whether grouped expert selection is used. """ - signature = _single_dense_format_signature("logits", dtype) + signature = _single_dense_tensor_format_signature("logits", dtype) kernel = select_kernel( "moe", "route", @@ -251,7 +251,7 @@ def moe_experts( * ``{"dispatch_gemm"}``: gather/dispatch tokens then GEMM (uses gather_indx). * ``{"gemm_combine"}``: GEMM then scatter/combine results (uses scatter_indx). """ - signature = _single_dense_format_signature("x", dtype) + signature = _single_dense_tensor_format_signature("x", dtype) kernel = select_kernel( "moe", "experts", @@ -272,7 +272,7 @@ def moe_combine( **kwargs, ): """Combine expert outputs with weighted reduction.""" - signature = _single_dense_format_signature("x", dtype) + signature = _single_dense_tensor_format_signature("x", dtype) kernel = select_kernel( "moe", "combine", diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py index 7dd549a89..af80b9efd 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/flashinfer.py @@ -25,7 +25,7 @@ from tokenspeed_kernel.registry import Priority, error_fn, register_kernel from tokenspeed_kernel.signature import ( ScaleFormat, - dense_format, + dense_tensor_format, format_signature, tensor_format, ) @@ -49,24 +49,27 @@ _BF16_FUSED_FORMAT_SIGNATURES = frozenset( { format_signature( - x=dense_format(torch.bfloat16), weight=dense_format(torch.bfloat16) + x=dense_tensor_format(torch.bfloat16), + weight=dense_tensor_format(torch.bfloat16), ) } ) _CUTLASS_FUSED_FORMAT_SIGNATURES = frozenset( { format_signature( - x=dense_format(torch.bfloat16), weight=dense_format(torch.bfloat16) + x=dense_tensor_format(torch.bfloat16), + weight=dense_tensor_format(torch.bfloat16), ), format_signature( - x=dense_format(torch.float16), weight=dense_format(torch.bfloat16) + x=dense_tensor_format(torch.float16), + weight=dense_tensor_format(torch.bfloat16), ), format_signature( - x=dense_format(torch.bfloat16), + x=dense_tensor_format(torch.bfloat16), weight=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), ), format_signature( - x=dense_format(torch.float16), + x=dense_tensor_format(torch.float16), weight=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), ), format_signature( @@ -82,11 +85,11 @@ _FP4_FUSED_FORMAT_SIGNATURES = frozenset( { format_signature( - x=dense_format(torch.bfloat16), + x=dense_tensor_format(torch.bfloat16), weight=tensor_format("mxfp4", torch.uint8, scale=_MXFP4_SCALE), ), format_signature( - x=dense_format(torch.bfloat16), + x=dense_tensor_format(torch.bfloat16), weight=tensor_format("nvfp4", torch.uint8, scale=_NVFP4_SCALE), ), format_signature( diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/__init__.py index f21da1f49..2b3819fd2 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/quantization/__init__.py @@ -21,7 +21,7 @@ import torch from tokenspeed_kernel.profiling import ShapeCapture, kernel_scope from tokenspeed_kernel.selection import select_kernel -from tokenspeed_kernel.signature import dense_format, format_signature +from tokenspeed_kernel.signature import dense_tensor_format, format_signature __all__ = [ "quantize_fp8", @@ -68,7 +68,7 @@ def quantize_fp8( traits = { "has_scale": scale is not None, } - signature = format_signature(x=dense_format(x.dtype)) + signature = format_signature(x=dense_tensor_format(x.dtype)) kernel = select_kernel( "quantization", "fp8", @@ -148,7 +148,7 @@ def quantize_fp8_with_scale( "granularity": granularity_trait, "scale_encoding": scale_encoding, } - signature = format_signature(x=dense_format(x.dtype)) + signature = format_signature(x=dense_tensor_format(x.dtype)) kernel = select_kernel( "quantization", "fp8_with_scale", @@ -207,7 +207,7 @@ def quantize_mxfp8( """ traits = {} - signature = format_signature(x=dense_format(x.dtype)) + signature = format_signature(x=dense_tensor_format(x.dtype)) kernel = select_kernel( "quantization", "mxfp8", @@ -268,7 +268,7 @@ def quantize_nvfp4( "scale_layout": scale_layout, "has_scale": scale is not None, } - signature = format_signature(x=dense_format(x.dtype)) + signature = format_signature(x=dense_tensor_format(x.dtype)) kernel = select_kernel( "quantization", "nvfp4", @@ -340,7 +340,7 @@ def quantize_mxfp4( "has_global_scale": global_scale is not None, "scale_encoding": "ue8m0", } - signature = format_signature(x=dense_format(x.dtype)) + signature = format_signature(x=dense_tensor_format(x.dtype)) kernel = select_kernel( "quantization", "mxfp4", diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/signature.py b/tokenspeed-kernel/python/tokenspeed_kernel/signature.py index 4372d13a9..ca60b8fdd 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/signature.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/signature.py @@ -30,7 +30,7 @@ "FormatSignature", "ScaleFormat", "TensorFormat", - "dense_format", + "dense_tensor_format", "format_signature", "tensor_format", "format_signatures", @@ -150,7 +150,7 @@ def tensor_format( return TensorFormat(storage_dtype=storage_dtype, format=format, scale=scale) -def dense_format(storage_dtype: torch.dtype) -> TensorFormat: +def dense_tensor_format(storage_dtype: torch.dtype) -> TensorFormat: """Construct a dense, unscaled tensor format for storage_dtype.""" return tensor_format("dense", storage_dtype) @@ -161,6 +161,13 @@ def format_signature(**roles: TensorFormat) -> FormatSignature: Keyword names are logical tensor roles for the operator, for example a/b for GEMM or q/k_cache/v_cache for attention. Values are the exact formats required for those roles. + + Examples: + >>> import torch + >>> format_signature( + ... a=dense_tensor_format(torch.bfloat16), + ... b=tensor_format("mxfp4", torch.uint8), + ... ) """ return FormatSignature(tuple(roles.items())) @@ -185,6 +192,14 @@ def format_signatures( Use ``format="dense"`` for dense same-format signatures. Use ``format_signature`` directly for mixed-role combinations such as dense activations with quantized weights. + + Examples: + >>> import torch + >>> format_signatures( + ... ("q", "k_cache", "v_cache"), + ... "dense", + ... {torch.float16, torch.bfloat16}, + ... ) """ normalized_roles = (roles,) if isinstance(roles, str) else tuple(roles) return frozenset( diff --git a/tokenspeed-kernel/test/test_registry.py b/tokenspeed-kernel/test/test_registry.py index bf1f23b3f..c9aecd8e7 100644 --- a/tokenspeed-kernel/test/test_registry.py +++ b/tokenspeed-kernel/test/test_registry.py @@ -31,7 +31,7 @@ ) from tokenspeed_kernel.signature import ( ScaleFormat, - dense_format, + dense_tensor_format, format_signature, format_signatures, tensor_format, @@ -67,12 +67,12 @@ def test_format_signature_bundles_scale_metadata(self): block_shape=(32,), ) mixed = format_signature( - a=dense_format(torch.bfloat16), + a=dense_tensor_format(torch.bfloat16), b=tensor_format("mxfp4", torch.uint8, scale=scale), ) dense = format_signature( - a=dense_format(torch.bfloat16), - b=dense_format(torch.uint8), + a=dense_tensor_format(torch.bfloat16), + b=dense_tensor_format(torch.uint8), ) assert mixed != dense From 986c90b8a9e47f6fbf8f02c9469b0cb9370d810b Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sat, 23 May 2026 21:20:14 +0000 Subject: [PATCH 08/16] Clarify format signature expansion docs Document that each FormatSignature is one concrete operand-format combination with one TensorFormat per role. Expand the helper examples to show the concrete signatures produced by format_signature and format_signatures. Signed-off-by: Lei Zhang --- .../python/tokenspeed_kernel/signature.py | 46 +++++++++++++++++-- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/signature.py b/tokenspeed-kernel/python/tokenspeed_kernel/signature.py index ca60b8fdd..b26c9b815 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/signature.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/signature.py @@ -83,7 +83,13 @@ def __str__(self) -> str: @dataclass(frozen=True) class FormatSignature: - """Role-indexed tensor formats for one supported operand-format combination.""" + """Role-indexed tensor formats for one concrete operand-format combination. + + Each role appears at most once and maps to exactly one ``TensorFormat``. + A kernel that supports alternatives for a role represents them as multiple + ``FormatSignature`` values in ``KernelSpec.format_signatures`` rather than + as multiple formats inside one signature. + """ roles: tuple[tuple[str, TensorFormat], ...] @@ -156,11 +162,13 @@ def dense_tensor_format(storage_dtype: torch.dtype) -> TensorFormat: def format_signature(**roles: TensorFormat) -> FormatSignature: - """Construct a role-indexed format signature. + """Construct one concrete role-indexed format signature. Keyword names are logical tensor roles for the operator, for example a/b for GEMM or q/k_cache/v_cache for attention. - Values are the exact formats required for those roles. + Values are the exact formats required for those roles. Each role gets + exactly one ``TensorFormat``; represent alternatives by constructing + multiple ``FormatSignature`` values. Examples: >>> import torch @@ -168,6 +176,15 @@ def format_signature(**roles: TensorFormat) -> FormatSignature: ... a=dense_tensor_format(torch.bfloat16), ... b=tensor_format("mxfp4", torch.uint8), ... ) + + This creates one signature equivalent to: + + >>> FormatSignature( + ... ( + ... ("a", dense_tensor_format(torch.bfloat16)), + ... ("b", tensor_format("mxfp4", torch.uint8)), + ... ) + ... ) """ return FormatSignature(tuple(roles.items())) @@ -179,7 +196,7 @@ def format_signatures( *, scale: ScaleFormat | None = None, ) -> frozenset[FormatSignature]: - """Construct same-format signatures for each dtype. + """Construct same-format signatures for each storage dtype. Args: roles: Logical tensor roles. Pass a string for one role or an iterable @@ -191,7 +208,8 @@ def format_signatures( Use ``format="dense"`` for dense same-format signatures. Use ``format_signature`` directly for mixed-role combinations such as dense - activations with quantized weights. + activations with quantized weights. This helper expands dtype alternatives + into separate signatures; it does not put multiple formats on one role. Examples: >>> import torch @@ -200,6 +218,24 @@ def format_signatures( ... "dense", ... {torch.float16, torch.bfloat16}, ... ) + + This expands to a ``frozenset`` containing one signature per dtype, + equivalent to: + + >>> frozenset( + ... { + ... format_signature( + ... q=dense_tensor_format(torch.float16), + ... k_cache=dense_tensor_format(torch.float16), + ... v_cache=dense_tensor_format(torch.float16), + ... ), + ... format_signature( + ... q=dense_tensor_format(torch.bfloat16), + ... k_cache=dense_tensor_format(torch.bfloat16), + ... v_cache=dense_tensor_format(torch.bfloat16), + ... ), + ... } + ... ) """ normalized_roles = (roles,) if isinstance(roles, str) else tuple(roles) return frozenset( From 4faadecbdf7364875309836193ed74dda3f704d2 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sat, 23 May 2026 21:27:19 +0000 Subject: [PATCH 09/16] Format --- .../python/tokenspeed_kernel/numerics/reference/moe.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py index fa56fc022..fd2f3bbca 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/reference/moe.py @@ -33,10 +33,11 @@ topk_ids_logical_to_physical, ) from tokenspeed_kernel.registry import Priority, register_kernel -from tokenspeed_kernel.signature import dense_tensor_format, format_signature - -# isort: split -from tokenspeed_kernel.signature import format_signatures +from tokenspeed_kernel.signature import ( + dense_tensor_format, + format_signature, + format_signatures, +) from tokenspeed_kernel.torch_compile import get_compiler_backend # --------------------------------------------------------------------------- From 4dc5350227257996ed6499fe34e8fb6c3f8abcd5 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sun, 24 May 2026 15:32:31 +0000 Subject: [PATCH 10/16] Refine some comments --- .../python/tokenspeed_kernel/signature.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/signature.py b/tokenspeed-kernel/python/tokenspeed_kernel/signature.py index b26c9b815..b6f7331da 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/signature.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/signature.py @@ -39,12 +39,11 @@ @dataclass(frozen=True) class ScaleFormat: - """Representation metadata for a tensor scale sidecar. + """Metadata representation for one tensor scale sidecar. Args: storage_dtype: Physical dtype used by the scale tensor. - granularity: Scale granularity, such as "tensor", "channel", - or "block". + granularity: Scale granularity, such as "tensor", "channel", "block". block_shape: Logical block shape covered by each scale value when granularity is block-based. """ @@ -62,12 +61,12 @@ def __str__(self) -> str: @dataclass(frozen=True) class TensorFormat: - """Storage representation for one logical tensor role. + """Metadata representation for one logical tensor. Args: storage_dtype: Physical dtype used by the main tensor payload. - format: Logical representation format, such as "dense", - "fp8", "mxfp8", "mxfp4", or "nvfp4". + format: Logical representation format, such as "dense", "fp8", + "mxfp8", "mxfp4", or "nvfp4". scale: Optional scale sidecar metadata bundled with this tensor role. """ @@ -83,7 +82,7 @@ def __str__(self) -> str: @dataclass(frozen=True) class FormatSignature: - """Role-indexed tensor formats for one concrete operand-format combination. + """One concrete set of role-indexed tensor formats for all tensor operands. Each role appears at most once and maps to exactly one ``TensorFormat``. A kernel that supports alternatives for a role represents them as multiple @@ -206,10 +205,9 @@ def format_signatures( dtype. scale: Optional scale sidecar metadata assigned to every role. - Use ``format="dense"`` for dense same-format signatures. Use - ``format_signature`` directly for mixed-role combinations such as dense - activations with quantized weights. This helper expands dtype alternatives - into separate signatures; it does not put multiple formats on one role. + This helper expands dtype alternatives into separate signatures; it does + not put multiple formats on one role. Use ``format_signature`` directly for + mixed-role combinations such as dense activations with quantized weights. Examples: >>> import torch From 1a6871914ad49e62a9397c4ffeb7f0e4589c4f00 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sun, 24 May 2026 15:46:57 +0000 Subject: [PATCH 11/16] Document warmup selection smoke path Clarify that warmup_selection without explicit ops picks one deterministic representative signature per registered operator and is not comprehensive. Direct model init code to pass exact format signatures and traits for hot-path warmup. Signed-off-by: Lei Zhang --- .../python/tokenspeed_kernel/selection.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/selection.py b/tokenspeed-kernel/python/tokenspeed_kernel/selection.py index 37d7614d7..615694593 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/selection.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/selection.py @@ -849,10 +849,17 @@ def explain_selection( def warmup_selection( ops: list[tuple[str, str, FormatSignature, dict | None]] | None = None, ) -> None: - """Pre-resolve kernel selection for the given ops (or all registered ops). + """Pre-resolve kernel selection for explicit op signatures. - Call this at model init to front-load both heuristic and autotune costs. - After warmup, every select_kernel() call on the hot path is a cache hit. + Pass ``ops`` from model initialization to front-load heuristic and autotune + costs for the actual hot-path call sites. Each entry must include the exact + ``FormatSignature`` and trait values used by runtime selection. + + When ``ops`` is ``None``, this performs only a deterministic smoke warmup: + one representative signature for each registered operator, with no traits. + That path verifies the registry and fills a small cache sample, but it does + not warm all supported signatures, trait combinations, or feature-specific + call paths. """ if ops is None: @@ -862,6 +869,9 @@ def warmup_selection( specs = registry.get_for_operator(family, mode) if not specs or not specs[0].format_signatures: continue + # No-arg warmup is intentionally a smoke path. Pick a stable + # representative from the highest-priority spec; callers that need + # comprehensive warmup should pass explicit op signatures. format_signature = sorted(specs[0].format_signatures, key=str)[0] ops.append((family, mode, format_signature, None)) From ba9a9f4e29623a96e918f5a43b50c0fe6123a123 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sun, 24 May 2026 16:08:15 +0000 Subject: [PATCH 12/16] Require explicit MXFP8 block shape Remove the implicit [128, 128] fallback from GEMM numerics input generation so mxfp8 block-scaled signatures must provide block_shape metadata. Add regression coverage for missing block_shape and rename the Triton GEMM MXFP8 scale constant to emphasize that it describes block scale metadata. Signed-off-by: Lei Zhang --- .../python/tokenspeed_kernel/numerics/gemm.py | 4 +++- .../tokenspeed_kernel/ops/gemm/triton.py | 4 ++-- tokenspeed-kernel/test/test_numerics.py | 21 +++++++++++++++++++ 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/gemm.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/gemm.py index 51d12663d..07321a6ea 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/gemm.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/gemm.py @@ -150,7 +150,9 @@ def _scale_for_format( return None if scale.granularity == "block" and tensor_format.format == "mxfp8": - block_n, block_k = block_size or [128, 128] + if block_size is None: + raise ValueError("mxfp8 block scale format requires block_shape") + block_n, block_k = block_size k_tiles = math.ceil(K / block_k) if role == "a": return self._generate_scales((M, k_tiles), scale.storage_dtype) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py index 06aff7789..c96e6e61a 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/gemm/triton.py @@ -36,7 +36,7 @@ logger = logging.getLogger(__name__) _fp8_dtype = Platform.get().fp8e4m3fn.dtype -_MXFP8_SCALE = ScaleFormat( +_MXFP8_BLOCK_SCALE = ScaleFormat( storage_dtype=torch.float32, granularity="block", block_shape=(128, 128), @@ -50,7 +50,7 @@ granularity="channel", ) _MXFP8_FORMAT_SIGNATURES = format_signatures( - ("a", "b"), "mxfp8", {_fp8_dtype}, scale=_MXFP8_SCALE + ("a", "b"), "mxfp8", {_fp8_dtype}, scale=_MXFP8_BLOCK_SCALE ) _FP8_SCALED_FORMAT_SIGNATURES = format_signatures( ("a", "b"), "fp8", {_fp8_dtype}, scale=_FP8_TENSOR_SCALE diff --git a/tokenspeed-kernel/test/test_numerics.py b/tokenspeed-kernel/test/test_numerics.py index 0660f3a8c..687498e46 100644 --- a/tokenspeed-kernel/test/test_numerics.py +++ b/tokenspeed-kernel/test/test_numerics.py @@ -90,6 +90,27 @@ def test_gemm_input_generator_uses_signature_scale_metadata() -> None: assert inputs["block_size"] == [128, 128] +def test_gemm_input_generator_requires_mxfp8_block_shape() -> None: + scale = ScaleFormat( + storage_dtype=torch.float32, + granularity="block", + ) + signature = next( + iter(format_signatures(("a", "b"), "mxfp8", {_fp8_dtype}, scale=scale)) + ) + generator = get_input_generator( + "gemm", + "mm", + dtype=_fp8_dtype, + traits={}, + format_signature=signature, + device="cpu", + ) + + with pytest.raises(ValueError, match="requires block_shape"): + generator.generate(M=4, N=256, K=128) + + class TestNumericsVerification: def _get_verifiable_specs( dtype: torch.dtype, family: str | None = None From 97c73b51eca29251d8c81ea7a44633377c5280a1 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sun, 24 May 2026 16:15:32 +0000 Subject: [PATCH 13/16] Clarify MoE fused weight formats Document what each moe_fused weight_format value means, including main tensor storage, scale storage, and how dtype disambiguates uint8 activations. Signed-off-by: Lei Zhang --- .../python/tokenspeed_kernel/ops/moe/__init__.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py index bcf0a30f8..dff8763a5 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe/__init__.py @@ -301,8 +301,19 @@ def moe_fused( * ``{"pre_routed"}``: routing already done by caller (cutlass, reference). Args: - weight_format: Weight tensor encoding. Supported values are ``"bf16"``, - ``"fp8"``, ``"mxfp4"``, and ``"nvfp4"``. + weight_format: Weight tensor encoding used for the expert weights. + Supported values are: + + * ``"bf16"``: dense bfloat16 weights with no scale tensor. + * ``"fp8"``: FP8 E4M3 weights with float32 block scales. + * ``"mxfp4"``: packed MXFP4 weights stored as uint8 with uint8 + block scales over 32-value blocks. + * ``"nvfp4"``: packed NVFP4 weights stored as uint8 with float32 + block scales. + + The activation/input format is selected by ``dtype``. When + ``dtype=torch.uint8``, ``weight_format`` disambiguates whether the + input is interpreted as MXFP4 or NVFP4. """ signature = _moe_fused_format_signature(dtype, weight_format) selection_traits = dict(traits or {}) From 8d37eeb2fa60403e2d7b2cb5f0701534a9a667e7 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sun, 24 May 2026 16:28:00 +0000 Subject: [PATCH 14/16] Add format signature selection coverage Add focused selection tests for exact mixed-operand signatures and kernels with multiple registered format signatures. Also keep optional backend placeholder helpers out of registry.__all__ while preserving explicit imports. Signed-off-by: Lei Zhang --- .../python/tokenspeed_kernel/registry.py | 2 - tokenspeed-kernel/test/test_selection.py | 98 ++++++++++++++++++- 2 files changed, 97 insertions(+), 3 deletions(-) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/registry.py b/tokenspeed-kernel/python/tokenspeed_kernel/registry.py index 9df5f5ccd..905187f0b 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/registry.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/registry.py @@ -37,14 +37,12 @@ logger = logging.getLogger(__name__) __all__ = [ - "ErrorClass", "KernelSpec", "KernelRegistry", "Priority", "load_builtin_kernels", "register_kernel", "describe_kernel", - "error_fn", ] diff --git a/tokenspeed-kernel/test/test_selection.py b/tokenspeed-kernel/test/test_selection.py index b92a8858c..c5b92171d 100644 --- a/tokenspeed-kernel/test/test_selection.py +++ b/tokenspeed-kernel/test/test_selection.py @@ -55,7 +55,13 @@ spec_matches_traits, warmup_selection, ) -from tokenspeed_kernel.signature import format_signatures +from tokenspeed_kernel.signature import ( + ScaleFormat, + dense_tensor_format, + format_signature, + format_signatures, + tensor_format, +) from utils import register_all_samples pytestmark = pytest.mark.usefixtures("fresh_registry") @@ -557,6 +563,96 @@ def test_no_kernel_after_trait_filter(self, h100_platform): traits={"head_dim": 128}, ) + def test_selects_exact_mixed_operand_signature(self, h100_platform): + reg = KernelRegistry.get() + scale = ScaleFormat( + storage_dtype=torch.uint8, + granularity="block", + block_shape=(32,), + ) + mixed_signature = format_signature( + a=dense_tensor_format(torch.bfloat16), + b=tensor_format("mxfp4", torch.uint8, scale=scale), + ) + dense_uint8_signature = format_signature( + a=dense_tensor_format(torch.bfloat16), + b=dense_tensor_format(torch.uint8), + ) + + reg.register( + KernelSpec( + name="dense_uint8", + family="gemm", + mode="mm", + solution="test", + format_signatures=frozenset({dense_uint8_signature}), + priority=19, + ), + lambda: "dense_uint8", + ) + reg.register( + KernelSpec( + name="mixed_mxfp4", + family="gemm", + mode="mm", + solution="test", + format_signatures=frozenset({mixed_signature}), + priority=5, + ), + lambda: "mixed_mxfp4", + ) + + impl = select_kernel( + "gemm", + "mm", + mixed_signature, + platform=h100_platform, + ) + + assert impl() == "mixed_mxfp4" + + def test_selects_each_registered_format_signature(self, h100_platform): + reg = KernelRegistry.get() + fp8_scale = ScaleFormat( + storage_dtype=torch.float32, + granularity="tensor", + ) + fp8_signature = format_signature( + a=tensor_format("fp8", torch.float8_e4m3fn, scale=fp8_scale), + b=tensor_format("fp8", torch.float8_e4m3fn, scale=fp8_scale), + ) + + reg.register( + KernelSpec( + name="dense_multi", + family="gemm", + mode="mm", + solution="test", + format_signatures=frozenset({GEMM_BF16, GEMM_FP16}), + priority=10, + ), + lambda: "dense_multi", + ) + reg.register( + KernelSpec( + name="fp8_scaled", + family="gemm", + mode="mm", + solution="test", + format_signatures=frozenset({fp8_signature}), + priority=10, + ), + lambda: "fp8_scaled", + ) + + bf16_impl = select_kernel("gemm", "mm", GEMM_BF16, platform=h100_platform) + fp16_impl = select_kernel("gemm", "mm", GEMM_FP16, platform=h100_platform) + fp8_impl = select_kernel("gemm", "mm", fp8_signature, platform=h100_platform) + + assert bf16_impl() == "dense_multi" + assert fp16_impl() == "dense_multi" + assert fp8_impl() == "fp8_scaled" + def test_override_by_name(self, sample_specs, h100_platform): reg = KernelRegistry.get() register_all_samples(reg, sample_specs) From e9ad17dad4ee0d469ba11a52446624649d86a380 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Mon, 25 May 2026 15:55:21 +0000 Subject: [PATCH 15/16] Fix embedding CUDA signature import Signed-off-by: Lei Zhang --- tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/cuda.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/cuda.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/cuda.py index 0f3ae6daa..cebc7ed8a 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/cuda.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/embedding/cuda.py @@ -22,6 +22,7 @@ import torch from tokenspeed_kernel.platform import CapabilityRequirement, current_platform from tokenspeed_kernel.registry import Priority, register_kernel +from tokenspeed_kernel.signature import format_signatures platform = current_platform() From 11cce5988dedc917d79c2f9d0a308ce0470c5d7f Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Mon, 25 May 2026 16:17:13 +0000 Subject: [PATCH 16/16] Fix FP8 numerics reference signature selection Try all primary-dtype format signatures during numerics verification and choose the one with a compatible reference kernel. This avoids selecting the FP8 channel-scale signature when only the tensor-scale reference is registered. Signed-off-by: Lei Zhang --- .../tokenspeed_kernel/numerics/verify.py | 67 +++++++++++++------ tokenspeed-kernel/test/test_numerics.py | 42 +++++++++++- 2 files changed, 87 insertions(+), 22 deletions(-) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/verify.py b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/verify.py index 314835462..f69eca43c 100644 --- a/tokenspeed-kernel/python/tokenspeed_kernel/numerics/verify.py +++ b/tokenspeed-kernel/python/tokenspeed_kernel/numerics/verify.py @@ -38,11 +38,12 @@ ToleranceOverride, get_family_tolerance, ) -from tokenspeed_kernel.registry import KernelRegistry, load_builtin_kernels +from tokenspeed_kernel.registry import KernelRegistry, KernelSpec, load_builtin_kernels from tokenspeed_kernel.selection import ( ref_compatible_with_spec, spec_matches_shape_traits, ) +from tokenspeed_kernel.signature import FormatSignature # isort: split import tokenspeed_kernel.numerics.gemm # noqa: F401 @@ -65,6 +66,49 @@ def _as_tolerance_fn(override: ToleranceOverride | None) -> ToleranceFn | None: ) +def _format_signatures_for_primary_storage_dtype( + spec: KernelSpec, + dtype: torch.dtype, +) -> tuple[FormatSignature, ...]: + return tuple( + signature + for signature in sorted(spec.format_signatures, key=str) + if signature.primary_storage_dtype() == dtype + ) + + +def _compatible_reference_for_signature( + registry: KernelRegistry, + spec: KernelSpec, + signature: FormatSignature, +) -> KernelSpec | None: + ref_specs = registry.get_for_operator( + spec.family, + spec.mode, + format_signature=signature, + solution="reference", + ) + for ref in ref_specs: + if ref.name == spec.name: + continue + if ref_compatible_with_spec(ref, spec): + return ref + return None + + +def _verification_signature_and_reference( + registry: KernelRegistry, + spec: KernelSpec, + dtype: torch.dtype, +) -> tuple[FormatSignature | None, KernelSpec | None]: + signatures = _format_signatures_for_primary_storage_dtype(spec, dtype) + for signature in signatures: + ref_spec = _compatible_reference_for_signature(registry, spec, signature) + if ref_spec is not None: + return signature, ref_spec + return (signatures[0], None) if signatures else (None, None) + + def verify_kernel( kernel_name: str, *, @@ -86,30 +130,11 @@ def verify_kernel( if kernel is None: raise ValueError(f"Kernel implementation for {kernel_name!r} is missing") - signature = spec.format_signature_for_primary_storage_dtype(dtype) + signature, ref_spec = _verification_signature_and_reference(registry, spec, dtype) if signature is None: raise ValueError( f"Kernel {kernel_name!r} does not support primary storage dtype={dtype}" ) - - ref_specs = registry.get_for_operator( - spec.family, - spec.mode, - format_signature=signature, - solution="reference", - ) - if not ref_specs: - raise ValueError( - f"No reference kernel found for {spec.family}.{spec.mode} and dtype={dtype}" - ) - - ref_spec = None - for ref in ref_specs: - if ref.name == spec.name: - continue - if ref_compatible_with_spec(ref, spec): - ref_spec = ref - break if ref_spec is None: raise ValueError( "No compatible reference kernel found for " diff --git a/tokenspeed-kernel/test/test_numerics.py b/tokenspeed-kernel/test/test_numerics.py index 687498e46..381184124 100644 --- a/tokenspeed-kernel/test/test_numerics.py +++ b/tokenspeed-kernel/test/test_numerics.py @@ -25,7 +25,10 @@ from tokenspeed_kernel.numerics.comparison import compare_outputs, format_comparison from tokenspeed_kernel.numerics.inputs import get_input_generator from tokenspeed_kernel.numerics.tolerance import Tolerance -from tokenspeed_kernel.numerics.verify import verify_kernel +from tokenspeed_kernel.numerics.verify import ( + _verification_signature_and_reference, + verify_kernel, +) from tokenspeed_kernel.platform import Platform from tokenspeed_kernel.registry import KernelRegistry, KernelSpec, load_builtin_kernels from tokenspeed_kernel.signature import ScaleFormat, format_signatures @@ -111,6 +114,43 @@ def test_gemm_input_generator_requires_mxfp8_block_shape() -> None: generator.generate(M=4, N=256, K=128) +def test_verification_uses_signature_with_compatible_reference(fresh_registry) -> None: + tensor_scale = ScaleFormat(storage_dtype=torch.float32, granularity="tensor") + channel_scale = ScaleFormat(storage_dtype=torch.float32, granularity="channel") + tensor_signature = next( + iter(format_signatures(("a", "b"), "fp8", {_fp8_dtype}, scale=tensor_scale)) + ) + channel_signature = next( + iter(format_signatures(("a", "b"), "fp8", {_fp8_dtype}, scale=channel_scale)) + ) + ref_spec = KernelSpec( + name="test_tensor_scale_reference", + family="gemm", + mode="mm", + solution="reference", + format_signatures=frozenset({tensor_signature}), + traits={"b_layout": frozenset({"KN"})}, + ) + test_spec = KernelSpec( + name="test_fp8_scaled", + family="gemm", + mode="mm", + solution="triton", + format_signatures=frozenset({channel_signature, tensor_signature}), + traits={"b_layout": frozenset({"KN"})}, + ) + registry = KernelRegistry.get() + registry.register(ref_spec, lambda **_kwargs: None) + registry.register(test_spec, lambda **_kwargs: None) + + signature, reference = _verification_signature_and_reference( + registry, test_spec, _fp8_dtype + ) + + assert signature == tensor_signature + assert reference is ref_spec + + class TestNumericsVerification: def _get_verifiable_specs( dtype: torch.dtype, family: str | None = None