diff --git a/examples/mscclang/allreduce_a100_pcie_hierarchical.py b/examples/mscclang/allreduce_a100_pcie_hierarchical.py index 3ea7460..ae81780 100644 --- a/examples/mscclang/allreduce_a100_pcie_hierarchical.py +++ b/examples/mscclang/allreduce_a100_pcie_hierarchical.py @@ -12,14 +12,16 @@ def allpairs_reduce_scatter(gpuIds, size, offset): if gpuIds[r1] != gpuIds[r2]: index = offset + r2 * size c = chunk(gpuIds[r1], Buffer.input, index, size=size) - c.copy(gpuIds[r2], 'scratch', sendtb=gpuIds[r2], recvtb=gpuIds[r1]) + c.copy(gpuIds[r2], 'scratch', index=r1*size, sendtb=gpuIds[r2], recvtb=gpuIds[r1]) # Each rank performs a local reduction on the nth chunk # Utilize 8 threadblocks for this reduction for better parallelism for r in range(ngpus): - for index in range(0, size * (ngpus-1)): + for index in range(0, size): c = chunk(gpuIds[r], Buffer.input, offset + r*size + (index % size)) - c.reduce(chunk(gpuIds[r], 'scratch', index), sendtb=(index % size)) + for r2 in range(ngpus): + if gpuIds[r] != gpuIds[r2]: + c.reduce(chunk(gpuIds[r], 'scratch', (r2 * size) + index), sendtb=(index % size)) def allpairs_all_gather(gpuIds, size, offset):