diff --git a/backends/cuda/tests/test_int4_dispatch.py b/backends/cuda/tests/test_int4_dispatch.py index fd748ae8584..ecf1a53e48e 100644 --- a/backends/cuda/tests/test_int4_dispatch.py +++ b/backends/cuda/tests/test_int4_dispatch.py @@ -59,7 +59,10 @@ def _make_int4_linear(N, K, group_size=128, symmetric=False, bias=False): ) int4_w = quantize_weight(w_bf16, config) - module = nn.Linear(K, N, bias=bias, dtype=torch.bfloat16) + # device="cuda" so the random init draws from the CUDA RNG to match the + # same random weight as regular int4 dispatch and fit the same numerical + # error tolerance. + module = nn.Linear(K, N, bias=bias, dtype=torch.bfloat16, device="cuda") pack_linear_for_cuda(module, {"weight": int4_w}) module.cuda() return module, w_bf16.cuda()