Skip to content

Commit a464b37

Browse files
Update
[ghstack-poisoned]
2 parents ca90cb2 + 1098b31 commit a464b37

7 files changed

Lines changed: 67 additions & 51 deletions

File tree

backends/webgpu/runtime/WebGPUBackend.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ Result<DelegateHandle*> WebGPUBackend::init(
7676
}
7777

7878
try {
79-
graph->build(flatbuffer_data, constant_data, context.get_named_data_map());
79+
graph->build(flatbuffer_data, constant_data);
8080
} catch (const std::exception& e) {
8181
ET_LOG(Error, "WebGPU graph build failed: %s", e.what());
8282
graph->~WebGPUGraph();

backends/webgpu/runtime/WebGPUGraph.cpp

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
1111

1212
#include <executorch/backends/vulkan/serialization/schema_generated.h>
13-
#include <executorch/runtime/core/named_data_map.h>
1413

1514
#include <executorch/backends/webgpu/runtime/WebGPUDevice.h>
1615
#include <webgpu/wgpu.h>
@@ -94,8 +93,7 @@ WebGPUGraph::~WebGPUGraph() {
9493

9594
void WebGPUGraph::build(
9695
const void* flatbuffer_data,
97-
const uint8_t* constant_data,
98-
const executorch::runtime::NamedDataMap* named_data_map) {
96+
const uint8_t* constant_data) {
9997
if (!device_) {
10098
auto* ctx = get_default_webgpu_context();
10199
if (ctx) {
@@ -167,31 +165,6 @@ void WebGPUGraph::build(
167165
const uint8_t* src = constant_data + vk_bytes->offset();
168166
wgpuQueueWriteBuffer(
169167
queue_, tensor.buffer, 0, src, tensor.nbytes);
170-
} else if (
171-
vk_bytes->named_key() != nullptr &&
172-
named_data_map != nullptr) {
173-
// Constant stored in the PTE named-data map.
174-
auto buf =
175-
named_data_map->get_data(vk_bytes->named_key()->c_str());
176-
if (!buf.ok()) {
177-
throw std::runtime_error(
178-
std::string("WebGPU: named constant '") +
179-
vk_bytes->named_key()->c_str() +
180-
"' not found in NamedDataMap");
181-
}
182-
if (buf->size() < tensor.nbytes) {
183-
throw std::runtime_error(
184-
std::string("WebGPU: named constant '") +
185-
vk_bytes->named_key()->c_str() + "' undersized: have " +
186-
std::to_string(buf->size()) + " bytes, need " +
187-
std::to_string(tensor.nbytes));
188-
}
189-
wgpuQueueWriteBuffer(
190-
queue_, tensor.buffer, 0, buf->data(), tensor.nbytes);
191-
buf->Free();
192-
} else {
193-
throw std::runtime_error(
194-
"WebGPU: constant has no inline offset and no named-data key");
195168
}
196169
}
197170
}

backends/webgpu/runtime/WebGPUGraph.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
#include <unordered_map>
1616
#include <vector>
1717

18-
#include <executorch/runtime/core/named_data_map.h>
19-
2018
namespace executorch {
2119
namespace backends {
2220
namespace webgpu {
@@ -68,10 +66,7 @@ class WebGPUGraph {
6866

6967
// Build the graph from a deserialized VkGraph flatbuffer and constant data.
7068
// The flatbuffer_data pointer must remain valid during build().
71-
void build(
72-
const void* flatbuffer_data,
73-
const uint8_t* constant_data,
74-
const executorch::runtime::NamedDataMap* named_data_map = nullptr);
69+
void build(const void* flatbuffer_data, const uint8_t* constant_data);
7570

7671
// Copy input tensor data from host pointers into GPU buffers.
7772
void copy_inputs(const std::vector<std::pair<const void*, size_t>>& inputs);
backends/webgpu/runtime/WebGPUUtils.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <webgpu/webgpu.h>
12+
13+
#include <algorithm>
14+
#include <cstdint>
15+
#include <stdexcept>
16+
#include <string>
17+
18+
namespace executorch::backends::webgpu::utils {
19+
20+
// Clamp workgroup size to device limit (SwiftShader caps at 128).
21+
inline uint32_t clamp_workgroup_size(WGPUDevice device, uint32_t desired) {
22+
WGPULimits limits = {};
23+
if (wgpuDeviceGetLimits(device, &limits) == WGPUStatus_Success &&
24+
limits.maxComputeInvocationsPerWorkgroup > 0) {
25+
return std::min(desired, limits.maxComputeInvocationsPerWorkgroup);
26+
}
27+
return desired;
28+
}
29+
30+
// 1D dispatch count (mirrors Vulkan div_up); throws if > device limit.
31+
inline uint32_t compute_1d_workgroup_count(
32+
WGPUDevice device,
33+
uint32_t num_threads,
34+
uint32_t workgroup_size,
35+
const char* op_name) {
36+
uint32_t count = (num_threads + workgroup_size - 1) / workgroup_size;
37+
WGPULimits limits = {};
38+
uint32_t max_count =
39+
wgpuDeviceGetLimits(device, &limits) == WGPUStatus_Success &&
40+
limits.maxComputeWorkgroupsPerDimension > 0
41+
? limits.maxComputeWorkgroupsPerDimension
42+
: 65535u; // WebGPU spec-default floor
43+
if (count > max_count) {
44+
throw std::runtime_error(
45+
std::string("WebGPU ") + op_name +
46+
": workgroup count exceeds the 1D dispatch limit");
47+
}
48+
return count;
49+
}
50+
51+
} // namespace executorch::backends::webgpu::utils

backends/webgpu/runtime/ops/add/BinaryOp.cpp

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
*/
88

99
#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
10+
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
1011
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
1112
#include <executorch/backends/webgpu/runtime/ops/add/binary_add_wgsl.h>
1213

1314
#include <webgpu/webgpu.h>
1415

15-
#include <algorithm>
1616
#include <cmath>
1717
#include <cstring>
1818

@@ -51,21 +51,10 @@ void add_impl(WebGPUGraph& graph, const std::vector<int>& args) {
5151
uint32_t num_elements =
5252
static_cast<uint32_t>(out_tensor.nbytes / sizeof(float));
5353

54-
// Clamp the workgroup size to the device limit (SwiftShader caps at 128).
55-
WGPULimits limits = {};
56-
uint32_t device_max =
57-
wgpuDeviceGetLimits(device, &limits) == WGPUStatus_Success &&
58-
limits.maxComputeInvocationsPerWorkgroup > 0
59-
? limits.maxComputeInvocationsPerWorkgroup
60-
: kBinaryAddWorkgroupSize;
61-
uint32_t wg_size = std::min(kBinaryAddWorkgroupSize, device_max);
62-
uint32_t workgroup_count = (num_elements + wg_size - 1) / wg_size;
63-
64-
// Validate the 1D dispatch limit before allocating any GPU objects.
65-
if (workgroup_count > 65535u) {
66-
throw std::runtime_error(
67-
"WebGPU add: workgroup count exceeds the 1D dispatch limit (65535)");
68-
}
54+
uint32_t wg_size =
55+
utils::clamp_workgroup_size(device, kBinaryAddWorkgroupSize);
56+
uint32_t workgroup_count =
57+
utils::compute_1d_workgroup_count(device, num_elements, wg_size, "add");
6958

7059
WGPUConstantEntry wg_size_constant = {};
7160
wg_size_constant.key = {"wg_size", WGPU_STRLEN};

backends/webgpu/runtime/ops/rms_norm/rms_norm.wgsl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
// NOTE: This file is for editor/tooling support only. The runtime consumes the
2+
// inline copy of this shader in `rms_norm_wgsl.h` (kRmsNormWGSL). Keep the two
3+
// in sync by hand — any edit here must be mirrored there.
14
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
25
@group(0) @binding(1) var<storage, read> t_in: array<f32>;
36
@group(0) @binding(2) var<storage, read> t_weight: array<f32>;

backends/webgpu/runtime/ops/rms_norm/rms_norm_wgsl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
namespace executorch::backends::webgpu {
1414

1515
// WGSL shader source for rms_norm: y = x * w * rsqrt(mean(x^2) + eps)
16+
//
17+
// NOTE: This inline string is the runtime source of truth — it is what gets
18+
// passed to wgpuDeviceCreateShaderModule. The sibling `rms_norm.wgsl` file
19+
// exists only for editor/tooling support and must be kept identical to this
20+
// string by hand; there is no build-time sync.
1621
inline constexpr const char* kRmsNormWGSL = R"(
1722
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
1823
@group(0) @binding(1) var<storage, read> t_in: array<f32>;

0 commit comments

Comments
 (0)