Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src_denoising/models/non_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,11 @@ def log1mexp(x, expm1_guard = 1e-7):
# for x close to 0 we need expm1 for numerically stable computation
# we furtmermore modify the backward pass to avoid instable gradients,
# ie situations where the incoming output gradient is close to 0 and the gradient of expm1 is very large
expxm1 = torch.expm1(x[1 - t])
expxm1 = torch.expm1(x[~t])
log1mexp_fw = (-expxm1).log()
log1mexp_bw = (-expxm1+expm1_guard).log() # limits magnitude of gradient

y[1 - t] = log1mexp_fw.detach() + (log1mexp_bw - log1mexp_bw.detach())
y[~t] = log1mexp_fw.detach() + (log1mexp_bw - log1mexp_bw.detach())
return y


Expand Down
6 changes: 3 additions & 3 deletions src_denoising/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def forward(ctx, x, y, I, chunk_size=64):

If = I_chunk.view(b,1,this_chunk_size,o).expand(b,k,this_chunk_size,o)
y_full = torch.cuda.FloatTensor(b,k,this_chunk_size,n).fill_(0)
y_full = y_full.scatter_add(source=y_chunk.permute(0,3,1,2), index=If, dim=3)
y_full.scatter_add_(3, If.long(), y_chunk.permute(0,3,1,2))
z_interm = torch.cat([torch.matmul(y_full[:,i_k:i_k+1,:,:], x_interm) for i_k in range(k)], 1)
z_chunk = z_interm.permute(0,2,3,1)
z_chunks.append(z_chunk)
Expand All @@ -90,7 +90,7 @@ def backward(ctx, grad):
If = I_chunk.view(b,1,this_chunk_size,o).expand(b,k,this_chunk_size,o)
del I_chunk
y_full = torch.cuda.FloatTensor(b,k,this_chunk_size,n).fill_(0)
y_full = y_full.scatter_add(source=y_chunk.permute(0,3,1,2), index=If, dim=3)
y_full.scatter_add_(3, If.long(), y_chunk.permute(0,3,1,2))
del y_chunk

for i_k in range(k):
Expand Down Expand Up @@ -206,4 +206,4 @@ def _finfo(tensor):

def clamp_probs(probs):
eps = _finfo(probs).eps
return probs.clamp(min=eps, max=1 - eps)
return probs.clamp(min=eps, max=1 - eps)