Skip to content

Commit c36f9fa

Browse files
JulianCloudNTH authored and facebook-github-bot committed
Generate *_wgsl.h embedded shaders from *.wgsl (#19981)
Summary: Adds `backends/webgpu/scripts/gen_wgsl_headers.py` to generate each `runtime/ops/<op>/<shader>_wgsl.h` from its `<shader>.wgsl`, so each WGSL shader has a single canonical source instead of a hand-maintained embedded copy that can silently drift. Each header embeds the shader verbatim (`inline constexpr const char* k<Op>WGSL = R"(...)";` plus the workgroup-size constants) and a `// wgsl-sha256:` of the source; `--check` (wired into `test_build_webgpu.sh` and the `webgpu_backend` CMake build) and the unit tests fail the build if any committed header drifts. `workgroup_size` is parsed for all three dims (WGSL allows 1-3; y and z default to 1), emitting `k<Op>WorkgroupSizeX/Y/Z` so future 2D/3D shaders need no codegen change; the two current 1D consumers read `...X`. The X/Y/Z naming and `uint32_t`-per-axis mirror Vulkan's `utils::WorkgroupSize` (`backends/vulkan/runtime/utils/VecUtils.h`); WGSL `workgroup_size` is compile-time, so the value is parsed from the shader rather than set via runtime spec-constants as in Vulkan. The drift check compares the full rendered header (not just the shader sha), so a generator-logic change is also detected/regenerated. The parser accepts the spaced form `workgroup_size (n)` and suffix-typed literals (`64u`). Regenerates the two existing committed op headers: `binary_add_wgsl.h` and `rms_norm_wgsl.h` gain the `...X/Y/Z` constants (X = the 1D size, Y=Z=1); `rms_norm.wgsl` also drops its now-obsolete 3-line "keep in sync by hand" note (codegen + `--check` make it false). The shader code itself is unchanged. This change was authored with assistance from Claude. Reviewed By: SS-JIA Differential Revision: D107403275
1 parent e56be3e commit c36f9fa

9 files changed

Lines changed: 408 additions & 21 deletions

File tree

backends/webgpu/CMakeLists.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,17 @@ set(WEBGPU_SRCS
3737

3838
add_library(webgpu_backend ${WEBGPU_SRCS})
3939

40+
# Verify committed *_wgsl.h match their *.wgsl (drift fails the build).
41+
resolve_python_executable()
42+
add_custom_target(
43+
webgpu_wgsl_headers_check ALL
44+
COMMAND "${PYTHON_EXECUTABLE}"
45+
"${CMAKE_CURRENT_SOURCE_DIR}/scripts/gen_wgsl_headers.py" --check
46+
WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}"
47+
COMMENT "Checking WebGPU embedded-WGSL headers are in sync"
48+
)
49+
add_dependencies(webgpu_backend webgpu_wgsl_headers_check)
50+
4051
target_include_directories(
4152
webgpu_backend PRIVATE $<BUILD_INTERFACE:${EXECUTORCH_ROOT}/..>
4253
)

backends/webgpu/runtime/ops/add/BinaryOp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ void add_impl(WebGPUGraph& graph, const std::vector<int>& args) {
5252
static_cast<uint32_t>(out_tensor.nbytes / sizeof(float));
5353

5454
uint32_t wg_size =
55-
utils::clamp_workgroup_size(device, kBinaryAddWorkgroupSize);
55+
utils::clamp_workgroup_size(device, kBinaryAddWorkgroupSizeX);
5656
uint32_t workgroup_count =
5757
utils::compute_1d_workgroup_count(device, num_elements, wg_size, "add");
5858

backends/webgpu/runtime/ops/add/binary_add_wgsl.h

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88

99
#pragma once
1010

11-
namespace executorch {
12-
namespace backends {
13-
namespace webgpu {
11+
#include <cstdint>
1412

15-
// WGSL shader source for element-wise add: output = input1 + alpha * input2
13+
namespace executorch::backends::webgpu {
14+
15+
// @generated from binary_add.wgsl - DO NOT EDIT.
16+
// wgsl-sha256: c1ceec80c8d4d3d56986ad91ce0d7f9a57cd8467b8c3aa07a28da70e51d141d9
1617
inline constexpr const char* kBinaryAddWGSL = R"(
1718
@group(0) @binding(0) var<storage, read> input1: array<f32>;
1819
@group(0) @binding(1) var<storage, read> input2: array<f32>;
@@ -36,8 +37,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
3637
}
3738
)";
3839

39-
inline constexpr uint32_t kBinaryAddWorkgroupSize = 256;
40+
inline constexpr uint32_t kBinaryAddWorkgroupSizeX = 256;
41+
inline constexpr uint32_t kBinaryAddWorkgroupSizeY = 1;
42+
inline constexpr uint32_t kBinaryAddWorkgroupSizeZ = 1;
4043

41-
} // namespace webgpu
42-
} // namespace backends
43-
} // namespace executorch
44+
} // namespace executorch::backends::webgpu

backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,9 @@ void rms_norm_impl(WebGPUGraph& graph, const std::vector<int>& args) {
172172
bg_desc.entries = bg_entries;
173173
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);
174174

175-
// One workgroup per row (kRmsNormWorkgroupSize threads cooperate per row)
175+
// One workgroup per row (kRmsNormWorkgroupSizeX threads cooperate per row)
176176
static_assert(
177-
kRmsNormWorkgroupSize == 64,
177+
kRmsNormWorkgroupSizeX == 64,
178178
"must match @workgroup_size and WG_SIZE in rms_norm.wgsl");
179179
graph.add_dispatch({pipeline, bind_group, num_rows});
180180

backends/webgpu/runtime/ops/rms_norm/rms_norm.wgsl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
// NOTE: This file is for editor/tooling support only. The runtime consumes the
2-
// inline copy of this shader in `rms_norm_wgsl.h` (kRmsNormWGSL). Keep the two
3-
// in sync by hand — any edit here must be mirrored there.
41
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
52
@group(0) @binding(1) var<storage, read> t_in: array<f32>;
63
@group(0) @binding(2) var<storage, read> t_weight: array<f32>;

backends/webgpu/runtime/ops/rms_norm/rms_norm_wgsl.h

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,8 @@
1212

1313
namespace executorch::backends::webgpu {
1414

15-
// WGSL shader source for rms_norm: y = x * w * rsqrt(mean(x^2) + eps)
16-
//
17-
// NOTE: This inline string is the runtime source of truth — it is what gets
18-
// passed to wgpuDeviceCreateShaderModule. The sibling `rms_norm.wgsl` file
19-
// exists only for editor/tooling support and must be kept identical to this
20-
// string by hand; there is no build-time sync.
15+
// @generated from rms_norm.wgsl - DO NOT EDIT.
16+
// wgsl-sha256: 340dcbf3c06dc311e70bef953c1e9cbbdf4121fe177eedd3253549e614b55069
2117
inline constexpr const char* kRmsNormWGSL = R"(
2218
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
2319
@group(0) @binding(1) var<storage, read> t_in: array<f32>;
@@ -93,6 +89,8 @@ fn main(
9389
}
9490
)";
9591

96-
inline constexpr uint32_t kRmsNormWorkgroupSize = 64;
92+
inline constexpr uint32_t kRmsNormWorkgroupSizeX = 64;
93+
inline constexpr uint32_t kRmsNormWorkgroupSizeY = 1;
94+
inline constexpr uint32_t kRmsNormWorkgroupSizeZ = 1;
9795

9896
} // namespace executorch::backends::webgpu
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
"""Generate runtime/ops/<op>/<stem>_wgsl.h from each <stem>.wgsl.
9+
10+
Each header embeds the shader verbatim as `inline constexpr const char*
11+
k<Pascal>WGSL` plus `k<Pascal>WorkgroupSizeX/Y/Z` (parsed from @workgroup_size).
12+
13+
Usage:
14+
gen_wgsl_headers.py # (re)write all <stem>_wgsl.h
15+
gen_wgsl_headers.py --check # exit 1 if any committed header is stale
16+
17+
Stdlib only (the devserver has no third-party pip).
18+
"""
19+
20+
import argparse
21+
import hashlib
22+
import re
23+
import sys
24+
from pathlib import Path
25+
26+
BACKEND_ROOT = Path(__file__).resolve().parents[1]
27+
28+
_SHA_RE = re.compile(r"// wgsl-sha256: ([0-9a-f]{64})")
29+
30+
_BSD_HEADER = """\
31+
/*
32+
* Copyright (c) Meta Platforms, Inc. and affiliates.
33+
* All rights reserved.
34+
*
35+
* This source code is licensed under the BSD-style license found in the
36+
* LICENSE file in the root directory of this source tree.
37+
*/"""
38+
39+
40+
def symbol_base(stem: str) -> str:
    """Turn a snake_case shader stem into its PascalCase symbol base.

    Each underscore-separated word is capitalized and the words are joined
    without a separator, e.g. ``binary_add`` -> ``BinaryAdd``.
    """
    words = stem.split("_")
    pieces = [word.capitalize() for word in words]
    return "".join(pieces)
43+
44+
45+
_INT_LITERAL_RE = re.compile(r"^(\d+)[uUiI]?$")

