Skip to content

Commit 9629cfe

Browse files
author
Mark Saroufim
committed
update
1 parent 825decb commit 9629cfe

1 file changed

Lines changed: 56 additions & 54 deletions

File tree

scripts/test_distributed.py

Lines changed: 56 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,82 @@
1-
import os
2-
import signal
3-
import sys
4-
51
import torch
62
import torch.distributed as dist
73
import torch.multiprocessing as mp
8-
4+
from multiprocessing import Pool
5+
import os
6+
import signal
7+
import sys
98

109
def timeout_handler(signum, frame):
    """SIGALRM handler: report a hung step and abort it with ``TimeoutError``.

    Raising ``TimeoutError`` (an ``Exception`` subclass) instead of calling
    ``sys.exit(1)`` lets the surrounding ``except Exception`` in the worker
    turn the timeout into a normal failure result. ``SystemExit`` would kill
    a ``multiprocessing.Pool`` worker without delivering a result, so the
    parent would only notice once its own, much longer, timeout expired.

    Args:
        signum: Number of the signal that fired (unused).
        frame: Stack frame that was interrupted (unused).

    Raises:
        TimeoutError: Always.
    """
    print('✗ TIMEOUT: Process hung')
    raise TimeoutError('process hung')
1312

14-
15-
def test_worker(rank, world_size, master_port):
13+
def test_worker(args):
    """Run one rank of the NCCL all-reduce smoke test.

    Args:
        args: ``(rank, world_size, master_port)`` tuple. The three values are
            packed into one argument because the caller dispatches this
            function through ``Pool.map_async``, which passes a single item
            per call.

    Returns:
        bool: True if the process group was initialised and the all-reduce
        finished; False if any step raised or the alarm handler aborted it.
    """
    rank, world_size, master_port = args
    try:
        # Rendezvous settings read by the "env://" init method.
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = str(master_port)
        os.environ['RANK'] = str(rank)
        os.environ['WORLD_SIZE'] = str(world_size)

        # Watchdog: give process-group initialisation at most 30 seconds.
        signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(30)

        print(f'Rank {rank}: Init NCCL...')
        # The rank doubles as the CUDA device index (one GPU per rank).
        device = torch.device(f'cuda:{rank}')
        dist.init_process_group(
            "nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size,
            device_id=device,
        )
        signal.alarm(0)

        # Every element holds this rank's index before the reduction.
        tensor = torch.ones(100, device=device) * rank

        # Watchdog: give the collective at most 15 seconds.
        signal.alarm(15)
        dist.all_reduce(tensor)
        signal.alarm(0)

        print(f'✓ Rank {rank}: sum = {tensor[0].item()}')
        dist.destroy_process_group()
        return True

    except (Exception, SystemExit) as e:
        # SystemExit (raised when an alarm handler calls sys.exit()) is not an
        # Exception subclass. Catch it as well so a timeout is reported as a
        # False result instead of killing the pool worker and leaving the
        # parent waiting for a result that never arrives.
        print(f'✗ Rank {rank}: {e}')
        return False
    finally:
        # Never leave a pending alarm behind, whichever path exits.
        signal.alarm(0)
4443

4544
def main():
    """Run four rounds of the multi-GPU NCCL all-reduce smoke test.

    Each round starts one pool worker per visible CUDA device, runs
    ``test_worker`` on every rank and waits up to 60 seconds for all results.
    Round ``i`` (0-based) uses master port ``29500 + i``.

    Exits the interpreter with status 1 when no CUDA device is visible, when
    any rank reports failure, when a round times out, or when dispatching the
    round raises.
    """
    num_gpus = torch.cuda.device_count()
    print(f'Testing {num_gpus} GPUs - 4 rounds')

    # Pool(processes=0) raises ValueError; fail with a clear message instead
    # of an uncaught traceback when there is nothing to test.
    if num_gpus < 1:
        print('✗ No CUDA devices found')
        sys.exit(1)

    # The start method is a process-wide setting, so set it once rather than
    # on every round.
    mp.set_start_method('spawn', force=True)

    for round_num in range(4):
        print(f'=== ROUND {round_num + 1} ===')
        master_port = 29500 + round_num

        # Prepare worker arguments: one (rank, world_size, port) tuple per GPU.
        worker_args = [(rank, num_gpus, master_port) for rank in range(num_gpus)]

        with Pool(processes=num_gpus) as pool:
            try:
                # map_async + get(timeout=...) bounds how long a round may take.
                result = pool.map_async(test_worker, worker_args)
                results = result.get(timeout=60)

                # Every rank must report success for the round to pass.
                if not all(results):
                    print(f'✗ ROUND {round_num + 1} FAILED')
                    sys.exit(1)

            except mp.TimeoutError:
                print(f'✗ ROUND {round_num + 1} HUNG')
                # Stop the stuck workers before exiting.
                pool.terminate()
                pool.join()
                sys.exit(1)
            except Exception as e:
                # SystemExit from the FAILED branch is not an Exception, so it
                # passes through this handler untouched.
                print(f'✗ ROUND {round_num + 1} ERROR: {e}')
                sys.exit(1)

        print(f'✓ ROUND {round_num + 1} PASSED')

    print('✓ ALL ROUNDS PASSED')
7380

74-
print(f"✓ ROUND {round_num + 1} PASSED")
75-
76-
print("✓ ALL ROUNDS PASSED")
77-
78-
79-
if __name__ == "__main__":
80-
main()
81+
# Script entry point: run the test only when this file is executed directly,
# not when it is imported.
if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)