diff --git a/graph_net/agent/parallel_extract.py b/graph_net/agent/parallel_extract.py index 2c6b9fc93..42eea8bd0 100644 --- a/graph_net/agent/parallel_extract.py +++ b/graph_net/agent/parallel_extract.py @@ -383,13 +383,13 @@ def _parse_args() -> argparse.Namespace: "--gpus", type=str, default=None, - help="Comma-separated GPU indices to use (GPU mode; if set, ignores --num-workers)", + help="Comma-separated GPU indices to use (default: auto-detect all available GPUs)", ) parser.add_argument( - "--num-workers", + "--cpu-workers", type=int, default=None, - help="Number of worker processes in CPU mode (default: CPU count)", + help="Number of worker processes in CPU-only mode (default: half of CPU cores)", ) parser.add_argument( "--output", @@ -435,22 +435,29 @@ def _resolve_config(args: argparse.Namespace): ) print(f"[INFO] Workspace: {workspace}") - if get_device_type() == "cuda": + # Decide GPU vs CPU mode: if --cpu-workers is set, force CPU-only mode. + # If no CUDA available, also fall back to CPU mode. + if (args.cpu_workers and args.cpu_workers > 0) or get_device_type() == "cpu": + # CPU-only mode. Default to half of CPU cores to avoid overloading the + # system, since each worker is a heavy process (model loading + graph + # extraction). + gpus = [] + num_workers = ( + args.cpu_workers if args.cpu_workers else max(1, (os.cpu_count() or 2) // 2) + ) + print(f"[INFO] CPU-only mode: {num_workers} workers") + extract_timeout = ( + args.extract_timeout if args.extract_timeout is not None else 2000 + ) + verify_timeout = args.verify_timeout if args.verify_timeout is not None else 600 + else: gpus = get_gpu_ids(args) num_workers = len(gpus) - print(f"[INFO] GPU mode (torch fallback): {gpus}") + print(f"[INFO] GPU mode: {num_workers} workers on GPUs {gpus}") extract_timeout = ( args.extract_timeout if args.extract_timeout is not None else 1000 ) verify_timeout = args.verify_timeout if args.verify_timeout is not None else 300 - else: - gpus = [] - num_workers = args.num_workers if args.num_workers else 1 - print(f"[INFO] CPU mode: {num_workers} workers") - extract_timeout = ( - args.extract_timeout if args.extract_timeout is not None else 2000 - ) - verify_timeout = args.verify_timeout if args.verify_timeout is not None else 600 return workspace, gpus, num_workers, extract_timeout, verify_timeout @@ -468,7 +475,10 @@ def main() -> int: print("[ERROR] Empty model list, nothing to do") return 1 - print(f"[INFO] Total models: {len(model_ids)}, workers: {num_workers}") + if gpus: + print(f"[INFO] Total models: {len(model_ids)}, GPU workers: {num_workers}") + else: + print(f"[INFO] Total models: {len(model_ids)}, CPU workers: {num_workers}") # --- Populate shared task queue --- task_queue: multiprocessing.Queue = multiprocessing.Queue() @@ -479,8 +489,9 @@ def main() -> int: result_queue: multiprocessing.Queue = multiprocessing.Queue() start_time = datetime.now() + worker_type = "GPU" if gpus else "CPU" print( - f"\n[START] {start_time.strftime('%Y-%m-%d %H:%M:%S')} — launching {num_workers} workers\n" + f"\n[START] {start_time.strftime('%Y-%m-%d %H:%M:%S')} — launching {num_workers} {worker_type} workers\n" ) processes = []