Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions encodec/balancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ class Balancer:
losses['loss_b'] = compute_loss_b(x, y)
if model.training:
balancer.backward(losses, x)

Expected usage when each loss is computed from a different input:
weights = {'loss_a': 1, 'loss_b': 4}
balancer = Balancer(weights, ...)
losses: dict = {}
losses['loss_a'] = compute_loss_a(x1, y1)
losses['loss_b'] = compute_loss_b(x2, y2)
if model.training:
balancer.backward(losses, [x1, x2])

.. Warning:: It is unclear how this will interact with DistributedDataParallel,
in particular if you have some losses not handled by the balancer. In that case
Expand Down Expand Up @@ -80,10 +89,15 @@ def __init__(self, weights: tp.Dict[str, float], rescale_grads: bool = True, tot
def metrics(self):
return self._metrics

def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor):
def backward(self, losses: tp.Dict[str, torch.Tensor], inputs: tp.List[torch.Tensor]):
if type(inputs) != list:
inputs = [inputs] * len(losses)
elif len(inputs) == 1:
inputs = inputs * len(losses)

norms = {}
grads = {}
for name, loss in losses.items():
for name, loss, input in zip(losses.keys(), losses.values(), inputs):
grad, = autograd.grad(loss, [input], retain_graph=True)
if self.per_batch_item:
dims = tuple(range(1, grad.dim()))
Expand All @@ -107,15 +121,18 @@ def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor):
total_weights = sum([self.weights[k] for k in avg_norms])
ratios = {k: w / total_weights for k, w in self.weights.items()}

out_grad: tp.Any = 0
for name, avg_norm in avg_norms.items():
i = 0
for name, avg_norm, input in zip(avg_norms.keys(), avg_norms.values(), inputs):
if self.rescale_grads:
scale = ratios[name] * self.total_norm / (self.epsilon + avg_norm)
grad = grads[name] * scale
else:
grad = self.weights[name] * grads[name]
out_grad += grad
input.backward(out_grad)
if i != len(inputs) - 1:
input.backward(grad, retain_graph=True)
else:
input.backward(grad)
i += 1


def test():
Expand Down