Skip to content

Commit f6945a4

Browse files
[ExecuTorch][WebGPU] Add mul op with full broadcast (aten.mul.Tensor)
Pull Request resolved: #20358 Adds `aten.mul.Tensor` to the WebGPU delegate with full PyTorch broadcast, plus the shared `runtime/ops/TensorMeta.h` per-tensor uniform that broadcast ops reuse. Mul is on the Llama critical path — `F.silu` decomposes to `sigmoid` + `mul`, and SwiGLU multiplies two same-shape activations (the fast path). Composition (single dispatch): - `TensorMeta.h` (NEW) — 48-byte std140 `{ndim, numel, sizes[4], strides[4]}` UBO mirroring Vulkan's per-tensor `BufferMetadata`; `fill_tensor_meta_broadcast` right-aligns operand dims (rank>4 throws); `static_assert(sizeof==48)`. - `mul/BinaryOp.cpp` — builds 3 `TensorMeta` UBOs (out/in1/in2 at bindings 3/4/5), guards fp32 + rank≤4, 1D-dispatches over `compute_1d_workgroup_count(numel)`, releases all uniforms after the bind group. - `mul/binary_mul.wgsl` — same-shape fast path + a broadcast path (delinearize output index, clamp each input coord per-dim to size-1, relinearize on input strides). - `WebGPUUtils.h` — adds the shared `utils::make_uniform` helper (first use). ghstack-source-id: 394848336 @exported-using-ghexport Differential Revision: [D108793167](https://our.internmc.facebook.com/intern/diff/D108793167/)
1 parent 044e16a commit f6945a4

6 files changed

Lines changed: 421 additions & 0 deletions

File tree

backends/webgpu/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ set(WEBGPU_SRCS
3838
runtime/ops/sdpa/Sdpa.cpp
3939
runtime/ops/select_as_symint/SelectAsSymint.cpp
4040
runtime/ops/quantized_linear/QuantizedLinear.cpp
41+
runtime/ops/mul/BinaryOp.cpp
4142
)
4243

4344
add_library(webgpu_backend ${WEBGPU_SRCS})

backends/webgpu/runtime/WebGPUUtils.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include <algorithm>
1414
#include <cstdint>
15+
#include <cstring>
1516
#include <stdexcept>
1617
#include <string>
1718

@@ -48,4 +49,25 @@ inline uint32_t compute_1d_workgroup_count(
4849
return count;
4950
}
5051

