diff --git a/linear_operator/utils/interpolation.py b/linear_operator/utils/interpolation.py index 73faee18..92f8cce9 100644 --- a/linear_operator/utils/interpolation.py +++ b/linear_operator/utils/interpolation.py @@ -63,12 +63,13 @@ def left_t_interp(interp_indices, interp_values, rhs, output_dim): device=interp_values.device, ) size = torch.Size((batch_size, output_dim, num_data * num_interp)) - type_name = summing_matrix_values.type().split(".")[-1] # e.g. FloatTensor - if interp_values.is_cuda: - cls = getattr(torch.cuda.sparse, type_name) - else: - cls = getattr(torch.sparse, type_name) - summing_matrix = cls(summing_matrix_indices, summing_matrix_values, size) + summing_matrix = torch.sparse_coo_tensor( + summing_matrix_indices, + summing_matrix_values, + size, + dtype=summing_matrix_values.dtype, + device=summing_matrix_values.device, + ) # Sum up the values appropriately by performing sparse matrix multiplication values = values.reshape(batch_size, num_data * num_interp, num_cols)