Metal backend: Add gated delta rule kernel for linear attention#18878
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18878
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@claude review this code |
|
Claude finished @metascroy's task in 2m 37s —— View job
Review: Metal gated delta rule kernelOverall this is a well-structured addition that follows existing Metal backend patterns. The kernel logic is correct and the in-place state mutation approach is sound. A few issues worth addressing: Bug: Missing error check on output tensor creation ( The // op_gather_qmv.mm pattern:
AOTITorchError create_result = aoti_torch_create_tensor_from_blob_v2(
out_contents_ptr, 2, output_sizes.data(), output_strides.data(),
0, dtype, 13, 0, &out_tensor_handle, 0, nullptr, 0);
if (create_result != Error::Ok || !out_tensor_handle) {
ET_LOG(Error, "...: Failed to create output tensor");
aoti_torch_mps_free(out_contents_ptr);
return Error::Internal;
}In Bug: Python reference impl silently assumes Hk == Hv ( The reference implementation broadcasts Currently all instantiated kernels use
Nit: Stale comment in instantiation block ( Line 124 says "Qwen 3.5 MoE dimensions: Dk=128, Dv=128, Hk=16, Hv=32" but line 143 instantiates Nit: Tolerance relaxation for gather_qmv bfloat16 ( The bfloat16 tolerances for Minor observations (non-blocking):
|
This PR needs a
|
…rch#18878) Adds Metal kernel for the gated delta rule recurrence used by Qwen 3.5 MoE's GatedDeltaNet linear attention layers. Ported from the MLX delegate PR (pytorch#18785) Metal shader. The kernel processes the full sequence sequentially within a single GPU dispatch, keeping recurrent state in per-thread registers. Grid: [32, Dv, B*Hv], Threadgroup: [32, 4, 1]. Each simdgroup of 32 threads handles Dk/32 elements of the key dimension with SIMD reduction for dot products. The op mutates the recurrent state buffer in-place (mutates_args). Instantiated for both real model (Dk=128, Dv=128, Hk=32, Hv=32) and tiny test (Dk=64, Dv=64, Hk=4, Hv=4) dimensions. Includes: Metal shader + C++ host dispatch, Python custom op definition (metal::gated_delta_rule) with reference CPU impl and Meta impl, C shim dict, fallback kernel registration, CMakeLists entry, and test module.
Adds Metal kernel for the gated delta rule recurrence used by Qwen 3.5
MoE's GatedDeltaNet linear attention layers. Ported from the MLX delegate
PR (#18785) Metal shader. The kernel processes the full sequence
sequentially within a single GPU dispatch, keeping recurrent state in
per-thread registers.
Grid: [32, Dv, B*Hv], Threadgroup: [32, 4, 1]. Each simdgroup of 32
threads handles Dk/32 elements of the key dimension with SIMD reduction
for dot products.
The op mutates the recurrent state buffer in-place (mutates_args).
Instantiated for both real model (Dk=128, Dv=128, Hk=32, Hv=32) and
tiny test (Dk=64, Dv=64, Hk=4, Hv=4) dimensions.
Includes: Metal shader + C++ host dispatch, Python custom op definition
(metal::gated_delta_rule) with reference CPU impl and Meta impl, C shim
dict, fallback kernel registration, CMakeLists entry, and test module.
Authored with Claude.