From 530e835f774b1b4c7e5ade919b4495de2b7759db Mon Sep 17 00:00:00 2001 From: Caio Date: Thu, 12 Dec 2024 01:52:29 +0000 Subject: [PATCH 1/2] adding option for create all chuncks in outplace case for allgather collective --- msccl/language/collectives.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/msccl/language/collectives.py b/msccl/language/collectives.py index 2e4fa73..6ea0ab3 100755 --- a/msccl/language/collectives.py +++ b/msccl/language/collectives.py @@ -103,8 +103,13 @@ def init_buffers(self): for r in range(self.num_ranks): input_buffer = [None] * self.chunk_factor output_buffer = [None] * (self.num_ranks * self.chunk_factor) - for ch in range(self.chunk_factor): - input_buffer[ch] = Chunk(r, ch, -1, r * self.chunk_factor + ch) + if not self.create_all_chunks: + for ch in range(self.chunk_factor): + input_buffer[ch] = Chunk(r, ch, -1, r * self.chunk_factor + ch) + else: + for rank in range(self.num_ranks): + for ch in range(self.chunk_factor): + output_buffer[rank * self.chunk_factor + ch] = Chunk(rank, ch, -1, rank * self.chunk_factor + ch) buffers = {Buffer.input: input_buffer, Buffer.output: output_buffer} rank_buffers.append(buffers) return rank_buffers From 23ac27ec6b30276d6079092d68430c38db75197c Mon Sep 17 00:00:00 2001 From: Caio Date: Thu, 12 Dec 2024 18:15:34 +0000 Subject: [PATCH 2/2] WIP --- msccl/language/collectives.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/msccl/language/collectives.py b/msccl/language/collectives.py index 6ea0ab3..6550ccb 100755 --- a/msccl/language/collectives.py +++ b/msccl/language/collectives.py @@ -103,10 +103,9 @@ def init_buffers(self): for r in range(self.num_ranks): input_buffer = [None] * self.chunk_factor output_buffer = [None] * (self.num_ranks * self.chunk_factor) - if not self.create_all_chunks: - for ch in range(self.chunk_factor): - input_buffer[ch] = Chunk(r, ch, -1, r * self.chunk_factor + ch) - else: + for ch in range(self.chunk_factor): + input_buffer[ch] = Chunk(r, ch, -1, r * self.chunk_factor + ch) + if self.create_all_chunks: for rank in range(self.num_ranks): for ch in range(self.chunk_factor): output_buffer[rank * self.chunk_factor + ch] = Chunk(rank, ch, -1, rank * self.chunk_factor + ch)