Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions backends/webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,13 @@ if(NOT TARGET vulkan_schema)
endif()

set(WEBGPU_SRCS
runtime/WebGPUBackend.cpp runtime/WebGPUGraph.cpp
runtime/WebGPUDelegateHeader.cpp runtime/WebGPUDevice.cpp
runtime/ops/OperatorRegistry.cpp runtime/ops/add/BinaryOp.cpp
runtime/WebGPUBackend.cpp
runtime/WebGPUGraph.cpp
runtime/WebGPUDelegateHeader.cpp
runtime/WebGPUDevice.cpp
runtime/ops/OperatorRegistry.cpp
runtime/ops/add/BinaryOp.cpp
runtime/ops/rms_norm/RmsNorm.cpp
)

add_library(webgpu_backend ${WEBGPU_SRCS})
Expand Down Expand Up @@ -116,4 +120,35 @@ if(EXECUTORCH_BUILD_WEBGPU_TEST)

target_compile_options(webgpu_native_test PRIVATE -fexceptions)
set_property(TARGET webgpu_native_test PROPERTY CXX_STANDARD 17)

add_executable(webgpu_rms_norm_test test/native/test_rms_norm.cpp)

target_include_directories(
webgpu_rms_norm_test PRIVATE $<BUILD_INTERFACE:${EXECUTORCH_ROOT}/..>
"${WGPU_NATIVE_DIR}/include"
)

target_link_libraries(
webgpu_rms_norm_test
PRIVATE webgpu_backend
wgpu_native
executorch_core
extension_module_static
extension_data_loader
extension_tensor
portable_kernels
portable_ops_lib
)

if(APPLE)
target_link_libraries(
webgpu_rms_norm_test PRIVATE "-framework Metal" "-framework QuartzCore"
"-framework CoreGraphics"
)
else()
target_link_libraries(webgpu_rms_norm_test PRIVATE dl m pthread)
endif()

target_compile_options(webgpu_rms_norm_test PRIVATE -fexceptions)
set_property(TARGET webgpu_rms_norm_test PROPERTY CXX_STANDARD 17)
endif()
2 changes: 1 addition & 1 deletion backends/webgpu/runtime/WebGPUBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ Result<DelegateHandle*> WebGPUBackend::init(
}

