Skip to content

Commit e005de4

Browse files
author
Mark Saroufim
committed
update
1 parent 55a291d commit e005de4

1 file changed

Lines changed: 49 additions & 38 deletions

File tree

scripts/test_distributed.py

Lines changed: 49 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,93 @@
1-
import torch
2-
import torch.distributed as dist
3-
import torch.multiprocessing as mp
4-
from multiprocessing import Pool
51
import os
62
import signal
73
import sys
4+
from multiprocessing import Pool
5+
6+
import torch
7+
import torch.distributed as dist
8+
import torch.multiprocessing as mp
9+
810

911
def timeout_handler(signum, frame):
10-
print('✗ TIMEOUT: Process hung')
12+
print("✗ TIMEOUT: Process hung")
1113
sys.exit(1)
1214

15+
1316
def test_worker(args):
1417
rank, world_size, master_port = args
1518
try:
16-
os.environ['MASTER_ADDR'] = '127.0.0.1'
17-
os.environ['MASTER_PORT'] = str(master_port)
18-
os.environ['RANK'] = str(rank)
19-
os.environ['WORLD_SIZE'] = str(world_size)
20-
19+
os.environ["MASTER_ADDR"] = "127.0.0.1"
20+
os.environ["MASTER_PORT"] = str(master_port)
21+
os.environ["RANK"] = str(rank)
22+
os.environ["WORLD_SIZE"] = str(world_size)
23+
2124
signal.signal(signal.SIGALRM, timeout_handler)
2225
signal.alarm(30)
23-
24-
print(f'Rank {rank}: Init NCCL...')
25-
dist.init_process_group("nccl", init_method="env://", rank=rank, world_size=world_size, device_id=torch.device(f'cuda:{rank}'))
26+
27+
print(f"Rank {rank}: Init NCCL...")
28+
dist.init_process_group(
29+
"nccl",
30+
init_method="env://",
31+
rank=rank,
32+
world_size=world_size,
33+
device_id=torch.device(f"cuda:{rank}"),
34+
)
2635
signal.alarm(0)
27-
28-
device = torch.device(f'cuda:{rank}')
36+
37+
device = torch.device(f"cuda:{rank}")
2938
tensor = torch.ones(100, device=device) * rank
30-
39+
3140
signal.alarm(15)
3241
dist.all_reduce(tensor)
3342
signal.alarm(0)
34-
35-
print(f'✓ Rank {rank}: sum = {tensor[0].item()}')
43+
44+
print(f"✓ Rank {rank}: sum = {tensor[0].item()}")
3645
dist.destroy_process_group()
3746
return True
38-
47+
3948
except Exception as e:
4049
signal.alarm(0)
41-
print(f'✗ Rank {rank}: {e}')
50+
print(f"✗ Rank {rank}: {e}")
4251
return False
4352

53+
4454
def main():
4555
num_gpus = torch.cuda.device_count()
46-
print(f'Testing {num_gpus} GPUs - 4 rounds')
47-
56+
print(f"Testing {num_gpus} GPUs - 4 rounds")
57+
4858
for round_num in range(4):
49-
print(f'=== ROUND {round_num + 1} ===')
59+
print(f"=== ROUND {round_num + 1} ===")
5060
master_port = 29500 + round_num
51-
52-
mp.set_start_method('spawn', force=True)
53-
61+
62+
mp.set_start_method("spawn", force=True)
63+
5464
# Prepare worker arguments
5565
worker_args = [(rank, num_gpus, master_port) for rank in range(num_gpus)]
56-
66+
5767
with Pool(processes=num_gpus) as pool:
5868
try:
5969
# Use map_async with timeout
6070
result = pool.map_async(test_worker, worker_args)
6171
results = result.get(timeout=60)
62-
72+
6373
# Check if all workers succeeded
6474
if not all(results):
65-
print(f'✗ ROUND {round_num + 1} FAILED')
75+
print(f"✗ ROUND {round_num + 1} FAILED")
6676
sys.exit(1)
67-
77+
6878
except mp.TimeoutError:
69-
print(f'✗ ROUND {round_num + 1} HUNG')
79+
print(f"✗ ROUND {round_num + 1} HUNG")
7080
pool.terminate()
7181
pool.join()
7282
sys.exit(1)
7383
except Exception as e:
74-
print(f'✗ ROUND {round_num + 1} ERROR: {e}')
84+
print(f"✗ ROUND {round_num + 1} ERROR: {e}")
7585
sys.exit(1)
76-
77-
print(f'✓ ROUND {round_num + 1} PASSED')
78-
79-
print('✓ ALL ROUNDS PASSED')
8086

81-
if __name__ == '__main__':
82-
main()
87+
print(f"✓ ROUND {round_num + 1} PASSED")
88+
89+
print("✓ ALL ROUNDS PASSED")
90+
91+
92+
if __name__ == "__main__":
93+
main()

0 commit comments

Comments
 (0)