@@ -64,6 +64,10 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
6464 in_numel *= static_cast <uint64_t >(d);
6565 }
6666 const uint32_t M = static_cast <uint32_t >(in_numel / K);
67+ if (in_numel % K != 0 ) {
68+ throw std::runtime_error (
69+ " WebGPU linear_q4gsw: input numel not a multiple of K" );
70+ }
6771 const uint32_t N = static_cast <uint32_t >(weight.dims [0 ]);
6872 const uint32_t K_packed = static_cast <uint32_t >(weight.dims [1 ]);
6973 const uint32_t num_groups = static_cast <uint32_t >(scales.dims [0 ]);
@@ -75,6 +79,11 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
7579 if (K_packed != (K + 1 ) / 2 ) {
7680 throw std::runtime_error (" WebGPU linear_q4gsw: K_packed must be ceil(K/2)" );
7781 }
82+ // Weight is read as array<u32>; a non-multiple-of-4 byte count over-reads.
83+ if ((static_cast <uint64_t >(N) * K_packed) % 4u != 0u ) {
84+ throw std::runtime_error (
85+ " WebGPU linear_q4gsw: N*K_packed must be a multiple of 4 (u32-packed)" );
86+ }
7887
7988 // One workgroup per output row (M); validate dispatch before any alloc.
8089 const uint32_t workgroup_count =
@@ -100,6 +109,12 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
100109 if (group_size <= 0 ) {
101110 throw std::runtime_error (" WebGPU linear_q4gsw: group_size <= 0" );
102111 }
112+ // scales is indexed [(k/group_size)*padded_N + n]; guard the table bounds.
113+ const uint32_t gs = static_cast <uint32_t >(group_size);
114+ if (num_groups < (K + gs - 1u ) / gs || padded_N < N) {
115+ throw std::runtime_error (
116+ " WebGPU linear_q4gsw: scales dims too small for K/N" );
117+ }
103118
104119 // Optional bias: real buffer if present, else a dummy for the fixed layout.
105120 uint32_t has_bias = 0 ;
@@ -117,15 +132,14 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
117132 }
118133 if (bias_buffer == nullptr ) {
119134 bias_buffer = graph.create_scratch_buffer (4 );
120- bias_size = 4 ;
121135 }
122136
123137 Q4gswParams params = {};
124138 params.M = M;
125139 params.N = N;
126140 params.K = K;
127141 params.K_packed = K_packed;
128- params.group_size = static_cast < uint32_t >(group_size) ;
142+ params.group_size = gs ;
129143 params.padded_N = padded_N;
130144 params.has_bias = has_bias;
131145
0 commit comments