Skip to content

Commit fc0ba68

Browse files
JulianCloudNTH authored and facebook-github-bot committed
Generate *_wgsl.h embedded shaders from *.wgsl (#19981)
Summary: Adds `backends/webgpu/scripts/gen_wgsl_headers.py` to generate each `runtime/ops/<op>/<shader>_wgsl.h` from its `<shader>.wgsl`, so each WGSL shader has a single canonical source instead of a hand-maintained embedded copy that can silently drift. Each header embeds the shader verbatim (`inline constexpr const char* k<Op>WGSL = R"(...)";` plus `k<Op>WorkgroupSize` parsed from `workgroup_size`) and a `// wgsl-sha256:` of the source; `--check` and the unit test verify each committed header against that embedded sha. Regenerates the two existing committed op headers — `binary_add_wgsl.h` and `rms_norm_wgsl.h` — into this canonical form (embedded shader bodies byte-identical). Adds `test/test_wgsl_codegen.py` (unit + drift test) and a `--check` mode wired into `test_build_webgpu.sh` that fails if any committed header is stale. This change was authored with assistance from Claude. Differential Revision: D107403275
1 parent 4c9c444 commit fc0ba68

5 files changed

Lines changed: 280 additions & 13 deletions

File tree

backends/webgpu/runtime/ops/add/binary_add_wgsl.h

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88

99
#pragma once
1010

11-
namespace executorch {
12-
namespace backends {
13-
namespace webgpu {
11+
#include <cstdint>
1412

15-
// WGSL shader source for element-wise add: output = input1 + alpha * input2
13+
namespace executorch::backends::webgpu {
14+
15+
// @generated from binary_add.wgsl - DO NOT EDIT.
16+
// wgsl-sha256: c1ceec80c8d4d3d56986ad91ce0d7f9a57cd8467b8c3aa07a28da70e51d141d9
1617
inline constexpr const char* kBinaryAddWGSL = R"(
1718
@group(0) @binding(0) var<storage, read> input1: array<f32>;
1819
@group(0) @binding(1) var<storage, read> input2: array<f32>;
@@ -38,6 +39,4 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
3839

3940
inline constexpr uint32_t kBinaryAddWorkgroupSize = 256;
4041

41-
} // namespace webgpu
42-
} // namespace backends
43-
} // namespace executorch
42+
} // namespace executorch::backends::webgpu

backends/webgpu/runtime/ops/rms_norm/rms_norm_wgsl.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,12 @@
1212

1313
namespace executorch::backends::webgpu {
1414

15-
// WGSL shader source for rms_norm: y = x * w * rsqrt(mean(x^2) + eps)
16-
//
17-
// NOTE: This inline string is the runtime source of truth — it is what gets
18-
// passed to wgpuDeviceCreateShaderModule. The sibling `rms_norm.wgsl` file
19-
// exists only for editor/tooling support and must be kept identical to this
20-
// string by hand; there is no build-time sync.
15+
// @generated from rms_norm.wgsl - DO NOT EDIT.
16+
// wgsl-sha256: 41ea66b52f9d205e84daf7f67e489197c4242d8b78435b2248d97d547449b95f
2117
inline constexpr const char* kRmsNormWGSL = R"(
18+
// NOTE: This file is for editor/tooling support only. The runtime consumes the
19+
// inline copy of this shader in `rms_norm_wgsl.h` (kRmsNormWGSL). Keep the two
20+
// in sync by hand — any edit here must be mirrored there.
2221
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
2322
@group(0) @binding(1) var<storage, read> t_in: array<f32>;
2423
@group(0) @binding(2) var<storage, read> t_weight: array<f32>;
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
"""Generate runtime/ops/<op>/<stem>_wgsl.h from each <stem>.wgsl.
9+
10+
Each header embeds the shader verbatim as `inline constexpr const char*
11+
k<Pascal>WGSL` plus `k<Pascal>WorkgroupSize` (parsed from @workgroup_size).
12+
13+
Usage:
14+
gen_wgsl_headers.py # (re)write all <stem>_wgsl.h
15+
gen_wgsl_headers.py --check # exit 1 if any committed header is stale
16+
17+
Stdlib only (the devserver has no third-party pip).
18+
"""
19+
20+
import argparse
21+
import hashlib
22+
import re
23+
import sys
24+
from pathlib import Path
25+
26+
BACKEND_ROOT = Path(__file__).resolve().parents[1]
27+
28+
_SHA_RE = re.compile(r"// wgsl-sha256: ([0-9a-f]{64})")
29+
30+
_BSD_HEADER = """\
31+
/*
32+
* Copyright (c) Meta Platforms, Inc. and affiliates.
33+
* All rights reserved.
34+
*
35+
* This source code is licensed under the BSD-style license found in the
36+
* LICENSE file in the root directory of this source tree.
37+
*/"""
38+
39+
40+
def symbol_base(stem: str) -> str:
    """Convert a snake_case shader stem to its PascalCase symbol base.

    Each underscore-separated word is passed through ``str.capitalize`` and
    the results are concatenated, e.g. ``binary_add`` -> ``BinaryAdd``.
    """
    words = stem.split("_")
    return "".join(map(str.capitalize, words))
43+
44+
45+
# Decimal integer literal with an optional WGSL type suffix (e.g. `64`, `64u`).
_INT_LITERAL_RE = re.compile(r"^(\d+)[uUiI]?$")

# WGSL `// line` and `/* block */` comments. Nested block comments are not
# tracked; the non-greedy match ends at the first `*/`.
_WGSL_COMMENT_RE = re.compile(r"//[^\n]*|/\*.*?\*/", re.DOTALL)


def parse_workgroup_size(src: str) -> int:
    """Resolve the x dim of @workgroup_size: a literal, or an override/const ident.

    Accepts WGSL suffix-typed integer literals (e.g. `64u`, `64i`) both inside
    `@workgroup_size(...)` and on the right-hand side of an `override`/`const`.
    The type annotation on `override`/`const` is optional and may be `u32` or
    `i32`. Comments are stripped first so a commented-out attribute or
    declaration is never picked up.

    Args:
        src: Full WGSL shader source text.

    Returns:
        The x dimension of the first `@workgroup_size` attribute as an int.

    Raises:
        ValueError: If no `@workgroup_size` is present, or if its first
            argument is an identifier that is not bound by an
            `override`/`const` to a plain decimal integer literal (hex
            literals and arithmetic expressions are rejected rather than
            mis-parsed).
    """
    # Replace comments with a space so tokens on either side stay separated.
    code = _WGSL_COMMENT_RE.sub(" ", src)

    m = re.search(r"@workgroup_size\s*\(\s*([A-Za-z0-9_]+)", code)
    if not m:
        raise ValueError("no @workgroup_size found")
    tok = m.group(1)
    lit = _INT_LITERAL_RE.match(tok)
    if lit:
        return int(lit.group(1))

    # Require the terminating `;` right after the literal: otherwise
    # `= 0x40;` would resolve to 0 and `= 64 * 2;` to 64.
    m2 = re.search(
        r"\b(?:override|const)\s+"
        + re.escape(tok)
        + r"\s*(?::\s*[iu]32\s*)?=\s*(\d+)[uUiI]?\s*;",
        code,
    )
    if not m2:
        raise ValueError(f"cannot resolve @workgroup_size identifier '{tok}'")
    return int(m2.group(1))
71+
72+
73+
def wgsl_sha256(wgsl_text: str) -> str:
    """Return the hex SHA-256 digest of the shader text's UTF-8 encoding."""
    digest = hashlib.sha256()
    digest.update(wgsl_text.encode("utf-8"))
    return digest.hexdigest()
75+
76+
77+
def embedded_sha256(header_text: str) -> str:
    """Extract the digest recorded on a header's `// wgsl-sha256:` line.

    Returns the 64-hex-digit group of the first match of `_SHA_RE`, or the
    empty string when the header carries no such line.
    """
    found = _SHA_RE.search(header_text)
    if found is None:
        return ""
    return found.group(1)
80+
81+
82+
def render_header(wgsl_path, wgsl_text: str) -> str:
    """Build the complete <stem>_wgsl.h text for one shader.

    The shader text is placed verbatim inside a C++ `R"(...)"` raw string
    literal, preceded by the license banner, include guard, namespace opener,
    the `@generated` marker and the `wgsl-sha256` line, and followed by the
    `k<Pascal>WorkgroupSize` constant and the namespace closer.

    Raises:
        ValueError: If the shader contains `)"` (which would terminate the raw
            string early), or if its workgroup size cannot be parsed.
    """
    if wgsl_text.find(')"') != -1:
        raise ValueError('shader contains )" which would close the R"( literal')

    stem = Path(wgsl_path).stem
    base = symbol_base(stem)
    workgroup_size = parse_workgroup_size(wgsl_text)
    digest = wgsl_sha256(wgsl_text)

    # Everything up to and including the newline that opens the raw string.
    preamble = (
        f"{_BSD_HEADER}\n"
        "\n"
        "#pragma once\n"
        "\n"
        "#include <cstdint>\n"
        "\n"
        "namespace executorch::backends::webgpu {\n"
        "\n"
        f"// @generated from {stem}.wgsl - DO NOT EDIT.\n"
        f"// wgsl-sha256: {digest}\n"
        f'inline constexpr const char* k{base}WGSL = R"(\n'
    )
    # Closes the raw string, then emits the size constant and namespace end.
    epilogue = (
        ')";\n'
        "\n"
        f"inline constexpr uint32_t k{base}WorkgroupSize = {workgroup_size};\n"
        "\n"
        "} // namespace executorch::backends::webgpu\n"
    )
    return preamble + wgsl_text + epilogue
112+
113+
114+
def discover():
    """Return every `.wgsl` path found recursively under runtime/ops, sorted."""
    ops_dir = BACKEND_ROOT / "runtime/ops"
    shader_paths = list(ops_dir.glob("**/*.wgsl"))
    shader_paths.sort()
    return shader_paths
117+
118+
119+
def main(argv=None) -> int:
    """Regenerate, or with --check verify, every <stem>_wgsl.h next to its shader.

    A header counts as up to date only when its full text equals what
    `render_header` produces for the current shader. This is the same
    comparison the unit test's drift check makes, so a hand-edited header or
    a changed generator template is detected, not just a changed shader.

    Files are read and written as UTF-8 explicitly: the embedded sha is
    computed over UTF-8 bytes, so decoding with a locale-dependent default
    could report false drift for shaders containing non-ASCII characters.

    Args:
        argv: Argument list to parse; None makes argparse use sys.argv[1:].

    Returns:
        0 on success; 1 if any shader could not be rendered, or if --check
        found a missing or stale header.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--check",
        action="store_true",
        help="verify committed headers match (exit 1 on drift)",
    )
    args = parser.parse_args(argv)

    stale = []
    missing = []
    errors = []
    for wgsl in discover():
        wgsl_text = wgsl.read_text(encoding="utf-8")
        try:
            want = render_header(wgsl, wgsl_text)
        except ValueError as e:
            errors.append(f"{wgsl.relative_to(BACKEND_ROOT)}: {e}")
            continue
        header = wgsl.with_name(wgsl.stem + "_wgsl.h")

        # `have` stays None only when the header file does not exist.
        have = None
        if header.exists():
            try:
                have = header.read_text(encoding="utf-8")
            except UnicodeDecodeError:
                # An undecodable header can never equal `want`: treat as stale.
                have = ""
        if have == want:
            continue
        if args.check:
            (missing if have is None else stale).append(header)
        else:
            header.write_text(want, encoding="utf-8")

    if errors:
        print("Cannot generate header (malformed shader):")
        for e in errors:
            print(f"  {e}")
        return 1
    if args.check and (stale or missing):
        if missing:
            print("Missing embedded WGSL headers (run scripts/gen_wgsl_headers.py):")
            for h in missing:
                print(f"  {h.relative_to(BACKEND_ROOT)}")
        if stale:
            print("Stale embedded WGSL headers (run scripts/gen_wgsl_headers.py):")
            for h in stale:
                print(f"  {h.relative_to(BACKEND_ROOT)}")
        return 1
    return 0
164+
165+
166+
if __name__ == "__main__":
167+
sys.exit(main())

backends/webgpu/test/test_build_webgpu.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ EXECUTORCH_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)"
1515
PYTHON_EXECUTABLE="${PYTHON_EXECUTABLE:-python3}"
1616
NPROC=$(nproc 2>/dev/null || sysctl -n hw.ncpu)
1717

18+
echo "=== Check embedded WGSL headers are up to date ==="
19+
"${PYTHON_EXECUTABLE}" "${SCRIPT_DIR}/../scripts/gen_wgsl_headers.py" --check \
20+
|| { echo "ERROR: *_wgsl.h out of sync with .wgsl; run scripts/gen_wgsl_headers.py"; exit 1; }
21+
1822
# ── Step 1: Python export tests ──────────────────────────────────────────────
1923

2024
echo "=== Step 1: Run Python export tests ==="
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Unit + drift tests for the embedded-WGSL-header generator.
8+
9+
Loads the generator by file path (no package/namespace dependency).
10+
"""
11+
12+
import hashlib
13+
import importlib.util
14+
import unittest
15+
from pathlib import Path
16+
17+
_GEN = Path(__file__).resolve().parents[1] / "scripts" / "gen_wgsl_headers.py"
18+
_spec = importlib.util.spec_from_file_location("gen_wgsl_headers", _GEN)
19+
g = importlib.util.module_from_spec(_spec)
20+
_spec.loader.exec_module(g)
21+
22+
23+
class WgslCodegenTest(unittest.TestCase):
    """Unit tests for the generator helpers plus a drift check on committed headers.

    `g` is the generator module loaded by file path at import time.
    """

    def test_symbol_base(self) -> None:
        self.assertEqual(g.symbol_base("binary_add"), "BinaryAdd")
        self.assertEqual(
            g.symbol_base("sdpa_compute_attn_weights"), "SdpaComputeAttnWeights"
        )
        self.assertEqual(g.symbol_base("update_cache"), "UpdateCache")
        self.assertEqual(g.symbol_base("rms_norm"), "RmsNorm")

    def test_parse_workgroup_literal(self) -> None:
        self.assertEqual(
            g.parse_workgroup_size("@compute @workgroup_size(64, 1, 1)\nfn main(){}"),
            64,
        )

    def test_parse_workgroup_override_indirection(self) -> None:
        src = "override wg_size: u32 = 256;\n@compute @workgroup_size(wg_size)\nfn main(){}"
        self.assertEqual(g.parse_workgroup_size(src), 256)

    def test_parse_workgroup_suffix_typed_literal(self) -> None:
        self.assertEqual(
            g.parse_workgroup_size("@compute @workgroup_size(64u, 1, 1)\nfn main(){}"),
            64,
        )

    def test_parse_workgroup_const_without_type_annotation(self) -> None:
        src = "const WG = 128u;\n@compute @workgroup_size(WG)\nfn main(){}"
        self.assertEqual(g.parse_workgroup_size(src), 128)

    def test_parse_workgroup_not_fooled_by_const(self) -> None:
        # rms_norm/softmax shape: a sibling `const WG_SIZE` beside a LITERAL size.
        src = (
            "const WG_SIZE: u32 = 64u;\n@compute @workgroup_size(64, 1, 1)\nfn main(){}"
        )
        self.assertEqual(g.parse_workgroup_size(src), 64)

    def test_render_header_shape(self) -> None:
        wgsl = "@compute @workgroup_size(64, 1, 1)\nfn main(){}\n"
        h = g.render_header(Path("runtime/ops/update_cache/update_cache.wgsl"), wgsl)
        self.assertIn("#pragma once", h)
        self.assertIn("#include <cstdint>", h)
        self.assertIn("namespace executorch::backends::webgpu {", h)
        self.assertIn("// @generated from update_cache.wgsl - DO NOT EDIT.", h)
        self.assertIn('inline constexpr const char* kUpdateCacheWGSL = R"(', h)
        self.assertIn("inline constexpr uint32_t kUpdateCacheWorkgroupSize = 64;", h)
        self.assertNotIn("Confidential", h)
        # the shader is embedded verbatim:
        body = h.split('R"(', 1)[1].split(')";', 1)[0]
        self.assertEqual(body, "\n" + wgsl)
        self.assertTrue(h.endswith("\n"))

    def test_render_header_embeds_sha256(self) -> None:
        wgsl = "@compute @workgroup_size(64, 1, 1)\nfn main(){}\n"
        h = g.render_header(Path("runtime/ops/update_cache/update_cache.wgsl"), wgsl)
        want = hashlib.sha256(wgsl.encode("utf-8")).hexdigest()
        self.assertIn(f"// wgsl-sha256: {want}", h)
        self.assertEqual(g.embedded_sha256(h), want)
        self.assertEqual(g.wgsl_sha256(wgsl), want)

    def test_embedded_sha256_missing_returns_empty(self) -> None:
        self.assertEqual(g.embedded_sha256("no sha line here\n"), "")

    def test_sha256_changes_with_shader(self) -> None:
        a = g.wgsl_sha256("@compute @workgroup_size(64, 1, 1)\nfn main(){}\n")
        b = g.wgsl_sha256("@compute @workgroup_size(256)\nfn main(){}\n")
        self.assertNotEqual(a, b)

    def test_committed_headers_match_generator(self) -> None:
        wgsls = g.discover()
        self.assertGreater(len(wgsls), 0, "no .wgsl shaders discovered")
        for wgsl in wgsls:
            # subTest lets the loop continue past a failure, so one run
            # reports every stale or missing header instead of only the first.
            with self.subTest(shader=wgsl.name):
                header = wgsl.with_name(wgsl.stem + "_wgsl.h")
                self.assertTrue(
                    header.exists(),
                    f"{header.name} missing; run scripts/gen_wgsl_headers.py",
                )
                # Decode as UTF-8 explicitly: the embedded sha is defined over
                # UTF-8 bytes, so a locale default could fake a mismatch.
                want = g.render_header(wgsl, wgsl.read_text(encoding="utf-8"))
                got = header.read_text(encoding="utf-8")
                self.assertEqual(
                    got, want, f"{wgsl.stem}_wgsl.h stale; run scripts/gen_wgsl_headers.py"
                )

0 commit comments

Comments
 (0)