11import base64
2+ import copy
23import dataclasses
34import multiprocessing
45import re
@@ -65,7 +66,7 @@ def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]:
6566
6667 tests = []
6768 lines = content .splitlines ()
68- match = r"\s*([a-zA-Z ]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*"
69+ match = r"\s*([a-zA-Z_ ]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*"
6970 for line in lines :
7071 parts = line .split (";" )
7172 case = {}
@@ -123,18 +124,19 @@ def calculate_stats(durations: list[int]):
123124 worst = float (worst ))
124125
125126
126- def _clone_data (data ):
127+ def _clone_data (data , rank : int ):
127128 """
128129 Recursively goes through data and clones all tensors.
129130 """
130131 if isinstance (data , tuple ):
131- return tuple (_clone_data (x ) for x in data )
132+ return tuple (_clone_data (x , rank ) for x in data )
132133 elif isinstance (data , list ):
133- return [_clone_data (x ) for x in data ]
134+ return [_clone_data (x , rank ) for x in data ]
134135 elif isinstance (data , dict ):
135- return {k : _clone_data (v ) for k , v in data .items ()}
136+ return {k : _clone_data (v , rank ) for k , v in data .items ()}
136137 elif isinstance (data , torch .Tensor ):
137- return data .clone ()
138+ device = f"cuda:{ rank } "
139+ return data .clone ().to (device )
138140 else :
139141 return data
140142
@@ -157,16 +159,60 @@ def _run_single_test(test: TestCase):
157159 from submission import custom_kernel
158160 data = generate_input (** test .args )
159161 torch .cuda .synchronize ()
160- submission_output = custom_kernel (_clone_data (data ))
162+ submission_output = custom_kernel (_clone_data (data , 0 ))
161163 torch .cuda .synchronize ()
162164 return wrap_check_implementation (data , submission_output )
163165
164166
def _run_distributed_test(test: TestCase, rank: int):
    """
    Runs a single test case as one rank of a multi-GPU job. Do not call
    directly; it is meant to be executed inside a worker process, once per
    rank, with all ranks running concurrently.

    Args:
        test: Test case to run; `test.args` must contain "world_size".
        rank: Rank of this process within the job. It is also used as the
            index of the CUDA device this process works on.

    Returns:
        Whatever `wrap_check_implementation` returns for this rank's data
        and the submission's output.
    """
    from submission import custom_kernel
    import torch.distributed as dist
    world_size = test.args["world_size"]
    # Every rank must agree on the rendezvous address for init_method="env://".
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12356"
    # Make this rank's GPU the current device, so that the argument-less
    # torch.cuda.synchronize() calls below wait on the device that actually
    # runs this rank's work rather than on the default device.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", init_method="env://", rank=rank, world_size=world_size)
    try:
        data = generate_input(**test.args, rank=rank)
        torch.cuda.synchronize()
        # The submission gets a clone, so `data` stays pristine for the check.
        submission_output = custom_kernel(_clone_data(data, rank))
        torch.cuda.synchronize()
        return wrap_check_implementation(data, submission_output)
    finally:
        # Always tear the process group down, even if the kernel or check raised.
        dist.destroy_process_group()
185+
186+
def run_multi_gpu_test(pool: multiprocessing.Pool, test: TestCase, world_size: int):
    """
    Runs a single multi-GPU test, one rank per task submitted to `pool`.

    Args:
        pool: Worker pool the per-rank tasks are submitted to.
        test: Test case to run on every rank.
        world_size: Number of ranks to launch (mandatory for multi-GPU tests).

    Returns:
        A `(correct, error_messages)` tuple. `correct` is True only if every
        rank's result has a truthy first element; `error_messages` joins the
        second element of each failing rank's result, one line per rank,
        prefixed with that rank's number.
    """
    # Submit every rank before waiting on any of them, so that the ranks
    # run concurrently.
    # NOTE(review): presumably the pool needs at least `world_size` workers,
    # otherwise ranks could wait on peers that never start - confirm.
    pending = [
        pool.apply_async(_run_distributed_test, args=(test, rank))
        for rank in range(world_size)
    ]
    rets = [job.get() for job in pending]

    correct = all(ret[0] for ret in rets)
    error_messages = "\n".join(
        f"rank {rank}: {ret[1]}" for rank, ret in enumerate(rets) if not ret[0]
    )
    return correct, error_messages
205+
206+
def run_single_test(pool: multiprocessing.Pool, test: TestCase):
    """
    Runs a single test in another process.

    A test whose arguments contain "world_size" is run as a multi-GPU test
    with that many ranks; any other test is run by one pool worker.

    Args:
        pool: Worker pool used to execute the test.
        test: Test case to run.

    Returns:
        The result of `_run_single_test` for a single-GPU test, or the
        `(correct, error_messages)` tuple of `run_multi_gpu_test` otherwise.
    """
    world_size = test.args.get("world_size", None)
    if world_size is None:
        # _run_single_test takes the test case as its only argument.
        return pool.apply(_run_single_test, (test,))
    else:
        return run_multi_gpu_test(pool, test, world_size)
170216
171217
172218def run_testing (logger : PopcornOutput , pool : multiprocessing .Pool , tests : list [TestCase ]):
@@ -345,14 +391,15 @@ def main():
345391 mode = sys .argv [1 ]
346392 seed = os .getenv ("POPCORN_SEED" )
347393 os .unsetenv ("POPCORN_SEED" )
394+ n_gpus = int (os .getenv ("POPCORN_GPUS" , "1" ))
348395 seed = int (seed ) if seed else None
349396 set_seed (seed or 42 )
350397 tests = get_test_cases (sys .argv [2 ], seed )
351398
352399 with PopcornOutput (int (fd )) as logger :
353400 import multiprocessing
354401 mp_context = multiprocessing .get_context ('spawn' )
355- with mp_context .Pool (1 ) as pool :
402+ with mp_context .Pool (n_gpus ) as pool :
356403 if mode == "test" :
357404 return run_testing (logger , pool , tests )
358405 if mode == "benchmark" :
0 commit comments