Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ def forward(
activation_type=ActivationType.Swiglu,
dtype=x.dtype,
features={"pre_routed"},
weight_format="fp8",
traits={
"weight_dtype": "fp8",
"tp": self.spec.tp_size > 1,
"ep": self.spec.ep_size > 1,
"cuda_graph": False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def _call_kernel(self, router_logits, x_quant, x_scale, layer, top_k, output):
output=output,
dtype=torch.bfloat16,
features={"self_routing"},
traits={"weight_dtype": "mxfp4"},
weight_format="mxfp4",
expected_kernel_name="flashinfer_trtllm_fp4_fused_moe",
)[0]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,6 @@ def forward(
fused_activation=act,
dtype=hidden_states.dtype,
features={"ragged_metadata", "dispatch_gemm"},
traits={"weight_dtype": "mxfp4"},
expected_kernel_name="triton_kernels_dispatch_gemm",
)

Expand All @@ -328,7 +327,6 @@ def forward(
n_expts_act=top_k,
dtype=hidden_states.dtype,
features={"ragged_metadata", "gemm_combine"},
traits={"weight_dtype": "mxfp4"},
expected_kernel_name="triton_kernels_gemm_combine",
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ def _call_kernel(
capacity=capacity,
dtype=x_fp4.dtype,
features={"pre_routed"},
weight_format="nvfp4",
traits={
"weight_dtype": "nvfp4",
"tp": False,
"ep": True,
"cuda_graph": True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ def forward(
activation_type=ActivationType.Swiglu,
dtype=x.dtype,
features={"pre_routed"},
weight_format="nvfp4",
traits={
"weight_dtype": "nvfp4",
"tp": True,
"ep": True,
"cuda_graph": False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def forward(
tune_max_num_tokens=next_power_of_2(num_tokens),
dtype=x.dtype,
features={"self_routing"},
traits={"weight_dtype": "nvfp4"},
weight_format="nvfp4",
expected_kernel_name="flashinfer_trtllm_fp4_fused_moe",
)
if do_finalize:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ def _call_cutlass_kernel(self, x, layer, topk_output):
activation_type=ActivationType.Swiglu,
dtype=x.dtype,
features={"pre_routed"},
weight_format="bf16",
traits={
"weight_dtype": "bf16",
"tp": True,
"ep": True,
"cuda_graph": False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def _call_trtllm_kernel(self, router_logits, x, layer, top_k, do_finalize):
tune_max_num_tokens=next_power_of_2(x.shape[0]),
dtype=x.dtype,
features={"self_routing"},
traits={"weight_dtype": "bf16"},
weight_format="bf16",
expected_kernel_name="flashinfer_trtllm_bf16_fused_moe",
)
if do_finalize:
Expand Down
7 changes: 4 additions & 3 deletions tokenspeed-kernel/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ choices (still evolving; subject to change):
public API (mha_prefill, mm, moe_fused, ...)
┌───────────┴───────────┐
│ select_kernel │ (family, mode, dtype, traits, ...)
│ select_kernel │ (family, mode, format_signature, traits, ...)
└───────────┬───────────┘
│ queries
┌──────────┴──────────┐
Expand All @@ -57,8 +57,8 @@ choices (still evolving; subject to change):
```

- **Registration** — backends register with `@register_kernel(family, mode, ...)`,
declaring supported dtypes, arch capability requirements, traits (head dim,
GQA factor, ...), and a priority band.
declaring supported `format_signatures`, arch capability requirements,
non-format traits (head dim, GQA factor, ...), and a priority band.
- **Auto-selection** — `select_kernel` filters by capability and traits,
ranks the survivors with an optional per-family `SelectionOracle` and
priority, and returns a callable. Selection accepts an objective (latency,
Expand All @@ -71,6 +71,7 @@ choices (still evolving; subject to change):
tokenspeed_kernel/
__init__.py # Public API re-exports
platform.py # PlatformInfo, capability detection
signature.py # TensorFormat, ScaleFormat, FormatSignature
registry.py # KernelRegistry, register_kernel, Priority bands
selection.py # select_kernel, oracles, overrides
profiling.py # ShapeCapture, kernel_scope, Proton bootstrap
Expand Down
32 changes: 23 additions & 9 deletions tokenspeed-kernel/python/tokenspeed_kernel/benchmark/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,16 @@ def _benchmark_one_shape(
if not spec_matches_shape_traits(spec, shape):
return None

signature = spec.format_signature_for_primary_storage_dtype(dtype)
if signature is None:
return None

generator = get_input_generator(
spec.family,
spec.mode,
dtype=dtype,
traits=spec.traits,
format_signature=signature,
device="cuda",
seed=self.config.seed,
)
Expand Down Expand Up @@ -224,10 +229,14 @@ def _verify_one_shape(
return None, None, None

registry = KernelRegistry.get()
signature = spec.format_signature_for_primary_storage_dtype(dtype)
if signature is None:
return None, None, None

ref_specs = registry.get_for_operator(
spec.family,
spec.mode,
dtype=dtype,
format_signature=signature,
solution="reference",
)
if not ref_specs:
Expand Down Expand Up @@ -285,8 +294,10 @@ def _benchmark_kernel_impl(
if spec is None:
raise ValueError(f"Kernel {kernel_name!r} is not registered")

if dtype not in spec.dtypes:
raise ValueError(f"Kernel {kernel_name!r} does not support dtype={dtype}")
if spec.format_signature_for_primary_storage_dtype(dtype) is None:
raise ValueError(
f"Kernel {kernel_name!r} does not support primary storage dtype={dtype}"
)

platform = current_platform()
if not spec.capability.satisfied_by(platform):
Expand Down Expand Up @@ -337,12 +348,15 @@ def _benchmark_op_impl(
"""Benchmark all implementations of an op."""
registry = KernelRegistry.get()
platform = current_platform()
specs = registry.get_for_operator(
op_family,
op_mode,
platform=platform,
dtype=dtype,
)
specs = [
spec
for spec in registry.get_for_operator(
op_family,
op_mode,
platform=platform,
)
if spec.format_signature_for_primary_storage_dtype(dtype) is not None
]

results: list[BenchmarkResult] = []
for spec in sorted(specs, key=lambda item: (item.solution, item.name)):
Expand Down
8 changes: 6 additions & 2 deletions tokenspeed-kernel/python/tokenspeed_kernel/numerics/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@ def _iter_candidate_specs(
specs = [s for s in specs if s.family == family and s.mode == mode]

if dtype_filter is not None:
specs = [s for s in specs if dtype_filter in s.dtypes]
specs = [
s
for s in specs
if s.format_signature_for_primary_storage_dtype(dtype_filter) is not None
]

specs.sort(key=lambda s: (s.family, s.mode, s.name))
return specs
Expand All @@ -92,7 +96,7 @@ def _iter_dtypes(
) -> Iterable[torch.dtype]:
if dtype_filter is not None:
return (dtype_filter,)
return sorted(spec.dtypes, key=str)
return sorted(spec.primary_storage_dtypes(), key=str)


def main(argv: list[str] | None = None) -> int:
Expand Down
97 changes: 74 additions & 23 deletions tokenspeed-kernel/python/tokenspeed_kernel/numerics/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
set_standard_shapes,
)
from tokenspeed_kernel.numerics.tolerance import Tolerance, set_family_tolerance
from tokenspeed_kernel.signature import TensorFormat

# ---------------------------------------------------------------------------
# Tolerance
Expand Down Expand Up @@ -119,45 +120,95 @@ def _generate_scales(self, shape: tuple[int, ...], dtype) -> torch.Tensor:
)
return scales.to(dtype)

def _format(self, role: str) -> TensorFormat | None:
if self.format_signature is None:
return None
return self.format_signature.format_for(role)

def _block_size(
self,
*formats: TensorFormat | None,
) -> list[int] | None:
for tensor_format in formats:
scale = tensor_format.scale if tensor_format is not None else None
if scale is not None and scale.block_shape is not None:
return list(scale.block_shape)
return None

def _scale_for_format(
self,
tensor_format: TensorFormat | None,
role: str,
*,
M: int,
N: int,
K: int,
block_size: list[int] | None,
) -> torch.Tensor | None:
scale = tensor_format.scale if tensor_format is not None else None
if scale is None:
return None

if scale.granularity == "block" and tensor_format.format == "mxfp8":
if block_size is None:
raise ValueError("mxfp8 block scale format requires block_shape")
block_n, block_k = block_size
k_tiles = math.ceil(K / block_k)
if role == "a":
return self._generate_scales((M, k_tiles), scale.storage_dtype)
if role == "b":
n_tiles = math.ceil(N / block_n)
return self._generate_scales((n_tiles, k_tiles), scale.storage_dtype)

if scale.granularity == "channel":
return self._generate_scales(
(M,) if role == "a" else (N,),
scale.storage_dtype,
)

return self._generate_scales((1,), scale.storage_dtype)

def generate(
self,
M: int,
N: int,
K: int,
) -> dict[str, Any]:
quant = self.traits.get("quant")
scale_type = self.traits.get("scale_type")
a_layout = self.traits.get("a_layout")
b_layout = self.traits.get("b_layout")
a_format = self._format("a")
b_format = self._format("b")
a_dtype = a_format.storage_dtype if a_format is not None else self.dtype
b_dtype = b_format.storage_dtype if b_format is not None else self.dtype

A = (
self._generate_value((K, M), self.dtype)
self._generate_value((K, M), a_dtype)
if a_layout == {"KM"}
else self._generate_value((M, K), self.dtype)
else self._generate_value((M, K), a_dtype)
)
B = (
self._generate_value((K, N), self.dtype)
self._generate_value((K, N), b_dtype)
if b_layout == {"KN"}
else self._generate_value((N, K), self.dtype)
else self._generate_value((N, K), b_dtype)
)

A_scales = None
B_scales = None
block_size = None

if quant == {"mxfp8"}:
block_size = [128, 128]
k_tiles = math.ceil(K / block_size[0])
n_tiles = math.ceil(N / block_size[1])
A_scales = self._generate_scales((M, k_tiles), torch.float32)
B_scales = self._generate_scales((n_tiles, k_tiles), torch.float32)
else:
if scale_type == {"per_channel"}:
A_scales = self._generate_scales((M,), torch.float32)
B_scales = self._generate_scales((N,), torch.float32)
else:
A_scales = self._generate_scales((1,), torch.float32)
B_scales = self._generate_scales((1,), torch.float32)
block_size = self._block_size(a_format, b_format)
A_scales = self._scale_for_format(
a_format,
"a",
M=M,
N=N,
K=K,
block_size=block_size,
)
B_scales = self._scale_for_format(
b_format,
"b",
M=M,
N=N,
K=K,
block_size=block_size,
)

out_dtype = torch.bfloat16
alpha = None
Expand Down
15 changes: 13 additions & 2 deletions tokenspeed-kernel/python/tokenspeed_kernel/numerics/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from typing import Any, Callable

import torch
from tokenspeed_kernel.registry import KernelSpec
from tokenspeed_kernel.signature import FormatSignature

__all__ = [
"InputGenerator",
Expand Down Expand Up @@ -52,13 +52,15 @@ def __init__(
dtype: torch.dtype,
traits: dict,
*,
format_signature: FormatSignature | None = None,
device: str | None = None,
seed: int = 42,
) -> None:
self.op_family = op_family
self.op_mode = op_mode
self.dtype = dtype
self.traits = traits
self.format_signature = format_signature
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

rng_device = "cuda" if self.device.startswith("cuda") else "cpu"
Expand Down Expand Up @@ -98,6 +100,7 @@ def get_input_generator(
dtype: torch.dtype,
traits: dict,
*,
format_signature: FormatSignature | None = None,
device: str | None = None,
seed: int = 42,
) -> InputGenerator:
Expand All @@ -107,7 +110,15 @@ def get_input_generator(
raise KeyError(
f"No input generator registered for {op_family}.{op_mode}. Known: {known}"
)
return factory(op_family, op_mode, dtype, traits, device=device, seed=seed)
return factory(
op_family,
op_mode,
dtype,
traits,
format_signature=format_signature,
device=device,
seed=seed,
)


def get_standard_shapes(op_family: str, op_mode: str) -> list[dict[str, Any]]:
Expand Down
Loading
Loading