# WGSL line (`// ...`) and block (`/* ... */`) comments. WGSL has no string
# literals, so a comment opener can never appear inside one. Nested block
# comments are not unwrapped: the match ends at the first `*/`.
_WGSL_COMMENT_RE = re.compile(r"//[^\n]*|/\*.*?\*/", re.DOTALL)


def _strip_comments(src: str) -> str:
    """Replace every WGSL comment in `src` with a single space."""
    return _WGSL_COMMENT_RE.sub(" ", src)


def _resolve_dim(tok: str, src: str) -> int:
    """Resolve one @workgroup_size dim token: a literal or an override/const ident.

    Accepts WGSL suffix-typed integer literals (e.g. `64u`, `64i`) both as the
    token and on the right-hand side of an `override`/`const` declaration. The
    declared type is optional and may be `u32` or `i32`.

    The initializer must be a single decimal literal directly followed by `;`.
    Anything else (e.g. `64 * 2`) is rejected instead of being resolved to its
    leading digits.

    Args:
        tok: One comma-separated argument of @workgroup_size, already stripped.
        src: Shader source searched for the declaration of `tok`. Comments are
            ignored.

    Returns:
        The integer value of the dimension.

    Raises:
        ValueError: `tok` is neither an integer literal nor an identifier with
            a resolvable `override`/`const` declaration in `src`.
    """
    lit = _INT_LITERAL_RE.match(tok)
    if lit:
        return int(lit.group(1))
    m = re.search(
        r"\b(?:override|const)\s+"
        + re.escape(tok)
        + r"\s*(?::\s*[iu]32\s*)?=\s*(\d+)[uUiI]?\s*;",
        _strip_comments(src),
    )
    if not m:
        raise ValueError(f"cannot resolve @workgroup_size identifier '{tok}'")
    return int(m.group(1))


def parse_workgroup_size(src: str) -> tuple[int, int, int]:
    """Resolve the (x, y, z) dims of @workgroup_size; y and z default to 1.

    Only the first @workgroup_size attribute outside of comments is used.

    Raises:
        ValueError: no @workgroup_size attribute, a dim count outside 1-3,
            an unresolvable dim, or a dim smaller than 1.
    """
    # Strip comments first so a commented-out attribute or constant is never
    # mistaken for the live one.
    code = _strip_comments(src)
    m = re.search(r"@workgroup_size\s*\(([^)]*)\)", code)
    if not m:
        raise ValueError("no @workgroup_size found")
    toks = [t.strip() for t in m.group(1).split(",") if t.strip()]
    if not toks or len(toks) > 3:
        raise ValueError(f"@workgroup_size takes 1-3 dims, got {len(toks)}")
    dims = [_resolve_dim(t, code) for t in toks]
    for d in dims:
        # WGSL requires every workgroup dimension to be at least 1.
        if d < 1:
            raise ValueError(f"@workgroup_size dims must be >= 1, got {d}")
    while len(dims) < 3:
        dims.append(1)
    return (dims[0], dims[1], dims[2])
80+
81+
82+
def wgsl_sha256(wgsl_text: str) -> str:
    """Return the hex SHA-256 digest of the shader text encoded as UTF-8."""
    digest = hashlib.sha256()
    digest.update(wgsl_text.encode("utf-8"))
    return digest.hexdigest()
84+
85+
86+
def embedded_sha256(header_text: str) -> str:
    """Extract the `// wgsl-sha256:` digest recorded in a generated header.

    Returns the 64-character hex digest of the first match, or an empty string
    when the header carries no such line.
    """
    found = _SHA_RE.search(header_text)
    if found is None:
        return ""
    return found.group(1)
89+
90+
91+
def render_header(wgsl_path, wgsl_text: str) -> str:
    """Build the complete text of `<stem>_wgsl.h` for one shader.

    The shader is embedded verbatim inside a C++ raw string literal, followed
    by the three per-axis workgroup-size constants.

    Args:
        wgsl_path: Path of the `.wgsl` file; only its stem is used, to derive
            the symbol names and the `@generated from` note.
        wgsl_text: Full shader source.

    Raises:
        ValueError: the shader contains `)"` (which would terminate the raw
            string early), or its @workgroup_size cannot be parsed.
    """
    if wgsl_text.find(')"') != -1:
        raise ValueError('shader contains )" which would close the R"( literal')

    shader_stem = Path(wgsl_path).stem
    sym = symbol_base(shader_stem)
    size_x, size_y, size_z = parse_workgroup_size(wgsl_text)
    digest = wgsl_sha256(wgsl_text)

    # Everything up to and including the opening of the raw string literal.
    preamble = (
        f"{_BSD_HEADER}\n"
        "\n"
        "#pragma once\n"
        "\n"
        "#include <cstdint>\n"
        "\n"
        "namespace executorch::backends::webgpu {\n"
        "\n"
        f"// @generated from {shader_stem}.wgsl - DO NOT EDIT.\n"
        f"// wgsl-sha256: {digest}\n"
        f'inline constexpr const char* k{sym}WGSL = R"(\n'
    )
    # One constant per axis, in X, Y, Z order.
    size_constants = "".join(
        f"inline constexpr uint32_t k{sym}WorkgroupSize{axis} = {value};\n"
        for axis, value in zip("XYZ", (size_x, size_y, size_z))
    )
    epilogue = "\n} // namespace executorch::backends::webgpu\n"
    return preamble + wgsl_text + ')";\n\n' + size_constants + epilogue
