diff --git a/src_denoising/models/non_local.py b/src_denoising/models/non_local.py
index 7b5adcc..36474ae 100644
--- a/src_denoising/models/non_local.py
+++ b/src_denoising/models/non_local.py
@@ -100,11 +100,11 @@ def log1mexp(x, expm1_guard = 1e-7):
     # for x close to 0 we need expm1 for numerically stable computation
     # we furtmermore modify the backward pass to avoid instable gradients,
     # ie situations where the incoming output gradient is close to 0 and the gradient of expm1 is very large
-    expxm1 = torch.expm1(x[1 - t])
+    expxm1 = torch.expm1(x[~t])
     log1mexp_fw = (-expxm1).log()
     log1mexp_bw = (-expxm1+expm1_guard).log() # limits magnitude of gradient
 
-    y[1 - t] = log1mexp_fw.detach() + (log1mexp_bw - log1mexp_bw.detach())
+    y[~t] = log1mexp_fw.detach() + (log1mexp_bw - log1mexp_bw.detach())
 
     return y
 
diff --git a/src_denoising/ops.py b/src_denoising/ops.py
index ae40b88..d109d5b 100644
--- a/src_denoising/ops.py
+++ b/src_denoising/ops.py
@@ -63,7 +63,7 @@ def forward(ctx, x, y, I, chunk_size=64):
             If = I_chunk.view(b,1,this_chunk_size,o).expand(b,k,this_chunk_size,o)
 
             y_full = torch.cuda.FloatTensor(b,k,this_chunk_size,n).fill_(0)
-            y_full = y_full.scatter_add(source=y_chunk.permute(0,3,1,2), index=If, dim=3)
+            y_full.scatter_add_(3, If.long(), y_chunk.permute(0,3,1,2))
             z_interm = torch.cat([torch.matmul(y_full[:,i_k:i_k+1,:,:], x_interm) for i_k in range(k)], 1)
             z_chunk = z_interm.permute(0,2,3,1)
             z_chunks.append(z_chunk)
@@ -90,7 +90,7 @@ def backward(ctx, grad):
             If = I_chunk.view(b,1,this_chunk_size,o).expand(b,k,this_chunk_size,o)
             del I_chunk
             y_full = torch.cuda.FloatTensor(b,k,this_chunk_size,n).fill_(0)
-            y_full = y_full.scatter_add(source=y_chunk.permute(0,3,1,2), index=If, dim=3)
+            y_full.scatter_add_(3, If.long(), y_chunk.permute(0,3,1,2))
             del y_chunk
 
             for i_k in range(k):
@@ -206,4 +206,4 @@ def _finfo(tensor):
 
 def clamp_probs(probs):
     eps = _finfo(probs).eps
-    return probs.clamp(min=eps, max=1 - eps)
\ No newline at end of file
+    return probs.clamp(min=eps, max=1 - eps)