diff --git a/msccl/language/collectives.py b/msccl/language/collectives.py index 2e4fa73..6550ccb 100755 --- a/msccl/language/collectives.py +++ b/msccl/language/collectives.py @@ -105,6 +105,10 @@ def init_buffers(self): output_buffer = [None] * (self.num_ranks * self.chunk_factor) for ch in range(self.chunk_factor): input_buffer[ch] = Chunk(r, ch, -1, r * self.chunk_factor + ch) + if self.create_all_chunks: + for rank in range(self.num_ranks): + for ch in range(self.chunk_factor): + output_buffer[rank * self.chunk_factor + ch] = Chunk(rank, ch, -1, rank * self.chunk_factor + ch) buffers = {Buffer.input: input_buffer, Buffer.output: output_buffer} rank_buffers.append(buffers) return rank_buffers