
fix(vulkan): correctness fixes for gather, scatter, matmul, and compiled ops#40

Open
goniz wants to merge 10 commits into feat/vulkan from fix/issue-39-vulkan-correctness

Conversation

@goniz (Owner) commented Apr 21, 2026

Summary

This PR fixes several Vulkan backend correctness issues identified in goniz/mlx-vulkan#39:

Changes

  1. Gather operations (gather.cpp, gather_*.comp shaders)

    • Fixed index bounds checking in gather kernels
    • Added proper handling for edge cases in axis/pair/take variants
  2. Scatter operations (scatter_*.comp shaders)

    • Fixed atomic operations in scatter kernels
    • Improved memory consistency for concurrent updates
  3. Matmul (matmul.cpp)

    • Hardened follow-up paths for various matrix configurations
    • Improved handling of edge cases in matrix dimensions
  4. Compiled ops (compiled.cpp)

    • Implemented chunking for large compiled dispatches to prevent timeout hangs
    • Added safe power codegen to avoid problematic shader generation
  5. Copy operations (copy.cpp)

    • Fixed data transfer paths for various tensor layouts
  6. Device management (device.cpp, device.h, vulkan.cpp, vulkan.h)

    • Added device ranking for context selection
    • Exposed 8-bit capability flags
    • Improved stream synchronization primitives
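The chunking described in item 4 can be sketched in plain C++. This is an illustrative sketch, not the PR's actual code: the `Chunk` struct and `make_chunks` helper are hypothetical names. The idea is to split a large contiguous launch into pieces that each stay below a per-dispatch element limit (such as 2^32) so each dispatch can keep using 32-bit indexing in the shader.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical sketch: split a large contiguous compiled launch into chunks
// that each stay below a per-dispatch element limit.
struct Chunk {
  uint64_t offset;    // element offset where this chunk starts
  uint64_t elements;  // number of elements this chunk covers
};

std::vector<Chunk> make_chunks(uint64_t total_elements,
                               uint64_t max_per_dispatch) {
  std::vector<Chunk> chunks;
  for (uint64_t off = 0; off < total_elements; off += max_per_dispatch) {
    // The final chunk may be smaller than max_per_dispatch.
    chunks.push_back({off, std::min(max_per_dispatch, total_elements - off)});
  }
  return chunks;
}
```

Each chunk would then get its own dispatch with descriptor offsets advanced by `offset * item_size`.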

Fixes

Testing

  • Added test_vulkan_ops.py with targeted test cases
  • Verified with Qwen3-0.6B and Qwen3-2B models

Benchmarks

Benchmarks pending (model download taking longer than expected). Will update with results once available.

Note: This PR touches 20 files with ~1060 insertions and ~249 deletions, focused on Vulkan backend stability and correctness.

goniz added 7 commits April 21, 2026 13:21
Fix compiled strided output writes to honor output strides and track transfer staging endpoints so cross-queue waits follow real copy dependencies.
Reject bool/int8/uint8 compiled kernels until feature support is wired end-to-end, and fail fast for >2^32-element compiled launches instead of using an incomplete large-index key path.
Select compute devices using a deterministic score (device type, local memory, queue topology) with env overrides for forced index and preferred vendor, and remove dead compiled op mappings that were not actually supported.
Add chunked contiguous compiled execution for tensors above 2^32 elements and plumb Vulkan 8-bit/scalar-layout capability flags for future compiled dtype work (tracked in issue #40).
Tighten Vulkan buffer checks and non-contiguous matmul zero-init handling, add copy offset bounds and dynamic-copy push constants, and synchronize kernel cache/pipeline access to avoid racey map mutations.
Fix compiled dispatch to honor push-descriptor pipelines, add Vulkan Power execution for float dtypes, and avoid unsafe subgroup-first matvec selection on non-Nvidia GPUs so Gemma decode no longer hangs on the compiled GELU-like path.
Replace unsafe pow(x,3) with a safe helper that handles integer-like exponents on negative bases, and add a row-slice gather fallback for large low-precision embedding lookups that were returning zeros.
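The safe-power commit above can be modeled in plain C++. In GLSL, `pow(x, y)` is undefined for `x < 0`; a helper like the one described might detect integer-like exponents, exponentiate the absolute value, and restore the sign for odd powers. `safe_pow` below is a hypothetical stand-in, not the repo's actual helper:

```cpp
#include <cmath>

// Hypothetical C++ model of a "safe power" helper for shaders, where
// pow(x, y) is undefined for x < 0. For negative bases with an integer-like
// exponent, exponentiate |x| and restore the sign for odd exponents;
// otherwise fall back to std::pow (NaN for negative base, fractional exponent).
float safe_pow(float x, float y) {
  if (x >= 0.0f) {
    return std::pow(x, y);
  }
  float rounded = std::round(y);
  if (std::fabs(y - rounded) < 1e-6f) {  // integer-like exponent
    float magnitude = std::pow(-x, rounded);  // pow on the positive base
    bool odd = std::fabs(std::fmod(rounded, 2.0f)) > 0.5f;
    return odd ? -magnitude : magnitude;  // restore sign for odd powers
  }
  return std::pow(x, y);  // matches IEEE semantics for the remaining cases
}
```

With this, `pow(x, 3)` on a negative `x` yields `-|x|^3` instead of an undefined (often NaN or zero) result.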

@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 09e4fb8686

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +1224 to +1227
```cpp
if (use_push_descriptor) {
  for (auto& write : writes) {
    write.dstSet = vk::DescriptorSet();
  }
```


P1 Badge Re-push descriptors after updating chunk offsets

When supports_push_descriptor is true, update_descriptor_set_for_chunk only mutates buffer_infos/writes but never calls push_descriptor_fn again, while the single push happens earlier before chunk-specific offsets are applied. This means offsetted or chunked compiled kernels still execute with stale descriptor offsets (typically zero), producing wrong reads/writes on drivers that use push descriptors.


```cpp
const uint32_t workgroups = std::max<uint32_t>(
    (static_cast<uint32_t>(total_elements) + 255u) / 256u, 1u);
vulkan::dispatch_dynamic_compute(
constexpr uint32_t kPushConstantSize = sizeof(uint32_t) + sizeof(int64_t) * 4;
```


P1 Badge Size push constants from struct layout, not field sum

The push-constant block contains uint + 4*int64_t, which is 40 bytes with required alignment/padding, but kPushConstantSize is set to 36 bytes. This under-declares/pushes the range and can yield undefined dynamic_o_base values (or validation failures) when dynamic output offsets are used, causing incorrect dynamic copy addressing.
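The mismatch is easy to reproduce with a plain struct. The field names below (other than `dynamic_o_base`, which the review mentions) are hypothetical stand-ins for the real push-constant block:

```cpp
#include <cstdint>

// Hypothetical stand-in for the push-constant block: a uint32_t followed by
// four int64_t members. The int64_t members force 8-byte alignment, so the
// compiler inserts 4 bytes of padding after `total`.
struct PushConstants {
  uint32_t total;          // offset 0, then 4 bytes of padding
  int64_t a_base;          // offset 8
  int64_t b_base;          // offset 16
  int64_t o_base;          // offset 24
  int64_t dynamic_o_base;  // offset 32..39
};

// Summing field sizes gives 36, but the struct's real size is 40.
constexpr uint32_t kWrongSize = sizeof(uint32_t) + sizeof(int64_t) * 4;  // 36
constexpr uint32_t kRightSize = sizeof(PushConstants);                   // 40
static_assert(kRightSize > kWrongSize, "padding makes the struct larger");
```

Pushing only `kWrongSize` bytes truncates `dynamic_o_base`; sizing the range with `sizeof(PushConstants)` avoids the mismatch.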


Comment on lines +150 to +151
```cpp
os << " float lhs = " << power_input_expr(a_dtype, "a_buf") << ";\n";
os << " float rhs = " << power_input_expr(b_dtype, "b_buf") << ";\n";
```


P1 Badge Apply input offsets in dynamic Power shader loads

The shader computes a_idx and b_idx from push constants but the generated load expressions still read a_buf.data[idx]/b_buf.data[idx]. For views with non-zero element offsets, Power reads the wrong source elements and returns incorrect results.



```cpp
std::string build_power_shader(Dtype a_dtype, Dtype b_dtype, Dtype out_dtype) {
  std::ostringstream os;
  os << vulkan::emit_dynamic_shader_preamble(a_dtype, out_dtype, false);
```


P1 Badge Include exponent dtype when emitting Power shader preamble

The preamble is emitted from only a_dtype and out_dtype, but the shader also declares InputB using b_dtype. Mixed-type cases such as a=float32, b=float16/bfloat16, out=float32 will reference 16-bit types without enabling the required extensions, causing shader compilation failure instead of executing/falling back cleanly.
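One way to address this, sketched with hypothetical names and a reduced dtype set, is to derive the extension directives from every dtype the shader touches rather than just `a_dtype` and `out_dtype`:

```cpp
#include <string>
#include <vector>

// Hypothetical sketch, not the repo's actual API: emit the GLSL extension
// directives required by *all* dtypes the shader will declare, so InputB's
// dtype cannot be missed.
enum class Dtype { float32, float16, bfloat16 };

std::string emit_extensions(const std::vector<Dtype>& dtypes) {
  bool needs_16bit = false;
  for (Dtype d : dtypes) {
    if (d == Dtype::float16 || d == Dtype::bfloat16) needs_16bit = true;
  }
  std::string out;
  if (needs_16bit) {
    // 16-bit storage/arithmetic extensions must be enabled before any
    // 16-bit type appears in a buffer declaration.
    out += "#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require\n";
    out += "#extension GL_EXT_shader_16bit_storage : require\n";
  }
  return out;
}
```

The preamble call would then pass `{a_dtype, b_dtype, out_dtype}` instead of a single input dtype.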


goniz added 2 commits April 21, 2026 17:41
Restore deterministic device selection, route large-offset shader copies through the dynamic Vulkan copy path, and make async event/fence waits observe stream completion so cached decode generation no longer hangs.
@goniz (Owner, Author) commented Apr 21, 2026

Benchmark Results

Benchmarks after the latest fix (fix(vulkan): unblock generate with large-offset copies):

bf16

Running Qwen3 benchmark with quantization: bf16
Running warmup..
Timing with prompt_tokens=4096, generation_tokens=128, batch_size=1.
Trial 1:  prompt_tps=1385.750, generation_tps=11.148, peak_memory=2.061, total_time=14.567
Trial 2:  prompt_tps=1382.511, generation_tps=11.157, peak_memory=2.062, total_time=14.569
Trial 3:  prompt_tps=1371.342, generation_tps=11.075, peak_memory=2.062, total_time=14.682
Trial 4:  prompt_tps=1388.792, generation_tps=11.000, peak_memory=2.062, total_time=14.723
Trial 5:  prompt_tps=1382.877, generation_tps=11.056, peak_memory=2.063, total_time=14.678
Averages: prompt_tps=1382.255, generation_tps=11.087, peak_memory=2.062

8bit

Running Qwen3 benchmark with quantization: 8bit
Running warmup..
Timing with prompt_tokens=4096, generation_tokens=128, batch_size=1.
Trial 1:  prompt_tps=848.006, generation_tps=6.481, peak_memory=1.393, total_time=24.719
Trial 2:  prompt_tps=851.011, generation_tps=6.575, peak_memory=1.394, total_time=24.422
Trial 3:  prompt_tps=883.247, generation_tps=6.756, peak_memory=1.394, total_time=23.723
Trial 4:  prompt_tps=890.729, generation_tps=6.723, peak_memory=1.394, total_time=23.783
Trial 5:  prompt_tps=894.327, generation_tps=6.752, peak_memory=1.394, total_time=23.676
Averages: prompt_tps=873.464, generation_tps=6.657, peak_memory=1.394

@goniz (Owner, Author) commented Apr 21, 2026

@codex review


@chatgpt-codex-connector (bot) left a comment



Reviewed commit: 180f735483


Comment on lines +1246 to +1252
```cpp
push_fn(
    cmd_buffer,
    VK_PIPELINE_BIND_POINT_COMPUTE,
    pipeline->layout,
    0,
    static_cast<uint32_t>(writes.size()),
    writes.data());
```


P1 Badge Re-push descriptors after chunk offset updates

In the push-descriptor path, descriptors are pushed once before any call to update_descriptor_set_for_chunk, but that lambda mutates buffer_infos for each chunk/offset afterward. vkCmdPushDescriptorSetKHR copies descriptor state at record time, so later buffer_infos edits are not seen by subsequent dispatches. On devices using push descriptors, large compiled kernels will repeatedly read/write the initial chunk region instead of advancing through the tensor.
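The record-time-copy behavior can be modeled without Vulkan at all: a push that copies state by value never observes later mutations of its source, mirroring how `vkCmdPushDescriptorSetKHR` snapshots the `VkWriteDescriptorSet` contents when the command is recorded. A minimal model, with hypothetical names:

```cpp
#include <cstdint>
#include <vector>

// Plain C++ model of the record-time-copy pitfall: the "push" captures the
// descriptor state by value, so a later per-chunk mutation of the source
// array is never seen by the already-recorded dispatch.
struct BufferInfo {
  uint64_t offset;
};

// Returns the offset the recorded push actually sees after a later mutation.
uint64_t recorded_offset_after_mutation() {
  std::vector<BufferInfo> buffer_infos = {{0}};
  std::vector<BufferInfo> recorded = buffer_infos;  // push copies by value
  buffer_infos[0].offset = 4096;  // per-chunk update happens too late
  return recorded[0].offset;      // still 0: the chunk reads the wrong region
}
```

The fix the review suggests follows directly: re-push the descriptors inside the per-chunk loop, after the chunk offsets are applied.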



```cpp
vkCmdDispatch(cmd_buffer, workgroups, 1, 1);
if (!large) {
  update_descriptor_set_for_chunk(0, num_elements);
```


P1 Badge Avoid updating bound descriptor sets without update-after-bind

This call updates descriptor bindings after the set has already been bound earlier in the same command buffer. The set layout is not created with VK_DESCRIPTOR_BINDING_UPDATE_AFTER_BIND_BIT, so vkUpdateDescriptorSets in this state invalidates the command buffer (per Vulkan descriptor update rules). The chunked/non-chunked dispatch path can therefore become invalid or driver-dependent when recording compiled kernels.


Comment on lines +1200 to +1201
```cpp
uint64_t offset_bytes = static_cast<uint64_t>(input_offsets[i]) * item_size;
if (large && !is_scalar(arr)) {
```


P1 Badge Stop applying runtime element offsets twice

The generated GLSL already embeds input_offsets/output_offsets into index expressions, but this code also adds those same offsets into descriptor buffer offsets. For sliced/view tensors with non-zero element offsets, compiled kernels will read/write from doubly-shifted addresses and produce incorrect results.
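A minimal model of the double-offset bug (simplified, not the repo's code): if the generated index already embeds the view's element offset, shifting the descriptor base by the same offset reads a doubly-shifted address.

```cpp
#include <cstddef>
#include <vector>

// Simplified model: the view starts at storage[2], the codegen-style index
// already includes that offset, and the "buggy" path also shifts the base
// pointer, so it lands at storage[4] instead of storage[2].
float read_view_element(bool buggy) {
  std::vector<float> storage = {10, 20, 30, 40, 50};
  std::size_t element_offset = 2;        // the view starts at storage[2]
  std::size_t idx = 0 + element_offset;  // index already embeds the offset
  const float* base =
      buggy ? storage.data() + element_offset  // offset applied twice
            : storage.data();                  // offset applied exactly once
  return base[idx];
}
```

Applying the offset in exactly one place, either in the generated index or in the descriptor base, but not both, resolves it.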


```cpp
const uint32_t workgroups = std::max<uint32_t>(
    (static_cast<uint32_t>(total_elements) + 255u) / 256u, 1u);
vulkan::dispatch_dynamic_compute(
constexpr uint32_t kPushConstantSize = sizeof(uint32_t) + sizeof(int64_t) * 4;
```


P2 Badge Use sizeof(PushConstants) for dynamic copy push constants

The push-constant size is hardcoded as 4 + 4*8 = 36 bytes, but the actual C++ struct has 8-byte alignment and is 40 bytes (dynamic_o_base starts at offset 32). Pushing only 36 bytes leaves the upper half of dynamic_o_base undefined in shader-visible state, which can corrupt dynamic offset indexing when this path is used.


Development

Successfully merging this pull request may close these issues:

GPT 5.4 Pro analysis