123+
124+
125+
def discover():
    """Return every `*.wgsl` file below `runtime/ops`, in sorted order."""
    ops_dir = BACKEND_ROOT / "runtime/ops"
    shaders = list(ops_dir.glob("**/*.wgsl"))
    shaders.sort()
    return shaders
128+
129+
130+
def _report_drift(missing, stale) -> None:
    """Print the --check report: missing headers first, then stale ones.

    A section is printed only when its list is non-empty; each header path is
    shown relative to the backend root.
    """
    sections = (
        ("Missing embedded WGSL headers (run scripts/gen_wgsl_headers.py):", missing),
        ("Stale embedded WGSL headers (run scripts/gen_wgsl_headers.py):", stale),
    )
    for title, headers in sections:
        if not headers:
            continue
        print(title)
        for header in headers:
            print(f" {header.relative_to(BACKEND_ROOT)}")
140+
141+
142+
def main(argv=None) -> int:
    """Regenerate every `<stem>_wgsl.h`, or verify them with `--check`.

    Without `--check`, each header whose content differs from the rendered
    text (or that does not exist) is rewritten. With `--check`, nothing is
    written; missing and stale headers are reported instead.

    Files are always read and written as UTF-8, so the result does not depend
    on the locale encoding and matches the UTF-8 bytes the sha256 is computed
    over.

    Args:
        argv: Argument list handed to argparse; `None` lets argparse fall back
            to the process arguments.

    Returns:
        0 on success; 1 if any shader could not be rendered, or if `--check`
        found a missing or stale header.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--check",
        action="store_true",
        help="verify committed headers match (exit 1 on drift)",
    )
    args = parser.parse_args(argv)

    stale = []
    missing = []
    errors = []
    for wgsl in discover():
        try:
            # Reading inside the try also reports an undecodable shader
            # (UnicodeDecodeError is a ValueError) instead of crashing.
            wgsl_text = wgsl.read_text(encoding="utf-8")
            want = render_header(wgsl, wgsl_text)
        except ValueError as e:
            errors.append(f"{wgsl.relative_to(BACKEND_ROOT)}: {e}")
            continue
        header = wgsl.with_name(wgsl.stem + "_wgsl.h")
        header_exists = header.exists()
        # Full-content compare (not just the sha) catches generator-logic drift too.
        if header_exists and header.read_text(encoding="utf-8") == want:
            continue
        if args.check:
            (stale if header_exists else missing).append(header)
        else:
            header.write_text(want, encoding="utf-8")

    if errors:
        print("Cannot generate header (malformed shader):")
        for e in errors:
            print(f" {e}")
        return 1
    if args.check and (stale or missing):
        _report_drift(missing, stale)
        return 1
    return 0
179+
180+
181+
if __name__ == "__main__":
182+
sys.exit(main())

backends/webgpu/test/test_build_webgpu.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@ EXECUTORCH_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)"
1515
PYTHON_EXECUTABLE="${PYTHON_EXECUTABLE:-python3}"
1616
NPROC=$(nproc 2>/dev/null || sysctl -n hw.ncpu)
1717

18+
# Fail fast if any committed *_wgsl.h no longer matches its .wgsl source.
echo "=== Check embedded WGSL headers are up to date ==="
"${PYTHON_EXECUTABLE}" "${SCRIPT_DIR}/../scripts/gen_wgsl_headers.py" --check \
  || { echo "ERROR: *_wgsl.h out of sync with .wgsl; run scripts/gen_wgsl_headers.py"; exit 1; }

# Unit tests for the WGSL header generator itself.
# The interpreter path is quoted, as in the check above, so a path containing
# spaces is not word-split.
"${PYTHON_EXECUTABLE}" -m pytest "${SCRIPT_DIR}/test_wgsl_codegen.py" -v
24+
1825
# ── Step 1: Python export tests ──────────────────────────────────────────────
1926

2027
echo "=== Step 1: Run Python export tests ==="

0 commit comments

Comments
 (0)