From df563a528b559efe82806eb6cbf7e73a1ec7de24 Mon Sep 17 00:00:00 2001 From: Gabriel Mongaras Date: Sat, 5 Aug 2023 10:14:34 -0700 Subject: [PATCH] Balancer for different loss inputs --- encodec/balancer.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/encodec/balancer.py b/encodec/balancer.py index 0cb70e1..96202fc 100644 --- a/encodec/balancer.py +++ b/encodec/balancer.py @@ -45,6 +45,15 @@ class Balancer: losses['loss_b'] = compute_loss_b(x, y) if model.training(): balancer.backward(losses, x) + + Expected usage for different loss inputs + weights = {'loss_a': 1, 'loss_b': 4} + balancer = Balancer(weights, ...) + losses: dict = {} + losses['loss_a'] = compute_loss_a(x1, y1) + losses['loss_b'] = compute_loss_b(x2, y2) + if model.training(): + balancer.backward(losses, [x1, x2]) ..Warning:: It is unclear how this will interact with DistributedDataParallel, in particular if you have some losses not handled by the balancer. In that case @@ -80,10 +89,15 @@ def __init__(self, weights: tp.Dict[str, float], rescale_grads: bool = True, tot def metrics(self): return self._metrics - def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor): + def backward(self, losses: tp.Dict[str, torch.Tensor], inputs: tp.List[torch.Tensor]): + if type(inputs) != list: + inputs = [inputs] * len(losses) + elif len(inputs) == 1: + inputs = inputs * len(losses) + norms = {} grads = {} - for name, loss in losses.items(): + for name, loss, input in zip(losses.keys(), losses.values(), inputs): grad, = autograd.grad(loss, [input], retain_graph=True) if self.per_batch_item: dims = tuple(range(1, grad.dim())) @@ -107,15 +121,18 @@ def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor): total_weights = sum([self.weights[k] for k in avg_norms]) ratios = {k: w / total_weights for k, w in self.weights.items()} - out_grad: tp.Any = 0 - for name, avg_norm in avg_norms.items(): + i = 0 + for name, avg_norm, input in zip(avg_norms.keys(), avg_norms.values(), inputs): if self.rescale_grads: scale = ratios[name] * self.total_norm / (self.epsilon + avg_norm) grad = grads[name] * scale else: grad = self.weights[name] * grads[name] - out_grad += grad - input.backward(out_grad) + if i != len(inputs) - 1: + input.backward(grad, retain_graph=True) + else: + input.backward(grad) + i += 1 def test():