try {
graph->build(flatbuffer_data, constant_data);
graph->build(flatbuffer_data, constant_data, context.get_named_data_map());
} catch (const std::exception& e) {
ET_LOG(Error, "WebGPU graph build failed: %s", e.what());
graph->~WebGPUGraph();
Expand Down
23 changes: 22 additions & 1 deletion backends/webgpu/runtime/WebGPUGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/serialization/schema_generated.h>
#include <executorch/runtime/core/named_data_map.h>

#include <executorch/backends/webgpu/runtime/WebGPUDevice.h>
#include <webgpu/wgpu.h>
Expand Down Expand Up @@ -93,7 +94,8 @@ WebGPUGraph::~WebGPUGraph() {

void WebGPUGraph::build(
const void* flatbuffer_data,
const uint8_t* constant_data) {
const uint8_t* constant_data,
const executorch::runtime::NamedDataMap* named_data_map) {
if (!device_) {
auto* ctx = get_default_webgpu_context();
if (ctx) {
Expand Down Expand Up @@ -165,6 +167,25 @@ void WebGPUGraph::build(
const uint8_t* src = constant_data + vk_bytes->offset();
wgpuQueueWriteBuffer(
queue_, tensor.buffer, 0, src, tensor.nbytes);
} else if (
vk_bytes->named_key() != nullptr &&
named_data_map != nullptr) {
// Constant stored in the PTE named-data map.
auto buf =
named_data_map->get_data(vk_bytes->named_key()->c_str());
if (buf.ok() && buf->size() >= tensor.nbytes) {
wgpuQueueWriteBuffer(
queue_, tensor.buffer, 0, buf->data(), tensor.nbytes);
buf->Free();

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we free this here ? does this step "named_data_map->get_data" allocate new memory at Ln174 ?

} else {
throw std::runtime_error(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nits: The error message conflates "key not found" and "buffer undersized"
into one string. Splitting into two branches with sizes/error codes would make debugging ?

std::string("WebGPU: named constant '") +
vk_bytes->named_key()->c_str() +
"' missing or undersized in NamedDataMap");
}
} else {
throw std::runtime_error(
"WebGPU: constant has no inline offset and no named-data key");
}
}
}
Expand Down
9 changes: 8 additions & 1 deletion backends/webgpu/runtime/WebGPUGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@
#include <unordered_map>
#include <vector>

#include <executorch/runtime/core/named_data_map.h>

namespace executorch {
namespace backends {
namespace webgpu {

using executorch::runtime::NamedDataMap;

struct WebGPUTensor {
WGPUBuffer buffer = nullptr;
std::vector<int64_t> dims;
Expand Down Expand Up @@ -66,7 +70,10 @@ class WebGPUGraph {

// Build the graph from a deserialized VkGraph flatbuffer and constant data.
// The flatbuffer_data pointer must remain valid during build().
void build(const void* flatbuffer_data, const uint8_t* constant_data);
void build(
const void* flatbuffer_data,
const uint8_t* constant_data,
const NamedDataMap* named_data_map = nullptr);

// Copy input tensor data from host pointers into GPU buffers.
void copy_inputs(const std::vector<std::pair<const void*, size_t>>& inputs);
Expand Down
192 changes: 192 additions & 0 deletions backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
#include <executorch/backends/webgpu/runtime/ops/rms_norm/rms_norm_wgsl.h>

#include <webgpu/webgpu.h>

#include <cstdint>
#include <cstring>
#include <stdexcept>

namespace executorch::backends::webgpu {

namespace {

// Uniform layout matching the WGSL Params struct (16-byte aligned).
struct RmsNormParams {
uint32_t num_rows;
uint32_t row_width;
float epsilon;
uint32_t _pad;
};
static_assert(sizeof(RmsNormParams) == 16, "RmsNormParams must be 16 bytes");

void rms_norm_impl(WebGPUGraph& graph, const std::vector<int>& args) {
// et_vk.rms_norm.default args: [in, weight, eps, out]
const int in_id = args.at(0);
const int weight_id = args.at(1);
const int eps_id = args.at(2);
const int out_id = args.at(3);

WGPUDevice device = graph.device();

// Get epsilon (Double from a Python float; defaults to float32 eps)
float epsilon = 1.1920928955078125e-07f;
if (graph.get_value_type(eps_id) == WebGPUGraph::ValueType::Double) {
epsilon = static_cast<float>(graph.get_double(eps_id));
} else if (graph.get_value_type(eps_id) == WebGPUGraph::ValueType::Int) {
epsilon = static_cast<float>(graph.get_int(eps_id));
}

// row_width = last dim; num_rows = product of the rest (PyTorch NCHW order)
const auto& in_tensor = graph.get_tensor(in_id);
if (in_tensor.dims.empty() || in_tensor.nbytes == 0) {
throw std::runtime_error("WebGPU rms_norm: empty input");
}
const uint32_t row_width = static_cast<uint32_t>(in_tensor.dims.back());
if (row_width == 0) {
throw std::runtime_error("WebGPU rms_norm: zero row width");
}
uint64_t in_numel = 1;
for (int64_t d : in_tensor.dims) {
in_numel *= static_cast<uint64_t>(d);
}
// fp32-only shader: bail if the bytes don't match an fp32 element count.
if (in_tensor.nbytes != in_numel * sizeof(float)) {
throw std::runtime_error("WebGPU rms_norm: fp32-only (byte-size mismatch)");
}
const uint32_t num_rows = static_cast<uint32_t>(in_numel / row_width);
if (num_rows == 0) {
throw std::runtime_error("WebGPU rms_norm: zero rows");
}

// Create uniform buffer for params
RmsNormParams params = {};
params.num_rows = num_rows;
params.row_width = row_width;
params.epsilon = epsilon;

WGPUBufferDescriptor uniform_desc = {};
uniform_desc.size = sizeof(RmsNormParams);
uniform_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
uniform_desc.mappedAtCreation = true;
WGPUBuffer uniform_buffer = wgpuDeviceCreateBuffer(device, &uniform_desc);

@psiddh psiddh Jun 2, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would this buffer be freed eventually ? When should this be released ?

void* mapped =
wgpuBufferGetMappedRange(uniform_buffer, 0, sizeof(RmsNormParams));
std::memcpy(mapped, &params, sizeof(RmsNormParams));
wgpuBufferUnmap(uniform_buffer);

graph.add_uniform_buffer_bytes(sizeof(RmsNormParams));

// Create shader module from built-in WGSL source
WGPUShaderSourceWGSL wgsl_desc = {};
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
wgsl_desc.code = {kRmsNormWGSL, WGPU_STRLEN};

WGPUShaderModuleDescriptor shader_desc = {};
shader_desc.nextInChain = &wgsl_desc.chain;
WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);

// Create bind group layout: out (rw) + in/weight (ro storage) + params
WGPUBindGroupLayoutEntry entries[4] = {};

// t_out - storage buffer, read-write
entries[0].binding = 0;
entries[0].visibility = WGPUShaderStage_Compute;
entries[0].buffer.type = WGPUBufferBindingType_Storage;

// t_in - storage buffer, read-only
entries[1].binding = 1;
entries[1].visibility = WGPUShaderStage_Compute;
entries[1].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;

// t_weight - storage buffer, read-only
entries[2].binding = 2;
entries[2].visibility = WGPUShaderStage_Compute;
entries[2].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;

// params - uniform buffer
entries[3].binding = 3;
entries[3].visibility = WGPUShaderStage_Compute;
entries[3].buffer.type = WGPUBufferBindingType_Uniform;

WGPUBindGroupLayoutDescriptor bgl_desc = {};
bgl_desc.entryCount = 4;
bgl_desc.entries = entries;
WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);

// Create pipeline layout
WGPUPipelineLayoutDescriptor pl_desc = {};
pl_desc.bindGroupLayoutCount = 1;
pl_desc.bindGroupLayouts = &bgl;
WGPUPipelineLayout pipeline_layout =
wgpuDeviceCreatePipelineLayout(device, &pl_desc);

// Create compute pipeline
WGPUComputePipelineDescriptor pipeline_desc = {};
pipeline_desc.layout = pipeline_layout;
pipeline_desc.compute.module = shader;
pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
WGPUComputePipeline pipeline =
wgpuDeviceCreateComputePipeline(device, &pipeline_desc);

// Create bind group with actual buffers
const auto& out_tensor = graph.get_tensor(out_id);
const auto& weight_tensor = graph.get_tensor(weight_id);

WGPUBindGroupEntry bg_entries[4] = {};

bg_entries[0].binding = 0;
bg_entries[0].buffer = out_tensor.buffer;
bg_entries[0].size = out_tensor.nbytes;

bg_entries[1].binding = 1;
bg_entries[1].buffer = in_tensor.buffer;
bg_entries[1].size = in_tensor.nbytes;

bg_entries[2].binding = 2;
bg_entries[2].buffer = weight_tensor.buffer;
bg_entries[2].size = weight_tensor.nbytes;

bg_entries[3].binding = 3;
bg_entries[3].buffer = uniform_buffer;
bg_entries[3].size = sizeof(RmsNormParams);

WGPUBindGroupDescriptor bg_desc = {};
bg_desc.layout = bgl;
bg_desc.entryCount = 4;
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

// One workgroup per row (kRmsNormWorkgroupSize threads cooperate per row)
static_assert(
kRmsNormWorkgroupSize == 64,
"must match @workgroup_size and WG_SIZE in rms_norm.wgsl");
if (num_rows > 65535u) {
throw std::runtime_error(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we are throwing exception here, should we release / free all objects that have been allocated thus far such as (uniform buffer, shader module, etc...)

"WebGPU rms_norm: num_rows exceeds the 1D dispatch limit (65535)");
}
graph.add_dispatch({pipeline, bind_group, num_rows});

// Release intermediate objects (pipeline and bind_group are kept by dispatch)
wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
wgpuPipelineLayoutRelease(pipeline_layout);
// uniform_buffer is kept alive by the bind group
}

} // namespace

WEBGPU_REGISTER_OPERATORS {
WEBGPU_REGISTER_OP(et_vk.rms_norm.default, rms_norm_impl);
}

} // namespace executorch::backends::webgpu
72 changes: 72 additions & 0 deletions backends/webgpu/runtime/ops/rms_norm/rms_norm.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
@group(0) @binding(1) var<storage, read> t_in: array<f32>;
@group(0) @binding(2) var<storage, read> t_weight: array<f32>;

struct Params {
num_rows: u32,
row_width: u32,
epsilon: f32,
_pad: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

const WG_SIZE: u32 = 64u;

var<workgroup> shared_sum: array<f32, WG_SIZE>;

fn reduce_shared(worker_id: u32) {
workgroupBarrier();
var stride: u32 = WG_SIZE / 2u;
loop {
if (stride == 0u) {
break;
}
if (worker_id < stride) {
shared_sum[worker_id] = shared_sum[worker_id] + shared_sum[worker_id + stride];
}
workgroupBarrier();
stride = stride >> 1u;
}
}

@compute @workgroup_size(64, 1, 1)
fn main(
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
let row_idx = wid.x;
let worker_id = lid.x;

if (row_idx >= params.num_rows) {
return;
}

let base = row_idx * params.row_width;

var local_sq_sum: f32 = 0.0;
var x: u32 = worker_id;
loop {
if (x >= params.row_width) {
break;
}
let v = t_in[base + x];
local_sq_sum = local_sq_sum + v * v;
x = x + WG_SIZE;
}

shared_sum[worker_id] = local_sq_sum;
reduce_shared(worker_id);

let mean_sq = shared_sum[0] / f32(params.row_width);
let rstd = inverseSqrt(mean_sq + params.epsilon);

x = worker_id;
loop {
if (x >= params.row_width) {
break;
}
let v = t_in[base + x];
let w = t_weight[x];
t_out[base + x] = v * rstd * w;
x = x + WG_SIZE;
}
}
Loading
Loading