52+
// Create a uniform buffer mapped-at-creation, copy `size` bytes in, and unmap.
53+
inline WGPUBuffer
54+
make_uniform(WGPUDevice device, const void* data, size_t size) {
55+
WGPUBufferDescriptor desc = {};
56+
desc.size = size;
57+
desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
58+
desc.mappedAtCreation = true;
59+
WGPUBuffer buf = wgpuDeviceCreateBuffer(device, &desc);
60+
if (!buf) {
61+
throw std::runtime_error("make_uniform: buffer creation failed");
62+
}
63+
void* ptr = wgpuBufferGetMappedRange(buf, 0, size);
64+
if (!ptr) {
65+
wgpuBufferRelease(buf);
66+
throw std::runtime_error("make_uniform: mapped range is null");
67+
}
68+
std::memcpy(ptr, data, size);
69+
wgpuBufferUnmap(buf);
70+
return buf;
71+
}
72+
5173
} // namespace executorch::backends::webgpu::utils
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
12+
13+
#include <cstddef>
14+
#include <cstdint>
15+
#include <stdexcept>
16+
17+
namespace executorch::backends::webgpu {
18+
19+
// Highest tensor rank a TensorMeta can describe (number of size/stride slots).
constexpr uint32_t kTensorMetaMaxNdim = 4;

// Host-side image of the per-tensor metadata uniform; mirrors Vulkan
// BufferMetadata (4-dim NCHW, std140). The field order and the explicit
// padding are pinned by the static_asserts that follow.
struct TensorMeta {
  uint32_t ndim; // number of dimensions in use
  uint32_t numel; // total element count
  uint32_t _pad[2]; // filler so that `sizes` begins at byte offset 16
  uint32_t sizes[kTensorMetaMaxNdim]; // per-dimension extents
  uint32_t strides[kTensorMetaMaxNdim]; // per-dimension element strides
};

// Total size of the uniform block.
static_assert(
    sizeof(TensorMeta) == 48,
    "TensorMeta std140 layout must be 48 bytes to match the WGSL uniform");

// Individual field offsets read by the WGSL side, not just the total size.
static_assert(
    offsetof(TensorMeta, ndim) == 0,
    "TensorMeta::ndim must sit at byte offset 0");
static_assert(
    offsetof(TensorMeta, numel) == 4,
    "TensorMeta::numel must sit at byte offset 4");
static_assert(
    offsetof(TensorMeta, sizes) == 16,
    "TensorMeta::sizes must sit at byte offset 16");
static_assert(
    offsetof(TensorMeta, strides) == 32,
    "TensorMeta::strides must sit at byte offset 32");
38+
39+
// Fill TensorMeta from NCHW dims: contiguous strides, padded trailing slots.
40+
inline void fill_tensor_meta(const WebGPUTensor& t, TensorMeta* m) {
41+
const uint32_t ndim = static_cast<uint32_t>(t.dims.size());
42+
if (ndim > kTensorMetaMaxNdim) {
43+
throw std::runtime_error("TensorMeta: tensor rank exceeds 4 (MAX_NDIM)");
44+
}
45+
*m = {};
46+
for (uint32_t d = 0; d < kTensorMetaMaxNdim; d++) {
47+
m->sizes[d] = 1u;
48+
m->strides[d] = 0u;
49+
}
50+
m->ndim = ndim;
51+
uint32_t numel = 1u;
52+
uint32_t acc = 1u;
53+
for (int i = static_cast<int>(ndim) - 1; i >= 0; i--) {
54+
const uint32_t sz = static_cast<uint32_t>(t.dims[i]);
55+
m->sizes[i] = sz;
56+
m->strides[i] = acc;
57+
acc *= sz;
58+
numel *= sz;
59+
}
60+
m->numel = numel;
61+
}
62+
63+
// Broadcast variant: right-align operand dims into out rank (PyTorch trailing).
64+
inline void fill_tensor_meta_broadcast(
65+
const WebGPUTensor& t,
66+
uint32_t out_ndim,
67+
TensorMeta* m) {
68+
const uint32_t rank = static_cast<uint32_t>(t.dims.size());
69+
if (out_ndim > kTensorMetaMaxNdim) {
70+
throw std::runtime_error("TensorMeta: out_ndim exceeds 4 (MAX_NDIM)");
71+
}
72+
if (rank > out_ndim) {
73+
throw std::runtime_error("TensorMeta: operand rank exceeds out_ndim");
74+
}
75+
*m = {};
76+
for (uint32_t d = 0; d < kTensorMetaMaxNdim; d++) {
77+
m->sizes[d] = 1u;
78+
m->strides[d] = 0u;
79+
}
80+
m->ndim = out_ndim;
81+
uint32_t acc = 1u;
82+
uint32_t numel = 1u;
83+
for (int i = static_cast<int>(rank) - 1; i >= 0; i--) {
84+
const uint32_t slot = out_ndim - rank + static_cast<uint32_t>(i);
85+
const uint32_t sz = static_cast<uint32_t>(t.dims[i]);
86+
m->sizes[slot] = sz;
87+
m->strides[slot] = acc;
88+
acc *= sz;
89+
numel *= sz;
90+
}
91+
m->numel = numel;
92+
}
93+
94+
} // namespace executorch::backends::webgpu
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
10+
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
11+
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
12+
#include <executorch/backends/webgpu/runtime/ops/TensorMeta.h>
13+
#include <executorch/backends/webgpu/runtime/ops/mul/binary_mul_wgsl.h>
14+
15+
#include <webgpu/webgpu.h>
16+
17+
#include <stdexcept>
18+
#include <vector>
19+
20+
namespace executorch::backends::webgpu {
21+
22+
namespace {
23+
24+
void mul_impl(WebGPUGraph& graph, const std::vector<int>& args) {
25+
// aten.mul.Tensor args: [in1, in2, out] (self, other; no alpha)
26+
const int in1_id = args.at(0);
27+
const int in2_id = args.at(1);
28+
const int out_id = args.at(2);
29+
30+
WGPUDevice device = graph.device();
31+
32+
const auto& in1_tensor = graph.get_tensor(in1_id);
33+
const auto& in2_tensor = graph.get_tensor(in2_id);
34+
const auto& out_tensor = graph.get_tensor(out_id);
35+
36+
// Rank guard (NCHW backend is <= 4 dims; 1D dispatch only).
37+
if (out_tensor.dims.size() > kTensorMetaMaxNdim ||
38+
in1_tensor.dims.size() > kTensorMetaMaxNdim ||
39+
in2_tensor.dims.size() > kTensorMetaMaxNdim) {
40+
throw std::runtime_error("mul: tensor rank exceeds 4 (MAX_NDIM)");
41+
}
42+
43+
const uint32_t out_ndim = static_cast<uint32_t>(out_tensor.dims.size());
44+
45+
// 3 per-tensor meta uniforms (mirror Vulkan); inputs broadcast-aligned.
46+
TensorMeta out_meta;
47+
TensorMeta in1_meta;
48+
TensorMeta in2_meta;
49+
fill_tensor_meta_broadcast(out_tensor, out_ndim, &out_meta);
50+
fill_tensor_meta_broadcast(in1_tensor, out_ndim, &in1_meta);
51+
fill_tensor_meta_broadcast(in2_tensor, out_ndim, &in2_meta);
52+
53+
// fp32-only: nbytes must equal numel * 4 for every operand.
54+
if (out_tensor.nbytes !=
55+
static_cast<size_t>(out_meta.numel) * sizeof(float) ||
56+
in1_tensor.nbytes !=
57+
static_cast<size_t>(in1_meta.numel) * sizeof(float) ||
58+
in2_tensor.nbytes !=
59+
static_cast<size_t>(in2_meta.numel) * sizeof(float)) {
60+
throw std::runtime_error("mul: non-fp32 operand (nbytes != numel * 4)");
61+
}
62+
63+
uint32_t wg_size =
64+
utils::clamp_workgroup_size(device, kBinaryMulWorkgroupSizeX);
65+
uint32_t workgroup_count =
66+
utils::compute_1d_workgroup_count(device, out_meta.numel, wg_size, "mul");
67+
68+
WGPUConstantEntry wg_size_constant = {};
69+
wg_size_constant.key = {"wg_size", WGPU_STRLEN};
70+
wg_size_constant.value = static_cast<double>(wg_size);
71+
72+
WGPUBuffer out_meta_buf =
73+
utils::make_uniform(device, &out_meta, sizeof(TensorMeta));
74+
WGPUBuffer in1_meta_buf =
75+
utils::make_uniform(device, &in1_meta, sizeof(TensorMeta));
76+
WGPUBuffer in2_meta_buf =
77+
utils::make_uniform(device, &in2_meta, sizeof(TensorMeta));
78+
graph.add_uniform_buffer_bytes(3 * sizeof(TensorMeta));
79+
80+
WGPUShaderSourceWGSL wgsl_desc = {};
81+
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
82+
wgsl_desc.code = {kBinaryMulWGSL, WGPU_STRLEN};
83+
84+
WGPUShaderModuleDescriptor shader_desc = {};
85+
shader_desc.nextInChain = &wgsl_desc.chain;
86+
WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);
87+
88+
// Bind group: in1, in2, out (rw), out_meta, in1_meta, in2_meta (3 uniforms).
89+
WGPUBindGroupLayoutEntry entries[6] = {};
90+
91+
entries[0].binding = 0;
92+
entries[0].visibility = WGPUShaderStage_Compute;
93+
entries[0].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
94+
95+
entries[1].binding = 1;
96+
entries[1].visibility = WGPUShaderStage_Compute;
97+
entries[1].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
98+
99+
entries[2].binding = 2;
100+
entries[2].visibility = WGPUShaderStage_Compute;
101+
entries[2].buffer.type = WGPUBufferBindingType_Storage;
102+
103+
entries[3].binding = 3;
104+
entries[3].visibility = WGPUShaderStage_Compute;
105+
entries[3].buffer.type = WGPUBufferBindingType_Uniform;
106+
107+
entries[4].binding = 4;
108+
entries[4].visibility = WGPUShaderStage_Compute;
109+
entries[4].buffer.type = WGPUBufferBindingType_Uniform;
110+
111+
entries[5].binding = 5;
112+
entries[5].visibility = WGPUShaderStage_Compute;
113+
entries[5].buffer.type = WGPUBufferBindingType_Uniform;
114+
115+
WGPUBindGroupLayoutDescriptor bgl_desc = {};
116+
bgl_desc.entryCount = 6;
117+
bgl_desc.entries = entries;
118+
WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);
119+
120+
WGPUPipelineLayoutDescriptor pl_desc = {};
121+
pl_desc.bindGroupLayoutCount = 1;
122+
pl_desc.bindGroupLayouts = &bgl;
123+
WGPUPipelineLayout pipeline_layout =
124+
wgpuDeviceCreatePipelineLayout(device, &pl_desc);
125+
126+
WGPUComputePipelineDescriptor pipeline_desc = {};
127+
pipeline_desc.layout = pipeline_layout;
128+
pipeline_desc.compute.module = shader;
129+
pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
130+
pipeline_desc.compute.constantCount = 1;
131+
pipeline_desc.compute.constants = &wg_size_constant;
132+
WGPUComputePipeline pipeline =
133+
wgpuDeviceCreateComputePipeline(device, &pipeline_desc);
134+
135+
WGPUBindGroupEntry bg_entries[6] = {};
136+
137+
bg_entries[0].binding = 0;
138+
bg_entries[0].buffer = in1_tensor.buffer;
139+
bg_entries[0].size = in1_tensor.nbytes;
140+
141+
bg_entries[1].binding = 1;
142+
bg_entries[1].buffer = in2_tensor.buffer;
143+
bg_entries[1].size = in2_tensor.nbytes;
144+
145+
bg_entries[2].binding = 2;
146+
bg_entries[2].buffer = out_tensor.buffer;
147+
bg_entries[2].size = out_tensor.nbytes;
148+
149+
bg_entries[3].binding = 3;
150+
bg_entries[3].buffer = out_meta_buf;
151+
bg_entries[3].size = sizeof(TensorMeta);
152+
153+
bg_entries[4].binding = 4;
154+
bg_entries[4].buffer = in1_meta_buf;
155+
bg_entries[4].size = sizeof(TensorMeta);
156+
157+
bg_entries[5].binding = 5;
158+
bg_entries[5].buffer = in2_meta_buf;
159+
bg_entries[5].size = sizeof(TensorMeta);
160+
161+
WGPUBindGroupDescriptor bg_desc = {};
162+
bg_desc.layout = bgl;
163+
bg_desc.entryCount = 6;
164+
bg_desc.entries = bg_entries;
165+
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);
166+
167+
graph.add_dispatch({pipeline, bind_group, workgroup_count});
168+
169+
wgpuShaderModuleRelease(shader);
170+
wgpuBindGroupLayoutRelease(bgl);
171+
wgpuPipelineLayoutRelease(pipeline_layout);
172+
// Drop our refs; the bind group keeps the uniforms alive until release.
173+
wgpuBufferRelease(out_meta_buf);
174+
wgpuBufferRelease(in1_meta_buf);
175+
wgpuBufferRelease(in2_meta_buf);
176+
}
177+
178+
} // namespace
179+
180+
// Associates the op name aten.mul.Tensor with mul_impl through the registry
// macros (presumably declared in OperatorRegistry.h, which is not visible
// here — confirm how and when the registration body runs).
WEBGPU_REGISTER_OPERATORS {
181+
WEBGPU_REGISTER_OP(aten.mul.Tensor, mul_impl);
182+
}
183+
184+
} // namespace executorch::backends::webgpu
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Per-tensor metadata uniform. `sizes` and `strides` are vec4<u32>, so each
// starts on a 16-byte boundary: ndim at 0, numel at 4, sizes at 16,
// strides at 32.
struct TensorMeta {
  ndim: u32,
  numel: u32,
  sizes: vec4<u32>,
  strides: vec4<u32>,
}

