diff --git a/mlx/backend/vulkan/arange.cpp b/mlx/backend/vulkan/arange.cpp index 36430d6016..e35b757c86 100644 --- a/mlx/backend/vulkan/arange.cpp +++ b/mlx/backend/vulkan/arange.cpp @@ -43,7 +43,8 @@ void fill_arange_like_cpu(array& out, Stream s, double start, double step) { host_values.data(), host_values.size() * sizeof(T), out_buf->buffer, - out.offset()); + out.offset(), + out.data_shared_ptr()); vulkan::retain_array_for_stream(s, out); } diff --git a/mlx/backend/vulkan/compiled.cpp b/mlx/backend/vulkan/compiled.cpp index 62ca72419e..00cb934c88 100644 --- a/mlx/backend/vulkan/compiled.cpp +++ b/mlx/backend/vulkan/compiled.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -114,6 +115,8 @@ std::string dtype_to_glsl_compute(Dtype d) { return "float"; case complex64: return "vec2"; + case bool_: + return "bool"; default: return dtype_to_glsl_storage(d); } @@ -203,8 +206,8 @@ bool supports_primitive_name(const std::string& prim_name) { static const std::unordered_set supported = { "Abs", "Add", "AsType", "Broadcast", "Ceil", "Conjugate", "Cos", "Divide", "Exp", "Floor", "Imag", "Log", - "LogAddExp", "Maximum", "Minimum", "Multiply", "Negative", "Real", - "Round", "Sigmoid", "Sin", "Sqrt", "Subtract", "Tan"}; + "LogAddExp", "Maximum", "Minimum", "Multiply", "Negative", "Power", "Real", + "Round", "Sigmoid", "Sin", "Sqrt", "Subtract", "Tan", "Tanh"}; return supported.contains(prim_name); } @@ -213,7 +216,8 @@ std::string emit_glsl_preamble( bool uses_complex64, bool uses_float16_types, bool uses_int16_types, - bool uses_int8_types) { + bool uses_int8_types, + bool uses_power) { std::ostringstream os; os << "#version 450\n"; os << "#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require\n"; @@ -227,6 +231,8 @@ std::string emit_glsl_preamble( } if (uses_int8_types) { os << "#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require\n"; + os << "#extension GL_EXT_shader_8bit_storage : require\n"; + os << "#extension GL_EXT_scalar_block_layout : require\n"; } os << "\n"; @@ -291,6 +297,38 @@ vec2 complex_conjugate(vec2 z) { return vec2(z.x, -z.y); } +)"; + } + + if (uses_power) { + os << R"( +float safe_real_pow(float x, float y) { + if (x < 0.0) { + float yi = round(y); + if (abs(y - yi) <= 0.00001 && abs(yi) <= 64.0) { + int n = int(yi); + float base = abs(x); + float result = 1.0; + int exp = abs(n); + while (exp > 0) { + if ((exp & 1) != 0) { + result *= base; + } + base *= base; + exp >>= 1; + } + if (n < 0) { + result = 1.0 / result; + } + if ((abs(n) & 1) != 0) { + result = -result; + } + return result; + } + } + return pow(x, y); +} + )"; } @@ -307,17 +345,6 @@ std::string get_glsl_operator(const std::string& primitive_name) { {"Divide", "/"}, {"Maximum", "max"}, {"Minimum", "min"}, - {"Equal", "=="}, - {"NotEqual", "!="}, - {"Greater", ">"}, - {"Less", "<"}, - {"GreaterEqual", ">="}, - {"LessEqual", "<="}, - {"LogicalAnd", "&&"}, - {"LogicalOr", "||"}, - {"BitwiseAnd", "&"}, - {"BitwiseOr", "|"}, - {"BitwiseXor", "^"}, // GLSL built-in functions (lowercase) {"Exp", "exp"}, {"Log", "log"}, @@ -329,7 +356,9 @@ std::string get_glsl_operator(const std::string& primitive_name) { {"Floor", "floor"}, {"Ceil", "ceil"}, {"Round", "round"}, + {"Power", "pow"}, {"Sigmoid", "sigmoid"}, + {"Tanh", "tanh"}, }; auto it = op_map.find(primitive_name); @@ -403,6 +432,10 @@ inline void build_glsl_kernel( bool uses_int8_types = has_any_dtype(inputs, {int8, uint8, bool_}) || has_any_dtype(outputs, {int8, uint8, bool_}) || has_any_dtype(tape, {int8, uint8, bool_}); + const bool uses_power = std::any_of( + tape.begin(), tape.end(), [](const array& x) { + return x.primitive().name() == "Power"; + }); // GLSL header os = emit_glsl_preamble( @@ -410,7 +443,8 @@ inline void build_glsl_kernel( uses_complex64, uses_float16_types, uses_int16_types, - uses_int8_types); + uses_int8_types, + uses_power); // Determine max work per thread based on output dtype size int max_itemsize = 1; @@ -420,6 +454,10 @@ inline void build_glsl_kernel( int wpt = std::min(work_per_thread, 16 / max_itemsize); wpt = std::max(wpt, 1); + const std::string buffer_layout_fmt = uses_int8_types + ? "layout(scalar, binding = {})" + : "layout(binding = {})"; + // Buffer bindings for non-constant inputs int binding = 0; std::vector> input_bindings; // (index, name) @@ -432,8 +470,8 @@ inline void build_glsl_kernel( const auto& xname = get_var_name(x); os += fmt::format( - "layout(binding = {}) readonly buffer Buf{} {{ {} {}[]; }};\n", - binding++, + "{} readonly buffer Buf{} {{ {} {}[]; }};\n", + fmt::format(fmt::runtime(buffer_layout_fmt), binding++), i, dtype_to_glsl_storage(x.dtype()), xname); @@ -446,8 +484,8 @@ inline void build_glsl_kernel( auto& x = outputs[i]; const auto& xname = get_var_name(x); os += fmt::format( - "layout(binding = {}) buffer OutBuf{} {{ {} {}[]; }};\n", - binding++, + "{} buffer OutBuf{} {{ {} {}[]; }};\n", + fmt::format(fmt::runtime(buffer_layout_fmt), binding++), i, dtype_to_glsl_storage(x.dtype()), xname); @@ -625,6 +663,14 @@ layout(push_constant) uniform PushConstants { if (prim_name == "Negative" && x.inputs().size() == 1) { os += fmt::format("(-{});\n", get_input_expr(x.inputs()[0])); + } else if (prim_name == "Power" && x.inputs().size() == 2 && !is_complex) { + const auto lhs = get_input_expr(x.inputs()[0]); + const auto rhs = get_input_expr(x.inputs()[1]); + os += fmt::format( + "{}(safe_real_pow(float({}), float({})));\n", + type_str, + lhs, + rhs); } else if (prim_name == "LogAddExp" && x.inputs().size() == 2) { auto lhs = get_input_expr(x.inputs()[0]); auto rhs = get_input_expr(x.inputs()[1]); @@ -721,6 +767,10 @@ layout(push_constant) uniform PushConstants { } } + const auto& runtime_shape = contiguous ? Shape{} : *strided_shape; + const auto& runtime_output_strides = + contiguous ? Strides{} : (*strided_strides)[0]; + // Write outputs for (size_t i = 0; i < outputs.size(); ++i) { auto& x = outputs[i]; @@ -745,21 +795,37 @@ layout(push_constant) uniform PushConstants { " {}[idx + {}u] = t_{};\n", xname, base_offset, xname); } } else { + os += fmt::format(" uint loc_out_{} = {}u;\n", xname, base_offset); + os += fmt::format(" uint rem_out_{} = idx;\n", xname); + for (int axis = ndim - 1; axis >= 0; --axis) { + os += fmt::format( + " uint coord_out_{0}_{1} = rem_out_{0} % uint({2});\n", + xname, + axis, + runtime_shape[axis]); + os += fmt::format( + " rem_out_{0} /= uint({1});\n", xname, runtime_shape[axis]); + os += fmt::format( + " loc_out_{0} += coord_out_{0}_{1} * uint({2});\n", + xname, + axis, + runtime_output_strides[axis]); + } if (x.dtype() == bool_) { os += fmt::format( - " {}[idx + {}u] = uint8_t(t_{} ? 1 : 0);\n", + " {}[loc_out_{}] = uint8_t(t_{} ? 1 : 0);\n", + xname, xname, - base_offset, xname); } else if (x.dtype() == bfloat16) { os += fmt::format( - " {}[idx + {}u] = uint16_t(fp32_to_bf16(t_{}));\n", + " {}[loc_out_{}] = uint16_t(fp32_to_bf16(t_{}));\n", + xname, xname, - base_offset, xname); } else { os += fmt::format( - " {}[idx + {}u] = t_{};\n", xname, base_offset, xname); + " {}[loc_out_{}] = t_{};\n", xname, xname, xname); } } } @@ -833,14 +899,14 @@ void Compiled::eval_gpu( return true; } - auto has_bool_dtype = [](const std::vector& arrays) { + auto has_unimplemented_dtype = [](const std::vector& arrays) { return std::any_of(arrays.begin(), arrays.end(), [](const array& x) { - return x.dtype() == bool_; + return x.dtype() == bool_ || x.dtype() == int8 || x.dtype() == uint8; }); }; - if (has_bool_dtype(inputs_) || has_bool_dtype(outputs_) || - has_bool_dtype(tape_)) { + if (has_unimplemented_dtype(inputs_) || + has_unimplemented_dtype(outputs_) || has_unimplemented_dtype(tape_)) { return true; } @@ -873,7 +939,12 @@ void Compiled::eval_gpu( } // Use large index if needed - bool large = compiled_use_large_index(dispatch_inputs, outputs, contiguous); + bool large = compiled_use_large_index(dispatch_inputs, outputs, contiguous) || + outputs[0].data_size() > std::numeric_limits::max(); + if (large && !contiguous) { + throw std::runtime_error( + "Compiled kernel failed on Vulkan (arrays >2^32 elements are only supported for contiguous layouts)."); + } const bool has_nonzero_runtime_offset = std::any_of( input_offsets.begin(), @@ -931,10 +1002,6 @@ void Compiled::eval_gpu( kernel_name += fmt::format("_layout_{}", std::hash{}(layout_key.str())); } - if (large) { - kernel_name += "_large"; - } - // Check if we already have this kernel compiled (simple cache check) auto& manager = vulkan::KernelManager::get(); auto* existing_shader = manager.get_shader(kernel_name); @@ -1027,9 +1094,11 @@ void Compiled::eval_gpu( auto cmd_buffer = vulkan::begin_command_recording(s.index); const uint64_t descriptor_epoch = vulkan::descriptor_epoch_for_stream(s); - // Allocate descriptor set - auto descriptor_set = - manager.allocate_descriptor_set(pipeline->descriptor_layout); + const bool use_push_descriptor = pipeline->supports_push_descriptor; + vk::DescriptorSet descriptor_set; + if (!use_push_descriptor) { + descriptor_set = manager.allocate_descriptor_set(pipeline->descriptor_layout); + } // Prepare descriptor writes std::vector writes; @@ -1105,58 +1174,150 @@ void Compiled::eval_gpu( ++write_idx; } - // Update descriptor sets - if (!writes.empty()) { - vkUpdateDescriptorSets( - vulkan::VulkanContext::get().device(), - static_cast(writes.size()), - writes.data(), - 0, - nullptr); + if (large) { + for (size_t i = 0; i < dispatch_inputs.size(); ++i) { + if (is_constant_(i) || is_scalar(dispatch_inputs[i])) { + continue; + } + if (dispatch_inputs[i].data_size() != outputs[0].data_size()) { + throw std::runtime_error( + "Compiled kernel failed on Vulkan (large contiguous compiled kernels require elementwise non-scalar inputs)." + ); + } + } } + auto update_descriptor_set_for_chunk = + [&](uint64_t chunk_base_elements, uint64_t chunk_elements) { + size_t descriptor_binding = 0; + + for (size_t i = 0; i < dispatch_inputs.size(); ++i) { + if (is_constant_(i)) { + continue; + } + const auto& arr = dispatch_inputs[i]; + const uint64_t item_size = static_cast(size_of(arr.dtype())); + uint64_t offset_bytes = static_cast(input_offsets[i]) * item_size; + if (large && !is_scalar(arr)) { + offset_bytes += chunk_base_elements * item_size; + } + + buffer_infos[descriptor_binding].offset = static_cast(offset_bytes); + buffer_infos[descriptor_binding].range = VK_WHOLE_SIZE; + descriptor_binding++; + } + + for (size_t i = 0; i < outputs.size(); ++i) { + const auto& arr = outputs[i]; + const uint64_t item_size = static_cast(size_of(arr.dtype())); + uint64_t offset_bytes = + static_cast(output_offsets[i]) * item_size; + if (large) { + offset_bytes += chunk_base_elements * item_size; + } + + buffer_infos[descriptor_binding].offset = static_cast(offset_bytes); + buffer_infos[descriptor_binding].range = VK_WHOLE_SIZE; + descriptor_binding++; + } + + if (use_push_descriptor) { + for (auto& write : writes) { + write.dstSet = vk::DescriptorSet(); + } + } else if (!writes.empty()) { + vkUpdateDescriptorSets( + vulkan::VulkanContext::get().device(), + static_cast(writes.size()), + writes.data(), + 0, + nullptr); + } + }; + // Bind pipeline and descriptor set vkCmdBindPipeline( cmd_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline); - VkDescriptorSet vk_descriptor_set = - static_cast(descriptor_set); - vkCmdBindDescriptorSets( - cmd_buffer, - VK_PIPELINE_BIND_POINT_COMPUTE, - pipeline->layout, - 0, - 1, - &vk_descriptor_set, - 0, - nullptr); + if (use_push_descriptor) { + auto push_fn = vulkan::VulkanContext::get().push_descriptor_fn(); + if (push_fn == nullptr) { + throw std::runtime_error("Missing Vulkan push descriptor function for compiled kernel"); + } + push_fn( + cmd_buffer, + VK_PIPELINE_BIND_POINT_COMPUTE, + pipeline->layout, + 0, + static_cast(writes.size()), + writes.data()); + } else { + VkDescriptorSet vk_descriptor_set = static_cast(descriptor_set); + vkCmdBindDescriptorSets( + cmd_buffer, + VK_PIPELINE_BIND_POINT_COMPUTE, + pipeline->layout, + 0, + 1, + &vk_descriptor_set, + 0, + nullptr); + } // Set push constants struct PushConstants { uint32_t size; } pc; - // TODO: Handle arrays larger than 2^32 elements - pc.size = static_cast(outputs[0].data_size()); - - vkCmdPushConstants( - cmd_buffer, - pipeline->layout, - VK_SHADER_STAGE_COMPUTE_BIT, - 0, - sizeof(pc), - &pc); - // Dispatch uint64_t num_elements = outputs[0].data_size(); - uint32_t workgroups = static_cast( - (num_elements + 256ULL * static_cast(work_per_thread) - 1) / - (256ULL * static_cast(work_per_thread))); - workgroups = std::max(workgroups, 1u); - - vkCmdDispatch(cmd_buffer, workgroups, 1, 1); + if (!large) { + update_descriptor_set_for_chunk(0, num_elements); + + pc.size = static_cast(num_elements); + vkCmdPushConstants( + cmd_buffer, + pipeline->layout, + VK_SHADER_STAGE_COMPUTE_BIT, + 0, + sizeof(pc), + &pc); + + uint32_t workgroups = static_cast( + (num_elements + 256ULL * static_cast(work_per_thread) - 1) / + (256ULL * static_cast(work_per_thread))); + workgroups = std::max(workgroups, 1u); + vkCmdDispatch(cmd_buffer, workgroups, 1, 1); + } else { + constexpr uint64_t kMaxChunkElements = + static_cast(std::numeric_limits::max()); + for (uint64_t chunk_base = 0; chunk_base < num_elements;) { + const uint64_t chunk_elements = + std::min(kMaxChunkElements, num_elements - chunk_base); + update_descriptor_set_for_chunk(chunk_base, chunk_elements); + + pc.size = static_cast(chunk_elements); + vkCmdPushConstants( + cmd_buffer, + pipeline->layout, + VK_SHADER_STAGE_COMPUTE_BIT, + 0, + sizeof(pc), + &pc); + + uint32_t workgroups = static_cast( + (chunk_elements + 256ULL * static_cast(work_per_thread) - + 1) / + (256ULL * static_cast(work_per_thread))); + workgroups = std::max(workgroups, 1u); + vkCmdDispatch(cmd_buffer, workgroups, 1, 1); + chunk_base += chunk_elements; + } + } // Defer descriptor set cleanup - manager.defer_descriptor_set_free(s.index, descriptor_epoch, descriptor_set); + if (!use_push_descriptor) { + manager.defer_descriptor_set_free(s.index, descriptor_epoch, descriptor_set); + } } } // namespace mlx::core diff --git a/mlx/backend/vulkan/copy.cpp b/mlx/backend/vulkan/copy.cpp index 2cb4706483..0632fae56e 100644 --- a/mlx/backend/vulkan/copy.cpp +++ b/mlx/backend/vulkan/copy.cpp @@ -39,6 +39,9 @@ using vulkan::zero_literal_for_dtype; constexpr size_t kMinTransferQueueCopyBytes = 256 * 1024; std::string copy_dtype_suffix(Dtype dtype); +bool needs_bf16_helpers(Dtype in_dtype, Dtype out_dtype); +std::string emit_bf16_conversion_helpers(); +std::string bf16_cast_expr(const std::string& expr, Dtype in_dtype, Dtype out_dtype); bool has_row_contiguous_strides(const mlx::core::array& arr) { if (arr.ndim() == 0) { @@ -204,25 +207,29 @@ std::string build_dynamic_offset_shader( } std::string build_dynamic_general_copy_shader( - Dtype dtype, + Dtype in_dtype, + Dtype out_dtype, const Shape& shape, const Strides& i_strides, const Strides& o_strides, - int64_t in_base_offset, - int64_t out_base_offset, - int64_t i_offset, - int64_t o_offset, - int64_t dynamic_i_base_offset, - int64_t dynamic_o_base_offset, bool has_dynamic_i_offset, - bool has_dynamic_o_offset, - size_t total_elements) { + bool has_dynamic_o_offset) { std::ostringstream os; - os << emit_dynamic_shader_preamble(dtype, true); + os << emit_dynamic_shader_preamble(in_dtype, out_dtype, true); + if (needs_bf16_helpers(in_dtype, out_dtype)) { + os << emit_bf16_conversion_helpers(); + } + os << "layout(push_constant) uniform PushConstants {\n"; + os << " uint total_elements;\n"; + os << " int64_t input_base;\n"; + os << " int64_t output_base;\n"; + os << " int64_t dynamic_i_base;\n"; + os << " int64_t dynamic_o_base;\n"; + os << "} pc;\n"; os << "layout(set = 0, binding = 0) readonly buffer InputBuffer {" - << dtype_to_glsl_storage_type(dtype) << " data[];} input_buf;\n"; + << dtype_to_glsl_storage_type(in_dtype) << " data[];} input_buf;\n"; os << "layout(set = 0, binding = 1) buffer OutputBuffer {" - << dtype_to_glsl_storage_type(dtype) << " data[];} output_buf;\n"; + << dtype_to_glsl_storage_type(out_dtype) << " data[];} output_buf;\n"; if (has_dynamic_i_offset) { os << "layout(set = 0, binding = 2) readonly buffer DynamicInputOffset {int64_t data[];} dynamic_i_offset_buf;\n"; } @@ -231,20 +238,16 @@ std::string build_dynamic_general_copy_shader( } os << "\nvoid main() {\n"; os << " uint linear_idx = gl_GlobalInvocationID.x;\n"; - os << " if (linear_idx >= " << total_elements << "u) {\n"; + os << " if (linear_idx >= pc.total_elements) {\n"; os << " return;\n"; os << " }\n"; - os << " int64_t input_index = int64_t(" << (in_base_offset + i_offset) - << ");\n"; - os << " int64_t output_index = int64_t(" << (out_base_offset + o_offset) - << ");\n"; + os << " int64_t input_index = pc.input_base;\n"; + os << " int64_t output_index = pc.output_base;\n"; if (has_dynamic_i_offset) { - os << " input_index += dynamic_i_offset_buf.data[" << dynamic_i_base_offset - << "];\n"; + os << " input_index += dynamic_i_offset_buf.data[uint(pc.dynamic_i_base)];\n"; } if (has_dynamic_o_offset) { - os << " output_index += dynamic_o_offset_buf.data[" - << dynamic_o_base_offset << "];\n"; + os << " output_index += dynamic_o_offset_buf.data[uint(pc.dynamic_o_base)];\n"; } if (!shape.empty()) { os << " uint remaining = linear_idx;\n"; @@ -260,7 +263,15 @@ std::string build_dynamic_general_copy_shader( os << " }\n"; } } - os << " output_buf.data[uint(output_index)] = input_buf.data[uint(input_index)];\n"; + std::string cast_expr; + if (needs_bf16_helpers(in_dtype, out_dtype)) { + cast_expr = bf16_cast_expr( + "input_buf.data[uint(input_index)]", in_dtype, out_dtype); + } else { + cast_expr = cast_expr_for_dtype( + "input_buf.data[uint(input_index)]", in_dtype, out_dtype); + } + os << " output_buf.data[uint(output_index)] = " << cast_expr << ";\n"; os << "}\n"; return os.str(); } @@ -310,10 +321,6 @@ bool dispatch_dynamic_general_copy( validate_dynamic_offset_array(dynamic_i_offset); validate_dynamic_offset_array(dynamic_o_offset); - if (in.dtype() != out.dtype()) { - throw std::runtime_error( - "Dynamic Vulkan copy currently requires matching input/output dtypes."); - } if (!is_vulkan_storage_array(in) || !is_vulkan_storage_array(out)) { throw std::runtime_error( "Dynamic Vulkan copy requires Vulkan-backed input and output arrays."); @@ -339,10 +346,8 @@ bool dispatch_dynamic_general_copy( dynamic_o_offset.has_value() ? element_offset(*dynamic_o_offset) : 0; std::ostringstream layout_key; - layout_key << static_cast(in.dtype().val()) << ':' << in_base_offset - << ':' << out_base_offset << ':' << i_offset << ':' << o_offset - << ':' << dynamic_i_base_offset << ':' << dynamic_o_base_offset - << ':' << dynamic_i_offset.has_value() << ':' + layout_key << static_cast(in.dtype().val()) << ':' + << dynamic_i_offset.has_value() << ':' << dynamic_o_offset.has_value() << ':'; append_layout_key(layout_key, shape); layout_key << ':'; @@ -351,23 +356,17 @@ bool dispatch_dynamic_general_copy( append_layout_key(layout_key, o_strides); const std::string shader_name = "dynamic_general_copy_" + - copy_dtype_suffix(in.dtype()) + "_" + + copy_dtype_suffix(in.dtype()) + "_" + copy_dtype_suffix(out.dtype()) + "_" + std::to_string(std::hash{}(layout_key.str())); const std::string glsl_source = build_dynamic_general_copy_shader( in.dtype(), + out.dtype(), shape, i_strides, o_strides, - in_base_offset, - out_base_offset, - i_offset, - o_offset, - dynamic_i_base_offset, - dynamic_o_base_offset, dynamic_i_offset.has_value(), - dynamic_o_offset.has_value(), - total_elements); + dynamic_o_offset.has_value()); std::vector arrays; arrays.push_back({&in, 0}); @@ -379,17 +378,40 @@ bool dispatch_dynamic_general_copy( arrays.push_back({&*dynamic_o_offset, 3}); } - const uint32_t workgroups = std::max( - (static_cast(total_elements) + 255u) / 256u, 1u); - vulkan::dispatch_dynamic_compute( + constexpr uint32_t kPushConstantSize = sizeof(uint32_t) + sizeof(int64_t) * 4; + auto dispatch = vulkan::dispatch_dynamic_compute_begin( shader_name, glsl_source, static_cast(arrays.size()), arrays.data(), - workgroups, - 1, - 1, + kPushConstantSize, s); + + struct PushConstants { + uint32_t total_elements; + int64_t input_base; + int64_t output_base; + int64_t dynamic_i_base; + int64_t dynamic_o_base; + } pc{}; + pc.total_elements = static_cast(total_elements); + pc.input_base = in_base_offset + i_offset; + pc.output_base = out_base_offset + o_offset; + pc.dynamic_i_base = dynamic_i_base_offset; + pc.dynamic_o_base = dynamic_o_base_offset; + + vkCmdPushConstants( + dispatch.command_buffer, + dispatch.pipeline->layout, + VK_SHADER_STAGE_COMPUTE_BIT, + 0, + kPushConstantSize, + &pc); + + const uint32_t workgroups = std::max( + (static_cast(total_elements) + 255u) / 256u, 1u); + vkCmdDispatch(dispatch.command_buffer, workgroups, 1, 1); + vulkan::end_command_recording(s.index); return true; } @@ -501,6 +523,12 @@ bool dispatch_dynamic_vector_cast_copy( if (in_base_offset < 0 || out_base_offset < 0) { return false; } + if (static_cast(in_base_offset) > + static_cast(std::numeric_limits::max()) || + static_cast(out_base_offset) > + static_cast(std::numeric_limits::max())) { + return false; + } const std::string shader_name = "dynamic_vector_cast_copy_" + copy_dtype_suffix(in.dtype()) + "_" + copy_dtype_suffix(out.dtype()); @@ -737,7 +765,8 @@ bool try_host_vector_cast_copy( host_out.data(), host_out.size(), out_buf->buffer, - out_offset * size_of(out.dtype())); + out_offset * size_of(out.dtype()), + out.data_shared_ptr()); mlx::core::vulkan::retain_array_for_stream(s, in); mlx::core::vulkan::retain_array_for_stream(s, out); }; @@ -1240,7 +1269,8 @@ void copy_gpu_inplace( host_fill.data(), host_fill.size(), out_buf->buffer, - out.offset()); + out.offset(), + out.data_shared_ptr()); vulkan::retain_array_for_stream(s, *source); vulkan::retain_array_for_stream(s, out); } @@ -1250,22 +1280,18 @@ void copy_gpu_inplace( const bool is_slice_copy = shader_copy_type && shader_id.has_value() && in.size() != out.size(); - - const bool segmented_buffer_copy = same_dtype && + const bool large_shader_offset = (shader_copy || is_slice_copy) && (has_large_element_offset(in_view) || has_large_element_offset(out_view)); - const bool contiguous_large_rank_copy = same_dtype && - dispatch_shape.size() > 4 && in_view.flags().contiguous && - out_view.flags().contiguous && in_view.size() == out_view.size() && - !is_slice_copy; + const bool segmented_buffer_copy = same_dtype && large_shader_offset; const bool host_contiguous_copy = in_view.flags().row_contiguous && out_view.flags().row_contiguous && dispatch_elements == in_view.size() && dispatch_elements == out_view.size(); if (!raw_buffer_copy && !shader_copy && !is_slice_copy && - !contiguous_large_rank_copy && host_contiguous_copy) { + host_contiguous_copy) { if (try_host_vector_cast_copy( in_view, out_view, @@ -1296,8 +1322,7 @@ void copy_gpu_inplace( } } - if (!raw_buffer_copy && !shader_copy && !is_slice_copy && - !contiguous_large_rank_copy) { + if (!raw_buffer_copy && !shader_copy && !is_slice_copy) { std::ostringstream oss; oss << "Copy operation failed on Vulkan (unsupported dtype or layout): " << "ctype=" << copy_type_name(ctype) << " " @@ -1330,7 +1355,12 @@ void copy_gpu_inplace( std::memcpy(dst_ptr, src_ptr, in_view.nbytes()); } else { vulkan::enqueue_owned_staging_upload( - s, src_ptr, in_view.nbytes(), out_buf->buffer, out_view.offset()); + s, + src_ptr, + in_view.nbytes(), + out_buf->buffer, + out_view.offset(), + out.data_shared_ptr()); vulkan::retain_array_for_stream(s, *source); vulkan::retain_array_for_stream(s, out); } @@ -1362,8 +1392,7 @@ void copy_gpu_inplace( // Small copies, especially decode-time KV cache updates, are latency // sensitive and do better when they stay in the current compute submission. const bool use_transfer_queue = - (raw_buffer_copy || contiguous_large_rank_copy || - segmented_buffer_copy) && + (raw_buffer_copy || segmented_buffer_copy) && should_use_transfer_queue_for_copy(copy_bytes); vk::CommandBuffer cmd_buffer = use_transfer_queue ? vulkan::begin_transfer_command_recording(s.index) @@ -1387,16 +1416,6 @@ void copy_gpu_inplace( vulkan::retain_array_for_stream(s, *source); vulkan::retain_array_for_stream(s, out); - } else if (contiguous_large_rank_copy) { - VkBufferCopy copy_region{}; - copy_region.srcOffset = static_cast(in_view.offset()); - copy_region.dstOffset = static_cast(out_view.offset()); - copy_region.size = static_cast(in_view.nbytes()); - - cmd_buffer.copyBuffer(in_buf->buffer, out_buf->buffer, {copy_region}); - - vulkan::retain_array_for_stream(s, in_view); - vulkan::retain_array_for_stream(s, out_view); } else if (segmented_buffer_copy) { const auto copy_regions = make_strided_copy_regions(in_view, out_view); if (copy_regions.empty()) { @@ -1435,6 +1454,24 @@ void copy_gpu_inplace( "Copy operation failed on Vulkan: >4D non-contiguous arrays not supported"); } + if (large_shader_offset) { + if (!dispatch_dynamic_general_copy( + in_view, + out_view, + dispatch_shape, + dispatch_i_strides, + dispatch_o_strides, + 0, + 0, + s, + std::nullopt, + std::nullopt)) { + throw std::runtime_error( + "Large-offset shader copy does not support tensors with rank greater than 4."); + } + return; + } + vulkan::dispatch_unary_op(in_view, out_view, *shader_id, cmd_buffer, s); } catch (const std::runtime_error& e) { if (use_transfer_queue) { @@ -1480,9 +1517,21 @@ void fill_gpu(const array& val, array& out, const Stream& s) { std::memcpy(dst_ptr + i, val_ptr, val_size); } } else { - // For discrete GPUs, we need to use a compute shader or staging buffer - // TODO: Implement compute shader fill - throw std::runtime_error("fill_gpu not yet implemented for discrete GPUs"); + std::vector host_fill(out.nbytes()); + const char* val_ptr = static_cast(val.data()); + size_t val_size = size_of(val.dtype()); + for (size_t i = 0; i < host_fill.size(); i += val_size) { + std::memcpy(host_fill.data() + i, val_ptr, val_size); + } + vulkan::enqueue_owned_staging_upload( + s, + host_fill.data(), + host_fill.size(), + out_buf->buffer, + out.offset(), + out.data_shared_ptr()); + vulkan::retain_array_for_stream(s, val); + vulkan::retain_array_for_stream(s, out); } } diff --git a/mlx/backend/vulkan/device.cpp b/mlx/backend/vulkan/device.cpp index cfc6de976c..5eb753b8d5 100644 --- a/mlx/backend/vulkan/device.cpp +++ b/mlx/backend/vulkan/device.cpp @@ -432,7 +432,6 @@ struct SubmissionResources { vk::CommandBuffer compute_command_buffer; vk::CommandPool transfer_command_pool; vk::CommandBuffer transfer_command_buffer; - vk::Fence fence; }; struct SubmissionRecord { @@ -1577,8 +1576,7 @@ class VulkanDevice { continue; } - auto* storage = static_cast( - const_cast(static_cast(data->buffer.ptr()))); + auto* storage = referenced_vulkan_buffer(data); if (storage == nullptr || storage->buffer == VK_NULL_HANDLE) { continue; } @@ -1823,9 +1821,6 @@ class VulkanDevice { resources->transfer_command_buffer = resources->compute_command_buffer; } - vk::FenceCreateInfo fence_info; - resources->fence = vk_device.createFence(fence_info); - return resources; } @@ -1848,9 +1843,6 @@ class VulkanDevice { {resources->transfer_command_buffer}); device.destroyCommandPool(resources->transfer_command_pool); } - if (resources->fence) { - device.destroyFence(resources->fence); - } if (resources->compute_command_pool) { device.destroyCommandPool(resources->compute_command_pool); } @@ -2277,6 +2269,7 @@ class VulkanDevice { stream->recent_primitives.clear(); stream->in_flight_submissions.push_back(std::move(submission)); stream->recording_transfer = false; + stream->submission_count++; if (decode_batch_enabled() && submit_to_transfer_queue) { stream->decode_transfer_submit_count++; @@ -2398,7 +2391,8 @@ void enqueue_owned_staging_upload( const void* src, size_t size, vk::Buffer dst_buffer, - uint64_t dst_offset) { + uint64_t dst_offset, + std::shared_ptr tracked_dst_data) { if (size == 0) { return; } @@ -2432,6 +2426,7 @@ void enqueue_owned_staging_upload( command_buffer.copyBuffer(staging_buffer->buffer, dst_buffer, {copy_region}); VulkanDevice::get().retain_data(s.index, staging.owner); + VulkanDevice::get().retain_data(s.index, std::move(tracked_dst_data)); add_completion_callback_for_stream( s, [arena = std::move(staging.arena), @@ -2446,7 +2441,8 @@ void enqueue_owned_staging_readback( vk::Buffer src_buffer, uint64_t src_offset, size_t size, - std::function completion) { + std::function completion, + std::shared_ptr tracked_src_data) { if (size == 0) { completion(nullptr, 0); return; @@ -2477,6 +2473,7 @@ void enqueue_owned_staging_readback( command_buffer.copyBuffer(src_buffer, staging_buffer->buffer, {copy_region}); VulkanDevice::get().retain_data(s.index, staging.owner); + VulkanDevice::get().retain_data(s.index, std::move(tracked_src_data)); add_completion_callback_for_stream( s, [arena = std::move(staging.arena), diff --git a/mlx/backend/vulkan/device.h b/mlx/backend/vulkan/device.h index 4c7d07d4fc..9d407ea092 100644 --- a/mlx/backend/vulkan/device.h +++ b/mlx/backend/vulkan/device.h @@ -59,13 +59,15 @@ void enqueue_owned_staging_upload( const void* src, size_t size, vk::Buffer dst_buffer, - uint64_t dst_offset = 0); + uint64_t dst_offset = 0, + std::shared_ptr tracked_dst_data = nullptr); void enqueue_owned_staging_readback( const Stream& s, vk::Buffer src_buffer, uint64_t src_offset, size_t size, - std::function completion); + std::function completion, + std::shared_ptr tracked_src_data = nullptr); uint64_t descriptor_epoch_for_stream(const Stream& s); array acquire_scratch_array( const Stream& s, diff --git a/mlx/backend/vulkan/event.cpp b/mlx/backend/vulkan/event.cpp index baf38564d6..8991b40ee7 100644 --- a/mlx/backend/vulkan/event.cpp +++ b/mlx/backend/vulkan/event.cpp @@ -39,6 +39,13 @@ void Event::wait() { auto* counter = static_cast(event_.get()); if (stream_.device == Device::gpu && counter->value < value()) { vulkan::synchronize_stream(stream_); + { + std::lock_guard lock(counter->mutex); + if (counter->value < value()) { + counter->value = value(); + } + } + counter->cv.notify_all(); } std::unique_lock lock(counter->mutex); if (counter->value >= value()) { diff --git a/mlx/backend/vulkan/fence.cpp b/mlx/backend/vulkan/fence.cpp index 38052a69cb..3f53317716 100644 --- a/mlx/backend/vulkan/fence.cpp +++ b/mlx/backend/vulkan/fence.cpp @@ -64,6 +64,13 @@ void Fence::wait(Stream stream, const array&) { if (impl->stream.device == Device::gpu) { vulkan::synchronize_stream(impl->stream); + { + std::lock_guard lock(impl->mutex); + if (impl->value < target) { + impl->value = target; + } + } + impl->cv.notify_all(); } wait_fence_value(fence_, target); } diff --git a/mlx/backend/vulkan/gather.cpp b/mlx/backend/vulkan/gather.cpp index 91292707a4..1afdb296d2 100644 --- a/mlx/backend/vulkan/gather.cpp +++ b/mlx/backend/vulkan/gather.cpp @@ -1,7 +1,9 @@ // Copyright © 2024 Apple Inc. #include +#include +#include "mlx/backend/gpu/copy.h" #include "mlx/backend/vulkan/primitives_utils.h" #include "mlx/ops.h" @@ -9,10 +11,94 @@ namespace mlx::core { namespace { +uint32_t checked_shape_product( + const array& arr, + int begin, + int end, + const char* label); + bool needs_row_contiguous(const array& arr) { return !arr.flags().row_contiguous || arr.offset() != 0; } +bool is_low_precision_gather_value_dtype(Dtype dtype) { + return dtype == float16 || dtype == bfloat16; +} + +int64_t read_index_value(const array& idx, size_t flat_pos) { + switch (idx.dtype()) { + case int32: + return static_cast(idx.data()[flat_pos]); + case int64: + return idx.data()[flat_pos]; + case uint32: + return static_cast(idx.data()[flat_pos]); + case uint64: { + auto value = idx.data()[flat_pos]; + if (value > static_cast(std::numeric_limits::max())) { + throw std::out_of_range("Gather index exceeds int64 range."); + } + return static_cast(value); + } + default: + throw std::runtime_error("Unsupported gather index dtype."); + } +} + +bool try_eval_take_axis0_low_precision_copy_fallback( + const array& src, + array idx, + array& out, + Stream s) { + if (src.ndim() == 0 || idx.size() == 0) { + return false; + } + if (!is_low_precision_gather_value_dtype(src.dtype()) || !src.flags().row_contiguous || + src.offset() != 0 || idx.ndim() == 0) { + return false; + } + + idx.wait(); + const uint32_t size_axis = static_cast(src.shape(0)); + const uint32_t size_post = checked_shape_product( + src, 1, src.ndim(), "gather_take_fallback size_post"); + array out_work(out.shape(), out.dtype(), nullptr, {}); + out_work.set_data(allocator::malloc(out_work.nbytes())); + + for (size_t flat_pos = 0; flat_pos < idx.size(); ++flat_pos) { + int64_t raw_index = read_index_value(idx, flat_pos); + if (raw_index < 0) { + raw_index += size_axis; + } + if (raw_index < 0 || raw_index >= size_axis) { + throw std::out_of_range("Gather index out of bounds."); + } + + array src_slice({static_cast(size_post)}, src.dtype(), nullptr, {}); + src_slice.copy_shared_buffer( + src, + {1}, + {true, true, true}, + size_post, + static_cast(raw_index) * size_post); + src_slice.set_status(array::Status::available); + + array out_slice({static_cast(size_post)}, out.dtype(), nullptr, {}); + out_slice.copy_shared_buffer( + out_work, + {1}, + {true, true, true}, + size_post, + flat_pos * static_cast(size_post)); + out_slice.set_status(array::Status::available); + + copy_gpu_inplace(src_slice, out_slice, CopyType::Vector, s); + } + + copy_gpu(out_work, out, CopyType::GeneralGeneral, s); + return true; +} + array ensure_row_contiguous(array arr, Stream s) { if (needs_row_contiguous(arr)) { arr = contiguous_copy_gpu(arr, s); @@ -179,6 +265,11 @@ bool try_eval_gather_vulkan( array src = ensure_row_contiguous(src_input, s); idx = ensure_row_contiguous(idx, s); + if (axis == 0 && + try_eval_take_axis0_low_precision_copy_fallback(src, idx, out, s)) { + return true; + } + const uint32_t size_pre = checked_shape_product(src_input, 0, axis, "gather_take size_pre"); const uint32_t size_axis = diff --git a/mlx/backend/vulkan/kernels.cpp b/mlx/backend/vulkan/kernels.cpp index 1920ced1e1..440c7ae97d 100644 --- a/mlx/backend/vulkan/kernels.cpp +++ b/mlx/backend/vulkan/kernels.cpp @@ -998,6 +998,8 @@ void KernelManager::register_shader( const std::string& name, const void* data, size_t size_bytes) { + std::scoped_lock lock(shader_cache_mutex_, pipeline_cache_mutex_); + auto& shader = dynamic_shaders_[name]; if (!shader) { shader = std::make_unique(); @@ -1065,6 +1067,7 @@ vk::ShaderModule KernelManager::compile_shader( ShaderModule* KernelManager::get_shader(StaticShaderId id) { ensure_static_registry_initialized(); + std::lock_guard lock(shader_cache_mutex_); const size_t index = static_cast(id); if (index >= static_shaders_.size()) { return nullptr; @@ -1084,15 +1087,18 @@ ShaderModule* KernelManager::get_shader(StaticShaderId id) { mlx::core::vulkan::ShaderModule* mlx::core::vulkan::KernelManager::get_shader( const std::string& name) { - auto it = dynamic_shaders_.find(name); - if (it != dynamic_shaders_.end()) { - auto* shader = it->second.get(); - if (!shader->compiled) { - shader->module = compile_shader(shader->spirv_code); - shader->compiled = true; - } + { + std::lock_guard lock(shader_cache_mutex_); + auto it = dynamic_shaders_.find(name); + if (it != dynamic_shaders_.end()) { + auto* shader = it->second.get(); + if (!shader->compiled) { + shader->module = compile_shader(shader->spirv_code); + shader->compiled = true; + } - return shader; + return shader; + } } ensure_static_registry_initialized(); @@ -1145,9 +1151,12 @@ ComputePipeline* KernelManager::get_pipeline( pipeline_key.bindings.push_back(make_descriptor_binding_key(binding)); } - auto it = pipelines_.find(pipeline_key); - if (it != pipelines_.end()) { - return it->second.get(); + { + std::lock_guard lock(pipeline_cache_mutex_); + auto it = pipelines_.find(pipeline_key); + if (it != pipelines_.end()) { + return it->second.get(); + } } VkDevice device = VulkanContext::get().device(); @@ -1287,7 +1296,19 @@ ComputePipeline* KernelManager::get_pipeline( pipeline_ptr->supports_push_descriptor = use_push_descriptor; auto* result = pipeline_ptr.get(); - pipelines_.emplace(std::move(pipeline_key), std::move(pipeline_ptr)); + { + std::lock_guard lock(pipeline_cache_mutex_); + auto it = pipelines_.find(pipeline_key); + if (it != pipelines_.end()) { + vkDestroyPipeline(device, pipeline, nullptr); + vkDestroyPipelineLayout(device, pipeline_layout, nullptr); + if (descriptor_layout != VK_NULL_HANDLE) { + vkDestroyDescriptorSetLayout(device, descriptor_layout, nullptr); + } + return it->second.get(); + } + pipelines_.emplace(std::move(pipeline_key), std::move(pipeline_ptr)); + } return result; } @@ -1331,9 +1352,12 @@ ComputePipeline* KernelManager::get_pipeline( pipeline_key.bindings.push_back(make_descriptor_binding_key(binding)); } - auto it = pipelines_.find(pipeline_key); - if (it != pipelines_.end()) { - return it->second.get(); + { + std::lock_guard lock(pipeline_cache_mutex_); + auto it = pipelines_.find(pipeline_key); + if (it != pipelines_.end()) { + return it->second.get(); + } } VkDevice device = VulkanContext::get().device(); @@ -1477,12 +1501,25 @@ ComputePipeline* KernelManager::get_pipeline( pipeline_ptr->supports_push_descriptor = use_push_descriptor; auto* result = pipeline_ptr.get(); - pipelines_.emplace(std::move(pipeline_key), std::move(pipeline_ptr)); + { + std::lock_guard lock(pipeline_cache_mutex_); + auto it = pipelines_.find(pipeline_key); + if (it != pipelines_.end()) { + vkDestroyPipeline(device, pipeline, nullptr); + vkDestroyPipelineLayout(device, pipeline_layout, nullptr); + if (descriptor_layout != VK_NULL_HANDLE) { + vkDestroyDescriptorSetLayout(device, descriptor_layout, nullptr); + } + return it->second.get(); + } + pipelines_.emplace(std::move(pipeline_key), std::move(pipeline_ptr)); + } return result; } void KernelManager::init_descriptor_pool() { + std::lock_guard lock(descriptor_pool_mutex_); if (descriptor_pool_initialized_) { return; } @@ -1929,21 +1966,30 @@ void KernelManager::purge_descriptor_sets_for_layouts( void KernelManager::cleanup() { reclaim_all_descriptor_sets(); - pipelines_.clear(); - dynamic_shaders_.clear(); + { + std::lock_guard pipeline_lock(pipeline_cache_mutex_); + pipelines_.clear(); + } + { + std::lock_guard shader_lock(shader_cache_mutex_); + dynamic_shaders_.clear(); + for (auto& shader : static_shaders_) { + shader.reset(); + } + } { std::lock_guard lock(static_registry_mutex_); static_registry_initialized_ = false; } - for (auto& shader : static_shaders_) { - shader.reset(); - } - if (descriptor_pool_ != VK_NULL_HANDLE) { - VkDevice device = VulkanContext::get().device(); - vkDestroyDescriptorPool(device, descriptor_pool_, nullptr); - descriptor_pool_ = VK_NULL_HANDLE; - descriptor_pool_initialized_ = false; + { + std::lock_guard lock(descriptor_pool_mutex_); + if (descriptor_pool_ != VK_NULL_HANDLE) { + VkDevice device = VulkanContext::get().device(); + vkDestroyDescriptorPool(device, descriptor_pool_, nullptr); + descriptor_pool_ = VK_NULL_HANDLE; + descriptor_pool_initialized_ = false; + } } } @@ -2924,8 +2970,8 @@ void dispatch_mul_mat_vec_op( push_constants.batch_stride_d = nrows; push_constants.fusion_flags = 0; push_constants.ne02 = 1; - push_constants.ne12 = 1; - push_constants.broadcast2 = 1; + push_constants.ne12 = batch_rows; + push_constants.broadcast2 = batch_rows; push_constants.broadcast3 = 1; const std::array bound_arrays = {{ @@ -2940,11 +2986,15 @@ void dispatch_mul_mat_vec_op( const uint32_t groups_z = (nrows + kMaxWorkgroupsX - 1u) / kMaxWorkgroupsX; const uint32_t groups_x = (nrows + groups_z - 1u) / groups_z; + const bool force_single_column_dispatch = + VulkanContext::get().architecture() == GpuArchitecture::AmdRdna; + const uint32_t col_chunk = force_single_column_dispatch ? 1u : kMaxMulMatVecCols; + for (uint32_t base_work_group_y = 0; base_work_group_y < batch_rows; - base_work_group_y += kMaxMulMatVecCols) { + base_work_group_y += col_chunk) { push_constants.base_work_group_y = base_work_group_y; const uint32_t num_cols = - std::min(kMaxMulMatVecCols, batch_rows - base_work_group_y); + std::min(col_chunk, batch_rows - base_work_group_y); const std::array grid = {groups_x, 1u, groups_z}; const std::vector specialization_constants = { 32u, diff --git a/mlx/backend/vulkan/kernels.h b/mlx/backend/vulkan/kernels.h index aeca8ad1d0..4914c7794e 100644 --- a/mlx/backend/vulkan/kernels.h +++ b/mlx/backend/vulkan/kernels.h @@ -189,6 +189,8 @@ class KernelManager { pipelines_; bool static_registry_initialized_{false}; std::mutex static_registry_mutex_; + std::mutex shader_cache_mutex_; + std::mutex pipeline_cache_mutex_; struct DescriptorSetRecord { vk::DescriptorSet set; @@ -214,6 +216,7 @@ class KernelManager { VulkanHandleHash> descriptor_set_layouts_; std::mutex descriptor_sets_mutex_; + std::mutex descriptor_pool_mutex_; void init_descriptor_pool(); }; diff --git a/mlx/backend/vulkan/kernels/gather_axis.comp b/mlx/backend/vulkan/kernels/gather_axis.comp index 8e5f4cfaee..1651d4ef84 100644 --- a/mlx/backend/vulkan/kernels/gather_axis.comp +++ b/mlx/backend/vulkan/kernels/gather_axis.comp @@ -1,6 +1,7 @@ #version 450 #extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_scalar_block_layout : require #if defined(INDEX_IS_I64) #extension GL_EXT_shader_explicit_arithmetic_types_int64 : require @@ -25,9 +26,9 @@ layout(push_constant) uniform parameter { uint idx_axis_size; } p; -layout(binding = 0) readonly buffer A { VALUE_TYPE data_a[]; }; -layout(binding = 1) readonly buffer I { INDEX_TYPE data_i[]; }; -layout(binding = 2) writeonly buffer D { VALUE_TYPE data_d[]; }; +layout(scalar, binding = 0) readonly buffer A { VALUE_TYPE data_a[]; }; +layout(scalar, binding = 1) readonly buffer I { INDEX_TYPE data_i[]; }; +layout(scalar, binding = 2) writeonly buffer D { VALUE_TYPE data_d[]; }; uint get_idx() { return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; diff --git a/mlx/backend/vulkan/kernels/gather_pair.comp b/mlx/backend/vulkan/kernels/gather_pair.comp index 19c9b225ab..4837397f01 100644 --- a/mlx/backend/vulkan/kernels/gather_pair.comp +++ b/mlx/backend/vulkan/kernels/gather_pair.comp @@ -1,6 +1,7 @@ #version 450 #extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_scalar_block_layout : require #if defined(INDEX_IS_I64) #extension GL_EXT_shader_explicit_arithmetic_types_int64 : require @@ -29,10 +30,10 @@ layout(push_constant) uniform parameter { uint index_count; } p; -layout(binding = 0) readonly buffer A { VALUE_TYPE data_a[]; }; -layout(binding = 1) readonly buffer I0 { INDEX_TYPE data_i0[]; }; -layout(binding = 2) readonly buffer I1 { INDEX_TYPE data_i1[]; }; -layout(binding = 3) writeonly buffer D { VALUE_TYPE data_d[]; }; +layout(scalar, binding = 0) readonly buffer A { VALUE_TYPE data_a[]; }; +layout(scalar, binding = 1) readonly buffer I0 { INDEX_TYPE data_i0[]; }; +layout(scalar, binding = 2) readonly buffer I1 { INDEX_TYPE data_i1[]; }; +layout(scalar, binding = 3) writeonly buffer D { VALUE_TYPE data_d[]; }; uint get_idx() { return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; diff --git a/mlx/backend/vulkan/kernels/gather_take.comp b/mlx/backend/vulkan/kernels/gather_take.comp index 4cdb8def4a..38e88de34e 100644 --- a/mlx/backend/vulkan/kernels/gather_take.comp +++ b/mlx/backend/vulkan/kernels/gather_take.comp @@ -1,6 +1,7 @@ #version 450 #extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_scalar_block_layout : require #if defined(INDEX_IS_I64) #extension GL_EXT_shader_explicit_arithmetic_types_int64 : require @@ -25,9 +26,9 @@ layout(push_constant) uniform parameter { uint idx_axis_size; } p; -layout(binding = 0) readonly buffer A { VALUE_TYPE data_a[]; }; -layout(binding = 1) readonly buffer I { INDEX_TYPE data_i[]; }; -layout(binding = 2) writeonly buffer D { VALUE_TYPE data_d[]; }; +layout(scalar, binding = 0) readonly buffer A { VALUE_TYPE data_a[]; }; +layout(scalar, binding = 1) readonly buffer I { INDEX_TYPE data_i[]; }; +layout(scalar, binding = 2) writeonly buffer D { VALUE_TYPE data_d[]; }; uint get_idx() { return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; diff --git a/mlx/backend/vulkan/kernels/scatter_axis.comp b/mlx/backend/vulkan/kernels/scatter_axis.comp index 602dd47df2..74035e2cbe 100644 --- a/mlx/backend/vulkan/kernels/scatter_axis.comp +++ b/mlx/backend/vulkan/kernels/scatter_axis.comp @@ -1,6 +1,7 @@ #version 450 #extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_scalar_block_layout : require #if defined(INDEX_IS_I64) #extension GL_EXT_shader_explicit_arithmetic_types_int64 : require @@ -25,9 +26,9 @@ layout(push_constant) uniform parameter { uint idx_axis_size; } p; -layout(binding = 0) readonly buffer U { VALUE_TYPE data_u[]; }; -layout(binding = 1) readonly buffer I { INDEX_TYPE data_i[]; }; -layout(binding = 2) buffer D { VALUE_TYPE data_d[]; }; +layout(scalar, binding = 0) readonly buffer U { VALUE_TYPE data_u[]; }; +layout(scalar, binding = 1) readonly buffer I { INDEX_TYPE data_i[]; }; +layout(scalar, binding = 2) buffer D { VALUE_TYPE data_d[]; }; uint get_idx() { return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; diff --git a/mlx/backend/vulkan/kernels/scatter_pair.comp b/mlx/backend/vulkan/kernels/scatter_pair.comp index e6335bd1dd..fef6067826 100644 --- a/mlx/backend/vulkan/kernels/scatter_pair.comp +++ b/mlx/backend/vulkan/kernels/scatter_pair.comp @@ -1,6 +1,7 @@ #version 450 #extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_scalar_block_layout : require #if defined(INDEX_IS_I64) #extension GL_EXT_shader_explicit_arithmetic_types_int64 : require @@ -29,10 +30,10 @@ layout(push_constant) uniform parameter { uint index_count; } p; -layout(binding = 0) readonly buffer U { VALUE_TYPE data_u[]; }; -layout(binding = 1) readonly buffer I0 { INDEX_TYPE data_i0[]; }; -layout(binding = 2) readonly buffer I1 { INDEX_TYPE data_i1[]; }; -layout(binding = 3) buffer D { VALUE_TYPE data_d[]; }; +layout(scalar, binding = 0) readonly buffer U { VALUE_TYPE data_u[]; }; +layout(scalar, binding = 1) readonly buffer I0 { INDEX_TYPE data_i0[]; }; +layout(scalar, binding = 2) readonly buffer I1 { INDEX_TYPE data_i1[]; }; +layout(scalar, binding = 3) buffer D { VALUE_TYPE data_d[]; }; uint get_idx() { return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; diff --git a/mlx/backend/vulkan/kernels/scatter_take.comp b/mlx/backend/vulkan/kernels/scatter_take.comp index abb7dbf837..60024d76bf 100644 --- a/mlx/backend/vulkan/kernels/scatter_take.comp +++ b/mlx/backend/vulkan/kernels/scatter_take.comp @@ -1,6 +1,7 @@ #version 450 #extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_scalar_block_layout : require #if defined(INDEX_IS_I64) #extension GL_EXT_shader_explicit_arithmetic_types_int64 : require @@ -25,9 +26,9 @@ layout(push_constant) uniform parameter { uint idx_axis_size; } p; -layout(binding = 0) readonly buffer U { VALUE_TYPE data_u[]; }; -layout(binding = 1) readonly buffer I { INDEX_TYPE data_i[]; }; -layout(binding = 2) buffer D { VALUE_TYPE data_d[]; }; +layout(scalar, binding = 0) readonly buffer U { VALUE_TYPE data_u[]; }; +layout(scalar, binding = 1) readonly buffer I { INDEX_TYPE data_i[]; }; +layout(scalar, binding = 2) buffer D { VALUE_TYPE data_d[]; }; uint get_idx() { return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; diff --git a/mlx/backend/vulkan/matmul.cpp b/mlx/backend/vulkan/matmul.cpp index d4a4da0cb7..090b734495 100644 --- a/mlx/backend/vulkan/matmul.cpp +++ b/mlx/backend/vulkan/matmul.cpp @@ -25,11 +25,12 @@ namespace mlx::core { namespace { constexpr uint32_t kMaxGridZ = 65535; -constexpr uint32_t kMaxMulMatVecCols = 8; +constexpr uint32_t kMaxMulMatVecCols = 16; constexpr char kMatvecMatrixCastScratchLane[] = "matvec.matrix_f16"; constexpr char kMatvecVectorCastScratchLane[] = "matvec.vec_f16"; constexpr char kMatvecOutScratchLane[] = "matvec.out_work"; constexpr char kMatvecScoresVOutScratchLane[] = "matvec.scores_v.out_work"; +constexpr char kMatmulZeroScratchLane[] = "matmul.zero.out"; constexpr char kMulMmACastScratchLane[] = "mul_mm.a_f16"; constexpr char kMulMmBCastScratchLane[] = "mul_mm.b_f16"; constexpr char kMulMmOutScratchLane[] = "mul_mm.out_work"; @@ -80,6 +81,18 @@ bool matvec_enabled() { return enabled; } +bool prefer_subgroup_matvec() { + static const bool prefer = []() { + const char* env = std::getenv("MLX_VULKAN_PREFER_SUBGROUP_MATVEC"); + if (env != nullptr) { + return std::string(env) != "0"; + } + return vulkan::VulkanContext::get().architecture() == + vulkan::GpuArchitecture::Nvidia; + }(); + return prefer; +} + bool mul_mm_enabled() { static auto& runtime_disabled = []() -> std::atomic& { static std::atomic disabled{false}; @@ -164,43 +177,53 @@ std::vector matvec_shader_candidates( if (!base.has_value()) { return {}; } + const auto order = [&](vulkan::StaticShaderId standard, + vulkan::StaticShaderId subgroup, + vulkan::StaticShaderId subgroup_no_shmem) { + if (prefer_subgroup_matvec()) { + return std::vector{ + subgroup, + subgroup_no_shmem, + standard, + }; + } + return std::vector{ + standard, + subgroup, + subgroup_no_shmem, + }; + }; switch (*base) { case vulkan::StaticShaderId::mul_mat_vec_f32_f32_f32: - return { + return order( vulkan::StaticShaderId::mul_mat_vec_f32_f32_f32, vulkan::StaticShaderId::mul_mat_vec_f32_f32_f32_subgroup, - vulkan::StaticShaderId::mul_mat_vec_f32_f32_f32_subgroup_no_shmem, - }; + vulkan::StaticShaderId::mul_mat_vec_f32_f32_f32_subgroup_no_shmem); case vulkan::StaticShaderId::mul_mat_vec_f16_f32_f32: - return { + return order( vulkan::StaticShaderId::mul_mat_vec_f16_f32_f32, vulkan::StaticShaderId::mul_mat_vec_f16_f32_f32_subgroup, - vulkan::StaticShaderId::mul_mat_vec_f16_f32_f32_subgroup_no_shmem, - }; + vulkan::StaticShaderId::mul_mat_vec_f16_f32_f32_subgroup_no_shmem); case vulkan::StaticShaderId::mul_mat_vec_bf16_f32_f32: - return { + return order( vulkan::StaticShaderId::mul_mat_vec_bf16_f32_f32, vulkan::StaticShaderId::mul_mat_vec_bf16_f32_f32_subgroup, - vulkan::StaticShaderId::mul_mat_vec_bf16_f32_f32_subgroup_no_shmem, - }; + vulkan::StaticShaderId::mul_mat_vec_bf16_f32_f32_subgroup_no_shmem); case vulkan::StaticShaderId::mul_mat_vec_f32_f16_f32: - return { + return order( vulkan::StaticShaderId::mul_mat_vec_f32_f16_f32, vulkan::StaticShaderId::mul_mat_vec_f32_f16_f32_subgroup, - vulkan::StaticShaderId::mul_mat_vec_f32_f16_f32_subgroup_no_shmem, - }; + vulkan::StaticShaderId::mul_mat_vec_f32_f16_f32_subgroup_no_shmem); case vulkan::StaticShaderId::mul_mat_vec_f16_f16_f32: - return { + return order( vulkan::StaticShaderId::mul_mat_vec_f16_f16_f32, vulkan::StaticShaderId::mul_mat_vec_f16_f16_f32_subgroup, - vulkan::StaticShaderId::mul_mat_vec_f16_f16_f32_subgroup_no_shmem, - }; + vulkan::StaticShaderId::mul_mat_vec_f16_f16_f32_subgroup_no_shmem); case vulkan::StaticShaderId::mul_mat_vec_bf16_f16_f32: - return { + return order( vulkan::StaticShaderId::mul_mat_vec_bf16_f16_f32, vulkan::StaticShaderId::mul_mat_vec_bf16_f16_f32_subgroup, - vulkan::StaticShaderId::mul_mat_vec_bf16_f16_f32_subgroup_no_shmem, - }; + vulkan::StaticShaderId::mul_mat_vec_bf16_f16_f32_subgroup_no_shmem); case vulkan::StaticShaderId::Count: break; } @@ -456,6 +479,7 @@ choose_split_k(uint32_t k, uint32_t num_batches, uint32_t split_k_threshold) { MatmulDispatchTuning select_matmul_dispatch_tuning(Dtype dtype, uint32_t m, uint32_t n, uint32_t k) { const auto profile = matmul_profile_for_device(); + const auto& ctx = vulkan::VulkanContext::get(); const bool aligned = matmul_inputs_aligned(m, n, k); const MatmulFamily family = classify_matmul_family(m, n, k); const size_t family_index = static_cast(family); @@ -471,14 +495,36 @@ select_matmul_dispatch_tuning(Dtype dtype, uint32_t m, uint32_t n, uint32_t k) { tuning.prefer_fp32_accum = dtype == float32 || (dtype != float32 && k >= family_spec.fp32_accum_k_threshold); - if (const auto& ctx = vulkan::VulkanContext::get(); - !ctx.cooperative_matrix_supported() && + if (!ctx.cooperative_matrix_supported() && ctx.architecture() == vulkan::GpuArchitecture::AmdRdna && tuning.specialization_constants.size() > 0) { tuning.specialization_constants[0] = std::min(tuning.specialization_constants[0], 64u); } + if (tuning.specialization_constants.size() > 10 && + ctx.subgroup_size_control_supported()) { + const uint32_t preferred = std::clamp( + profile.preferred_subgroup_size, + std::max(ctx.subgroup_min_size(), 1u), + std::max(ctx.subgroup_max_size(), 1u)); + tuning.specialization_constants[10] = preferred; + } + + if (ctx.shader_core_count() > 0) { + const uint32_t core_scale = std::clamp(ctx.shader_core_count() / 8u, 1u, 4u); + tuning.split_k_threshold = + std::max(tuning.split_k_threshold, 512u * core_scale); + if (dtype != float32) { + tuning.prefer_fp32_accum = + k >= std::max(tuning.split_k_threshold, 1024u); + } + } + + if (ctx.architecture() == vulkan::GpuArchitecture::AmdRdna && dtype == bfloat16) { + tuning.prefer_fp32_accum = true; + } + return tuning; } @@ -519,7 +565,7 @@ TensorLayout4D make_tensor_layout_4d(const array& arr) { bool has_vulkan_buffer(const array& arr) { auto data = arr.data_shared_ptr(); - return data != nullptr && data->buffer.ptr() != nullptr; + return data != nullptr && vulkan::is_vulkan_buffer(data->buffer); } array cast_to_float16_scratch(const array& arr, Stream s, const char* lane) { @@ -558,6 +604,9 @@ bool try_eval_scores_v_matvec_vulkan( if (scores.shape(0) != values.shape(0) || scores.shape(0) != out.shape(0)) { return false; } + if (scores.shape(-3) != out.shape(-3)) { + return false; + } if (!ensure_vulkan_buffer(scores, s) || !ensure_vulkan_buffer(values, s)) { return false; @@ -700,7 +749,7 @@ bool ensure_vulkan_buffer(array& arr, Stream s) { } auto data = arr.data_shared_ptr(); - if (data == nullptr || data->buffer.ptr() == nullptr) { + if (data == nullptr || !vulkan::is_vulkan_buffer(data->buffer)) { return false; } @@ -709,14 +758,31 @@ bool ensure_vulkan_buffer(array& arr, Stream s) { } void zero_initialize_output(array& out, Stream s) { - out.set_data(allocator::malloc(out.nbytes())); - if (out.nbytes() == 0) { + if (out.size() == 0) { + return; + } + + auto zero_contiguous = [&](array& target) { + target.set_data(allocator::malloc(target.nbytes())); + if (target.nbytes() == 0) { + return; + } + auto* target_buf = static_cast(target.buffer().ptr()); + auto cmd_buffer = vulkan::begin_command_recording(s.index); + cmd_buffer.fillBuffer(target_buf->buffer, 0, target.nbytes(), 0); + vulkan::end_command_recording(s.index); + }; + + if (is_row_contiguous_zero_offset(out)) { + zero_contiguous(out); return; } - auto* out_buf = static_cast(out.buffer().ptr()); - auto cmd_buffer = vulkan::begin_command_recording(s.index); - cmd_buffer.fillBuffer(out_buf->buffer, 0, out.nbytes(), 0); - vulkan::end_command_recording(s.index); + + array scratch = + vulkan::acquire_scratch_array(s, kMatmulZeroScratchLane, out.shape(), out.dtype()); + zero_contiguous(scratch); + vulkan::mark_scratch_array_written(s, kMatmulZeroScratchLane); + copy_gpu(scratch, out, CopyType::General, s); } bool try_eval_matvec_vulkan( @@ -891,7 +957,13 @@ bool try_eval_mul_mm_vulkan( return true; }; - if (!materialize_broadcast_input(a) || !materialize_broadcast_input(b)) { + const bool can_keep_a_broadcast_view = a.ndim() == 4 && + a.shape(-2) == out.shape(-2) && a.shape(-1) == b.shape(-2) && + (a.shape(-3) == 1 || a.shape(-3) == out.shape(-3)) && + (a.shape(-4) == 1 || a.shape(-4) == out.shape(-4)); + + if ((!can_keep_a_broadcast_view && !materialize_broadcast_input(a)) || + !materialize_broadcast_input(b)) { return false; } @@ -1016,6 +1088,20 @@ bool try_eval_mul_mm_vulkan( } } + const uint32_t a_heads = static_cast( + a.ndim() >= 3 ? a.shape(-3) : 1); + const uint32_t out_heads = static_cast( + out_work.ndim() >= 3 ? out_work.shape(-3) : 1); + const uint32_t a_batches_outer = static_cast( + a.ndim() >= 4 ? a.shape(-4) : 1); + const uint32_t out_batches_outer = static_cast( + out_work.ndim() >= 4 ? out_work.shape(-4) : 1); + if (a_heads == 0 || out_heads == 0 || a_batches_outer == 0 || + out_batches_outer == 0 || (out_heads % a_heads) != 0 || + (out_batches_outer % a_batches_outer) != 0) { + return false; + } + vulkan::MatmulPushConstants push_constants{}; push_constants.M = m; push_constants.N = n; @@ -1028,10 +1114,10 @@ bool try_eval_mul_mm_vulkan( push_constants.batch_stride_d = batch_stride_d; push_constants.num_batches = num_batches; push_constants.k_split = round_up_div(k, split_k); - push_constants.ne02 = num_batches; - push_constants.ne12 = num_batches; - push_constants.broadcast2 = 1; - push_constants.broadcast3 = 1; + push_constants.ne02 = a_heads; + push_constants.ne12 = out_heads; + push_constants.broadcast2 = out_heads / a_heads; + push_constants.broadcast3 = out_batches_outer / a_batches_outer; push_constants.padded_N = n; if (matmul_debug_enabled()) { @@ -1182,25 +1268,41 @@ bool try_eval_mul_mm_vulkan( return try_dispatch(shader_candidates, out_work, needs_out_copy); } -} // namespace - -bool try_eval_matmul_vulkan( +bool try_eval_matmul_vulkan_impl( const std::vector& inputs, array& out, - Stream s) { + Stream s, + bool* used_matvec_fast_path) { + if (used_matvec_fast_path) { + *used_matvec_fast_path = false; + } if (inputs.size() == 2 && (inputs[0].size() == 0 || inputs[1].size() == 0)) { zero_initialize_output(out, s); return true; } if (matvec_enabled() && try_eval_matvec_vulkan(inputs, out, s)) { + if (used_matvec_fast_path) { + *used_matvec_fast_path = true; + } return true; } return try_eval_mul_mm_vulkan(inputs, out, s); } +} // namespace + +bool try_eval_matmul_vulkan( + const std::vector& inputs, + array& out, + Stream s) { + return try_eval_matmul_vulkan_impl(inputs, out, s, nullptr); +} + void Matmul::eval_gpu(const std::vector& inputs, array& out) { - if (try_eval_matmul_vulkan(inputs, out, stream())) { - log_matmul_path(inputs, "mul_mm"); + bool used_matvec_fast_path = false; + if (try_eval_matmul_vulkan_impl( + inputs, out, stream(), &used_matvec_fast_path)) { + log_matmul_path(inputs, used_matvec_fast_path ? "matvec" : "mul_mm"); return; } throw std::runtime_error( diff --git a/mlx/backend/vulkan/primitives.cpp b/mlx/backend/vulkan/primitives.cpp index 82f30ce891..75131b02da 100644 --- a/mlx/backend/vulkan/primitives.cpp +++ b/mlx/backend/vulkan/primitives.cpp @@ -1,10 +1,13 @@ // Copyright © 2024 Apple Inc. #include "mlx/distributed/primitives.h" +#include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/slicing.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/vulkan/allocator.h" +#include "mlx/backend/vulkan/kernels.h" #include "mlx/backend/vulkan/primitives_utils.h" +#include "mlx/backend/vulkan/shader_compiler.h" namespace mlx::core { @@ -63,6 +66,214 @@ namespace mlx::core { namespace { +array collapse_power_leading_dims(const array& arr, Stream s) { + if (arr.ndim() <= 4) { + return arr; + } + return flatten_in_eval(arr, 0, arr.ndim() - 4, s); +} + +bool ensure_vulkan_buffer_power(array& arr, Stream s) { + auto data = arr.data_shared_ptr(); + if (data != nullptr && vulkan::is_vulkan_buffer(data->buffer)) { + return true; + } + if (arr.has_primitive()) { + arr = contiguous_copy_gpu(arr, s); + data = arr.data_shared_ptr(); + return data != nullptr && vulkan::is_vulkan_buffer(data->buffer); + } + arr.wait(); + data = arr.data_shared_ptr(); + if (data != nullptr && vulkan::is_vulkan_buffer(data->buffer)) { + return true; + } + arr = contiguous_copy_gpu(arr, s); + data = arr.data_shared_ptr(); + return data != nullptr && vulkan::is_vulkan_buffer(data->buffer); +} + +std::string emit_bf16_power_helpers() { + return R"( +uint fp32_to_bf16(float f) { + uint u = floatBitsToUint(f); + u = (u + (0x7fffu + ((u >> 16) & 1u))) >> 16; + return u; +} + +float bf16_to_fp32(uint u) { + return uintBitsToFloat(u << 16); +} + +)"; +} + +std::string power_input_expr(Dtype dtype, const char* buffer_name) { + if (dtype == bfloat16) { + return std::string("bf16_to_fp32(uint(") + buffer_name + ".data[idx]))"; + } + if (dtype == float16) { + return std::string("float(") + buffer_name + ".data[idx])"; + } + return std::string(buffer_name) + ".data[idx]"; +} + +std::string power_output_expr(Dtype out_dtype, const std::string& expr) { + if (out_dtype == bfloat16) { + return "uint16_t(fp32_to_bf16(" + expr + "))"; + } + if (out_dtype == float16) { + return "float16_t(" + expr + ")"; + } + return expr; +} + +std::string build_power_shader(Dtype a_dtype, Dtype b_dtype, Dtype out_dtype) { + std::ostringstream os; + os << vulkan::emit_dynamic_shader_preamble(a_dtype, out_dtype, false); + if (a_dtype == bfloat16 || b_dtype == bfloat16 || out_dtype == bfloat16) { + os << emit_bf16_power_helpers(); + } + os << "layout(push_constant) uniform PushConstants { uint a_offset; uint b_offset; uint out_offset; uint total_elements; } pc;\n"; + os << "layout(set = 0, binding = 0) readonly buffer InputA {" + << vulkan::dtype_to_glsl_storage_type(a_dtype) << " data[];} a_buf;\n"; + os << "layout(set = 0, binding = 1) readonly buffer InputB {" + << vulkan::dtype_to_glsl_storage_type(b_dtype) << " data[];} b_buf;\n"; + os << "layout(set = 0, binding = 2) buffer Output {" + << vulkan::dtype_to_glsl_storage_type(out_dtype) << " data[];} out_buf;\n\n"; + os << "void main() {\n"; + os << " uint linear_idx = gl_GlobalInvocationID.x;\n"; + os << " if (linear_idx >= pc.total_elements) return;\n"; + os << " uint idx = linear_idx;\n"; + os << " uint a_idx = idx + pc.a_offset;\n"; + os << " uint b_idx = idx + pc.b_offset;\n"; + os << " float lhs = " << power_input_expr(a_dtype, "a_buf") << ";\n"; + os << " float rhs = " << power_input_expr(b_dtype, "b_buf") << ";\n"; + os << " out_buf.data[idx + pc.out_offset] = " + << power_output_expr(out_dtype, "pow(lhs, rhs)") << ";\n"; + os << "}\n"; + return os.str(); +} + +bool try_eval_power_vulkan( + const std::vector& inputs, + array& out, + Stream s) { + if (inputs.size() != 2) { + return false; + } + array a = inputs[0]; + array b = inputs[1]; + + auto is_supported_dtype = [](Dtype dtype) { + return dtype == float16 || dtype == float32 || dtype == bfloat16; + }; + if (!is_supported_dtype(a.dtype()) || !is_supported_dtype(b.dtype()) || + !is_supported_dtype(out.dtype())) { + return false; + } + + auto materialize_broadcast_input = [&](array& in) { + if (in.shape() == out.shape()) { + return true; + } + if (broadcast_shapes(in.shape(), out.shape()) != out.shape()) { + return false; + } + if (!ensure_vulkan_buffer_power(in, s)) { + return false; + } + array view(out.shape(), in.dtype(), nullptr, {}); + broadcast(in, view); + in = view; + return true; + }; + + if (!materialize_broadcast_input(a) || !materialize_broadcast_input(b)) { + return false; + } + + if (!is_supported_elementwise_layout(a)) { + a = contiguous_copy_gpu(a, s); + } + if (!is_supported_elementwise_layout(b)) { + b = contiguous_copy_gpu(b, s); + } + + a = collapse_power_leading_dims(a, s); + b = collapse_power_leading_dims(b, s); + + const bool staged_output = !is_supported_elementwise_layout(out); + array out_work = staged_output ? array(out.shape(), out.dtype(), nullptr, {}) : out; + out_work.set_data(allocator::malloc(out_work.nbytes())); + out_work = collapse_power_leading_dims(out_work, s); + + if (!is_supported_elementwise_layout(a) || !is_supported_elementwise_layout(b) || + !is_supported_elementwise_layout(out_work)) { + return false; + } + if (!ensure_vulkan_buffer_power(a, s) || !ensure_vulkan_buffer_power(b, s) || + !ensure_vulkan_buffer_power(out_work, s)) { + return false; + } + if (out_work.size() == 0) { + if (staged_output) { + copy_gpu(out_work, out, CopyType::General, s); + } + return true; + } + + const auto a_offset = static_cast(a.offset() / size_of(a.dtype())); + const auto b_offset = static_cast(b.offset() / size_of(b.dtype())); + const auto out_offset = static_cast(out_work.offset() / size_of(out_work.dtype())); + const auto total = static_cast(out_work.data_size()); + if (a_offset > std::numeric_limits::max() || + b_offset > std::numeric_limits::max() || + out_offset > std::numeric_limits::max() || + total > std::numeric_limits::max()) { + return false; + } + + const std::string shader_name = "dynamic_power_" + + std::to_string(static_cast(a.dtype().val())) + "_" + + std::to_string(static_cast(b.dtype().val())) + "_" + + std::to_string(static_cast(out_work.dtype().val())); + const std::string glsl_source = + build_power_shader(a.dtype(), b.dtype(), out_work.dtype()); + vulkan::DynamicArrayRef arrays[] = {{&a, 0}, {&b, 1}, {&out_work, 2}}; + constexpr uint32_t kPushConstantSize = sizeof(uint32_t) * 4; + auto dispatch = vulkan::dispatch_dynamic_compute_begin( + shader_name, glsl_source, 3, arrays, kPushConstantSize, s); + + struct PushConstants { + uint32_t a_offset; + uint32_t b_offset; + uint32_t out_offset; + uint32_t total_elements; + } pc{ + static_cast(a_offset), + static_cast(b_offset), + static_cast(out_offset), + static_cast(total), + }; + vkCmdPushConstants( + dispatch.command_buffer, + dispatch.pipeline->layout, + VK_SHADER_STAGE_COMPUTE_BIT, + 0, + kPushConstantSize, + &pc); + const uint32_t workgroups = + std::max((static_cast(total) + 255u) / 256u, 1u); + vkCmdDispatch(dispatch.command_buffer, workgroups, 1, 1); + vulkan::end_command_recording(s.index); + + if (staged_output) { + copy_gpu(out_work, out, CopyType::General, s); + } + return true; +} + bool is_supported_select_layout(const array& arr) { return arr.flags().contiguous && arr.offset() == 0 && arr.ndim() > 0 && arr.strides().back() == 1; @@ -235,7 +446,11 @@ NO_GPU_MULTI_STATE(SVD) CPU_FALLBACK(NotEqual) CPU_FALLBACK_STATE(Partition) -CPU_FALLBACK(Power) +void Power::eval_gpu(const std::vector& inputs, array& out) { + if (!try_eval_power_vulkan(inputs, out, stream())) { + eval_cpu_fallback_on_stream(inputs, out, stream()); + } +} // QuantizedMatmul and QQMatmul are implemented in quantized.cpp. CPU_FALLBACK(Real) diff --git a/mlx/backend/vulkan/reduce.cpp b/mlx/backend/vulkan/reduce.cpp index 59ea7bed2a..f80bffd329 100644 --- a/mlx/backend/vulkan/reduce.cpp +++ b/mlx/backend/vulkan/reduce.cpp @@ -104,7 +104,12 @@ bool try_eval_reduce_sum_rows_vulkan( } std::vector host_values(out.size(), fill_value); vulkan::enqueue_owned_staging_upload( - s, host_values.data(), host_values.size(), out_buf->buffer, 0); + s, + host_values.data(), + host_values.size(), + out_buf->buffer, + 0, + out.data_shared_ptr()); vulkan::retain_array_for_stream(s, out); return true; } diff --git a/mlx/backend/vulkan/vulkan.cpp b/mlx/backend/vulkan/vulkan.cpp index 5a9537c951..62e0419de2 100644 --- a/mlx/backend/vulkan/vulkan.cpp +++ b/mlx/backend/vulkan/vulkan.cpp @@ -9,8 +9,11 @@ #include #include +#include #include +#include #include +#include #include #include #include @@ -29,6 +32,71 @@ struct QueueFamilyIndices { bool has_separate_transfer{false}; }; +std::optional parse_env_uint32(const char* env_name) { + const char* env = std::getenv(env_name); + if (env == nullptr || *env == '\0') { + return std::nullopt; + } + + errno = 0; + char* end = nullptr; + unsigned long parsed = std::strtoul(env, &end, 0); + if (errno != 0 || end == env || (end != nullptr && *end != '\0') || + parsed > std::numeric_limits::max()) { + throw std::runtime_error( + std::string("[vulkan::init] Invalid ") + env_name + "='" + env + + "'. Expected unsigned integer."); + } + return static_cast(parsed); +} + +uint32_t device_type_rank(vk::PhysicalDeviceType type) { + switch (type) { + case vk::PhysicalDeviceType::eDiscreteGpu: + return 5; + case vk::PhysicalDeviceType::eIntegratedGpu: + return 4; + case vk::PhysicalDeviceType::eVirtualGpu: + return 3; + case vk::PhysicalDeviceType::eCpu: + return 2; + default: + return 1; + } +} + +uint64_t total_device_local_memory(vk::PhysicalDevice physical_device) { + const auto mem = physical_device.getMemoryProperties(); + uint64_t total = 0; + for (uint32_t i = 0; i < mem.memoryHeapCount; ++i) { + if ((mem.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) != + vk::MemoryHeapFlagBits{}) { + total += mem.memoryHeaps[i].size; + } + } + return total; +} + +uint64_t score_physical_device( + vk::PhysicalDevice physical_device, + const QueueFamilyIndices& indices, + std::optional preferred_vendor_id) { + const auto properties = physical_device.getProperties(); + const uint64_t type_score = + static_cast(device_type_rank(properties.deviceType)) << 60; + const uint64_t local_mem_score = + std::min(total_device_local_memory(physical_device), (1ull << 56) - 1) + << 4; + const uint64_t queue_topology_score = + indices.has_separate_transfer ? (1ull << 3) : 0ull; + const uint64_t vendor_score = + (preferred_vendor_id.has_value() && + properties.vendorID == preferred_vendor_id.value()) + ? (1ull << 2) + : 0ull; + return type_score + local_mem_score + queue_topology_score + vendor_score; +} + uint32_t find_queue_family( const std::vector& queue_families, const vk::QueueFlags& required, @@ -417,6 +485,9 @@ void VulkanContext::init() { bool has_separate_transfer_queue = false; bool is_unified_memory = false; bool shader_float16_supported = false; + bool shader_int8_supported = false; + bool storage_buffer_8bit_supported = false; + bool scalar_block_layout_supported = false; bool shader_bfloat16_supported = false; bool subgroup_size_control_supported = false; bool subgroup_require_full_support = false; @@ -456,8 +527,16 @@ void VulkanContext::init() { "[vulkan::init] Failed to find GPUs with Vulkan support."); } + const auto forced_device_index = parse_env_uint32("MLX_VULKAN_DEVICE_INDEX"); + const auto preferred_vendor_id = + parse_env_uint32("MLX_VULKAN_PREFERRED_VENDOR_ID"); + + std::optional> best_candidate; bool found_compute_device = false; - for (auto candidate : available_devices) { + for (uint32_t candidate_index = 0; + candidate_index < available_devices.size(); + ++candidate_index) { + auto candidate = available_devices[candidate_index]; auto queue_families = candidate.getQueueFamilyProperties(); bool has_compute = false; for (const auto& qf : queue_families) { @@ -468,19 +547,36 @@ void VulkanContext::init() { } } if (has_compute) { + if (forced_device_index.has_value() && + forced_device_index.value() != candidate_index) { + continue; + } auto indices = find_queue_families(candidate); - physical_device = candidate; - compute_queue_family_index = indices.compute_family; - compute_queue_index = indices.compute_queue_index; - transfer_queue_family_index = indices.transfer_family; - transfer_queue_index = indices.transfer_queue_index; - has_separate_transfer_queue = indices.has_separate_transfer; - found_compute_device = true; - break; + const uint64_t score = + score_physical_device(candidate, indices, preferred_vendor_id); + if (!best_candidate.has_value() || score > best_candidate->first) { + best_candidate = std::make_pair(score, candidate_index); + } } } + if (best_candidate.has_value()) { + const uint32_t selected_index = best_candidate->second; + physical_device = available_devices[selected_index]; + auto indices = find_queue_families(physical_device); + compute_queue_family_index = indices.compute_family; + compute_queue_index = indices.compute_queue_index; + transfer_queue_family_index = indices.transfer_family; + transfer_queue_index = indices.transfer_queue_index; + has_separate_transfer_queue = indices.has_separate_transfer; + found_compute_device = true; + } + if (!found_compute_device) { + if (forced_device_index.has_value()) { + throw std::runtime_error( + "[vulkan::init] Forced MLX_VULKAN_DEVICE_INDEX does not refer to a compute-capable physical device."); + } throw std::runtime_error( "[vulkan::init] Failed to find a compute-capable physical device."); } @@ -567,11 +663,15 @@ void VulkanContext::init() { vk::PhysicalDeviceFeatures2 supported_features; vk::PhysicalDeviceVulkan11Features supported_vulkan11_features; vk::PhysicalDeviceShaderFloat16Int8Features supported_shader_float16_int8; + vk::PhysicalDevice8BitStorageFeatures supported_storage_8bit; + vk::PhysicalDeviceScalarBlockLayoutFeatures supported_scalar_block_layout; supported_features.pNext = &supported_vulkan11_features; supported_vulkan11_features.pNext = &supported_shader_float16_int8; + supported_shader_float16_int8.pNext = &supported_storage_8bit; + supported_storage_8bit.pNext = &supported_scalar_block_layout; vk::PhysicalDeviceShaderIntegerDotProductFeatures supported_shader_integer_dot_product{}; - supported_shader_float16_int8.pNext = &supported_shader_integer_dot_product; + supported_scalar_block_layout.pNext = &supported_shader_integer_dot_product; vk::PhysicalDeviceSubgroupSizeControlFeatures supported_subgroup_size_control{}; @@ -641,11 +741,15 @@ void VulkanContext::init() { vk::PhysicalDeviceFeatures2 enabled_features; vk::PhysicalDeviceVulkan11Features enabled_vulkan11_features; vk::PhysicalDeviceShaderFloat16Int8Features enabled_shader_float16_int8; + vk::PhysicalDevice8BitStorageFeatures enabled_storage_8bit; + vk::PhysicalDeviceScalarBlockLayoutFeatures enabled_scalar_block_layout; enabled_features.pNext = &enabled_vulkan11_features; enabled_vulkan11_features.pNext = &enabled_shader_float16_int8; + enabled_shader_float16_int8.pNext = &enabled_storage_8bit; + enabled_storage_8bit.pNext = &enabled_scalar_block_layout; vk::PhysicalDeviceShaderIntegerDotProductFeatures enabled_shader_integer_dot_product{}; - enabled_shader_float16_int8.pNext = &enabled_shader_integer_dot_product; + enabled_scalar_block_layout.pNext = &enabled_shader_integer_dot_product; vk::PhysicalDeviceSubgroupSizeControlFeatures enabled_subgroup_size_control{}; @@ -698,9 +802,21 @@ void VulkanContext::init() { if (supported_vulkan11_features.storageBuffer16BitAccess) { enabled_vulkan11_features.storageBuffer16BitAccess = VK_TRUE; } + if (supported_storage_8bit.storageBuffer8BitAccess) { + enabled_storage_8bit.storageBuffer8BitAccess = VK_TRUE; + storage_buffer_8bit_supported = true; + } + if (supported_scalar_block_layout.scalarBlockLayout) { + enabled_scalar_block_layout.scalarBlockLayout = VK_TRUE; + scalar_block_layout_supported = true; + } if (supported_features.features.shaderInt16) { enabled_features.features.shaderInt16 = VK_TRUE; } + if (supported_shader_float16_int8.shaderInt8) { + enabled_shader_float16_int8.shaderInt8 = VK_TRUE; + shader_int8_supported = true; + } if (supported_shader_float16_int8.shaderFloat16) { enabled_shader_float16_int8.shaderFloat16 = VK_TRUE; shader_float16_supported = true; @@ -842,6 +958,9 @@ void VulkanContext::init() { if (shader_bfloat16_supported) { device_extensions.push_back(VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME); } + if (has_device_extension(extensions, VK_EXT_SCALAR_BLOCK_LAYOUT_EXTENSION_NAME)) { + device_extensions.push_back(VK_EXT_SCALAR_BLOCK_LAYOUT_EXTENSION_NAME); + } if (coopmat2_conv2d_supported) { device_extensions.push_back(VK_NV_COOPERATIVE_MATRIX_2_EXTENSION_NAME); } @@ -903,6 +1022,9 @@ void VulkanContext::init() { mem_properties_ = mem_properties; is_unified_memory_ = is_unified_memory; this->shader_float16_supported_ = shader_float16_supported; + this->shader_int8_supported_ = shader_int8_supported; + this->storage_buffer_8bit_supported_ = storage_buffer_8bit_supported; + this->scalar_block_layout_supported_ = scalar_block_layout_supported; this->shader_bfloat16_extension_present_ = has_shader_bfloat16_ext; this->shader_bfloat16_reported_supported_ = shader_bfloat16_supported; this->shader_bfloat16_supported_ = false; @@ -973,6 +1095,9 @@ void VulkanContext::cleanup() { timeline_value_ = 0; is_unified_memory_ = false; shader_float16_supported_ = false; + shader_int8_supported_ = false; + storage_buffer_8bit_supported_ = false; + scalar_block_layout_supported_ = false; shader_bfloat16_extension_present_ = false; shader_bfloat16_reported_supported_ = false; shader_bfloat16_supported_ = false; diff --git a/mlx/backend/vulkan/vulkan.h b/mlx/backend/vulkan/vulkan.h index c518d55cda..ede0fac40b 100644 --- a/mlx/backend/vulkan/vulkan.h +++ b/mlx/backend/vulkan/vulkan.h @@ -77,6 +77,15 @@ class VulkanContext { bool shader_float16_supported() const { return shader_float16_supported_; } + bool shader_int8_supported() const { + return shader_int8_supported_; + } + bool storage_buffer_8bit_supported() const { + return storage_buffer_8bit_supported_; + } + bool scalar_block_layout_supported() const { + return scalar_block_layout_supported_; + } bool shader_bfloat16_supported() const; bool subgroup_size_control_supported() const { return subgroup_size_control_supported_; @@ -161,6 +170,9 @@ class VulkanContext { bool initialized_{false}; bool is_unified_memory_{false}; bool shader_float16_supported_{false}; + bool shader_int8_supported_{false}; + bool storage_buffer_8bit_supported_{false}; + bool scalar_block_layout_supported_{false}; bool shader_bfloat16_extension_present_{false}; bool shader_bfloat16_reported_supported_{false}; mutable std::once_flag shader_bfloat16_probe_once_; diff --git a/python/tests/test_vulkan_ops.py b/python/tests/test_vulkan_ops.py index 1610a421d1..166b54d1d5 100644 --- a/python/tests/test_vulkan_ops.py +++ b/python/tests/test_vulkan_ops.py @@ -727,6 +727,37 @@ def submit(host_src, heavy=False): for gpu_out, host_src in zip(outs, host_srcs): self._assert_outputs_close(gpu_out, host_src, atol=0.0, rtol=0.0) + def test_large_low_precision_gather_regression(self): + for dtype, atol in ((mx.float16, 5e-3), (mx.bfloat16, 5e-2)): + with self.subTest(dtype=str(dtype)): + self._assert_cpu_gpu_same( + lambda dtype=dtype: mx.arange( + 100 * 8960, dtype=mx.float32 + ).reshape(100, 8960).astype(dtype)[mx.array([[23]], dtype=mx.int32)].astype( + mx.float32 + ), + atol=atol, + rtol=atol, + ) + + def test_compiled_gelu_approx_negative_power_regression(self): + x = mx.linspace(-6.0, 6.0, 6144, dtype=mx.float32).reshape(1, 1, 6144) + + def gelu_approx(x): + return 0.5 * x * (1.0 + mx.tanh(0.7978845608 * (x + 0.044715 * mx.power(x, 3.0)))) + + @mx.compile + def compiled(gate, value): + return gelu_approx(gate) * value + + expected = self._run_on_device(mx.cpu, lambda: gelu_approx(x) * x) + actual = self._run_on_device(mx.gpu, lambda: compiled(x, x)) + self._assert_outputs_close( + actual.astype(mx.float32), + expected.astype(mx.float32), + atol=1e-4, + rtol=1e-4, + ) def _cases(): return [