diff --git a/encodec/distrib.py b/encodec/distrib.py index b1662d8..695a1c2 100644 --- a/encodec/distrib.py +++ b/encodec/distrib.py @@ -87,7 +87,7 @@ def sync_buffer(buffers, average=True): for buffer, handle in handles: handle.wait() if average: - buffer.data /= world_size + buffer.data /= world_size() def sync_grad(params):