|
12 | 12 |
|
13 | 13 | #include <webgpu/webgpu.h> |
14 | 14 |
|
| 15 | +#include <algorithm> |
15 | 16 | #include <cmath> |
16 | 17 | #include <cstring> |
17 | 18 |
|
@@ -50,6 +51,26 @@ void add_impl(WebGPUGraph& graph, const std::vector<int>& args) { |
50 | 51 | uint32_t num_elements = |
51 | 52 | static_cast<uint32_t>(out_tensor.nbytes / sizeof(float)); |
52 | 53 |
|
| 54 | + // Clamp the workgroup size to the device's maxComputeInvocationsPerWorkgroup limit (SwiftShader caps at 128). |
| 55 | + WGPULimits limits = {}; |
| 56 | + uint32_t device_max = |
| 57 | + wgpuDeviceGetLimits(device, &limits) == WGPUStatus_Success && |
| 58 | + limits.maxComputeInvocationsPerWorkgroup > 0 |
| 59 | + ? limits.maxComputeInvocationsPerWorkgroup |
| 60 | + : kBinaryAddWorkgroupSize; |
| 61 | + uint32_t wg_size = std::min(kBinaryAddWorkgroupSize, device_max); |
| 62 | + uint32_t workgroup_count = (num_elements + wg_size - 1) / wg_size; |
| 63 | + |
| 64 | + // Validate the 1D dispatch limit before allocating any GPU objects. |
| 65 | + if (workgroup_count > 65535u) { |
| 66 | + throw std::runtime_error( |
| 67 | + "WebGPU add: workgroup count exceeds the 1D dispatch limit (65535)"); |
| 68 | + } |
| 69 | + |
| 70 | + WGPUConstantEntry wg_size_constant = {}; |
| 71 | + wg_size_constant.key = {"wg_size", WGPU_STRLEN}; |
| 72 | + wg_size_constant.value = static_cast<double>(wg_size); |
| 73 | + |
53 | 74 | // Create uniform buffer for params |
54 | 75 | AddParams params = {}; |
55 | 76 | params.num_elements = num_elements; |
@@ -115,6 +136,8 @@ void add_impl(WebGPUGraph& graph, const std::vector<int>& args) { |
115 | 136 | pipeline_desc.layout = pipeline_layout; |
116 | 137 | pipeline_desc.compute.module = shader; |
117 | 138 | pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN}; |
| 139 | + pipeline_desc.compute.constantCount = 1; |
| 140 | + pipeline_desc.compute.constants = &wg_size_constant; |
118 | 141 | WGPUComputePipeline pipeline = |
119 | 142 | wgpuDeviceCreateComputePipeline(device, &pipeline_desc); |
120 | 143 |
|
@@ -146,16 +169,14 @@ void add_impl(WebGPUGraph& graph, const std::vector<int>& args) { |
146 | 169 | bg_desc.entries = bg_entries; |
147 | 170 | WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc); |
148 | 171 |
|
149 | | - uint32_t workgroup_count = |
150 | | - (num_elements + kBinaryAddWorkgroupSize - 1) / kBinaryAddWorkgroupSize; |
151 | | - |
152 | 172 | graph.add_dispatch({pipeline, bind_group, workgroup_count}); |
153 | 173 |
|
154 | 174 | // Release intermediate objects (pipeline and bind_group are kept by dispatch) |
155 | 175 | wgpuShaderModuleRelease(shader); |
156 | 176 | wgpuBindGroupLayoutRelease(bgl); |
157 | 177 | wgpuPipelineLayoutRelease(pipeline_layout); |
158 | | - // uniform_buffer is kept alive by the bind group |
| 178 | + // Drop our ref; the bind group keeps the uniform buffer alive until the bind group itself is released. |
| 179 | + wgpuBufferRelease(uniform_buffer); |
159 | 180 | } |
160 | 181 |
|
161 | 182 | } // namespace |
|
0 commit comments