-
Notifications
You must be signed in to change notification settings - Fork 1k
Add rms_norm op (#19893) #19893
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add rms_norm op (#19893) #19893
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |
| #include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h> | ||
|
|
||
| #include <executorch/backends/vulkan/serialization/schema_generated.h> | ||
| #include <executorch/runtime/core/named_data_map.h> | ||
|
|
||
| #include <executorch/backends/webgpu/runtime/WebGPUDevice.h> | ||
| #include <webgpu/wgpu.h> | ||
|
|
@@ -93,7 +94,8 @@ WebGPUGraph::~WebGPUGraph() { | |
|
|
||
| void WebGPUGraph::build( | ||
| const void* flatbuffer_data, | ||
| const uint8_t* constant_data) { | ||
| const uint8_t* constant_data, | ||
| const executorch::runtime::NamedDataMap* named_data_map) { | ||
| if (!device_) { | ||
| auto* ctx = get_default_webgpu_context(); | ||
| if (ctx) { | ||
|
|
@@ -165,6 +167,25 @@ void WebGPUGraph::build( | |
| const uint8_t* src = constant_data + vk_bytes->offset(); | ||
| wgpuQueueWriteBuffer( | ||
| queue_, tensor.buffer, 0, src, tensor.nbytes); | ||
| } else if ( | ||
| vk_bytes->named_key() != nullptr && | ||
| named_data_map != nullptr) { | ||
| // Constant stored in the PTE named-data map. | ||
| auto buf = | ||
| named_data_map->get_data(vk_bytes->named_key()->c_str()); | ||
| if (buf.ok() && buf->size() >= tensor.nbytes) { | ||
| wgpuQueueWriteBuffer( | ||
| queue_, tensor.buffer, 0, buf->data(), tensor.nbytes); | ||
| buf->Free(); | ||
| } else { | ||
| throw std::runtime_error( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nits: The error message conflates "key not found" and "buffer undersized". |
||
| std::string("WebGPU: named constant '") + | ||
| vk_bytes->named_key()->c_str() + | ||
| "' missing or undersized in NamedDataMap"); | ||
| } | ||
| } else { | ||
| throw std::runtime_error( | ||
| "WebGPU: constant has no inline offset and no named-data key"); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,192 @@ | ||
| /* | ||
| * Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| * All rights reserved. | ||
| * | ||
| * This source code is licensed under the BSD-style license found in the | ||
| * LICENSE file in the root directory of this source tree. | ||
| */ | ||
|
|
||
| #include <executorch/backends/webgpu/runtime/WebGPUGraph.h> | ||
| #include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h> | ||
| #include <executorch/backends/webgpu/runtime/ops/rms_norm/rms_norm_wgsl.h> | ||
|
|
||
| #include <webgpu/webgpu.h> | ||
|
|
||
| #include <cstdint> | ||
| #include <cstring> | ||
| #include <stdexcept> | ||
|
|
||
| namespace executorch::backends::webgpu { | ||
|
|
||
| namespace { | ||
|
|
||
// Host-side mirror of the `Params` uniform struct declared in rms_norm.wgsl.
// Four 4-byte scalars, 16 bytes total; the shader reads the fields in this
// exact order, so neither the order nor the types may change.
struct RmsNormParams {
  std::uint32_t num_rows; // bytes 0..3: number of rows to normalize
  std::uint32_t row_width; // bytes 4..7: elements per row
  float epsilon; // bytes 8..11: added to the mean square before the rsqrt
  std::uint32_t _pad; // bytes 12..15: explicit padding up to 16 bytes
};
static_assert(sizeof(RmsNormParams) == 16, "RmsNormParams must be 16 bytes");
|
|
||
| void rms_norm_impl(WebGPUGraph& graph, const std::vector<int>& args) { | ||
| // et_vk.rms_norm.default args: [in, weight, eps, out] | ||
| const int in_id = args.at(0); | ||
| const int weight_id = args.at(1); | ||
| const int eps_id = args.at(2); | ||
| const int out_id = args.at(3); | ||
|
|
||
| WGPUDevice device = graph.device(); | ||
|
|
||
| // Get epsilon (Double from a Python float; defaults to float32 eps) | ||
| float epsilon = 1.1920928955078125e-07f; | ||
| if (graph.get_value_type(eps_id) == WebGPUGraph::ValueType::Double) { | ||
| epsilon = static_cast<float>(graph.get_double(eps_id)); | ||
| } else if (graph.get_value_type(eps_id) == WebGPUGraph::ValueType::Int) { | ||
| epsilon = static_cast<float>(graph.get_int(eps_id)); | ||
| } | ||
|
|
||
| // row_width = last dim; num_rows = product of the rest (PyTorch NCHW order) | ||
| const auto& in_tensor = graph.get_tensor(in_id); | ||
| if (in_tensor.dims.empty() || in_tensor.nbytes == 0) { | ||
| throw std::runtime_error("WebGPU rms_norm: empty input"); | ||
| } | ||
| const uint32_t row_width = static_cast<uint32_t>(in_tensor.dims.back()); | ||
| if (row_width == 0) { | ||
| throw std::runtime_error("WebGPU rms_norm: zero row width"); | ||
| } | ||
| uint64_t in_numel = 1; | ||
| for (int64_t d : in_tensor.dims) { | ||
| in_numel *= static_cast<uint64_t>(d); | ||
| } | ||
| // fp32-only shader: bail if the bytes don't match an fp32 element count. | ||
| if (in_tensor.nbytes != in_numel * sizeof(float)) { | ||
| throw std::runtime_error("WebGPU rms_norm: fp32-only (byte-size mismatch)"); | ||
| } | ||
| const uint32_t num_rows = static_cast<uint32_t>(in_numel / row_width); | ||
| if (num_rows == 0) { | ||
| throw std::runtime_error("WebGPU rms_norm: zero rows"); | ||
| } | ||
|
|
||
| // Create uniform buffer for params | ||
| RmsNormParams params = {}; | ||
| params.num_rows = num_rows; | ||
| params.row_width = row_width; | ||
| params.epsilon = epsilon; | ||
|
|
||
| WGPUBufferDescriptor uniform_desc = {}; | ||
| uniform_desc.size = sizeof(RmsNormParams); | ||
| uniform_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst; | ||
| uniform_desc.mappedAtCreation = true; | ||
| WGPUBuffer uniform_buffer = wgpuDeviceCreateBuffer(device, &uniform_desc); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would this buffer be freed eventually ? When should this be released ? |
||
| void* mapped = | ||
| wgpuBufferGetMappedRange(uniform_buffer, 0, sizeof(RmsNormParams)); | ||
| std::memcpy(mapped, ¶ms, sizeof(RmsNormParams)); | ||
| wgpuBufferUnmap(uniform_buffer); | ||
|
|
||
| graph.add_uniform_buffer_bytes(sizeof(RmsNormParams)); | ||
|
|
||
| // Create shader module from built-in WGSL source | ||
| WGPUShaderSourceWGSL wgsl_desc = {}; | ||
| wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL; | ||
| wgsl_desc.code = {kRmsNormWGSL, WGPU_STRLEN}; | ||
|
|
||
| WGPUShaderModuleDescriptor shader_desc = {}; | ||
| shader_desc.nextInChain = &wgsl_desc.chain; | ||
| WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc); | ||
|
|
||
| // Create bind group layout: out (rw) + in/weight (ro storage) + params | ||
| WGPUBindGroupLayoutEntry entries[4] = {}; | ||
|
|
||
| // t_out - storage buffer, read-write | ||
| entries[0].binding = 0; | ||
| entries[0].visibility = WGPUShaderStage_Compute; | ||
| entries[0].buffer.type = WGPUBufferBindingType_Storage; | ||
|
|
||
| // t_in - storage buffer, read-only | ||
| entries[1].binding = 1; | ||
| entries[1].visibility = WGPUShaderStage_Compute; | ||
| entries[1].buffer.type = WGPUBufferBindingType_ReadOnlyStorage; | ||
|
|
||
| // t_weight - storage buffer, read-only | ||
| entries[2].binding = 2; | ||
| entries[2].visibility = WGPUShaderStage_Compute; | ||
| entries[2].buffer.type = WGPUBufferBindingType_ReadOnlyStorage; | ||
|
|
||
| // params - uniform buffer | ||
| entries[3].binding = 3; | ||
| entries[3].visibility = WGPUShaderStage_Compute; | ||
| entries[3].buffer.type = WGPUBufferBindingType_Uniform; | ||
|
|
||
| WGPUBindGroupLayoutDescriptor bgl_desc = {}; | ||
| bgl_desc.entryCount = 4; | ||
| bgl_desc.entries = entries; | ||
| WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc); | ||
|
|
||
| // Create pipeline layout | ||
| WGPUPipelineLayoutDescriptor pl_desc = {}; | ||
| pl_desc.bindGroupLayoutCount = 1; | ||
| pl_desc.bindGroupLayouts = &bgl; | ||
| WGPUPipelineLayout pipeline_layout = | ||
| wgpuDeviceCreatePipelineLayout(device, &pl_desc); | ||
|
|
||
| // Create compute pipeline | ||
| WGPUComputePipelineDescriptor pipeline_desc = {}; | ||
| pipeline_desc.layout = pipeline_layout; | ||
| pipeline_desc.compute.module = shader; | ||
| pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN}; | ||
| WGPUComputePipeline pipeline = | ||
| wgpuDeviceCreateComputePipeline(device, &pipeline_desc); | ||
|
|
||
| // Create bind group with actual buffers | ||
| const auto& out_tensor = graph.get_tensor(out_id); | ||
| const auto& weight_tensor = graph.get_tensor(weight_id); | ||
|
|
||
| WGPUBindGroupEntry bg_entries[4] = {}; | ||
|
|
||
| bg_entries[0].binding = 0; | ||
| bg_entries[0].buffer = out_tensor.buffer; | ||
| bg_entries[0].size = out_tensor.nbytes; | ||
|
|
||
| bg_entries[1].binding = 1; | ||
| bg_entries[1].buffer = in_tensor.buffer; | ||
| bg_entries[1].size = in_tensor.nbytes; | ||
|
|
||
| bg_entries[2].binding = 2; | ||
| bg_entries[2].buffer = weight_tensor.buffer; | ||
| bg_entries[2].size = weight_tensor.nbytes; | ||
|
|
||
| bg_entries[3].binding = 3; | ||
| bg_entries[3].buffer = uniform_buffer; | ||
| bg_entries[3].size = sizeof(RmsNormParams); | ||
|
|
||
| WGPUBindGroupDescriptor bg_desc = {}; | ||
| bg_desc.layout = bgl; | ||
| bg_desc.entryCount = 4; | ||
| bg_desc.entries = bg_entries; | ||
| WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc); | ||
|
|
||
| // One workgroup per row (kRmsNormWorkgroupSize threads cooperate per row) | ||
| static_assert( | ||
| kRmsNormWorkgroupSize == 64, | ||
| "must match @workgroup_size and WG_SIZE in rms_norm.wgsl"); | ||
| if (num_rows > 65535u) { | ||
| throw std::runtime_error( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if we are throwing exception here, should we release / free all objects that have been allocated thus far such as (uniform buffer, shader module, etc...) |
||
| "WebGPU rms_norm: num_rows exceeds the 1D dispatch limit (65535)"); | ||
| } | ||
| graph.add_dispatch({pipeline, bind_group, num_rows}); | ||
|
|
||
| // Release intermediate objects (pipeline and bind_group are kept by dispatch) | ||
| wgpuShaderModuleRelease(shader); | ||
| wgpuBindGroupLayoutRelease(bgl); | ||
| wgpuPipelineLayoutRelease(pipeline_layout); | ||
| // uniform_buffer is kept alive by the bind group | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| WEBGPU_REGISTER_OPERATORS { | ||
| WEBGPU_REGISTER_OP(et_vk.rms_norm.default, rms_norm_impl); | ||
| } | ||
|
|
||
| } // namespace executorch::backends::webgpu | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
// Storage bindings. Indices 0..2 match the bind group layout built on the
// host: the output is read-write, the input and the weight are read-only.
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
@group(0) @binding(1) var<storage, read> t_in: array<f32>;
@group(0) @binding(2) var<storage, read> t_weight: array<f32>;

// Dispatch parameters: four 4-byte scalars, mirrored field-for-field by the
// host-side RmsNormParams struct.
struct Params {
  num_rows: u32,
  row_width: u32,
  epsilon: f32,
  _pad: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

// Invocations per workgroup; must equal the @workgroup_size of `main`.
const WG_SIZE: u32 = 64u;

// One partial sum of squares per invocation, combined by reduce_shared().
var<workgroup> shared_sum: array<f32, WG_SIZE>;
|
|
||
// Sums the WG_SIZE entries of `shared_sum` into shared_sum[0].
// Each pass folds the upper half of the active range onto the lower half and
// then halves the range. The leading barrier makes every invocation's own
// write visible first; the barrier after each pass publishes that pass.
// Must be called by every invocation of the workgroup.
fn reduce_shared(worker_id: u32) {
  workgroupBarrier();
  for (var gap: u32 = WG_SIZE / 2u; gap != 0u; gap >>= 1u) {
    if (worker_id < gap) {
      shared_sum[worker_id] += shared_sum[worker_id + gap];
    }
    workgroupBarrier();
  }
}
|
|
||
// RMS-normalizes one row per workgroup:
//   out[i] = in[i] * inverseSqrt(mean(in^2) + epsilon) * weight[i]
// The 64 invocations stride across the row to accumulate the sum of squares,
// combine their partial sums through reduce_shared(), then stride across the
// row again to write the scaled output.
@compute @workgroup_size(64, 1, 1)
fn main(
    @builtin(workgroup_id) wid: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>) {
  let row = wid.x;
  let lane = lid.x;

  // The whole workgroup shares `row`, so this exit is taken by all of its
  // invocations or by none of them.
  if (row >= params.num_rows) {
    return;
  }

  let row_start = row * params.row_width;

  // Pass 1: per-invocation partial sum of squares over a strided slice.
  var partial: f32 = 0.0;
  for (var col: u32 = lane; col < params.row_width; col += WG_SIZE) {
    let v = t_in[row_start + col];
    partial += v * v;
  }

  shared_sum[lane] = partial;
  reduce_shared(lane);

  let mean_sq = shared_sum[0] / f32(params.row_width);
  let rstd = inverseSqrt(mean_sq + params.epsilon);

  // Pass 2: scale each element by rstd and the per-column weight.
  for (var col: u32 = lane; col < params.row_width; col += WG_SIZE) {
    let v = t_in[row_start + col];
    let w = t_weight[col];
    t_out[row_start + col] = v * rstd * w;
  }
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we free this here? Does the "named_data_map->get_data" step allocate new memory at Ln174?