fix(vulkan): correctness fixes for gather, scatter, matmul, and compiled ops#40
fix(vulkan): correctness fixes for gather, scatter, matmul, and compiled ops#40goniz wants to merge 10 commits intofeat/vulkanfrom
Conversation
Fix compiled strided output writes to honor output strides and track transfer staging endpoints so cross-queue waits follow real copy dependencies.
Reject bool/int8/uint8 compiled kernels until feature support is wired end-to-end, and fail fast for >2^32-element compiled launches instead of using an incomplete large-index key path.
Select compute devices using a deterministic score (device type, local memory, queue topology) with env overrides for forced index and preferred vendor, and remove dead compiled op mappings that were not actually supported.
Add chunked contiguous compiled execution for tensors above 2^32 elements and plumb Vulkan 8-bit/scalar-layout capability flags for future compiled dtype work (tracked in issue #40).
Tighten Vulkan buffer checks and non-contiguous matmul zero-init handling, add copy offset bounds and dynamic-copy push constants, and synchronize kernel cache/pipeline access to avoid racey map mutations.
Fix compiled dispatch to honor push-descriptor pipelines, add Vulkan Power execution for float dtypes, and avoid unsafe subgroup-first matvec selection on non-Nvidia GPUs so Gemma decode no longer hangs on the compiled GELU-like path.
Replace unsafe pow(x,3) with a safe helper that handles integer-like exponents on negative bases, and add a row-slice gather fallback for large low-precision embedding lookups that were returning zeros.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 09e4fb8686
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| if (use_push_descriptor) { | ||
| for (auto& write : writes) { | ||
| write.dstSet = vk::DescriptorSet(); | ||
| } |
There was a problem hiding this comment.
Re-push descriptors after updating chunk offsets
When supports_push_descriptor is true, update_descriptor_set_for_chunk only mutates buffer_infos/writes but never calls push_descriptor_fn again, while the single push happens earlier before chunk-specific offsets are applied. This means offsetted or chunked compiled kernels still execute with stale descriptor offsets (typically zero), producing wrong reads/writes on drivers that use push descriptors.
Useful? React with 👍 / 👎.
| const uint32_t workgroups = std::max<uint32_t>( | ||
| (static_cast<uint32_t>(total_elements) + 255u) / 256u, 1u); | ||
| vulkan::dispatch_dynamic_compute( | ||
| constexpr uint32_t kPushConstantSize = sizeof(uint32_t) + sizeof(int64_t) * 4; |
There was a problem hiding this comment.
Size push constants from struct layout, not field sum
The push-constant block contains uint + 4*int64_t, which is 40 bytes with required alignment/padding, but kPushConstantSize is set to 36 bytes. This under-declares/pushes the range and can yield undefined dynamic_o_base values (or validation failures) when dynamic output offsets are used, causing incorrect dynamic copy addressing.
Useful? React with 👍 / 👎.
| os << " float lhs = " << power_input_expr(a_dtype, "a_buf") << ";\n"; | ||
| os << " float rhs = " << power_input_expr(b_dtype, "b_buf") << ";\n"; |
There was a problem hiding this comment.
Apply input offsets in dynamic Power shader loads
The shader computes a_idx and b_idx from push constants but the generated load expressions still read a_buf.data[idx]/b_buf.data[idx]. For views with non-zero element offsets, Power reads the wrong source elements and returns incorrect results.
Useful? React with 👍 / 👎.
|
|
||
| std::string build_power_shader(Dtype a_dtype, Dtype b_dtype, Dtype out_dtype) { | ||
| std::ostringstream os; | ||
| os << vulkan::emit_dynamic_shader_preamble(a_dtype, out_dtype, false); |
There was a problem hiding this comment.
Include exponent dtype when emitting Power shader preamble
The preamble is emitted from only a_dtype and out_dtype, but the shader also declares InputB using b_dtype. Mixed-type cases such as a=float32, b=float16/bfloat16, out=float32 will reference 16-bit types without enabling the required extensions, causing shader compilation failure instead of executing/falling back cleanly.
Useful? React with 👍 / 👎.
This reverts commit 62a8671.
Restore deterministic device selection, route large-offset shader copies through the dynamic Vulkan copy path, and make async event/fence waits observe stream completion so cached decode generation no longer hangs.
Benchmark ResultsBenchmarks after the latest fix ( bf168bit |
|
@codex review |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 180f735483
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| push_fn( | ||
| cmd_buffer, | ||
| VK_PIPELINE_BIND_POINT_COMPUTE, | ||
| pipeline->layout, | ||
| 0, | ||
| static_cast<uint32_t>(writes.size()), | ||
| writes.data()); |
There was a problem hiding this comment.
Re-push descriptors after chunk offset updates
In the push-descriptor path, descriptors are pushed once before any call to update_descriptor_set_for_chunk, but that lambda mutates buffer_infos for each chunk/offset afterward. vkCmdPushDescriptorSetKHR copies descriptor state at record time, so later buffer_infos edits are not seen by subsequent dispatches. On devices using push descriptors, large compiled kernels will repeatedly read/write the initial chunk region instead of advancing through the tensor.
Useful? React with 👍 / 👎.
|
|
||
| vkCmdDispatch(cmd_buffer, workgroups, 1, 1); | ||
| if (!large) { | ||
| update_descriptor_set_for_chunk(0, num_elements); |
There was a problem hiding this comment.
Avoid updating bound descriptor sets without update-after-bind
This call updates descriptor bindings after the set has already been bound earlier in the same command buffer. The set layout is not created with VK_DESCRIPTOR_BINDING_UPDATE_AFTER_BIND_BIT, so vkUpdateDescriptorSets in this state invalidates the command buffer (per Vulkan descriptor update rules). The chunked/non-chunked dispatch path can therefore become invalid or driver-dependent when recording compiled kernels.
Useful? React with 👍 / 👎.
| uint64_t offset_bytes = static_cast<uint64_t>(input_offsets[i]) * item_size; | ||
| if (large && !is_scalar(arr)) { |
There was a problem hiding this comment.
Stop applying runtime element offsets twice
The generated GLSL already embeds input_offsets/output_offsets into index expressions, but this code also adds those same offsets into descriptor buffer offsets. For sliced/view tensors with non-zero element offsets, compiled kernels will read/write from doubly-shifted addresses and produce incorrect results.
Useful? React with 👍 / 👎.
| const uint32_t workgroups = std::max<uint32_t>( | ||
| (static_cast<uint32_t>(total_elements) + 255u) / 256u, 1u); | ||
| vulkan::dispatch_dynamic_compute( | ||
| constexpr uint32_t kPushConstantSize = sizeof(uint32_t) + sizeof(int64_t) * 4; |
There was a problem hiding this comment.
Use sizeof(PushConstants) for dynamic copy push constants
The push-constant size is hardcoded as 4 + 4*8 = 36 bytes, but the actual C++ struct has 8-byte alignment and is 40 bytes (dynamic_o_base starts at offset 32). Pushing only 36 bytes leaves the upper half of dynamic_o_base undefined in shader-visible state, which can corrupt dynamic offset indexing when this path is used.
Useful? React with 👍 / 👎.
Summary
This PR fixes several Vulkan backend correctness issues identified in goniz/mlx-vulkan#39:
Changes
Gather operations (
gather.cpp,gather_*.compshaders)Scatter operations (
scatter_*.compshaders)Matmul (
matmul.cpp)Compiled ops (
compiled.cpp)Copy operations (
copy.cpp)Device management (
device.cpp,device.h,vulkan.cpp,vulkan.h)Fixes
Testing
test_vulkan_ops.pywith targeted test casesBenchmarks
Benchmarks pending (model download taking longer than expected). Will update with results once available.
Note: This PR depends on changes across 20 files with ~1060 insertions and ~249 deletions focused on Vulkan backend stability and correctness.