Skip to content

Commit 8e81436

Browse files
Update
[ghstack-poisoned]
1 parent bab43f6 commit 8e81436

1 file changed

Lines changed: 16 additions & 2 deletions

File tree

backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -64,6 +64,10 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
6464
in_numel *= static_cast<uint64_t>(d);
6565
}
6666
const uint32_t M = static_cast<uint32_t>(in_numel / K);
67+
if (in_numel % K != 0) {
68+
throw std::runtime_error(
69+
"WebGPU linear_q4gsw: input numel not a multiple of K");
70+
}
6771
const uint32_t N = static_cast<uint32_t>(weight.dims[0]);
6872
const uint32_t K_packed = static_cast<uint32_t>(weight.dims[1]);
6973
const uint32_t num_groups = static_cast<uint32_t>(scales.dims[0]);
@@ -75,6 +79,11 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
7579
if (K_packed != (K + 1) / 2) {
7680
throw std::runtime_error("WebGPU linear_q4gsw: K_packed must be ceil(K/2)");
7781
}
82+
// Weight is read as array<u32>; a non-multiple-of-4 byte count over-reads.
83+
if ((static_cast<uint64_t>(N) * K_packed) % 4u != 0u) {
84+
throw std::runtime_error(
85+
"WebGPU linear_q4gsw: N*K_packed must be a multiple of 4 (u32-packed)");
86+
}
7887

7988
// One workgroup per output row (M); validate dispatch before any alloc.
8089
const uint32_t workgroup_count =
@@ -100,6 +109,12 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
100109
if (group_size <= 0) {
101110
throw std::runtime_error("WebGPU linear_q4gsw: group_size <= 0");
102111
}
112+
// scales is indexed [(k/group_size)*padded_N + n]; guard the table bounds.
113+
const uint32_t gs = static_cast<uint32_t>(group_size);
114+
if (num_groups < (K + gs - 1u) / gs || padded_N < N) {
115+
throw std::runtime_error(
116+
"WebGPU linear_q4gsw: scales dims too small for K/N");
117+
}
103118

104119
// Optional bias: real buffer if present, else a dummy for the fixed layout.
105120
uint32_t has_bias = 0;
@@ -117,15 +132,14 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
117132
}
118133
if (bias_buffer == nullptr) {
119134
bias_buffer = graph.create_scratch_buffer(4);
120-
bias_size = 4;
121135
}
122136

123137
Q4gswParams params = {};
124138
params.M = M;
125139
params.N = N;
126140
params.K = K;
127141
params.K_packed = K_packed;
128-
params.group_size = static_cast<uint32_t>(group_size);
142+
params.group_size = gs;
129143
params.padded_N = padded_N;
130144
params.has_bias = has_bias;
131145

0 commit comments

Comments
 (0)