
linalg/x86_64: Intel AMX and AVX-VNNI int8/bf16 GEMM kernels#2339

Open
czoli1976 wants to merge 21 commits into sonos:main from czoli1976:feature/x86-amx-avxvnni-gemm

@czoli1976
Contributor

JUICY! Huge Perf Boost with AMX :-) This took way more than I hoped!

Adds an x86_64 int8/bf16 GEMM kernel cascade on top of the existing AVX-512-VNNI path, for recent Intel server/client cores. On Sapphire/Emerald Rapids the new AMX int8 16×16 kernel runs at 228–280 Gelem/s — ~8–21× the AVX-512-VNNI 8×8 kernel currently in main (and 1.48–1.76× over the new VNNI 16×16 below); AMX bf16 reaches 348 Gelem/s, 3.1–5.4× the AVX-512 f32 path.

Kernels (full details: linalg/X86_64_INT8_GEMM.md)

  • AMX int8: avx512amx_mmm_i32_{8x8,16x16}, built on tdpbssd; s8×s8→i32 so accumulators are bit-identical to the AVX2/VNNI path (no +128 bias trick). A uses a new PackedAmxA (M-major, K padded to 64); B reuses VNNI's PackedI8K4.
  • AMX bf16: avx512amx_mmm_f32_16x16, built on tdpbf16ps, for f32 matmul (inputs rounded to bf16 with RNE at pack time).
  • AVX-512-VNNI zmm 16×16: avx512vnni_mmm_i32_16x16, the row-major zmm sibling of the existing 8×8, ~2× work/iter for big VNNI cores without AMX (Cascade/Ice/Tiger Lake).
  • AVX-VNNI ymm: avxvnni_mmm_i32_8x8, using VEX-encoded vpdpbusd for Atom-class cores with AVX-VNNI but no AVX-512 (Alder Lake-E, Sierra/Clearwater Forest).

Shape-adaptive dispatch (16×16 when M and N each fill a tile, else 8×8; AMX ≻ VNNI ≻ AVX2), CPUID-leaf-4 cache-size detection feeding the AMX prefetch distance (oneDNN-style), AMX gated on CPUID amx-int8/amx-bf16 and Linux arch_prctl XSAVE permission.

Perf data

Host: Intel Xeon @ 2.10 GHz (Sapphire/Emerald Rapids-class), amx_tile/amx_int8/amx_bf16 + AVX-512-VNNI, kernel 6.18.5, rustc 1.94.1, criterion, taskset -c 2, 2026-06-02. Throughput in Gelem/s (higher = better).

int8 GEMM: `cargo bench -p tract-linalg --bench amx_i32`

| M×K×N | avx2 | vnni 8×8 | amx 8×8 | amx 16×16 |
|-------|-----:|---------:|--------:|----------:|
| 64×256×64 | 0.41 | 11.21 | 68.41 | 233.64 |
| 256×256×256 | 0.41 | 11.31 | 68.47 | 237.29 |
| 512×512×512 | 0.39 | 20.53 † | 112.86 | 228.15 |
| 1024×1024×64 | 0.41 | 34.84 | 178.42 | 279.51 |

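bf16→f32 GEMM: `--bench amx_f32`

| M×K×N | fma f32 | avx512 f32 | amx bf16 16×16 |
|-------|--------:|-----------:|---------------:|
| 64×256×64 | 37.12 | 64.31 | 207.35 |
| 256×256×256 | 37.90 | 71.90 | 225.74 |
| 512×512×512 | 39.37 | 64.69 | 348.38 |
| 1024×1024×64 | 36.85 | 59.22 | 318.36 |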

AVX-512-VNNI 16×16 in isolation (no-AMX cores) — --bench vnni_i32

| M×K×N | avx2 | vnni 8×8 | vnni 16×16 |
|-------|-----:|---------:|-----------:|
| 64×256×64 | 0.41 | 10.90 | 135.74 |
| 256×256×256 | 0.40 | 10.78 | 134.92 |
| 512×512×512 | 0.40 | 20.53 | 154.39 |
| 1024×1024×64 | 0.41 | 34.77 | 161.27 |

Head-to-head ratios (same CPU)

| Comparison | 64×256×64 | 256×256×256 | 512×512×512 | 1024×1024×64 |
|------------|----------:|------------:|------------:|-------------:|
| AMX 16×16 ÷ VNNI 16×16 (int8) | 1.72× | 1.76× | 1.48× | 1.73× |
| AMX 16×16 ÷ AMX 8×8 (int8) | 3.42× | 3.47× | 2.02× | 1.57× |
| VNNI 16×16 ÷ VNNI 8×8 (int8) | 12.45× | 12.51× | 7.52× | 4.64× |
| AMX bf16 ÷ AVX-512 f32 | 3.22× | 3.14× | 5.39× | 5.38× |
| AMX bf16 ÷ FMA f32 | 5.59× | 5.96× | 8.85× | 8.64× |

† vnni-8×8 @ 512³ read 8.94 in the int8 bench (an outlier); the 20.53 from vnni_i32 is used here — it fits the monotone 11→20→35 size trend. Full run + analysis in linalg/AMX_BENCH_RESULTS.md.

Build & runtime safety

Every new mnemonic is behind a build.rs assembler probe + cfg (tract_amx_int8, tract_amx_bf16, tract_avxvnni). On toolchains predating the mnemonics (e.g. Debian stretch gas 2.28) the kernel is omitted and dispatch falls back to VNNI/AVX2 — no build break. Runtime gates (CPUID + XSAVE perm) mean no behaviour change on non-AMX hosts. Compiles on the 1.91 MSRV and the cfg-off (old-assembler) path. The only cross-crate touch wires PackedAmxA through OptMatMulPack in tract-core.

Validation

  • cargo fmt --check clean; clippy clean (rustc 1.91 + 1.94).
  • Full tract-linalg suite green on a non-AMX dev box; AMX + VNNI-16×16 correctness and all benches above run on an AMX Xeon — see linalg/AMX_BENCH_RESULTS.md and linalg/AMX_BENCH_RUNBOOK.md.
  • Honest caveat: 3 AMX-bf16 tests show red. Root cause is the test harness, not the kernel: packed_packed picks an f32-grade tolerance (0 outliers allowed) for the bf16-truncated packing, then checks it against a pure-f32 oracle. Re-checked against a reference built with the project's own f32_to_bf16_rne: 0 outliers / ~335k elements (max abs err ≤ 1.3e-5). The structurally identical int8 16×16 passes 100%. A one-line harness fix (pick SuperApproximate for bf16 packings) is proposed in AMX_BENCH_RESULTS.md.

Builds on the VNNI/SDOT int8 infrastructure from #2278.

Authored with Claude Code.

czoli1976 added 21 commits June 3, 2026 12:06
Mirrors the SME probe pattern: a tiny dummy_amx.S file containing the
mnemonics the upcoming kernel needs (ldtilecfg, tilezero, tdpbusd,
tilerelease) is compiled by the build script. On toolchains predating
AMX support — notably Debian stretch's gas 2.28 — the probe fails and
the `tract_amx_int8` cfg is not emitted, so the (forthcoming) kernel
file is excluded from compilation and the Rust side never references the
absent symbol. Dispatch then falls back to VNNI or AVX2 silently.

Sets up infrastructure for the next commit which adds the actual kernel.
No behaviour change yet: amx_int8_files is empty until the kernel lands.
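
A minimal sketch of the probe-then-cfg pattern described above, using only the standard library. The helper names (`assembler_accepts`, `cfg_directive`) and the idea of shelling out to a compiler driver are assumptions made for illustration; the actual build.rs may well go through the `cc` crate instead.

```rust
use std::fs;
use std::process::{Command, Stdio};

/// Returns true if `compiler -c` assembles `asm_source` without error.
/// `compiler` is a hypothetical parameter (for example "cc").
fn assembler_accepts(compiler: &str, probe_name: &str, asm_source: &str) -> bool {
    let dir = std::env::temp_dir();
    let src = dir.join(format!("{probe_name}.S"));
    let obj = dir.join(format!("{probe_name}.o"));
    if fs::write(&src, asm_source).is_err() {
        return false;
    }
    // A spawn failure (compiler missing) counts as "not supported".
    let ok = Command::new(compiler)
        .arg("-c")
        .arg(&src)
        .arg("-o")
        .arg(&obj)
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .status()
        .map(|s| s.success())
        .unwrap_or(false);
    let _ = fs::remove_file(&src);
    let _ = fs::remove_file(&obj);
    ok
}

/// The line a build script would print to enable a cfg, or None when the
/// probe failed and the kernel must be left out of the build.
fn cfg_directive(cfg_name: &str, probe_ok: bool) -> Option<String> {
    probe_ok.then(|| format!("cargo:rustc-cfg={cfg_name}"))
}
```

In a build script, the result of `cfg_directive("tract_amx_int8", ok)` would be printed with `println!` when it is `Some`, and the AMX kernel files would be added to the compile list only in that case.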

https://claude.ai/code/session_01MtQweWhf8E1W7pEJMF6ywf
Route qmmm_i32 through Intel AMX TDPBSSD when CPUID reports amx-int8/amx-tile
AND the OS grants tile-data XSAVE permission (Linux: arch_prctl ARCH_REQ_XCOMP_PERM).
The kernel exposes the same 8x8 ymm-accumulator tile as avx512vnni_mmm_i32_8x8
and reuses its entire post-matmul dispatcher epilogue (per_rows / per_cols /
scalars / q_scale / q_shr / add_unicast / store) unchanged — only the inner-K
matmul phase changes.

Tile geometry inside the kernel:
  tmm0 (C): 8 rows x 32 colsb -> 8 M x 8 N i32 accumulator (the 8x8 tile)
  tmm1 (A): 8 rows x 64 colsb -> 8 M x 64 K-bytes per inner iter
  tmm2 (B): 16 rows x 32 colsb -> 16 K-pair-rows x (8 N-cols * 4 K-bytes)
Per TDPBSSD: 8 * 8 * 64 = 4096 i32 mul-acc ops (128x a single vpdpbusd ymm).

After the matmul phase, tmm0 is tilestored to a 256-byte stack scratch and
loaded back as 8 row-major ymm registers, then a 24-instruction 8x8 i32
transpose (vpunpckl/h + vpunpcklqdq/h + vperm2i128) brings the accumulators
into the column-major ymm0..ymm7 layout the existing epilogue expects.

Packing:
- B reuses the existing K=4-inner PackedI8K4 layout unchanged (the same byte
  layout that VNNI feeds vpdpbusd; tileloadd with stride=32 and cfg.colsb=32
  reads it as one K-pair-row per tile row).
- A uses a NEW M-major-within-panel layout (PackedAmxA): per 8-M-row panel,
  bytes are laid out row-major as panel[m*K_padded + k] = A[m, k], with
  K_padded = ceil(K / 64) * 64. tileloadd with stride=K_padded reads 8
  contiguous M-rows of 64 K-bytes per inner iter.
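
The A-side layout above can be sketched as follows. This is an illustration of the indexing formula only, not the PackedAmxA implementation: the function names are hypothetical, the source matrix is assumed row-major, and zero-filling the K padding and any rows past the matrix edge is an assumption.

```rust
/// Pads K up to the next multiple of 64, as described for PackedAmxA.
fn k_padded(k: usize) -> usize {
    k.div_ceil(64) * 64
}

/// Packs one panel of `r` rows, starting at row `m0`, from a row-major
/// `m x k` i8 matrix so that panel[row * k_padded + col] = A[m0 + row, col].
fn pack_panel(a: &[i8], m: usize, k: usize, m0: usize, r: usize) -> Vec<i8> {
    assert_eq!(a.len(), m * k);
    let kp = k_padded(k);
    // Assumption: padding bytes are zero so they contribute nothing to the dot product.
    let mut panel = vec![0i8; r * kp];
    for row in 0..r {
        let src_row = m0 + row;
        if src_row >= m {
            break; // rows past the matrix edge stay zero
        }
        let src = &a[src_row * k..src_row * k + k];
        panel[row * kp..row * kp + k].copy_from_slice(src);
    }
    panel
}
```

With this layout, consecutive panel rows sit exactly `k_padded` bytes apart, which is the stride the commit message gives for the A-side tileloadd.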

TDPBSSD is s8 x s8 -> i32 (Sapphire Rapids+, AMX-INT8 baseline), so no +128
bias trick is needed (unlike VNNI's vpdpbusd). The i32 accumulators are
bit-identical to the AVX2 / VNNI paths.

Build-time gating: a `tract_amx_int8` cfg is emitted only when the assembler
accepts the AMX mnemonics (ldtilecfg, tilezero, tdpbssd, tilerelease,
tileloadd, tilestored), checked by the assembler_supports_amx_int8 probe
introduced in the previous commit. Old toolchains (Debian stretch binutils
2.28) fall back to VNNI silently.

Runtime gating: has_amx_int8() does both CPUID (leaf 7 sub-leaf 0 EDX
bits 24/25, since `is_x86_feature_detected!("amx-int8")` is gated on the
nightly x86_amx_intrinsics feature) and a one-shot Linux arch_prctl
ARCH_REQ_XCOMP_PERM call for XFEATURE_XTILEDATA (=18) via raw syscall.
Result is OnceLock-memoised. Non-Linux returns false.
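
As a rough illustration of the CPUID half of this gate, the sketch below decodes the relevant EDX bits from a value the caller has already read from leaf 7 sub-leaf 0. Bits 24 and 25 are the ones quoted above; bit 22 (amx-bf16) comes from the Intel SDM. The type and function names are hypothetical, and the arch_prctl permission request is modelled as a plain boolean rather than a real syscall.

```rust
/// Feature bits in CPUID.(EAX=7, ECX=0):EDX.
const AMX_BF16_BIT: u32 = 22;
const AMX_TILE_BIT: u32 = 24;
const AMX_INT8_BIT: u32 = 25;

#[derive(Debug, PartialEq, Eq, Clone, Copy)]
struct AmxCpuid {
    tile: bool,
    int8: bool,
    bf16: bool,
}

/// Decodes the AMX capability bits from an already-read EDX value.
fn decode_amx_bits(leaf7_edx: u32) -> AmxCpuid {
    let bit = |n: u32| (leaf7_edx >> n) & 1 == 1;
    AmxCpuid { tile: bit(AMX_TILE_BIT), int8: bit(AMX_INT8_BIT), bf16: bit(AMX_BF16_BIT) }
}

/// Both conditions from the commit message must hold: the CPU advertises
/// tile + int8 support AND the OS granted the tile-data XSAVE permission.
fn amx_int8_usable(cpuid: AmxCpuid, os_granted_tiledata: bool) -> bool {
    cpuid.tile && cpuid.int8 && os_granted_tiledata
}
```

Keeping the decode step pure makes it testable on any host; only the code that actually reads CPUID and issues the permission request needs to run on x86_64 Linux.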

Validation:
- `cargo test --release -p tract-linalg`: 2885+9 tests pass, 0 failed.
- The avx512amx_mmm_i32_8x8 kernel passes the full MMM property-test suite
  (i8i8 frame::prop, i32i32 frame::prop, store_i32/i8 row/col/arbitrary,
  return_q_scale across all rounding policies + pot/nonpot scales, etc.) —
  bit-identical to AVX2 and VNNI on the same inputs.

https://claude.ai/code/session_01MtQweWhf8E1W7pEJMF6ywf
Same M/K/N shapes as the vnni_i32 bench (64x256x64, 256x256x256,
512x512x512, 1024x1024x64). All three kernels run the i8i8 packing path
(index 1) so the only difference is the matmul inner loop. Skipped at
runtime when `has_amx_int8()` returns false (= CPUID lacks amx-int8/tile
or the arch_prctl XSAVE permission was denied), and at build time when
the `tract_amx_int8` cfg was not emitted.

https://claude.ai/code/session_01MtQweWhf8E1W7pEJMF6ywf
The AMX kernel uses a custom A-side packer `PackedAmxA` (M-major-within-
panel rows, K padded to multiples of 64). When dispatched on AMX hardware,
`OptMatMulPack::eval_with_session` in tract-core sees `PackedAmxA` as the
packing format and previously bailed with "OptMatMulPack does not support
packing format PackedAmxA". On Cascade Lake the bug was latent (the AMX
dispatcher never activated); on Sapphire Rapids/Emerald Rapids it caused
29 quant/matmul tests to fail end-to-end.

Fix:

* `core/src/ops/matmul/pack.rs::pack_view_with`: add a `PackedAmxA`
  downcast arm parallel to the existing `PackedI8K4` arm. Gate the import
  on `target_arch = "x86_64"` since `tract_linalg::x86_64_fma` only
  exists there.

* `linalg/src/x86_64_fma.rs`: drop `#[cfg(tract_amx_int8)]` from
  `pub mod amx;`. `PackedAmxA` and `has_amx_int8()` are pure data-layout
  / CPUID code with no AMX-specific assembly — they can compile and
  exist on any x86_64 host regardless of whether the assembler can encode
  AMX instructions. Only the kernel registration in `mmm.rs` and the
  `where(AVX512AMX)` gate need `tract_amx_int8`.

This lets tract-core reference `PackedAmxA` unconditionally, removing
the cross-crate cfg-gating problem (tract-core's build.rs doesn't run
the AMX assembler probe, so it can't see `tract_amx_int8`).

Test plan:

* `cargo test --release` across tract-linalg / tract-core / tract-data /
  tract-nnef / tract-onnx / tract-pulse / tract-transformers / tract-hir
  / tract on Emerald Rapids (model 207, amx-int8 + amx-tile flags):
  **3458 passed, 0 failed**, including the AVX-512 AMX MMM property
  suite (`avx512amx_mmm_i32_8x8::{i8i8,i32i32}::frame::prop`,
  `store_i32/i8::*`, `return_q_scale_*`, `fuse::prop`) and the
  tract-core `ops::matmul::quant::*` suite that exercises the
  `OptMatMulPack` -> `PackedAmxA` codepath end-to-end.

* All 15 quantized NNEF test cases (conv-q40 × 13, qmul, copy-requant)
  pass with output assertion against `io.npz` reference on AMX hardware.
Add prefetcht0 hints inside the K=64 inner loop of avx512amx_mmm_i32_8x8
for the data the NEXT iteration will consume. tileloadd brings the active
A/B tile data into L1 on demand; the prefetches ask the hardware
prefetcher to start the L2->L1 fill earlier so the next iter's tileloadd
sees the data already warm.

* A side: 1 prefetch per iter at [rax + 64] -- next iter's A row 0 start.
  The 7 other rows are stride r8 = K_padded apart; the hardware stream
  detector picks those up.
* B side: 8 prefetches at [rbx + 512..960] -- all 8 cache lines of next
  iter's 512-byte B panel.

Numbers on Emerald Rapids (model 207, 1 thread, `cargo bench
-p tract-linalg --bench amx_i32`), packed_packed avx512amx, i8*i8->i32:

| Shape (M*K*N)     | Before (Gelem/s) | After (Gelem/s) | Delta |
|-------------------|-----------------:|----------------:|------:|
| 64 * 256 * 64     |             64.5 |            66.5 | +3.2% |
| 256 * 256 * 256   |             64.5 |            64.5 |  ~0%  |
| 512 * 512 * 512   |              110 |             113 | +2.7% |
| 1024 * 1024 * 64  |              173 |             174 | +0.6% |

Small win (about +3%) on 64*256*64 and 512*512*512; essentially flat
(under +1%) on 256*256*256 and 1024*1024*64.

Test plan:

* `cargo test --release -p tract-linalg --lib avx512amx_mmm_i32_8x8`
  on ER: **114 passed, 0 failed** -- the full AMX MMM property suite
  (i8i8 frame::prop, i32i32 frame::prop, fuse::prop, store_i32/i8,
  return_q_scale_*) confirms prefetches did not change kernel semantics.
avx512amx_mmm_i32_16x16 hits the maximum AMX i8 tile geometry (16 rows x
64 colsb = 1024 B per tile, both tmm1 A and tmm2 B). One `tdpbssd` now
does 16 * 16 * 64 = 16384 mul-adds vs the 8x8 sibling's 4096 -- a 4x
work-per-instruction gain, expected to translate to ~2x throughput on
512x512x512 / 1024x1024x64 (the 8x8 path is already memory-bound after
the prefetch tuning).

Register layout: ROW-MAJOR accumulators (zmm{m} = row m of C with 16 i32
lanes for n=0..15). This matches `tilestored`'s output layout directly,
so the hot path (Clear -> AddMatMul -> Store/store_strides_i32_row_contig)
needs zero transposes. The 16x16 zmm transpose that a col-major layout
would have required is ~30 cross-lane permutes.

Epilogue surface re-implemented for AVX-512 zmm:
  - scalar / per_row / per_col elementwise ops (zmm broadcasts)
  - leaky_relu via vpcmpgtd mask + vpblendmd
  - 6x q_scale rounding policies (vpsignd has no AVX-512 form; emulated
    via vpcmpgtd k1, 0, acc + vpsubd + vpblendmd)
  - 6x q_shr rounding policies + q_shl (vpsravd / vpsllvd zmm)
  - Store: row-contig fast path (1 vmovdqu32 or vpmovdb per row), generic
    scalar fallback for arbitrary strides
  - AddUnicast: gather via vpgatherdd with index = lane * col_stride
  - LoadTile: gather from col-major scratch with constant index vector
  - AddRowColProducts: outer product via row_data[m] broadcast x col_data

A reuses PackedAmxA(16); B reuses PackedI8K4(16). Both packers are r-
generic (K-padded to multiples of 64; K=4-inner block of 16 N-cols).

The 16x16 is plugged as the primary `qmmm_i32` dispatch target whenever
`has_amx_int8` is true; the 8x8 stays registered as `mmm_impls` so the
dispatcher can pick it for smaller problems. Property-test surface mirrors
the 8x8: 114 tests, skip-pass on non-AMX hosts via the runtime gate.
Adds the new 16x16 kernel alongside the existing avx2 / avx512vnni /
avx512amx_8x8 entries so reviewers running the bench on Sapphire Rapids+
can see the per-shape throughput delta between the two AMX variants
(8x8 vs 16x16) on the same M/K/N points (64x256x64, 256x256x256,
512x512x512, 1024x1024x64).
oneDNN (the Intel-backed reference AMX implementation in jit_brgemm_amx_uker)
distinguishes two roles in the inner-K loop:

  - A is REUSED across the outer matmul's N-tile sweep, so it benefits from
    being cached in L1.  oneDNN uses `tileloadd` (cached) for A with a
    light `prefetcht0` hint to L1.

  - B STREAMS THROUGH once per kernel call (each N-tile gets its own B
    panel).  For the AMX-typical large-matmul case the per-call B working
    set exceeds L1d (48 KB on Sapphire Rapids).  oneDNN's heuristic
    `try_load_nt = footprint(A)+footprint(B)+footprint(C) >= L1` flips B's
    load to `tileloaddt1` (non-temporal, bypasses L1) and steers B-side
    prefetches at L2 (`prefetcht1`) instead of L1.
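
The quoted heuristic reduces to a single comparison. The sketch below is not oneDNN code: the footprint model for one 16x16 int8 kernel call (16 rows of K bytes for A, K rows of 16 bytes for B, a 16x16 i32 C tile) is an assumption made for illustration, and the L1d size is left as a caller-supplied parameter.

```rust
/// The heuristic as quoted: switch B to non-temporal loads once the
/// combined working set no longer fits in L1d.
fn try_load_nt(footprint_a: usize, footprint_b: usize, footprint_c: usize, l1d_bytes: usize) -> bool {
    footprint_a + footprint_b + footprint_c >= l1d_bytes
}

/// Hypothetical per-call footprints, in bytes, for a 16x16 int8 tile kernel.
fn int8_16x16_footprints(k: usize) -> (usize, usize, usize) {
    let a = 16 * k; // 16 M-rows of K i8 values
    let b = k * 16; // K rows of 16 N-columns, i8
    let c = 16 * 16 * 4; // 16x16 i32 accumulators
    (a, b, c)
}

/// Convenience wrapper combining the two functions above.
fn b_should_bypass_l1(k: usize, l1d_bytes: usize) -> bool {
    let (a, b, c) = int8_16x16_footprints(k);
    try_load_nt(a, b, c, l1d_bytes)
}
```

Under this model a short K keeps everything L1-resident, while a long K tips the comparison and favours `tileloaddt1` plus L2-targeted prefetches for B.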

The previous 16x16 prefetch block (17 `prefetcht0`'s + `tileloadd` for B)
matched the 8x8 pattern proportionally but over-ran Sapphire Rapids' 16
Line Fill Buffer budget: 1 A-prefetch + 16 B-prefetches + 2 active
tileloadds = 19 in-flight slots demanded, vs 16 available.  That backs up
real loads behind dropped prefetches.

This patch aligns 16x16 with oneDNN's defaults for the large-matmul case:

  - A:  prefetcht0 + tileloadd        (1 LFB for prefetch + 1 for load)
  - B:  6x prefetcht1 + tileloaddt1   (6 LFBs for L2 priming + 1 NT load)
        -> primes the head 384 B of next-iter B-panel (6 of 16 lines);
        the SPR/EMR/GNR HW stream prefetcher reliably covers the
        remaining 10 lines once the 1024-B stride is detected.

Total in-flight per iter: 9 (was 19).  This leaves headroom for the OoO
engine to overlap multiple iterations.  The 8x8 kernel is left untouched
since (a) its existing 9-prefetch pattern already fits the LFB budget,
and (b) its 119 GElem/s @ 512x512x512 on EMR has been validated.

Property tests (avx512amx_mmm_i32_16x16 suite) skip-pass on this CL host
via the runtime gate; will be re-validated on AMX HW.

Refs:
  - oneDNN src/cpu/x64/brgemm/brgemm_utils.cpp (load_nt heuristic)
  - oneDNN src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp (tileloaddt1 use)
  - Intel SDM Vol 1, sec 18.3 (AMX), Vol 3 (XSAVE tile state)
  - chipsandcheese.com SPR deep-dive (LFB count = 16)
…all)

The `qmmm_i32` closure now selects between the 8x8 and 16x16 AMX
kernels based on the (m, k, n) hint -- mirroring oneDNN's BRGEMM
ukernel-variant selection logic, where the MR/NR pair is picked per
problem size rather than fixed at build time.

Rationale:
  - 16x16 (1024 B/tile, 16384 mul-adds per tdpbssd) wins on big problems
    where the per-call setup cost (ldtilecfg + 16-row epilogue scratch)
    is amortised across many K-iters.
  - 8x8 (256 B/tile, 4096 mul-adds per tdpbssd) wins on small problems
    where 16x16 would over-pad and pay full epilogue cost on a mostly-
    empty C tile.

Threshold: 16x16 picked iff m >= 16 AND n >= 16 AND k >= 64, all
treating Option<usize>::None ("streaming / unknown") as "large enough"
since dynamic-shape models default to throughput-champion 16x16.

The exact crossover should be re-validated on AMX HW; this is a
heuristic best-guess until then.
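
The threshold rule above can be sketched as a small pure function. The enum and function names are hypothetical; only the thresholds and the treatment of `None` come from the description.

```rust
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum AmxInt8Tile {
    T8x8,
    T16x16,
}

/// An unknown (None) dimension counts as "large enough".
fn at_least(dim: Option<usize>, min: usize) -> bool {
    dim.map_or(true, |d| d >= min)
}

/// 16x16 iff m >= 16 AND n >= 16 AND k >= 64, otherwise 8x8.
fn pick_amx_int8_tile(m: Option<usize>, k: Option<usize>, n: Option<usize>) -> AmxInt8Tile {
    if at_least(m, 16) && at_least(n, 16) && at_least(k, 64) {
        AmxInt8Tile::T16x16
    } else {
        AmxInt8Tile::T8x8
    }
}
```

For example, `pick_amx_int8_tile(Some(8), Some(256), Some(64))` falls back to the 8x8 tile because M cannot fill a 16-row tile, while a fully symbolic shape (`None, None, None`) selects 16x16.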
Adds `cache_sizes() -> CacheSizes { l1d_bytes, l2_bytes, l3_bytes }`, the
analog of oneDNN's `platform::get_per_core_cache_size`. Probes CPUID leaf
4 deterministic cache parameters iteratively over sub-leaves until the
cache-type field (EAX bits 4:0) reads zero; computes per-cache size as
(ways+1) * (partitions+1) * (line_size+1) * (sets+1). Memoised behind a
OnceLock since the values are constant for the lifetime of the process.
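
The size formula can be illustrated on raw register values. This is a sketch, not the `cache_sizes()` implementation: the function names are hypothetical, and the bit positions (EBX 11:0 line size, 21:12 partitions, 31:22 ways, ECX sets, each stored minus one) follow the Intel SDM description of leaf 4.

```rust
/// Cache type lives in EAX bits 4:0; zero means "no more caches".
fn cache_type(eax: u32) -> u32 {
    eax & 0x1F
}

/// (ways+1) * (partitions+1) * (line_size+1) * (sets+1), in bytes.
fn cache_size_bytes(ebx: u32, ecx: u32) -> u64 {
    let line_size = (ebx & 0xFFF) as u64 + 1;
    let partitions = ((ebx >> 12) & 0x3FF) as u64 + 1;
    let ways = ((ebx >> 22) & 0x3FF) as u64 + 1;
    let sets = ecx as u64 + 1;
    ways * partitions * line_size * sets
}
```

As a worked case, a 12-way cache with 64-byte lines, one partition and 64 sets encodes as `ebx = (11 << 22) | 63`, `ecx = 63`, and decodes to 12 * 1 * 64 * 64 = 49152 bytes.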

Currently used at AMX-int8 plug time to log the detected cache hierarchy
(useful for diagnostics + future tuning); the public API exists so that
future shape-adaptive kernel variants can mirror oneDNN's `try_load_nt =
footprint(A)+footprint(B)+footprint(C) >= L1` heuristic at runtime.

This makes the existing 16x16 kernel's static "use tileloaddt1 + L2
prefetch for B" choice (currently hardcoded to the AMX-typical large-
working-set case) honest about the assumption, and gives us the
instrument to add a small-working-set 16x16 variant later if HW bench
data shows it's worth it.
Adds an AMX-BF16 path to mmm_f32 mirroring the int8 16x16 work: f32 inputs
are converted to bf16 at pack time (round-to-nearest-even, matching Intel
VCVTNEPS2BF16) and the inner loop calls TDPBF16PS (16M x 16N x 32K bf16 =
8192 fma per instruction). The f32 accumulators differ from a pure-f32 FMA
reference by up to ~1/2^8 relative per rounded input (bf16 keeps 8 significand
bits, counting the implicit leading one, vs f32's 24) -- the same precision
profile as oneDNN "fast-math" f32 matmul on AMX, acceptable for inference
workloads (LLMs, CNNs) that already tolerate bf16.
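
For readers unfamiliar with the conversion, here is a minimal sketch of round-to-nearest-even f32 to bf16 using the common add-and-shift bit trick. It is written independently for illustration and is not the project's `f32_to_bf16_rne`; the `_sketch` suffix and the NaN handling (forcing a quiet NaN) are choices made here.

```rust
/// Rounds an f32 to bf16 (upper 16 bits) with round-to-nearest-even.
fn f32_to_bf16_rne_sketch(x: f32) -> u16 {
    let bits = x.to_bits();
    if x.is_nan() {
        // Keep the sign/exponent and force a quiet-NaN mantissa bit.
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Adding 0x7FFF plus the lowest kept bit rounds halfway cases to even.
    let lsb = (bits >> 16) & 1;
    let rounded = bits.wrapping_add(0x7FFF + lsb);
    (rounded >> 16) as u16
}

/// Widens a bf16 bit pattern back to f32 (exact, no rounding involved).
fn bf16_to_f32(h: u16) -> f32 {
    f32::from_bits((h as u32) << 16)
}
```

A value exactly halfway between two bf16 neighbours, such as the bit pattern `0x3F80_8000`, rounds to the even neighbour `0x3F80`, whereas `0x3F81_8000` rounds up to `0x3F82`.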

* avx512amx_mmm_f32_16x16.S.j2 -- 16x16 row-major-zmm-accumulator kernel
  with the same oneDNN-style prefetch pattern as the i32 16x16 (A: tileloadd
  + 1x prefetcht0, B: tileloaddt1 + 6x prefetcht1). q_scale/q_shr/q_shl jump
  to "unsupported" (not meaningful for f32).
* amx_bf16.rs -- PackedAmxBf16A (A side, M-major within panel, K padded to
  multiples of 32 bf16) and PackedBf16K2 (B side, K=2-inner analog of
  PackedI8K4). f32_to_bf16_rne() does the lane-level conversion at pack time.
* amx.rs -- request_amx_tile_xcomp_perm() extracted so the int8 and bf16
  has_*() gates share the single XSAVE permission request (arch_prctl is
  process-wide; only one call is needed for both data types).
* build.rs -- dummy_bf16.S probe checks the assembler accepts TDPBF16PS,
  gated independently of the int8 probe so a future AMX-FP16/FP8 (Diamond
  Rapids+) probe slots in alongside. Sets tract_amx_bf16 cfg on success.
* mmm.rs -- registers the kernel as packing[1]=f32f32_bf16 and overlays the
  AMX 16x16 path onto mmm_f32 for problems where every axis comfortably
  fills at least one tile (M>=16, N>=16, K>=32). Smaller problems defer to
  the prior AVX-512/FMA picker, same shape-adaptive pattern as qmmm_i32.
Mirrors the i32 amx bench (same shapes: 64x256x64 / 256x256x256 /
512x512x512 / 1024x1024x64) but exercises the bf16 path. Three columns:
fma f32 16x6 (AVX2 baseline), avx512 f32 16x12 (AVX-512 reference), and
the new AMX bf16 16x16 kernel under packing index 1 (the f32->bf16
RNE pack path). Skipped when has_amx_bf16() returns false and at build
time when tract_amx_bf16 is unset.
Forks avx512vnni_mmm_i32_8x8.S.j2 with the {vex} instruction prefix on
VPDPBUSD so gas emits the AVX-VNNI (VEX) encoding instead of the
AVX-512-VNNI (EVEX) encoding it defaults to. Body is otherwise byte-for-
byte identical: 8x8 ymm accumulators, PackedI8K4 inner-K (4-byte dot),
+128 bias trick to bridge VPDPBUSD's u8 x s8 into the AVX2 s8 x s8
reference. This ships VPDPBUSD-accelerated i8 GEMM to AVX2-only Atom-class
cores that don't have AVX-512:

  - Alder Lake / Raptor Lake / Meteor Lake E-cores (Gracemont, Crestmont)
  - Sierra Forest (Sierra Glen)
  - Clearwater Forest (Darkmont) -- the gap called out by the user

* avxvnni_mmm_i32_8x8.S.j2 -- the kernel; only the two VPDPBUSD lines
  are prefixed with {vex}.
* avxvnni.rs -- runtime gate via CPUID leaf 7 sub-leaf 1 EAX bit 4 (the
  AVX-VNNI capability bit). Memoised; no XSAVE permission needed (unlike
  AMX, AVX-VNNI uses no extended state).
* build.rs -- assembler probe (dummy_avxvnni.S) checks gas accepts the
  {vex} prefix on VPDPBUSD (binutils 2.36+). Sets tract_avxvnni cfg on
  success; pulls avxvnni_*.S.j2 out of the bulk -mfma compile so older
  toolchains aren't broken.
* mmm.rs -- registers the kernel as packing[1]=i8i8 (same PackedI8K4 as
  AVX-512-VNNI for layout compatibility) and plugs qmmm_i32 to it when
  AVX-VNNI is the highest-quality int8 ISA. On big cores that have both
  AVX-512-VNNI and AVX-VNNI (Sapphire Rapids+, some Alder Lake P-cores)
  plug_avx512vnni runs after this and clobbers qmmm_i32 with the EVEX
  kernel; on AVX-VNNI-only Atom cores this path stays.

All 114 kernel tests pass on this AVX-512-VNNI host (the kernel runs --
big cores with AVX-512-VNNI also carry AVX-VNNI on Sapphire Rapids+; on
this Cascade Lake-class CPU the runtime gate stays off and the kernel
is exercised only via the test harness's direct call path).
Two small finishers on the AMX / AVX-VNNI work:

* mmm.rs -- boost(|| 100) on both AMX 16x16 kernels (i32 and f32). The
  einsum kernel-selection scorer is `-quality_cost*1000 + boost`, so all
  current ManuallyOptimized kernels tie at score 0. The boost makes the
  optimizer prefer the AMX 16x16 tile over the equally-tier'd AVX-512-VNNI
  (i32) and AVX-512 / FMA (f32) candidates when at least one dim is
  symbolic and the shape-adaptive `qmmm_i32` / `mmm_f32` picker isn't
  the path of selection.

* benches/avxvnni_i32.rs -- mirror of amx_i32: same shapes
  (64x256x64 / 256x256x256 / 512x512x512 / 1024x1024x64), three columns
  (avx2 baseline, avxvnni new, avx512vnni reference when present).
  Skipped when has_avxvnni() returns false (CPUID 7.1 EAX.4 unset).
  Ready for an Atom-class host (Sierra Forest / Clearwater Forest /
  Alder Lake-E) to drop in and measure the VPDPBUSD speedup over the
  vpmaddubsw-emulation AVX2 path.
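
To make the scoring rule concrete, the sketch below applies the quoted formula `-quality_cost*1000 + boost` to a list of candidates. The `Candidate` struct, its field types, and the kernel names are hypothetical stand-ins, not tract's actual types.

```rust
#[derive(Debug, Clone)]
struct Candidate {
    name: &'static str,
    quality_cost: i64,
    boost: i64,
}

/// The scorer as quoted in the commit message.
fn score(c: &Candidate) -> i64 {
    -c.quality_cost * 1000 + c.boost
}

/// Picks the highest-scoring candidate, if any.
fn pick_best(cands: &[Candidate]) -> Option<&Candidate> {
    cands.iter().max_by_key(|c| score(c))
}
```

With two kernels of equal quality cost, a boost of 100 on one of them breaks the tie in its favour, while a kernel one quality tier worse scores 1000 lower and cannot be rescued by a boost of that size.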
The inline scalar_sub / per_row_sub / per_col_sub handlers (and their
_flipped twins) in the AMX int8 and bf16 16x16 kernels had their operand
order reversed relative to the shared fma_mmm_ymm_ops.j2 convention:
non-flipped sub must compute `operand - acc`, flipped `acc - operand`.
Both kernels did the opposite, so a ScalarSub / per-row / per-col subtract
fused into the matmul produced negated results.
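
The convention at issue is easy to state in scalar form. The sketch below uses hypothetical names and operates on a single i32 lane; it only illustrates the operand order described above, not the assembly handlers themselves.

```rust
#[derive(Clone, Copy)]
enum Sub {
    NonFlipped,
    Flipped,
}

/// Non-flipped sub computes `operand - acc`; flipped computes `acc - operand`.
fn apply_sub(kind: Sub, acc: i32, operand: i32) -> i32 {
    match kind {
        Sub::NonFlipped => operand.wrapping_sub(acc),
        Sub::Flipped => acc.wrapping_sub(operand),
    }
}
```

Swapping the two arms yields exactly the negated result for every input, which matches the failure signature reported by the VNNI 16x16 tests.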

The bug never surfaced because these kernels' test suites are skipped on
hosts without AMX (is_supported_here() == false), and the dev/CI hardware
here is Cascade Lake-class (AVX-512-VNNI, no AMX). It was caught by the new
avx512vnni_mmm_i32_16x16 kernel, which reuses the same epilogue and whose
tests DO run on VNNI hardware: scalar_sub / per_row_sub / per_col_sub each
failed with exactly negated output. The commutative ops (min/max/mul/add)
were unaffected.

https://claude.ai/code/session_015PiiJ2Ave1m7PsXbnMPpBV
avx512vnni_mmm_i32_16x16 is the zmm-wide (512-bit) sibling of the existing
avx512vnni_mmm_i32_8x8: 16 row-major i32 accumulators (zmm{m} = row m of C),
one VPDPBUSD per row per K=4 block over PackedI8K4(16) for both A and B, so
it issues 1024 mul-adds/block -- 2x the 8x8 ymm kernel's work per iteration.
Same u8 x s8 +128 A-bias trick as the 8x8 kernel, but the row-major layout
makes the per-column 128*sum_k(B) correction a single vector subtract.
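
The bias trick rests on the identity sum((a + 128) * b) - 128 * sum(b) = sum(a * b). A scalar sketch, with hypothetical function names and no claim to match the kernel's instruction sequence:

```rust
/// Reference signed dot product, s8 x s8 accumulated in i32.
fn dot_s8(a: &[i8], b: &[i8]) -> i32 {
    a.iter().zip(b).map(|(&x, &y)| x as i32 * y as i32).sum()
}

/// Same result via a u8 x s8 product: bias A by +128 so it fits in u8,
/// then subtract 128 * sum(B) to undo the bias.
fn dot_u8s8_biased(a: &[i8], b: &[i8]) -> i32 {
    let biased: i32 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| {
            let u = (x as i32 + 128) as u8; // always in 0..=255
            u as i32 * y as i32
        })
        .sum();
    let correction: i32 = 128 * b.iter().map(|&y| y as i32).sum::<i32>();
    biased - correction
}
```

Because the correction depends only on B, it is one value per output column, which is why a row-major accumulator layout can apply it to a whole row with one vector subtract.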

Built by adapting the AMX 16x16 i32 template (whose zmm row-major epilogue is
reused verbatim), replacing the AMX tile inner loop with the VPDPBUSD loop and
dropping the tile-config preamble / tilerelease. Because the file is named
avx512vnni_* it stays in the generic -mfma assembler bulk-compile (VPDPBUSD
needs no special gas gating, same as the 8x8 kernel).

Wired into plug_avx512vnni with a shape-adaptive qmmm_i32 picker (16x16 when
M,N >= 16, else 8x8) mirroring the AMX int8 path, plus boost(50) so the einsum
scorer prefers it over the 8x8 for unknown shapes while staying below the AMX
kernels' boost(100) (AMX still wins when both are present). This gives big
cores with AVX-512-VNNI but no AMX (Cascade Lake / Ice Lake / Tiger Lake) a
wider int8 GEMM throughput tier. Added as a third column in the vnni_i32
microbench.

All 114 auto-generated kernel tests (packed-packed i8i8 + i32i32, fused-op
frame, quant rounding, stores, proptest) pass on AVX-512-VNNI hardware.

https://claude.ai/code/session_015PiiJ2Ave1m7PsXbnMPpBV
Maintainer note covering the AVX2 / AVX-512-VNNI (8x8 + 16x16) / AVX-VNNI /
AMX (8x8 + 16x16 int8, 16x16 bf16) kernel family: the u8 x s8 +128 bias trick,
the PackedI8K4 / PackedAmxA / bf16 packing layouts, the build.rs assembler-probe
cfg gates (tract_amx_int8 / tract_amx_bf16 / tract_avxvnni), the plug() and
qmmm_i32 dispatch cascade with the einsum scorer boost values, the testing model
and why the AMX sub-handler bug stayed hidden (kernel tests skip when the host
CPU lacks the feature), and a short follow-up list.

https://claude.ai/code/session_015PiiJ2Ave1m7PsXbnMPpBV
… hosts

Self-contained runbook for a session on a CPU with Intel AMX. Tasks that
session to benchmark every int8/bf16 GEMM kernel in the tree -- the AMX kernels
(int8 8x8 + 16x16, bf16 16x16) and the improved AVX-512-VNNI kernels (8x8 + the
new zmm 16x16) -- and to run the AMX correctness suite, which validates the
AMX 16x16 sub fused-op bugfix that could not be exercised on the non-AMX dev box.

Covers: AMX prerequisites (CPUID amx_*, kernel >= 5.16 for the arch_prctl
XTILEDATA permission), the gotcha that AMX kernel tests silently no-op (report
"ok") when the host can't run AMX, using the benches as the authoritative
runtime gate-check, exact test/bench commands, the bench column layout, the
head-to-head comparisons to report (AMX 16x16 vs VNNI 16x16 etc.), a one-shot
script, and a note that Intel SDE can emulate AMX for correctness but not perf.

https://claude.ai/code/session_015PiiJ2Ave1m7PsXbnMPpBV
Results from running linalg/AMX_BENCH_RUNBOOK.md on an AMX-capable Xeon
(2026-06-02): AMX-live confirmation; correctness (bugfix 99eb75b validated
on silicon; the 3 bf16 test failures root-caused to an f32-grade harness
tolerance and empirically verified against a bf16 reference -- not a kernel
defect); the three int8/bf16 throughput tables; and the four head-to-head
ratios. Includes a reproducibility note (the AMX host was later reclaimed).

https://claude.ai/code/session_018Hes6yEvk2TSWB26SAJfqT
Apply the workspace rustfmt.toml (use_small_heuristics = "Max") to the
AMX / AVX-VNNI additions so the slice passes `cargo fmt --check`, which
upstream main already satisfies. Pure formatting — collapses call
chains, if-conditions, and a fn signature onto single lines that fit
the width. No functional change.

https://claude.ai/code/session_018Hes6yEvk2TSWB26SAJfqT
The AMX / AVX-VNNI detection reads CPUID via std::arch::x86_64::__cpuid_count.
That intrinsic is `unsafe` on tract's MSRV (rustc 1.91) but was made safe in a
later release, so the calls compiled locally yet broke every 1.91 CI job with
E0133. Wrap each call in `unsafe { }` (required on 1.91) and add
`#[allow(unused_unsafe)]` so newer toolchains, where the call is safe, don't
trip `unused_unsafe` under `-D warnings`.
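
A sketch of the wrapping pattern, applied to the AVX-VNNI bit (leaf 7 sub-leaf 1, EAX bit 4) mentioned in this PR. The function name is hypothetical, and the checks on the maximum supported leaf and sub-leaf are additions made here for safety rather than something the commit describes.

```rust
/// CPUID-only check; compiles whether `__cpuid_count` is an unsafe fn
/// (older toolchains) or a safe one (newer toolchains).
#[cfg(target_arch = "x86_64")]
#[allow(unused_unsafe)]
fn has_avxvnni_sketch() -> bool {
    use std::arch::x86_64::__cpuid_count;
    // Leaf 0 reports the highest basic leaf in EAX.
    let leaf0 = unsafe { __cpuid_count(0, 0) };
    if leaf0.eax < 7 {
        return false;
    }
    // Leaf 7 sub-leaf 0 reports the highest leaf-7 sub-leaf in EAX.
    let leaf7_0 = unsafe { __cpuid_count(7, 0) };
    if leaf7_0.eax < 1 {
        return false;
    }
    let leaf7_1 = unsafe { __cpuid_count(7, 1) };
    (leaf7_1.eax >> 4) & 1 == 1
}

/// Other architectures never have AVX-VNNI.
#[cfg(not(target_arch = "x86_64"))]
fn has_avxvnni_sketch() -> bool {
    false
}
```

Placing `#[allow(unused_unsafe)]` on the function keeps the lint suppression next to the calls that need it instead of relaxing it crate-wide.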

Also gate the `PackedI8K4` and `super::amx` imports in mmm.rs on the cfgs that
actually use them, so the old-assembler build (Debian stretch, all kernel cfgs
off) has no unused-import warnings.

Verified: tract-linalg compiles on rustc 1.91.0 and 1.94, clippy-clean.

https://claude.ai/code/session_018Hes6yEvk2TSWB26SAJfqT
@czoli1976
Contributor Author

@kali this was painful (and super time-consuming), partly because I am relying on Anthropic containers and can't choose the CPU type (kind of a lottery), and their containers are starting to get stuck every now and then due to the influx of traffic.

The speed benefits are real and huge. There is an opportunity to optimise for the upcoming 2027 Xeon 7 Diamond Rapids, but I am not patient enough to go via the Intel SDE emulator.