// Operand data: two read-only inputs and one writable output.
@group(0) @binding(0) var<storage, read> input1: array<f32>;
@group(0) @binding(1) var<storage, read> input2: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

// One metadata block per operand, in the order out / in1 / in2.
@group(0) @binding(3) var<uniform> out_meta: TensorMeta;
@group(0) @binding(4) var<uniform> in1_meta: TensorMeta;
@group(0) @binding(5) var<uniform> in2_meta: TensorMeta;

// Workgroup width along x; overridable at pipeline creation, 64 by default.
override wg_size: u32 = 64u;
16+
17+
// One invocation per output element: output[i] = input1[..] * input2[..].
@compute @workgroup_size(wg_size, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let out_index = gid.x;
  // Invocations past the last output element have nothing to write.
  if (out_index >= out_meta.numel) {
    return;
  }

  let rank = out_meta.ndim;

  // Fast path: when both inputs have the output's size in every dimension,
  // the output index addresses both inputs directly.
  var shapes_match = true;
  for (var axis = 0u; axis < rank; axis++) {
    let want = out_meta.sizes[axis];
    if (!((in1_meta.sizes[axis] == want) && (in2_meta.sizes[axis] == want))) {
      shapes_match = false;
    }
  }
  if (shapes_match) {
    output[out_index] = input1[out_index] * input2[out_index];
    return;
  }

  // Broadcast path: peel the output index into one coordinate per dimension
  // using the output strides, clamp that coordinate to each input's last
  // valid index (so size-1 dims always read index 0), and rebuild a linear
  // offset from each input's own strides.
  var remaining = out_index;
  var in1_offset = 0u;
  var in2_offset = 0u;
  for (var axis = 0u; axis < rank; axis++) {
    let out_stride = out_meta.strides[axis];
    let coord = remaining / out_stride;
    remaining %= out_stride;
    in1_offset += min(coord, in1_meta.sizes[axis] - 1u) * in1_meta.strides[axis];
    in2_offset += min(coord, in2_meta.sizes[axis] - 1u) * in2_meta.strides[axis];
  }
  output[out_index] = input1[in1_offset] * input2[in2_offset];
}

0 commit comments

Comments
 (0)