From baa44cfd7565a4e15c151e1cb1dc9616cbce8301 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Tue, 11 Mar 2025 22:24:23 -0700 Subject: [PATCH] Update deprecated sparse tensor construction Previously this code would raise the following: ``` linear_operator\utils\interpolation.py:71: UserWarning: torch.sparse.SparseTensor(indices, values, shape, *, device=) is deprecated. Please use torch.sparse_coo_tensor(indices, values, shape, dtype=, device=). (Triggered internally at ..\torch\csrc\utils\tensor_new.cpp:620.) summing_matrix = cls(summing_matrix_indices, summing_matrix_values, size) ``` This updates the code to avoid this. --- linear_operator/utils/interpolation.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/linear_operator/utils/interpolation.py b/linear_operator/utils/interpolation.py index 73faee18..92f8cce9 100644 --- a/linear_operator/utils/interpolation.py +++ b/linear_operator/utils/interpolation.py @@ -63,12 +63,13 @@ def left_t_interp(interp_indices, interp_values, rhs, output_dim): device=interp_values.device, ) size = torch.Size((batch_size, output_dim, num_data * num_interp)) - type_name = summing_matrix_values.type().split(".")[-1] # e.g. FloatTensor - if interp_values.is_cuda: - cls = getattr(torch.cuda.sparse, type_name) - else: - cls = getattr(torch.sparse, type_name) - summing_matrix = cls(summing_matrix_indices, summing_matrix_values, size) + summing_matrix = torch.sparse_coo_tensor( + summing_matrix_indices, + summing_matrix_values, + size, + dtype=summing_matrix_values.dtype, + device=summing_matrix_values.device, + ) # Sum up the values appropriately by performing sparse matrix multiplication values = values.reshape(batch_size, num_data * num_interp, num_cols)