|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +"""Unit + drift tests for the embedded-WGSL-header generator. |
| 8 | +
|
| 9 | +Loads the generator by file path (no package/namespace dependency). |
| 10 | +""" |
| 11 | + |
| 12 | +import hashlib |
| 13 | +import importlib.util |
| 14 | +import unittest |
| 15 | +from pathlib import Path |
| 16 | + |
| 17 | +_GEN = Path(__file__).resolve().parents[1] / "scripts" / "gen_wgsl_headers.py" |
| 18 | +_spec = importlib.util.spec_from_file_location("gen_wgsl_headers", _GEN) |
| 19 | +g = importlib.util.module_from_spec(_spec) |
| 20 | +_spec.loader.exec_module(g) |
| 21 | + |
| 22 | + |
| 23 | +class WgslCodegenTest(unittest.TestCase): |
| 24 | + def test_symbol_base(self) -> None: |
| 25 | + self.assertEqual(g.symbol_base("binary_add"), "BinaryAdd") |
| 26 | + self.assertEqual( |
| 27 | + g.symbol_base("sdpa_compute_attn_weights"), "SdpaComputeAttnWeights" |
| 28 | + ) |
| 29 | + self.assertEqual(g.symbol_base("update_cache"), "UpdateCache") |
| 30 | + self.assertEqual(g.symbol_base("rms_norm"), "RmsNorm") |
| 31 | + |
| 32 | + def test_parse_workgroup_literal(self) -> None: |
| 33 | + self.assertEqual( |
| 34 | + g.parse_workgroup_size("@compute @workgroup_size(64, 1, 1)\nfn main(){}"), |
| 35 | + 64, |
| 36 | + ) |
| 37 | + |
| 38 | + def test_parse_workgroup_override_indirection(self) -> None: |
| 39 | + src = "override wg_size: u32 = 256;\n@compute @workgroup_size(wg_size)\nfn main(){}" |
| 40 | + self.assertEqual(g.parse_workgroup_size(src), 256) |
| 41 | + |
| 42 | + def test_parse_workgroup_suffix_typed_literal(self) -> None: |
| 43 | + self.assertEqual( |
| 44 | + g.parse_workgroup_size("@compute @workgroup_size(64u, 1, 1)\nfn main(){}"), |
| 45 | + 64, |
| 46 | + ) |
| 47 | + |
| 48 | + def test_parse_workgroup_const_without_type_annotation(self) -> None: |
| 49 | + src = "const WG = 128u;\n@compute @workgroup_size(WG)\nfn main(){}" |
| 50 | + self.assertEqual(g.parse_workgroup_size(src), 128) |
| 51 | + |
| 52 | + def test_parse_workgroup_not_fooled_by_const(self) -> None: |
| 53 | + # rms_norm/softmax shape: a sibling `const WG_SIZE` beside a LITERAL size. |
| 54 | + src = ( |
| 55 | + "const WG_SIZE: u32 = 64u;\n@compute @workgroup_size(64, 1, 1)\nfn main(){}" |
| 56 | + ) |
| 57 | + self.assertEqual(g.parse_workgroup_size(src), 64) |
| 58 | + |
| 59 | + def test_render_header_shape(self) -> None: |
| 60 | + wgsl = "@compute @workgroup_size(64, 1, 1)\nfn main(){}\n" |
| 61 | + h = g.render_header(Path("runtime/ops/update_cache/update_cache.wgsl"), wgsl) |
| 62 | + self.assertIn("#pragma once", h) |
| 63 | + self.assertIn("#include <cstdint>", h) |
| 64 | + self.assertIn("namespace executorch::backends::webgpu {", h) |
| 65 | + self.assertIn("// @generated from update_cache.wgsl - DO NOT EDIT.", h) |
| 66 | + self.assertIn('inline constexpr const char* kUpdateCacheWGSL = R"(', h) |
| 67 | + self.assertIn("inline constexpr uint32_t kUpdateCacheWorkgroupSize = 64;", h) |
| 68 | + self.assertNotIn("Confidential", h) |
| 69 | + # the shader is embedded verbatim: |
| 70 | + body = h.split('R"(', 1)[1].split(')";', 1)[0] |
| 71 | + self.assertEqual(body, "\n" + wgsl) |
| 72 | + self.assertTrue(h.endswith("\n")) |
| 73 | + |
| 74 | + def test_render_header_embeds_sha256(self) -> None: |
| 75 | + wgsl = "@compute @workgroup_size(64, 1, 1)\nfn main(){}\n" |
| 76 | + h = g.render_header(Path("runtime/ops/update_cache/update_cache.wgsl"), wgsl) |
| 77 | + want = hashlib.sha256(wgsl.encode("utf-8")).hexdigest() |
| 78 | + self.assertIn(f"// wgsl-sha256: {want}", h) |
| 79 | + self.assertEqual(g.embedded_sha256(h), want) |
| 80 | + self.assertEqual(g.wgsl_sha256(wgsl), want) |
| 81 | + |
| 82 | + def test_embedded_sha256_missing_returns_empty(self) -> None: |
| 83 | + self.assertEqual(g.embedded_sha256("no sha line here\n"), "") |
| 84 | + |
| 85 | + def test_sha256_changes_with_shader(self) -> None: |
| 86 | + a = g.wgsl_sha256("@compute @workgroup_size(64, 1, 1)\nfn main(){}\n") |
| 87 | + b = g.wgsl_sha256("@compute @workgroup_size(256)\nfn main(){}\n") |
| 88 | + self.assertNotEqual(a, b) |
| 89 | + |
| 90 | + def test_committed_headers_match_generator(self) -> None: |
| 91 | + wgsls = g.discover() |
| 92 | + self.assertGreater(len(wgsls), 0, "no .wgsl shaders discovered") |
| 93 | + for wgsl in wgsls: |
| 94 | + want = g.render_header(wgsl, wgsl.read_text()) |
| 95 | + got = wgsl.with_name(wgsl.stem + "_wgsl.h").read_text() |
| 96 | + self.assertEqual( |
| 97 | + got, want, f"{wgsl.stem}_wgsl.h stale; run scripts/gen_wgsl_headers.py" |
| 98 | + ) |
0 commit comments