import os
import signal
import sys
from multiprocessing import Pool

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def timeout_handler(signum, frame):
    """SIGALRM handler: report the hang and abort the current process.

    Args:
        signum: Signal number delivered by the OS (unused).
        frame: Stack frame that was interrupted (unused).

    Raises:
        SystemExit: Always, with exit status 1.
    """
    print(' ✗ TIMEOUT: Process hung')
    sys.exit(1)


def test_worker(args):
    """Run one rank of the NCCL smoke test.

    Args:
        args: ``(rank, world_size, master_port)`` tuple. The values are
            packed into a single argument so the function can be mapped
            over a list of tuples by a process pool.

    Returns:
        True when process-group initialisation and the all-reduce both
        complete; False when either raises or the SIGALRM watchdog fires.
    """
    rank, world_size, master_port = args
    try:
        # Rendezvous settings read by init_method="env://".
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = str(master_port)
        os.environ['RANK'] = str(rank)
        os.environ['WORLD_SIZE'] = str(world_size)

        signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(30)

        print(f' Rank {rank}: Init NCCL...')
        device = torch.device(f'cuda:{rank}')
        dist.init_process_group(
            "nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size,
            device_id=device,
        )
        signal.alarm(0)

        tensor = torch.ones(100, device=device) * rank

        signal.alarm(15)
        dist.all_reduce(tensor)
        # Read the result while the alarm is still armed: the collective is
        # only enqueued by all_reduce, and the host actually waits for it
        # when .item() copies the value back.
        total = tensor[0].item()
        signal.alarm(0)

        print(f' ✓ Rank {rank}: sum = {total}')
        dist.destroy_process_group()
        return True

    except SystemExit:
        # Raised by timeout_handler, which has already printed the reason.
        # Report failure as a normal return value so the caller receives a
        # result instead of this process exiting without one.
        signal.alarm(0)
        return False
    except Exception as e:
        signal.alarm(0)
        print(f'✗ Rank {rank}: {e}')
        return False


def main():
    """Run four NCCL all-reduce rounds across every visible GPU.

    Each round starts one ``test_worker`` per GPU in a fresh spawn-based
    process pool, using a distinct rendezvous port per round
    (29500 + round index). The script exits with status 1 as soon as a
    round reports a failed worker, raises, or does not finish within
    60 seconds; it also exits with status 1 when no GPU is visible.
    """
    num_gpus = torch.cuda.device_count()
    print(f' Testing {num_gpus} GPUs - 4 rounds')

    if num_gpus < 1:
        # A pool cannot be created with zero processes (ValueError), and
        # there is nothing to test without a GPU, so fail with a clear
        # message instead of a traceback.
        print('✗ No CUDA GPUs detected')
        sys.exit(1)

    # Use an explicit spawn context instead of changing the global start
    # method on every round; workers then start from a fresh interpreter.
    spawn_ctx = mp.get_context('spawn')

    for round_num in range(4):
        print(f' === ROUND {round_num + 1} ===')
        master_port = 29500 + round_num

        # One (rank, world_size, port) tuple per GPU.
        worker_args = [(rank, num_gpus, master_port) for rank in range(num_gpus)]

        with spawn_ctx.Pool(processes=num_gpus) as pool:
            try:
                # map_async + get(timeout) bounds the whole round.
                results = pool.map_async(test_worker, worker_args).get(timeout=60)
            except mp.TimeoutError:
                print(f'✗ ROUND {round_num + 1} HUNG')
                pool.terminate()
                pool.join()
                sys.exit(1)
            except Exception as e:
                print(f'✗ ROUND {round_num + 1} ERROR: {e}')
                sys.exit(1)

            # Every worker returns True on success.
            if not all(results):
                print(f'✗ ROUND {round_num + 1} FAILED')
                sys.exit(1)

        print(f'✓ ROUND {round_num + 1} PASSED')

    print('✓ ALL ROUNDS PASSED')


# Script entry point; the guard also keeps spawn-started workers, which
# re-import this module, from re-running main().
if __name__ == '__main__':
    main()