Skip to content

Commit 02f7d4e

Browse files
Update
[ghstack-poisoned]
1 parent 16bbbba commit 02f7d4e

4 files changed

Lines changed: 95 additions & 0 deletions

File tree

backends/webgpu/test/op_tests/cases.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@
3030
AddModule,
3131
AddSelfModule,
3232
)
33+
from executorch.backends.webgpu.test.ops.mul.test_mul import (
34+
CONFIGS as _MUL_CONFIGS,
35+
MulModule,
36+
)
3337
from executorch.backends.webgpu.test.ops.rms_norm.test_rms_norm import (
3438
_CASES,
3539
_linspace_weight,
@@ -106,3 +110,14 @@ def _rms_norm_suite() -> WebGPUTestSuite:
106110
)
107111
)
108112
return WebGPUTestSuite(module_factory=_rms_norm_factory, cases=cases)
113+
114+
115+
@register_op_test("mul")
116+
def _mul_suite() -> WebGPUTestSuite:
117+
# Full numeric coverage incl. broadcast (binary_mul.wgsl over a TensorMeta UBO); fp64 golden.
118+
return WebGPUTestSuite(
119+
module_factory=lambda: MulModule(),
120+
cases=[
121+
Case(name=name, inputs=(sa, sb)) for name, (sa, sb) in _MUL_CONFIGS.items()
122+
],
123+
)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""`aten.mul.Tensor` (full broadcast) module + configs for the WebGPU op-test framework.
8+
9+
`MulModule` + `CONFIGS` are imported by `cases.py` to drive the declarative op-test
10+
suite (export via VulkanPartitioner + fp64 torch golden, run on Dawn). `TestMul` is
11+
the export-delegation + eager-correctness smoke test. Configs span the same-shape
12+
fast path (SwiGLU), last-dim broadcast at LLM width, and a mixed-rank left-pad case.
13+
"""
14+
15+
import unittest
16+
17+
import torch
18+
19+
from executorch.backends.vulkan import VulkanPartitioner
20+
from executorch.exir import to_edge_transform_and_lower
21+
22+
# name -> (shape_a, shape_b). Output shape is the broadcast of the two.
23+
CONFIGS = {
24+
"same": ((8, 32), (8, 32)), # fast path (SwiGLU same-shape)
25+
"bcast_lastdim": ((1, 1, 7, 896), (1, 1, 7, 1)), # last-dim broadcast, LLM width
26+
"mixedrank": ((4,), (3, 4)), # right-aligned left-pad (in.ndim < out.ndim)
27+
}
28+
29+
30+
class MulModule(torch.nn.Module):
31+
def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
32+
return a * b
33+
34+
35+
def _det_inputs(shape_a, shape_b):
36+
"""Deterministic fp32 inputs (fixed seed) for a config."""
37+
g = torch.Generator().manual_seed(0)
38+
a = torch.randn(*shape_a, generator=g, dtype=torch.float32)
39+
b = torch.randn(*shape_b, generator=g, dtype=torch.float32)
40+
return a, b
41+
42+
43+
def _export(a: torch.Tensor, b: torch.Tensor):
44+
ep = torch.export.export(MulModule().eval(), (a, b))
45+
return to_edge_transform_and_lower(
46+
ep, partitioner=[VulkanPartitioner()]
47+
).to_executorch()
48+
49+
50+
def _delegated(et) -> bool:
51+
return any(
52+
d.id == "VulkanBackend"
53+
for plan in et.executorch_program.execution_plan
54+
for d in plan.delegates
55+
)
56+
57+
58+
class TestMul(unittest.TestCase):
59+
def test_export_delegates(self) -> None:
60+
for name, (sa, sb) in CONFIGS.items():
61+
a, b = _det_inputs(sa, sb)
62+
et = _export(a, b)
63+
self.assertTrue(
64+
_delegated(et), f"Expected a VulkanBackend delegate (mul {name})"
65+
)
66+
67+
def test_golden_matches_eager(self) -> None:
68+
for _, (sa, sb) in CONFIGS.items():
69+
a, b = _det_inputs(sa, sb)
70+
torch.testing.assert_close(MulModule()(a, b), a * b)
71+
72+
73+
if __name__ == "__main__":
74+
unittest.main()

backends/webgpu/test/tester.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
WEBGPU_SUPPORTED_OPS = [
2222
exir_ops.edge.aten.add.Tensor,
2323
exir_ops.edge.et_vk.rms_norm.default,
24+
exir_ops.edge.aten.mul.Tensor,
2425
]
2526

2627

0 commit comments

Comments
 (0)