1- import torch
2- import torch .distributed as dist
3- import torch .multiprocessing as mp
4- from multiprocessing import Pool
51import os
62import signal
73import sys
4+ from multiprocessing import Pool
5+
6+ import torch
7+ import torch .distributed as dist
8+ import torch .multiprocessing as mp
9+
810
911def timeout_handler (signum , frame ):
10- print (' ✗ TIMEOUT: Process hung' )
12+ print (" ✗ TIMEOUT: Process hung" )
1113 sys .exit (1 )
1214
15+
1316def test_worker (args ):
1417 rank , world_size , master_port = args
1518 try :
16- os .environ [' MASTER_ADDR' ] = ' 127.0.0.1'
17- os .environ [' MASTER_PORT' ] = str (master_port )
18- os .environ [' RANK' ] = str (rank )
19- os .environ [' WORLD_SIZE' ] = str (world_size )
20-
19+ os .environ [" MASTER_ADDR" ] = " 127.0.0.1"
20+ os .environ [" MASTER_PORT" ] = str (master_port )
21+ os .environ [" RANK" ] = str (rank )
22+ os .environ [" WORLD_SIZE" ] = str (world_size )
23+
2124 signal .signal (signal .SIGALRM , timeout_handler )
2225 signal .alarm (30 )
23-
24- print (f'Rank { rank } : Init NCCL...' )
25- dist .init_process_group ("nccl" , init_method = "env://" , rank = rank , world_size = world_size , device_id = torch .device (f'cuda:{ rank } ' ))
26+
27+ print (f"Rank { rank } : Init NCCL..." )
28+ dist .init_process_group (
29+ "nccl" ,
30+ init_method = "env://" ,
31+ rank = rank ,
32+ world_size = world_size ,
33+ device_id = torch .device (f"cuda:{ rank } " ),
34+ )
2635 signal .alarm (0 )
27-
28- device = torch .device (f' cuda:{ rank } ' )
36+
37+ device = torch .device (f" cuda:{ rank } " )
2938 tensor = torch .ones (100 , device = device ) * rank
30-
39+
3140 signal .alarm (15 )
3241 dist .all_reduce (tensor )
3342 signal .alarm (0 )
34-
35- print (f' ✓ Rank { rank } : sum = { tensor [0 ].item ()} ' )
43+
44+ print (f" ✓ Rank { rank } : sum = { tensor [0 ].item ()} " )
3645 dist .destroy_process_group ()
3746 return True
38-
47+
3948 except Exception as e :
4049 signal .alarm (0 )
41- print (f' ✗ Rank { rank } : { e } ' )
50+ print (f" ✗ Rank { rank } : { e } " )
4251 return False
4352
53+
4454def main ():
4555 num_gpus = torch .cuda .device_count ()
46- print (f' Testing { num_gpus } GPUs - 4 rounds' )
47-
56+ print (f" Testing { num_gpus } GPUs - 4 rounds" )
57+
4858 for round_num in range (4 ):
49- print (f' === ROUND { round_num + 1 } ===' )
59+ print (f" === ROUND { round_num + 1 } ===" )
5060 master_port = 29500 + round_num
51-
52- mp .set_start_method (' spawn' , force = True )
53-
61+
62+ mp .set_start_method (" spawn" , force = True )
63+
5464 # Prepare worker arguments
5565 worker_args = [(rank , num_gpus , master_port ) for rank in range (num_gpus )]
56-
66+
5767 with Pool (processes = num_gpus ) as pool :
5868 try :
5969 # Use map_async with timeout
6070 result = pool .map_async (test_worker , worker_args )
6171 results = result .get (timeout = 60 )
62-
72+
6373 # Check if all workers succeeded
6474 if not all (results ):
65- print (f' ✗ ROUND { round_num + 1 } FAILED' )
75+ print (f" ✗ ROUND { round_num + 1 } FAILED" )
6676 sys .exit (1 )
67-
77+
6878 except mp .TimeoutError :
69- print (f' ✗ ROUND { round_num + 1 } HUNG' )
79+ print (f" ✗ ROUND { round_num + 1 } HUNG" )
7080 pool .terminate ()
7181 pool .join ()
7282 sys .exit (1 )
7383 except Exception as e :
74- print (f' ✗ ROUND { round_num + 1 } ERROR: { e } ' )
84+ print (f" ✗ ROUND { round_num + 1 } ERROR: { e } " )
7585 sys .exit (1 )
76-
77- print (f'✓ ROUND { round_num + 1 } PASSED' )
78-
79- print ('✓ ALL ROUNDS PASSED' )
8086
81- if __name__ == '__main__' :
82- main ()
87+ print (f"✓ ROUND { round_num + 1 } PASSED" )
88+
89+ print ("✓ ALL ROUNDS PASSED" )
90+
91+
92+ if __name__ == "__main__" :
93+ main ()
0 commit comments