diff --git a/.github/scripts/run_new_examples.sh b/.github/scripts/run_new_examples.sh index 076e7de46..b54724d82 100755 --- a/.github/scripts/run_new_examples.sh +++ b/.github/scripts/run_new_examples.sh @@ -56,7 +56,7 @@ EXIT_CODE=0 fi fi echo \"Running: \$example_file with $NUM_RANKS ranks\" - torchrun --nproc_per_node=$NUM_RANKS --standalone \"\$example_file\" + torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:0 --nnodes=1 --nproc_per_node=$NUM_RANKS \"\$example_file\" fi done " || { EXIT_CODE=$?; } diff --git a/.github/scripts/run_perf_benchmark.sh b/.github/scripts/run_perf_benchmark.sh index 85e580a86..26abd6254 100755 --- a/.github/scripts/run_perf_benchmark.sh +++ b/.github/scripts/run_perf_benchmark.sh @@ -30,7 +30,7 @@ echo "[PERF-BENCHMARK] Using GPUs: $GPU_DEVICES" cd /iris_workspace pip install -e . - torchrun --nproc_per_node=8 examples/${EXAMPLE_PATH}/benchmark.py \ + torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:0 --nnodes=1 --nproc_per_node=8 examples/${EXAMPLE_PATH}/benchmark.py \ --benchmark \ --validate \ ${BENCHMARK_ARGS} \ diff --git a/.github/scripts/run_tests.sh b/.github/scripts/run_tests.sh index bbcba6585..c126df7ca 100755 --- a/.github/scripts/run_tests.sh +++ b/.github/scripts/run_tests.sh @@ -82,7 +82,7 @@ EXIT_CODE=0 for test_file in tests/$TEST_DIR/test_*.py; do if [ -f \"\$test_file\" ]; then echo \"Testing: \$test_file with $NUM_RANKS ranks (install: $INSTALL_METHOD)\" - torchrun --nproc_per_node=$NUM_RANKS --standalone tests/run_tests_distributed.py \"\$test_file\" -v --tb=short --durations=10 + torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:0 --nnodes=1 --nproc_per_node=$NUM_RANKS tests/run_tests_distributed.py \"\$test_file\" -v --tb=short --durations=10 fi done " || { EXIT_CODE=$?; } diff --git a/.github/workflows/iris-external-validation-test.yml b/.github/workflows/iris-external-validation-test.yml index 1330e8d3c..b6af5b1c7 100644 --- a/.github/workflows/iris-external-validation-test.yml +++ 
b/.github/workflows/iris-external-validation-test.yml @@ -54,7 +54,7 @@ jobs: cd /iris_workspace pip install git+https://github.com/${{ github.repository }}.git@${{ github.sha }} wget -O test_iris_distributed.py https://gist.githubusercontent.com/mawad-amd/6375dc078e39e256828f379e03310ec7/raw/0827d023eaf8e9755b17cbe8ab06f2ce258e746a/test_iris_distributed.py - torchrun --nproc_per_node=2 test_iris_distributed.py + torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:0 --nnodes=1 --nproc_per_node=2 test_iris_distributed.py " echo "::endgroup::" @@ -103,7 +103,7 @@ jobs: cd /iris_workspace pip install git+https://github.com/${{ github.repository }}.git@${{ github.sha }} wget -O test_iris_gluon_distributed.py https://gist.githubusercontent.com/mawad-amd/2666dde8ebe2755eb0c4f2108709fcd5/raw/c5544943e2832c75252160bd9084600bf01a6b06/test_iris_gluon_distributed.py - torchrun --nproc_per_node=2 test_iris_gluon_distributed.py + torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:0 --nnodes=1 --nproc_per_node=2 test_iris_gluon_distributed.py " echo "::endgroup::" diff --git a/.github/workflows/iris-performance-regression-test.yml b/.github/workflows/iris-performance-regression-test.yml index ebde87df3..fc7081b80 100644 --- a/.github/workflows/iris-performance-regression-test.yml +++ b/.github/workflows/iris-performance-regression-test.yml @@ -24,11 +24,10 @@ jobs: matrix: # Performance baselines measured on AMD Instinct MI325X (8 GPUs) include: - # Disabled https://github.com/ROCm/iris/issues/238 - #- example_name: "GEMM All-Scatter WG Specialization" - # example_path: "10_gemm_all_scatter_wg_specialization" - # tflops_threshold: 1600 # Actual: ~2182 TFLOPs - # benchmark_args: "-m 16384 -n 16384 -k 16384 --BLK_M 128 --BLK_N 128 --BLK_K 64 --gsize_m 6 --gemm_sms 256" + - example_name: "GEMM All-Scatter WG Specialization" + example_path: "10_gemm_all_scatter_wg_specialization" + tflops_threshold: 1440 # Actual: ~1802 TFLOPs (80% regression threshold) + benchmark_args: "-m 
16384 -n 16384 -k 16384 --BLK_M 128 --BLK_N 128 --BLK_K 64 --gsize_m 6 --gemm_sms 256" - example_name: "GEMM All-Scatter" example_path: "07_gemm_all_scatter" diff --git a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py index 655c892f3..910ebdd6f 100755 --- a/examples/10_gemm_all_scatter_wg_specialization/benchmark.py +++ b/examples/10_gemm_all_scatter_wg_specialization/benchmark.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +import os + import torch import torch.distributed as dist import torch.multiprocessing as mp @@ -132,7 +134,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): total_blocks_N = triton.cdiv(args["n"], args["BLK_N"]) total_tiles = total_blocks_M * total_blocks_N - locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int8) + locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32) bias = None @@ -153,13 +155,18 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): # Allocate Timestamps timestamps = Timestamps(num_tiles=total_tiles) + def preamble(): + # Barrier 1: ensure all ranks finish previous iteration before clearing locks + shmem.barrier() + locks.zero_() + # Barrier 2: ensure all ranks see zeroed locks before any rank starts the kernel + shmem.barrier() + def run_experiment(): nonlocal local_C nonlocal global_C nonlocal kernel_timing - shmem.barrier() - if args["trace_tiles"]: timestamps.reset() shmem.barrier() @@ -215,6 +222,16 @@ def run_experiment(): kernel_timing[k]["experiments"] = 0 if args["validate"]: + # Run a dedicated validation kernel to ensure all cross-GPU writes are fully + # propagated before checking results. 
The warmup above may leave some + # iris.put stores in-flight on the xGMI interconnect; the extra + # preamble + run + barrier cycle guarantees all ranks have flushed their + # GPU caches and that rank-0 sees every scattered tile before we call + # validate_gemm. + preamble() + run_experiment() + shmem.barrier() + shmem.info("Validating...") matmul.set_debug(True) # Validate global result @@ -241,7 +258,7 @@ def run_experiment(): matmul.set_debug(False) shmem.info("Benchmarking...") perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3) - triton_ms = iris.do_bench(run_experiment, shmem.barrier) + triton_ms = iris.do_bench(run_experiment, shmem.barrier, preamble) triton_tflops = perf(triton_ms) algo_string = "all_scatter" shmem.info( @@ -275,15 +292,24 @@ def run_experiment(): def main(): args = parse_args() - num_ranks = args["num_ranks"] - - init_url = "tcp://127.0.0.1:29500" - mp.spawn( - fn=_worker, - args=(num_ranks, init_url, args), - nprocs=num_ranks, - join=True, - ) + # Check if running with torchrun (detected by environment variables) + if "RANK" in os.environ and "LOCAL_RANK" in os.environ: + # torchrun handles process spawning, so call _worker directly + print("Detected torchrun execution mode") + rank = int(os.environ.get("RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + init_url = os.environ.get("MASTER_ADDR", "127.0.0.1") + ":" + os.environ.get("MASTER_PORT", "29500") + _worker(rank, world_size, f"tcp://{init_url}", args) + else: + # Use multiprocessing spawn for backward compatibility + num_ranks = args["num_ranks"] + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) if __name__ == "__main__": diff --git a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py index 4d9c28255..643e84f90 100644 --- 
a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py +++ b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py @@ -140,8 +140,7 @@ def persistent_gemm_all_scatter_wg_specialization( tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) tl.store(c_global + global_offset, c, mask=sub_mask, cache_modifier=".wt") - tl.debug_barrier() - tl.store(locks + tile_id, 1, cache_modifier=".wt") + tl.atomic_xchg(locks + tile_id, 1, sem="release", scope="gpu") else: # pid >= GEMM_SMS COMM_SMS = NUM_SMS - GEMM_SMS @@ -163,8 +162,11 @@ def persistent_gemm_all_scatter_wg_specialization( global_offset = rm[:, None] * stride_cm_global + (rn[None, :] + cur_rank * N) * stride_cn_global # End: masks/offset calculations. + # Spin-wait: first check with a cheap volatile load, then acquire-CAS to + # ensure memory ordering once the lock is observed set. while tl.load(locks + tile_id, cache_modifier=".cv", volatile=True) != 1: pass + tl.atomic_cas(locks + tile_id, 1, 1, sem="acquire", scope="gpu") for remote_rank in range(world_size): if remote_rank != cur_rank: diff --git a/examples/25_ccl_all_gather/example.py b/examples/25_ccl_all_gather/example.py index 2c15f59e1..18266d905 100644 --- a/examples/25_ccl_all_gather/example.py +++ b/examples/25_ccl_all_gather/example.py @@ -18,6 +18,7 @@ import torch.distributed as dist import iris +from iris.ccl import Config def parse_args(): @@ -30,6 +31,12 @@ def parse_args(): parser.add_argument("--heap_size", type=int, default=1 << 31, help="Iris heap size") parser.add_argument("--datatype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type") parser.add_argument("-v", "--validate", action="store_true", help="Validate output against reference") + parser.add_argument("--block_size_m", type=int, default=32, help="Block size for M dimension tiling") + parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension tiling") + 
parser.add_argument("--comm_sms", type=int, default=64, help="Number of SMs for all-gather kernel") + parser.add_argument("--num_stages", type=int, default=1, help="Number of stages") + parser.add_argument("--num_warps", type=int, default=4, help="Number of warps") + parser.add_argument("--waves_per_eu", type=int, default=0, help="Number of waves per EU") return vars(parser.parse_args()) @@ -53,8 +60,18 @@ def main(): input_tensor.fill_(float(rank + 1)) output_tensor = ctx.zeros((world_size * M, N), dtype=dtype) + config_kwargs = { + "block_size_m": args["block_size_m"], + "block_size_n": args["block_size_n"], + "comm_sms": args["comm_sms"], + "num_stages": args["num_stages"], + "num_warps": args["num_warps"], + "waves_per_eu": args["waves_per_eu"], + } + config = Config(**config_kwargs) + ctx.barrier() - ctx.ccl.all_gather(output_tensor, input_tensor) + ctx.ccl.all_gather(output_tensor, input_tensor, config=config) torch.cuda.synchronize() if rank == 0: diff --git a/examples/26_ccl_all_to_all/example.py b/examples/26_ccl_all_to_all/example.py index d24fbd909..f55bd7403 100644 --- a/examples/26_ccl_all_to_all/example.py +++ b/examples/26_ccl_all_to_all/example.py @@ -18,6 +18,7 @@ import torch.distributed as dist import iris +from iris.ccl import Config def parse_args(): @@ -28,6 +29,12 @@ def parse_args(): parser.add_argument("-m", type=int, default=512, help="Number of rows") parser.add_argument("-n", type=int, default=128, help="Number of columns per rank slice") parser.add_argument("--heap_size", type=int, default=1 << 31, help="Iris heap size") + parser.add_argument("--block_size_m", type=int, default=32, help="Block size for M dimension tiling") + parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension tiling") + parser.add_argument("--comm_sms", type=int, default=64, help="Number of SMs for all-to-all kernel") + parser.add_argument("--num_stages", type=int, default=1, help="Number of stages") + 
parser.add_argument("--num_warps", type=int, default=4, help="Number of warps") + parser.add_argument("--waves_per_eu", type=int, default=0, help="Number of waves per EU") parser.add_argument("--datatype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type") parser.add_argument("-v", "--validate", action="store_true", help="Validate output against reference") return vars(parser.parse_args()) @@ -54,8 +61,18 @@ def main(): input_tensor[:, target_rank * N : (target_rank + 1) * N] = float(rank * 10 + target_rank + 1) output_tensor = ctx.zeros((M, N * world_size), dtype=dtype) + config_kwargs = { + "block_size_m": args["block_size_m"], + "block_size_n": args["block_size_n"], + "comm_sms": args["comm_sms"], + "num_stages": args["num_stages"], + "num_warps": args["num_warps"], + "waves_per_eu": args["waves_per_eu"], + } + config = Config(**config_kwargs) + ctx.barrier() - ctx.ccl.all_to_all(output_tensor, input_tensor) + ctx.ccl.all_to_all(output_tensor, input_tensor, config=config) torch.cuda.synchronize() if rank == 0: diff --git a/examples/27_ccl_reduce_scatter/example.py b/examples/27_ccl_reduce_scatter/example.py new file mode 100644 index 000000000..1a5a80975 --- /dev/null +++ b/examples/27_ccl_reduce_scatter/example.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Example: iris.ccl.reduce_scatter + +Each rank has input (M, N); each rank reduces its assigned tiles from all ranks +and stores the result only to its own output (same shape (M, N)). 
+ +Run with: + torchrun --nproc_per_node=NUM_GPUS --standalone example.py [--validate] +""" + +import argparse +import os + +import torch +import torch.distributed as dist + +import iris +from iris.ccl import Config + + +def parse_args(): + parser = argparse.ArgumentParser( + description="CCL reduce-scatter example", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=1024, help="Number of rows") + parser.add_argument("-n", type=int, default=512, help="Number of columns") + parser.add_argument("--heap_size", type=int, default=1 << 31, help="Iris heap size") + parser.add_argument("--block_size_m", type=int, default=32, help="Block size for M dimension tiling") + parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension tiling") + parser.add_argument("--comm_sms", type=int, default=64, help="Number of SMs for reduce-scatter kernel") + parser.add_argument("--num_stages", type=int, default=1, help="Number of stages") + parser.add_argument("--num_warps", type=int, default=4, help="Number of warps") + parser.add_argument("--waves_per_eu", type=int, default=0, help="Number of waves per EU") + parser.add_argument("--datatype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type") + parser.add_argument("-v", "--validate", action="store_true", help="Validate output against reference") + return vars(parser.parse_args()) + + +def main(): + args = parse_args() + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="gloo") + + ctx = iris.iris(heap_size=args["heap_size"]) + rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + + dtype_map = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} + dtype = dtype_map[args["datatype"]] + M, N = args["m"], args["n"] + + # Each rank fills its input with (rank + 1) + input_tensor = ctx.zeros((M, N), dtype=dtype) + input_tensor.fill_(float(rank 
+ 1)) + output_tensor = ctx.zeros((M, N), dtype=dtype) + + config_kwargs = { + "block_size_m": args["block_size_m"], + "block_size_n": args["block_size_n"], + "comm_sms": args["comm_sms"], + "num_stages": args["num_stages"], + "num_warps": args["num_warps"], + "waves_per_eu": args["waves_per_eu"], + "all_reduce_distribution": 1, + } + config = Config(**config_kwargs) + + ctx.barrier() + ctx.ccl.reduce_scatter(output_tensor, input_tensor, config=config) + torch.cuda.synchronize() + + if rank == 0: + ctx.info(f"reduce_scatter: world_size={world_size}, shape=({M},{N}), dtype={dtype}") + + if args["validate"]: + # Reference: gather all inputs, sum, then each rank checks its assigned tiles + ref_list = [torch.empty(M, N, dtype=dtype, device=input_tensor.device) for _ in range(world_size)] + dist.all_gather(ref_list, input_tensor) + full_reduced = sum(ref_list).float() + + block_size_m = args["block_size_m"] + block_size_n = args["block_size_n"] + num_pid_m = (M + block_size_m - 1) // block_size_m + num_pid_n = (N + block_size_n - 1) // block_size_n + total_tiles = num_pid_m * num_pid_n + tiles_per_rank = (total_tiles + world_size - 1) // world_size + start_tile = rank * tiles_per_rank + + # Build mask of (i,j) belonging to this rank's tiles (block distribution) + pid_m = torch.arange(M, device=output_tensor.device) // block_size_m + pid_n = torch.arange(N, device=output_tensor.device) // block_size_n + tile_id = pid_m[:, None] * num_pid_n + pid_n[None, :] + mask = (tile_id >= start_tile) & (tile_id < start_tile + tiles_per_rank) + + out_float = output_tensor.float() + expected_where = full_reduced[mask] + actual_where = out_float[mask] + assert torch.allclose(actual_where, expected_where, atol=0.6), f"Rank {rank}: output mismatch on assigned tiles" + if rank == 0: + ctx.info("Validation passed: output matches reference") + + ctx.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/iris/_distributed_helpers.py 
b/iris/_distributed_helpers.py index 9b222375e..ac237cb5e 100644 --- a/iris/_distributed_helpers.py +++ b/iris/_distributed_helpers.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + import torch import torch.distributed as dist import numpy as np +import triton +import triton.language as tl def _infer_device(): @@ -207,6 +210,76 @@ def distributed_broadcast_tensor(value_to_broadcast=None, root=0): return obj[0] +def extract_group_info(group, rank, num_ranks): + """ + Extract rank and stride information for a process group. + + Args: + group: ProcessGroup or None. If None, uses the provided rank/num_ranks + as the default (all-ranks) group. + rank: Global rank of the current process. + num_ranks: Total number of ranks in the default group. + + Returns: + Tuple of (rank_in_group, rank_global, world_size, rank_start, rank_stride): + - rank_in_group: Rank within the group (0-indexed) + - rank_global: Global rank of this process + - world_size: Number of ranks in the group + - rank_start: Starting global rank of the group + - rank_stride: Stride between consecutive ranks in the group + + Examples: + >>> # group=None: all ranks [0,1,2,3], current global rank is 2 + >>> extract_group_info(None, 2, 4) + (2, 2, 4, 0, 1) + + >>> # DP group: strided ranks [0,4,8,12], current global rank is 8 + >>> extract_group_info(dp_group, 8, 16) + (2, 8, 4, 0, 4) + """ + if group is None: + return rank, rank, num_ranks, 0, 1 + + if not dist.is_initialized(): + raise RuntimeError( + "torch.distributed must be initialized to use ProcessGroup. " + "Call torch.distributed.init_process_group() first." + ) + + group_ranks = dist.get_process_group_ranks(group) + world_size = len(group_ranks) + rank_global = rank + + if rank_global not in group_ranks: + raise RuntimeError( + f"Rank {rank_global} is not part of the specified process group. 
Group contains ranks: {group_ranks}" + ) + + rank_in_group = group_ranks.index(rank_global) + + if len(group_ranks) > 1: + strides = [group_ranks[i] - group_ranks[i - 1] for i in range(1, len(group_ranks))] + if not all(s == strides[0] for s in strides): + raise NotImplementedError( + f"Non-strided process groups are not yet supported. " + f"Group ranks: {group_ranks}. " + f"Please use groups with uniform stride (e.g., [0,1,2,3] or [0,4,8,12])." + ) + rank_start = group_ranks[0] + rank_stride = strides[0] + if rank_stride == 0: + raise ValueError( + f"Invalid process group: rank_stride is 0, indicating duplicate ranks. " + f"Group ranks: {group_ranks}. " + f"Each rank must appear exactly once in a process group." + ) + else: + rank_start = group_ranks[0] + rank_stride = 1 + + return rank_in_group, rank_global, world_size, rank_start, rank_stride + + def distributed_barrier(group=None): """ Synchronization barrier using PyTorch distributed. @@ -220,6 +293,98 @@ def distributed_barrier(group=None): dist.barrier(group=group) +@triton.jit +def _translate_ptr(ptr, from_rank, to_rank, heap_bases): + """Translate a pointer from one rank's address space to another's.""" + from_base = tl.load(heap_bases + from_rank) + to_base = tl.load(heap_bases + to_rank) + offset = tl.cast(ptr, tl.uint64) - from_base + translated_ptr = tl.cast(tl.cast(to_base, tl.pointer_type(tl.int8)) + offset, ptr.dtype) + return translated_ptr + + +@triton.jit +def _device_barrier_kernel( + flags_ptr, + iris_rank, + world_size: tl.constexpr, + rank_start, + rank_stride, + heap_bases, + MAX_SPINS: tl.constexpr = 1_000_000_000, +): + """ + Device-side barrier using atomic operations on the symmetric heap. + CUDA graph capturable. + + Stateless w.r.t. host-side epoch tracking: there is no CPU-side epoch + counter. Each rank's flag on the heap serves as its own epoch counter, + managed entirely by the GPU via atomic_add. A persistent per-group flags + tensor is cached in ``_device_barrier_state``. 
+ + Launched with grid=(1,). A single CTA: + 1. Atomically increments its own flag (atomic_add, release) + 2. Serially polls each remote rank's flag for the same value (acquire) + """ + # Increment own flag and determine target + own_flag_ptr = flags_ptr + iris_rank + own_translated = _translate_ptr(own_flag_ptr, iris_rank, iris_rank, heap_bases) + old = tl.atomic_add(own_translated, 1, sem="release", scope="sys") + target = old + 1 + + # Poll each remote rank serially + for i in range(world_size): + remote_rank = rank_start + i * rank_stride + if remote_rank != iris_rank: + remote_flag_ptr = flags_ptr + remote_rank + remote_translated = _translate_ptr(remote_flag_ptr, iris_rank, remote_rank, heap_bases) + spin_count = 0 + while ( + tl.atomic_cas( + remote_translated, + target, + target, + sem="acquire", + scope="sys", + ) + < target + ): + spin_count += 1 + tl.device_assert(spin_count < MAX_SPINS, "device_barrier: timeout") + + +def distributed_device_barrier(flags, group, rank, num_ranks, heap_bases): + """ + Device-side barrier using atomic operations on the symmetric heap. + CUDA graph capturable. + + Unlike ``distributed_barrier`` which uses host-side ``torch.distributed.barrier()``, + this launches a single-CTA Triton kernel that synchronizes via + device-side atomics, making it safe to use during CUDA graph capture. + + Stateless w.r.t. host-side epoch tracking: each rank's flag on the + symmetric heap serves as its own epoch counter, managed entirely by + the GPU via atomic_add. A persistent per-group flags tensor is cached + in ``_device_barrier_state``. + + Args: + flags: int32 tensor on symmetric heap, one element per rank. + group: ProcessGroup or None. If None, uses all ranks. + rank: Global rank of this process. + num_ranks: Total number of ranks in the default group. + heap_bases: Tensor of heap base addresses for all ranks. 
+ """ + _, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, rank, num_ranks) + _device_barrier_kernel[(1,)]( + flags, + rank_global, + world_size, + rank_start, + rank_stride, + heap_bases, + ) + + def init_distributed(): """ Initialize PyTorch distributed and return communicator info. diff --git a/iris/ccl/all_gather.py b/iris/ccl/all_gather.py index 190c96072..2093fb3bd 100644 --- a/iris/ccl/all_gather.py +++ b/iris/ccl/all_gather.py @@ -377,6 +377,9 @@ def all_gather( config.comm_sms, config.num_xcds, config.chunk_size, + num_stages=config.num_stages, + num_warps=config.num_warps, + waves_per_eu=config.waves_per_eu, ) if not async_op: diff --git a/iris/ccl/all_to_all.py b/iris/ccl/all_to_all.py index 9ff16a1bd..332d102b3 100644 --- a/iris/ccl/all_to_all.py +++ b/iris/ccl/all_to_all.py @@ -401,6 +401,9 @@ def all_to_all( config.comm_sms, config.num_xcds, config.chunk_size, + num_stages=config.num_stages, + num_warps=config.num_warps, + waves_per_eu=config.waves_per_eu, ) else: # Use Triton implementation @@ -428,6 +431,9 @@ def all_to_all( config.comm_sms, config.num_xcds, config.chunk_size, + num_stages=config.num_stages, + num_warps=config.num_warps, + waves_per_eu=config.waves_per_eu, ) if not async_op: diff --git a/iris/ccl/config.py b/iris/ccl/config.py index bb84c2ea2..5f1f8b9c3 100644 --- a/iris/ccl/config.py +++ b/iris/ccl/config.py @@ -80,6 +80,9 @@ class Config: all_reduce_num_rings: int = 1 all_reduce_ring_slice_n: int | None = None reduce_scatter_variant: str = "two_shot" + num_stages: int = 1 + num_warps: int = 4 + waves_per_eu: int = 0 def __post_init__(self): """Validate and auto-detect num_xcds if not set.""" diff --git a/iris/ccl/reduce_scatter.py b/iris/ccl/reduce_scatter.py index 2581a1b55..85e915a85 100644 --- a/iris/ccl/reduce_scatter.py +++ b/iris/ccl/reduce_scatter.py @@ -106,11 +106,11 @@ def persistent_reduce_scatter_two_shot( if is_full: start_rank_idx = pid % world_size start_rank_global = rank_start + 
start_rank_idx * rank_stride - acc = iris.load(base_ptr, iris_rank, start_rank_global, heap_bases).to(acc_dtype) + acc = iris.load(base_ptr, iris_rank, start_rank_global, heap_bases, hint=(1, BLOCK_SIZE_N)).to(acc_dtype) for i in tl.static_range(1, world_size): remote_rank_idx = (start_rank_idx + i) % world_size remote_rank = rank_start + remote_rank_idx * rank_stride - acc += iris.load(base_ptr, iris_rank, remote_rank, heap_bases).to(acc_dtype) + acc += iris.load(base_ptr, iris_rank, remote_rank, heap_bases, hint=(1, BLOCK_SIZE_N)).to(acc_dtype) reduced = acc.to(output_ptr.type.element_ty) @@ -124,11 +124,15 @@ def persistent_reduce_scatter_two_shot( start_rank_idx = pid % world_size start_rank_global = rank_start + start_rank_idx * rank_stride - acc = iris.load(base_ptr, iris_rank, start_rank_global, heap_bases, mask=mask).to(acc_dtype) + acc = iris.load(base_ptr, iris_rank, start_rank_global, heap_bases, mask=mask, hint=(1, BLOCK_SIZE_N)).to( + acc_dtype + ) for i in tl.static_range(1, world_size): remote_rank_idx = (start_rank_idx + i) % world_size remote_rank = rank_start + remote_rank_idx * rank_stride - acc += iris.load(base_ptr, iris_rank, remote_rank, heap_bases, mask=mask).to(acc_dtype) + acc += iris.load(base_ptr, iris_rank, remote_rank, heap_bases, mask=mask, hint=(1, BLOCK_SIZE_N)).to( + acc_dtype + ) reduced = acc.to(output_ptr.type.element_ty) @@ -247,6 +251,9 @@ def reduce_scatter( config.num_xcds, config.chunk_size, distribution, + num_stages=config.num_stages, + num_warps=config.num_warps, + waves_per_eu=config.waves_per_eu, ) if not async_op: diff --git a/iris/ccl/utils.py b/iris/ccl/utils.py index 4f90ac09a..eeff2781a 100644 --- a/iris/ccl/utils.py +++ b/iris/ccl/utils.py @@ -9,6 +9,7 @@ from typing import Tuple import triton import triton.language as tl +from iris._distributed_helpers import extract_group_info as _extract_group_info @triton.jit() @@ -67,83 +68,11 @@ def extract_group_info(group, shmem) -> Tuple[int, int, int, int, int]: 
Returns: Tuple of (rank_in_group, rank_global, world_size, rank_start, rank_stride) - - rank_in_group: Rank within the group (0-indexed), used for tile assignment and comparisons - - rank_global: Global rank of this process, used for iris RMA operations (heap_bases indexing) + - rank_in_group: Rank within the group (0-indexed) + - rank_global: Global rank of this process - world_size: Number of ranks in the group - rank_start: Starting global rank of the group - rank_stride: Stride between consecutive ranks in the group - - Examples: - >>> # group=None: all ranks [0,1,2,3], current global rank is 2 - >>> extract_group_info(None, shmem) - (2, 2, 4, 0, 1) # rank_in_group=2, rank_global=2, world_size=4, start=0, stride=1 - - >>> # TP group: consecutive ranks [0,1,2,3], current global rank is 2 - >>> extract_group_info(tp_group, shmem) - (2, 2, 4, 0, 1) # rank_in_group=2, rank_global=2, world_size=4, start=0, stride=1 - - >>> # DP group: strided ranks [0,4,8,12], current global rank is 8 - >>> extract_group_info(dp_group, shmem) - (2, 8, 4, 0, 4) # rank_in_group=2, rank_global=8, world_size=4, start=0, stride=4 """ - if group is None: - # Use all ranks in shmem context - # When group is None, rank_in_group equals rank_global - rank_global = shmem.get_rank() - rank_in_group = rank_global - world_size = shmem.get_num_ranks() - rank_start = 0 - rank_stride = 1 - return rank_in_group, rank_global, world_size, rank_start, rank_stride - - # Extract from ProcessGroup - import torch.distributed as dist - - if not dist.is_initialized(): - raise RuntimeError( - "torch.distributed must be initialized to use ProcessGroup. " - "Call torch.distributed.init_process_group() first." - ) - - group_ranks = dist.get_process_group_ranks(group) - world_size = len(group_ranks) - rank_global = dist.get_rank() - - if rank_global not in group_ranks: - raise RuntimeError( - f"Current rank {rank_global} is not part of the specified process group. 
" - f"Group contains ranks: {group_ranks}" - ) - - rank_in_group = group_ranks.index(rank_global) - - # Detect stride pattern - if len(group_ranks) > 1: - # Check if all consecutive pairs have the same stride - strides = [group_ranks[i] - group_ranks[i - 1] for i in range(1, len(group_ranks))] - is_strided = all(s == strides[0] for s in strides) - - if is_strided: - rank_start = group_ranks[0] - rank_stride = strides[0] - - # Validate rank_stride is not zero (would indicate duplicate ranks) - if rank_stride == 0: - raise ValueError( - f"Invalid process group: rank_stride is 0, indicating duplicate ranks. " - f"Group ranks: {group_ranks}. " - f"Each rank must appear exactly once in a process group." - ) - else: - # Non-strided group - not supported yet - raise NotImplementedError( - f"Non-strided process groups are not yet supported. " - f"Group ranks: {group_ranks}. " - f"Please use groups with uniform stride (e.g., [0,1,2,3] or [0,4,8,12])." - ) - else: - # Single rank group - rank_start = group_ranks[0] - rank_stride = 1 - - return rank_in_group, rank_global, world_size, rank_start, rank_stride + + return _extract_group_info(group, shmem.get_rank(), shmem.get_num_ranks()) diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index 8aead7c41..1a06f284e 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -516,6 +516,9 @@ def __init__(self, heap_size=1 << 30): # Initialize CCL interface self.ccl = self.CCL(self) + # Pre-build the device context tensor + self._build_device_context() + class CCL: """ Collective Communication Library (CCL) interface for IrisGluon. @@ -661,6 +664,19 @@ def error(self, message): """Log an error message with rank information.""" self._log_with_rank(logging.ERROR, message) + def _build_device_context(self): + """ + Build and cache the device context tensor. + + Called during __init__ to pre-build the tensor once. 
+ """ + # Convert heap_bases to a list for concatenation + heap_bases_list = self.heap_bases.tolist() + + # Create context tensor: [cur_rank, num_ranks, heap_base_0, heap_base_1, ...] + context_data = [self.cur_rank, self.num_ranks] + heap_bases_list + self._device_context = torch.tensor(context_data, dtype=torch.int64, device=self.device) + def get_device_context(self): """ Get the device context tensor for Gluon kernels. @@ -679,14 +695,7 @@ def get_device_context(self): >>> ctx = IrisDeviceCtx.initialize(context_tensor) >>> data = ctx.load(buffer, 1) """ - # Convert heap_bases to a list for concatenation - heap_bases_list = self.heap_bases.tolist() - - # Create context tensor: [cur_rank, num_ranks, heap_base_0, heap_base_1, ...] - context_data = [self.cur_rank, self.num_ranks] + heap_bases_list - context_tensor = torch.tensor(context_data, dtype=torch.int64, device=self.device) - - return context_tensor + return self._device_context def get_backend(self): """ diff --git a/iris/iris.py b/iris/iris.py index d88a0e426..7516e3eaa 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -38,6 +38,7 @@ >>> data = device_ctx.load(buffer, from_rank=remote_rank) """ +import os import triton import triton.language as tl from triton.language.core import _aggregate as aggregate @@ -45,6 +46,7 @@ from iris._distributed_helpers import ( init_distributed, distributed_barrier, + distributed_device_barrier, distributed_broadcast_scalar, distributed_broadcast_tensor, ) @@ -55,6 +57,7 @@ ) from iris.symmetric_heap import SymmetricHeap import numpy as np +from typing import Any import torch import logging @@ -115,7 +118,8 @@ def __init__(self, heap_size=1 << 30, allocator_type="torch"): import json heap_bases_list = [int(self.heap_bases[r].item()) for r in range(self.num_ranks)] - out_path = f"iris_rank_{self.cur_rank}_heap_bases.json" + prefix = os.environ.get("IRIS_HEAP_BASES_PREFIX", "iris") + out_path = f"{prefix}_rank_{self.cur_rank}_heap_bases.json" with open(out_path, "w") as f: 
json.dump( { @@ -135,9 +139,15 @@ def __init__(self, heap_size=1 << 30, allocator_type="torch"): # Lazy initialization for ops interface self._ops = None + # Device-side barrier state, keyed by process group (None = all ranks). + self._device_barrier_state: dict[Any, torch.Tensor] = {} + # Initialize tracing self.tracing = Tracing(self) + # Pre-build the device context tensor (rebuilt when tracing is enabled) + self._build_device_context() + def __del__(self): """Cleanup resources on deletion.""" try: @@ -899,31 +909,11 @@ def get_heap_bases(self): """ return self.heap_bases - def get_device_context(self): + def _build_device_context(self): """ - Get the device context tensor for DeviceContext initialization. - - Returns a tensor encoding: [cur_rank, world_size, heap_base_0, heap_base_1, ...] - If tracing is enabled, also includes: [trace_enabled, max_events, trace_counter_ptr, trace_buffer_ptrs...] - - This opaque format allows future extension without breaking the API. + Build and cache the device context tensor. - Returns: - torch.Tensor: Encoded context data as int64 tensor on device - - Example: - >>> import iris - >>> from iris import DeviceContext - >>> import triton - >>> import triton.language as tl - >>> - >>> ctx = iris.iris() - >>> context_tensor = shmem.get_device_context() - >>> - >>> @triton.jit - >>> def my_kernel(context_tensor, rank: tl.constexpr, world_size: tl.constexpr, ...): - >>> ctx = DeviceContext.initialize(context_tensor, rank, world_size) - >>> data = ctx.load(buffer, from_rank=1) + Called during __init__ and again after tracing.enable() to include tracing fields. 
""" # Convert heap_bases to a list for concatenation heap_bases_list = self.heap_bases.tolist() @@ -958,9 +948,35 @@ def get_device_context(self): else: context_data += [0] # trace_enabled = 0 (false) - context_tensor = torch.tensor(context_data, dtype=torch.int64, device=self.device) + self._device_context = torch.tensor(context_data, dtype=torch.int64, device=self.device) + + def get_device_context(self): + """ + Get the device context tensor for DeviceContext initialization. + + Returns a tensor encoding: [cur_rank, world_size, heap_base_0, heap_base_1, ...] + If tracing is enabled, also includes: [trace_enabled, max_events, trace_counter_ptr, trace_buffer_ptrs...] + + This opaque format allows future extension without breaking the API. + + Returns: + torch.Tensor: Encoded context data as int64 tensor on device - return context_tensor + Example: + >>> import iris + >>> from iris import DeviceContext + >>> import triton + >>> import triton.language as tl + >>> + >>> ctx = iris.iris() + >>> context_tensor = ctx.get_device_context() + >>> + >>> @triton.jit + >>> def my_kernel(context_tensor, rank: tl.constexpr, world_size: tl.constexpr, ...): + >>> ctx = DeviceContext.initialize(context_tensor, rank, world_size) + >>> data = ctx.load(buffer, from_rank=1) + """ + return self._device_context def barrier(self, stream=None, group=None): """ @@ -989,6 +1005,36 @@ def barrier(self, stream=None, group=None): # Distributed barrier distributed_barrier(group=group) + def device_barrier(self, group=None): + """ + Device-side barrier that is CUDA graph capturable. + + Unlike ``barrier()`` which uses host-side ``torch.distributed.barrier()``, + this uses device-side atomic operations on the symmetric heap to synchronize + ranks. Stateless w.r.t. host-side epoch tracking: each rank's flag on + the heap serves as its own epoch counter, managed entirely by the GPU + via atomic_add. A persistent per-group flags tensor is cached in + ``_device_barrier_state``. 
+ + Args: + group (ProcessGroup, optional): The process group to synchronize. + If None, uses all ranks in the shmem context. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> ctx.device_barrier() # Synchronize all ranks on device + """ + if group not in self._device_barrier_state: + self._device_barrier_state[group] = self.zeros((self.num_ranks,), dtype=torch.int32) + + distributed_device_barrier( + self._device_barrier_state[group], + group, + self.cur_rank, + self.num_ranks, + self.get_heap_bases(), + ) + def get_device(self): """ Get the underlying device where the Iris symmetric heap resides. @@ -1865,13 +1911,14 @@ def load( This function performs a memory read operation by translating the pointer from the `from_rank`'s address space to the `to_rank`'s address space and loading - data from the target memory location. If the `from_rank` and `to_rank` are the same, - this function performs a local load operation. + data from the target memory location. The load is **local** when + ``to_rank == from_rank``, and **remote** (cross-GPU) otherwise. - The `cache_modifier` parameter controls instruction-level cache behavior - by setting the appropriate scope (`SC0`, `SC1`) and non-temporal (`NT`) bits - in the global load instruction. These affect cache usage across the CU, - L2, and last-level caches. + The `cache_modifier` is passed through to the underlying ``tl.load()`` call + unconditionally — it is the caller's responsibility to choose an appropriate + modifier. Cache modifiers control instruction-level cache behavior by setting + the appropriate scope (``SC0``, ``SC1``) and non-temporal (``NT``) bits in + the load instruction, following the CDNA ISA. Args: pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. 
@@ -1880,7 +1927,7 @@ def load( heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None. other (Block, optional): Value to return for masked-out elements. If not provided, the result for masked-out elements is undefined. Defaults to None. - cache_modifier (str, optional): Controls cache behavior of the load. + cache_modifier (str, optional): Controls cache behavior of the load. It is the caller's responsibility to use modifiers appropriately. Supported values: - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. @@ -1926,13 +1973,14 @@ def store( This function performs a memory write operation by translating the pointer from the `from_rank`'s address space to the `to_rank`'s address space and storing - the provided data to the target memory location. If the `from_rank` and `to_rank` are the same, - this function performs a local store operation. + the provided data to the target memory location. The store is **local** when + ``from_rank == to_rank``, and **remote** (cross-GPU) otherwise. - The `cache_modifier` parameter controls instruction-level cache behavior - by setting the appropriate scope (`SC0`, `SC1`) and non-temporal (`NT`) bits - in the global store instruction. These affect cache usage across the CU (L1), - L2, and last-level cache (LLC), following the CDNA ISA. + The `cache_modifier` is always passed through to the underlying ``tl.store()`` + call unconditionally — it is the caller's responsibility to choose an appropriate + modifier for the operation (local vs. remote). Cache modifiers control instruction-level + cache behavior by setting the appropriate scope (``SC0``, ``SC1``) and non-temporal + (``NT``) bits in the store instruction, following the CDNA ISA. 
Args: pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. @@ -1942,7 +1990,7 @@ def store( heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None. hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). - cache_modifier (str, optional): Controls cache behavior of the store. Ignored for remote stores (when `from_rank != to_rank`) as cache modifiers are not supported for cross-GPU memory operations. Supported values are: + cache_modifier (str, optional): Controls cache behavior of the store. It is the caller's responsibility to use modifiers appropriately. Supported values are: - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. @@ -1963,10 +2011,7 @@ def store( >>> iris.store(ptr, value, cur_rank, remote_rank, heap_bases) """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) - if from_rank == to_rank: - tl.store(translated_ptr, value, mask=mask, cache_modifier=cache_modifier) - else: - tl.store(translated_ptr, value, mask=mask) + tl.store(translated_ptr, value, mask=mask, cache_modifier=cache_modifier) @triton.jit @@ -1988,9 +2033,13 @@ def copy( This function performs the transfer by translating `src_ptr` from the `from_rank`'s address space to the `to_rank`'s address space, performing a masked load from the translated source, and storing the loaded data to `dst_ptr` in the `to_rank` memory location. 
- If `from_rank` and `to_rank` are the same, this function performs a local copy operation. It is undefined behaviour if neither `from_rank` nor `to_rank` is the `cur_rank`. + The load is from ``from_rank`` (remote if ``from_rank != cur_rank``) and the store is to + ``to_rank`` (remote if ``to_rank != cur_rank``). Both ``load_cache_modifier`` and + ``store_cache_modifier`` are passed through unconditionally — it is the caller's + responsibility to choose appropriate modifiers for local vs. remote operations. + Args: src_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s local memory from which to read data. dst_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `to_rank`'s local memory where the data will be written. @@ -2000,13 +2049,13 @@ def copy( heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load from the translated src_ptr[idx] and do not store to dst_ptr[idx]. Defaults to None. other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. - load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + load_cache_modifier (str, optional): Controls cache behavior of the load. It is the caller's responsibility to use modifiers appropriately. Supported values are: - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. 
- store_cache_modifier (str, optional): Controls cache behavior of the store. Only effective for local stores (when `to_rank == cur_rank`). Supported values are: + store_cache_modifier (str, optional): Controls cache behavior of the store. It is the caller's responsibility to use modifiers appropriately. Supported values are: - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. @@ -2047,10 +2096,7 @@ def copy( translated_dst = tl.max_contiguous(tl.multiple_of(translated_dst, hint), hint) data = tl.load(translated_src, mask=mask, other=other, cache_modifier=load_cache_modifier) - if to_rank == cur_rank: - tl.store(translated_dst, data, mask=mask, cache_modifier=store_cache_modifier) - else: - tl.store(translated_dst, data, mask=mask) + tl.store(translated_dst, data, mask=mask, cache_modifier=store_cache_modifier) @triton.jit @@ -2071,8 +2117,12 @@ def get( This function performs a memory read operation by translating the `from_ptr` from the current rank's address space to the `from_rank`'s address space, loading data - from the `from_rank` memory location, and storing it to the local `to_ptr`. - If the `from_rank` is the same as the current rank, this function performs a local copy operation. + from the `from_rank`'s memory location, and storing it to the local `to_ptr`. + + The load is **remote** when ``from_rank != to_rank`` (reading from a peer GPU), while the + store is **always local** (writing to `to_ptr` in the current rank's own memory). Both + ``load_cache_modifier`` and ``store_cache_modifier`` are passed through unconditionally — + it is the caller's responsibility to choose appropriate modifiers. 
Args: from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `from_rank`'s address space. Must be the current rank where the pointer is local. @@ -2082,13 +2132,13 @@ def get( heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. - load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + load_cache_modifier (str, optional): Controls cache behavior of the load (remote when ``from_rank != to_rank``). It is the caller's responsibility to use modifiers appropriately. Supported values are: - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. - store_cache_modifier (str, optional): Controls cache behavior of the store. The store is always to local memory (`to_ptr`), so this is always applied. Supported values are: + store_cache_modifier (str, optional): Controls cache behavior of the store. The store is always to local memory (``to_ptr``). Supported values are: - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. - ".cg": Cache Global. 
Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. @@ -2131,7 +2181,11 @@ def put( This function performs a memory write operation by loading data from the current rank's `from_ptr`, translating the `to_ptr` from the current rank's address space to the `to_rank`'s address space, and storing the data to the `to_rank` memory location. - If the `to_rank` is the same as the current rank, this function performs a local copy operation. + + The load is **always local** (reading from the current rank's own ``from_ptr``), while the + store is **remote** when ``from_rank != to_rank`` (writing to a peer GPU). Both + ``load_cache_modifier`` and ``store_cache_modifier`` are passed through unconditionally — + it is the caller's responsibility to choose appropriate modifiers. Args: from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's local memory from which to read data. @@ -2142,13 +2196,13 @@ def put( mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. - load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + load_cache_modifier (str, optional): Controls cache behavior of the load (always local). Supported values are: - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. 
- store_cache_modifier (str, optional): Controls cache behavior of the store. Only effective for local stores (when `from_rank == to_rank`). Supported values are: + store_cache_modifier (str, optional): Controls cache behavior of the store (remote when ``from_rank != to_rank``). It is the caller's responsibility to use modifiers appropriately. Supported values are: - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. @@ -2170,10 +2224,7 @@ def put( data = tl.load(from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier) - if from_rank == to_rank: - tl.store(translated_to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) - else: - tl.store(translated_to_ptr, data, mask=mask) + tl.store(translated_to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) @triton.jit diff --git a/iris/topology.py b/iris/topology.py new file mode 100644 index 000000000..42a83ae0e --- /dev/null +++ b/iris/topology.py @@ -0,0 +1,1373 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. 
+ +from __future__ import annotations + +import json +import logging +import os +import re +import socket +from dataclasses import dataclass, field +from enum import IntEnum +from typing import Any, Dict, List, Optional, Set, Tuple + +import torch +import torch.distributed as dist + +logger = logging.getLogger("iris.topology") + + +class InterconnectLevel(IntEnum): + """Hierarchical interconnect tiers.""" + + INTRA_NODE = 0 + INTRA_RACK_FABRIC = 1 + INTER_NODE_RDMA = 2 + + def __str__(self) -> str: + return self.name + + +class IntraNodeLinkType(IntEnum): + """Link type between two GPUs.""" + + SELF = -1 # Same GPU (diagonal) + NVLINK = 0 # NVIDIA NVLink or AMD xGMI + NVSWITCH = 1 # Connected through NVSwitch (all-to-all NVLink) + PCIE_SWITCH = 2 # Same PCIe switch (PIX/PXB) + PCIE_HOST_BRIDGE = 3 # Same CPU socket / PCIe host bridge (PHB) + PCIE_NUMA = 4 # Crosses NUMA boundary (NODE) + PCIE_SYSTEM = 5 # Crosses QPI/UPI between sockets (SYS) + UNKNOWN = 99 + + def __str__(self) -> str: + return self.name + + +@dataclass +class FabricInfo: + """ + GPU fabric domain identification. + + This is a vendor-agnostic representation of where a GPU sits in the + physical fabric topology. Two GPUs with the same (cluster_uuid, clique_id) + are in the same high-speed fabric domain and can communicate via direct + GPU-to-GPU links (NVLink via NVSwitch, or xGMI) without RDMA. + + AMD mapping: + cluster_uuid <-> ppod_id (physical pod identifier, uint64) + clique_id <-> vpod_id (virtual pod identifier, uint32) + + NVIDIA mapping: + cluster_uuid <-> clusterUuid[16] (NVLink domain UUID, bytes) + clique_id <-> cliqueId (fabric clique ID, uint32) + + If both fields are empty/zero, fabric info is unavailable (e.g. no + NVSwitch, no xGMI hive, single-node PCIe-only system). 
+ """ + + cluster_uuid: str = "" # Domain identifier (ppod_id hex / clusterUuid hex) + clique_id: int = 0 # Sub-domain identifier (vpod_id / cliqueId) + + @property + def is_valid(self) -> bool: + """True if fabric info was successfully retrieved.""" + return bool(self.cluster_uuid) + + @property + def domain_key(self) -> str: + """ + Combined key for domain comparison. + + Two GPUs with the same domain_key are in the same fabric domain. + """ + if not self.is_valid: + return "" + return f"{self.cluster_uuid}:{self.clique_id}" + + def to_dict(self) -> dict: + return {"cluster_uuid": self.cluster_uuid, "clique_id": self.clique_id} + + @classmethod + def from_dict(cls, d: dict) -> "FabricInfo": + return cls(cluster_uuid=d.get("cluster_uuid", ""), clique_id=d.get("clique_id", 0)) + + +# --------------------------------------------------------------------------- +# Logical-to-physical GPU index translation +# --------------------------------------------------------------------------- + + +def _logical_to_physical_gpu_index(logical_idx: int, vendor: str) -> int: + """ + Translate a logical (PyTorch) GPU index to the physical index used by + vendor libraries (NVML / AMDSMI). + + PyTorch (CUDA runtime) respects CUDA_VISIBLE_DEVICES / HIP_VISIBLE_DEVICES + and remaps device indices so that logical 0 = first *visible* GPU. + NVML and AMDSMI always enumerate *all* physical GPUs starting from 0. + + This function parses the relevant visibility env vars to recover the + physical index. If no env var is set or the entry is not a plain + integer (e.g. GPU UUIDs in CUDA_VISIBLE_DEVICES), it falls back to + returning the logical index unchanged — callers that need robustness + against UUID-style entries should prefer PCI-bus-ID-based handle + resolution instead (see _get_nvml_handle_by_pci / _get_amdsmi_handle_by_pci). 
+ """ + if logical_idx < 0: + return logical_idx + + if vendor == "nvidia": + env_vars = ["CUDA_VISIBLE_DEVICES"] + elif vendor == "amd": + # ROCm checks HIP_VISIBLE_DEVICES first, then ROCR_VISIBLE_DEVICES + env_vars = ["HIP_VISIBLE_DEVICES", "ROCR_VISIBLE_DEVICES"] + else: + return logical_idx + + for env_var in env_vars: + visible = os.environ.get(env_var) + if not visible: + continue + + entries = [e.strip() for e in re.split(r"[,\s]+", visible) if e.strip()] + if logical_idx >= len(entries): + continue + + entry = entries[logical_idx] + + # Skip UUID-style entries (e.g. "GPU-xxxxxxxx-...") — can't map + # to a plain integer index. Caller should use PCI-based resolution. + if not entry.isdigit() and not (entry.startswith("-") and entry[1:].isdigit()): + logger.debug( + "%s entry '%s' is not a plain index; falling back to logical index %d", + env_var, + entry, + logical_idx, + ) + return logical_idx + + try: + return int(entry) + except ValueError: + return logical_idx + + return logical_idx + + +# --------------------------------------------------------------------------- +# PCI-bus-ID-based vendor handle resolution +# --------------------------------------------------------------------------- + + +def _get_nvml_handle_by_pci(pci_bus_id: str): + """ + Get an NVML device handle by PCI bus ID. + + This bypasses index-based lookup entirely and is immune to + CUDA_VISIBLE_DEVICES remapping. Returns None on failure. + """ + try: + import pynvml + except ImportError: + return None + + norm = _normalize_pci_bus_id(pci_bus_id) + # nvmlDeviceGetHandleByPciBusId expects full domain:bus:dev.fn + if len(norm.split(":")) == 2: + norm = f"0000:{norm}" + + try: + return pynvml.nvmlDeviceGetHandleByPciBusId(norm.encode()) + except Exception as e: + logger.debug("nvmlDeviceGetHandleByPciBusId(%s) failed: %s", norm, e) + return None + + +def _get_amdsmi_handle_by_pci(pci_bus_id: str, all_handles=None): + """ + Get an amdsmi processor handle by PCI bus ID (BDF). 
+ + If *all_handles* is provided, it is used directly; otherwise + amdsmi_get_processor_handles() is called (caller must have + already called amdsmi_init). + + Returns None if no match is found. + """ + try: + import amdsmi + except ImportError: + return None + + if all_handles is None: + all_handles = amdsmi.amdsmi_get_processor_handles() + + norm = _normalize_pci_bus_id(pci_bus_id) + + for handle in all_handles: + try: + bdf = amdsmi.amdsmi_get_gpu_device_bdf(handle) + if bdf and _normalize_pci_bus_id(str(bdf)) == norm: + return handle + except Exception: + continue + return None + + +# --------------------------------------------------------------------------- +# Fabric info +# --------------------------------------------------------------------------- + + +def _amd_get_gpu_fabric_info(gpu_id: int, pci_bus_id: str = "") -> FabricInfo: + """ + Get GPU fabric info from AMD's AMDSMI library. + + GPUs with matching (ppod_id, vpod_id) are in the same fabric + domain and can communicate via direct GPU links without RDMA. + + Args: + gpu_id: Local GPU device index (0-based, logical). + pci_bus_id: PCI bus ID for handle resolution (preferred over index). + + Returns: + FabricInfo with cluster_uuid = hex(ppod_id), clique_id = vpod_id. + Returns empty FabricInfo if the call fails or is not available. + """ + # Default placeholder: no fabric info available (single-node behavior) + return FabricInfo() + + +def _nvidia_get_gpu_fabric_info(gpu_id: int, pci_bus_id: str = "") -> FabricInfo: + """ + Get GPU fabric info from NVIDIA's NVML library. + + When *pci_bus_id* is provided, the NVML handle is resolved via PCI + address — immune to CUDA_VISIBLE_DEVICES remapping. Falls back to + physical-index-based lookup if PCI resolution fails. + + Args: + gpu_id: Local GPU device index (0-based, logical). + pci_bus_id: PCI bus ID for handle resolution (preferred). 
+ """ + try: + import pynvml + + pynvml.nvmlInit() + try: + # --- Resolve handle: prefer PCI bus ID, fall back to physical index --- + handle = None + if pci_bus_id and pci_bus_id != "unknown": + handle = _get_nvml_handle_by_pci(pci_bus_id) + + if handle is None: + physical_idx = _logical_to_physical_gpu_index(gpu_id, "nvidia") + handle = pynvml.nvmlDeviceGetHandleByIndex(physical_idx) + + fabric_info = None + try: + info_struct = pynvml.c_nvmlGpuFabricInfo_v2_t() + pynvml.nvmlDeviceGetGpuFabricInfoV(handle, info_struct) + fabric_info = info_struct + except (AttributeError, TypeError, pynvml.NVMLError): + # GPU doesn't support fabric + return FabricInfo() + + if fabric_info is None: + return FabricInfo() + + # Check registration state — must be COMPLETED (value 3) + state = getattr(fabric_info, "state", None) + if state is not None and state != 3: + return FabricInfo() + + # Check status — must be SUCCESS (value 0) + status = getattr(fabric_info, "status", None) + if status is not None and status != 0: + return FabricInfo() + + # Extract clusterUuid + cluster_uuid_raw = getattr(fabric_info, "clusterUuid", None) + if cluster_uuid_raw is None: + return FabricInfo() + + if isinstance(cluster_uuid_raw, bytes): + cluster_uuid_hex = cluster_uuid_raw.hex() + elif isinstance(cluster_uuid_raw, (list, tuple)): + cluster_uuid_hex = bytes(cluster_uuid_raw).hex() + else: + cluster_uuid_hex = str(cluster_uuid_raw) + + if all(c == "0" for c in cluster_uuid_hex): + return FabricInfo() + + clique_id = getattr(fabric_info, "cliqueId", 0) + + return FabricInfo( + cluster_uuid=cluster_uuid_hex, + clique_id=int(clique_id), + ) + finally: + pynvml.nvmlShutdown() + + except ImportError: + logger.debug("pynvml not available, skipping NVML fabric info") + except Exception as e: + logger.debug("NVML fabric info query failed for GPU %d: %s", gpu_id, e) + + return FabricInfo() + + +def get_gpu_fabric_info(gpu_id: int, vendor: str, pci_bus_id: str = "") -> FabricInfo: + """ + Get GPU fabric 
domain info for the given device. + + Dispatches to the appropriate vendor-specific implementation: + AMD: amdsmi_get_gpu_fabric_info (ppod_id, vpod_id) + NVIDIA: nvmlDeviceGetGpuFabricInfoV (clusterUuid, cliqueId) + + Args: + gpu_id: Local GPU device index. + vendor: "amd" or "nvidia". + pci_bus_id: PCI bus ID for PCI-based handle resolution (preferred). + + Returns: + FabricInfo identifying the fabric domain this GPU belongs to. + """ + if vendor == "amd": + return _amd_get_gpu_fabric_info(gpu_id, pci_bus_id=pci_bus_id) + elif vendor == "nvidia": + return _nvidia_get_gpu_fabric_info(gpu_id, pci_bus_id=pci_bus_id) + else: + return FabricInfo() + + +def _normalize_pci_bus_id(bus_id: str) -> str: + """ + Normalize a PCI bus ID to a canonical lowercase form for comparison. + + Handles formats like: + "0000:41:00.0" -> "0000:41:00.0" + "00000000:41:00.0" -> "0000:41:00.0" (nvidia sometimes uses 8-char domain) + "GPU 0000:41:00.0" -> "0000:41:00.0" (strip prefix junk) + """ + bus_id = bus_id.strip().lower() + # Extract the BDF pattern (domain:bus:device.function) + match = re.search(r"([0-9a-f]+:[0-9a-f]{2}:[0-9a-f]{2}\.[0-9a-f])", bus_id) + if not match: + return bus_id + bdf = match.group(1) + # Normalize domain to 4 hex chars + parts = bdf.split(":") + if len(parts) == 3: + domain = parts[0] + # Truncate or pad domain to 4 chars + if len(domain) > 4: + domain = domain[-4:] # Take last 4 (nvidia-smi uses 8-char 00000000) + elif len(domain) < 4: + domain = domain.zfill(4) + return f"{domain}:{parts[1]}:{parts[2]}" + return bdf + + +@dataclass +class GPUInfo: + """Information about a single GPU, gathered from the rank that owns it.""" + + global_rank: int + local_rank: int # rank index within the node (0..ranks_per_node-1) + hostname: str + gpu_id: int # CUDA/HIP device index (local to process visibility) + pci_bus_id: str # e.g. "0000:41:00.0" — physical PCI address + device_name: str # e.g. 
"NVIDIA A100-SXM4-80GB" or "AMD Instinct MI300X" + total_memory_mb: int + numa_node: int # NUMA node affinity (-1 if unknown) + vendor: str # "nvidia" or "amd" + uuid: str # GPU UUID + fabric_info: FabricInfo = field(default_factory=FabricInfo) + + def to_dict(self) -> dict: + d = {k: v for k, v in self.__dict__.items() if k != "fabric_info"} + d["fabric_info"] = self.fabric_info.to_dict() + return d + + @classmethod + def from_dict(cls, d: dict) -> "GPUInfo": + # Don't mutate the input dict — use .get() and filter instead + fabric_data = d.get("fabric_info", {}) + filtered = {k: v for k, v in d.items() if k != "fabric_info"} + info = cls(**filtered) + info.fabric_info = FabricInfo.from_dict(fabric_data) + return info + + +@dataclass +class NodeInfo: + """Aggregated information about a single physical host (node).""" + + hostname: str + ranks: List[int] = field(default_factory=list) + gpu_ids: List[int] = field(default_factory=list) # gpu_id per rank (may have dups) + unique_gpu_ids: List[int] = field(default_factory=list) # deduplicated, sorted + unique_pci_ids: List[str] = field(default_factory=list) # deduplicated physical PCI bus IDs + num_gpus: int = 0 # count of unique physical GPUs (by PCI bus ID), NOT ranks + num_ranks: int = 0 # count of ranks on this node + has_infiniband: bool = False + ib_devices: List[str] = field(default_factory=list) + # Intra-node link matrix: link_types[gpu_i][gpu_j] indexed by gpu_id (device index) + link_types: Optional[List[List[int]]] = None + # P2P access matrix: p2p_access[gpu_i][gpu_j] indexed by gpu_id (device index) + p2p_access: Optional[List[List[bool]]] = None + # Fabric domain keys for GPUs on this node + fabric_domain_key: str = "" # primary (for backward compat) + fabric_domain_keys: List[str] = field(default_factory=list) # all unique + + def get_link_type(self, gpu_id_a: int, gpu_id_b: int) -> IntraNodeLinkType: + """ + Look up the link type between two GPUs by their device index (gpu_id). 
+ + This is the safe accessor that avoids the oversubscription IndexError + where local_rank > len(link_types matrix). + """ + if self.link_types is None: + return IntraNodeLinkType.UNKNOWN + if gpu_id_a == gpu_id_b: + return IntraNodeLinkType.SELF + if gpu_id_a >= len(self.link_types) or gpu_id_b >= len(self.link_types[0]): + return IntraNodeLinkType.UNKNOWN + return IntraNodeLinkType(self.link_types[gpu_id_a][gpu_id_b]) + + def can_p2p_access(self, gpu_id_a: int, gpu_id_b: int) -> bool: + """ + Look up P2P accessibility between two GPUs by their device index (gpu_id). + """ + if self.p2p_access is None: + return gpu_id_a == gpu_id_b + if gpu_id_a == gpu_id_b: + return True + if gpu_id_a >= len(self.p2p_access) or gpu_id_b >= len(self.p2p_access[0]): + return False + return self.p2p_access[gpu_id_a][gpu_id_b] + + +@dataclass +class TopologyMap: + """ + Complete cluster topology, built from all-gathered GPU information. + + This is the primary output of topology discovery and is used by the + hierarchical memory manager to decide communication strategies. + """ + + world_size: int + num_nodes: int + gpu_info: Dict[int, GPUInfo] # rank -> GPUInfo + nodes: Dict[str, NodeInfo] # hostname -> NodeInfo + fabric_domains: Dict[str, List[str]] # domain_key -> [hostname, ...] + # Precomputed peer groups (lazily populated) + _node_peers: Dict[int, Set[int]] = field(default_factory=dict) + _fabric_domain_peers: Dict[int, Set[int]] = field(default_factory=dict) + + def get_interconnect_level(self, rank_a: int, rank_b: int) -> InterconnectLevel: + """ + Determine the interconnect tier between two ranks. + + Decision tree: + 1. Same hostname -> INTRA_NODE (IPC handles) + 2. Same fabric domain_key -> INTRA_RACK_FABRIC (NVLink/xGMI fabric) + 3. 
Otherwise -> INTER_NODE_RDMA (InfiniBand) + """ + if rank_a == rank_b: + return InterconnectLevel.INTRA_NODE + + info_a = self.gpu_info[rank_a] + info_b = self.gpu_info[rank_b] + + # Same hostname -> same physical node + if info_a.hostname == info_b.hostname: + return InterconnectLevel.INTRA_NODE + + # Same fabric domain -> intra-rack fabric (NVLink domain / xGMI hive) + key_a = info_a.fabric_info.domain_key + key_b = info_b.fabric_info.domain_key + if key_a and key_a == key_b: + return InterconnectLevel.INTRA_RACK_FABRIC + + # Everything else -> RDMA + return InterconnectLevel.INTER_NODE_RDMA + + def get_node_peers(self, rank: int) -> Set[int]: + """Return all ranks on the same node as `rank` (excluding self).""" + if rank not in self._node_peers: + hostname = self.gpu_info[rank].hostname + self._node_peers[rank] = {r for r, info in self.gpu_info.items() if info.hostname == hostname and r != rank} + return self._node_peers[rank] + + def get_fabric_domain_peers(self, rank: int) -> Set[int]: + """Return all ranks in the same fabric domain (excluding self).""" + if rank not in self._fabric_domain_peers: + domain_key = self.gpu_info[rank].fabric_info.domain_key + if not domain_key: + self._fabric_domain_peers[rank] = set() + else: + self._fabric_domain_peers[rank] = { + r for r, info in self.gpu_info.items() if info.fabric_info.domain_key == domain_key and r != rank + } + return self._fabric_domain_peers[rank] + + def get_rdma_peers(self, rank: int) -> Set[int]: + """Return all ranks reachable only via RDMA.""" + all_ranks = set(self.gpu_info.keys()) + node_peers = self.get_node_peers(rank) + fabric_peers = self.get_fabric_domain_peers(rank) + return all_ranks - node_peers - fabric_peers - {rank} + + def get_ranks_for_node(self, hostname: str) -> List[int]: + """Return sorted list of ranks on a given node.""" + if hostname in self.nodes: + return sorted(self.nodes[hostname].ranks) + return [] + + def get_ranks_for_fabric_domain(self, domain_key: str) -> List[int]: + 
"""Return sorted list of all ranks in a fabric domain.""" + if domain_key not in self.fabric_domains: + return [] + ranks = [] + for hostname in self.fabric_domains[domain_key]: + ranks.extend(self.get_ranks_for_node(hostname)) + return sorted(ranks) + + def summary(self) -> str: + """Human-readable summary of the topology.""" + lines = [ + "=== Iris Cluster Topology ===", + f"World size: {self.world_size} | Nodes: {self.num_nodes} | Fabric domains: {len(self.fabric_domains)}", + "", + ] + + for hostname, node in sorted(self.nodes.items()): + ib_str = f"IB: {', '.join(node.ib_devices)}" if node.has_infiniband else "IB: none" + fabric_str = f"fabric: {node.fabric_domain_key}" if node.fabric_domain_key else "" + oversubscribed = "" + if node.num_ranks > node.num_gpus: + oversubscribed = f" [oversubscribed: {node.num_ranks} ranks on {node.num_gpus} GPUs]" + lines.append( + f" Node '{hostname}': {node.num_gpus} GPUs, ranks {node.ranks}, {ib_str} {fabric_str}{oversubscribed}" + ) + for rank in sorted(node.ranks): + info = self.gpu_info[rank] + lines.append( + f" rank {rank}: GPU{info.gpu_id} " + f"({info.device_name}, {info.total_memory_mb}MB) " + f"PCI={info.pci_bus_id} NUMA={info.numa_node}" + ) + + if self.fabric_domains: + lines.append("") + lines.append("Fabric Domains:") + for domain_key, hostnames in sorted(self.fabric_domains.items()): + total_gpus = sum(self.nodes[h].num_gpus for h in hostnames if h in self.nodes) + lines.append(f" {domain_key}: {len(hostnames)} nodes, {total_gpus} GPUs, hosts={hostnames}") + + return "\n".join(lines) + + def __str__(self) -> str: + return self.summary() + + +def _detect_vendor() -> str: + """Detect whether we're running on NVIDIA or AMD GPUs.""" + if hasattr(torch.version, "hip") and torch.version.hip is not None: + return "amd" + if torch.cuda.is_available(): + return "nvidia" + return "unknown" + + +def _get_total_memory_mb(gpu_id: int) -> int: + """ + Get total GPU memory in MB, compatible across PyTorch versions. 
+ + Handles both `total_memory` (newer PyTorch) and `total_mem` (older) + attribute names on device properties. + """ + props = torch.cuda.get_device_properties(gpu_id) + total_bytes = getattr(props, "total_memory", None) + if total_bytes is None: + total_bytes = getattr(props, "total_mem", 0) + return total_bytes // (1024 * 1024) + + +def _get_pci_bus_id(device_idx: int, vendor: str) -> str: + """ + Get the PCI bus ID for a GPU device. + + Uses _logical_to_physical_gpu_index() to translate the PyTorch + logical device index to the physical index expected by NVML/AMDSMI, + then queries NVML/AMDSMI by that physical index to obtain and + normalize the busId/BDF string. + """ + if device_idx < 0: + logger.debug("Invalid device index: %d", device_idx) + return "unknown" + + physical_idx = _logical_to_physical_gpu_index(device_idx, vendor) + + if vendor == "nvidia": + try: + import pynvml + + pynvml.nvmlInit() + try: + handle = pynvml.nvmlDeviceGetHandleByIndex(physical_idx) + pci_info = pynvml.nvmlDeviceGetPciInfo(handle) + bus_id = pci_info.busId + if isinstance(bus_id, bytes): + bus_id = bus_id.decode("utf-8") + return _normalize_pci_bus_id(bus_id) + finally: + pynvml.nvmlShutdown() + except ImportError: + logger.debug("pynvml not available") + except Exception as e: + logger.debug( + "NVML query failed for device %d (physical %d): %s", + device_idx, + physical_idx, + e, + ) + + elif vendor == "amd": + try: + import amdsmi + + amdsmi.amdsmi_init() + try: + handles = amdsmi.amdsmi_get_processor_handles() + if 0 <= physical_idx < len(handles): + bus_id = amdsmi.amdsmi_get_gpu_device_bdf(handles[physical_idx]) + if bus_id: + return _normalize_pci_bus_id(str(bus_id)) + else: + logger.debug( + "physical_idx %d out of range (%d GPUs)", + physical_idx, + len(handles), + ) + finally: + amdsmi.amdsmi_shut_down() + except ImportError: + logger.debug("amdsmi not available") + except Exception as e: + logger.debug( + "amdsmi query failed for device %d (physical %d): %s", + 
device_idx, + physical_idx, + e, + ) + + else: + logger.debug("Unknown vendor: %s", vendor) + + return "unknown" + + +def _get_gpu_uuid(device_idx: int, vendor: str, pci_bus_id: str = "") -> str: + """ + Get the unique UUID for a GPU device. + + When *pci_bus_id* is available, resolves the vendor handle via PCI + address (immune to visibility env var issues). Otherwise falls back + to physical-index-based lookup via _logical_to_physical_gpu_index(). + """ + if device_idx < 0: + logger.debug("Invalid device index: %d", device_idx) + return f"gpu-{socket.gethostname()}-{device_idx}" + + physical_idx = _logical_to_physical_gpu_index(device_idx, vendor) + + if vendor == "nvidia": + try: + import pynvml + + pynvml.nvmlInit() + try: + # Prefer PCI-based resolution — immune to CUDA_VISIBLE_DEVICES + handle = None + if pci_bus_id and pci_bus_id != "unknown": + handle = _get_nvml_handle_by_pci(pci_bus_id) + if handle is None: + handle = pynvml.nvmlDeviceGetHandleByIndex(physical_idx) + + uuid = pynvml.nvmlDeviceGetUUID(handle) + if isinstance(uuid, bytes): + uuid = uuid.decode("utf-8") + return uuid + finally: + pynvml.nvmlShutdown() + except ImportError: + logger.debug("pynvml not available") + except Exception as e: + logger.debug( + "NVML UUID query failed for device %d (physical %d): %s", + device_idx, + physical_idx, + e, + ) + + elif vendor == "amd": + try: + import amdsmi + + amdsmi.amdsmi_init() + try: + # Prefer PCI-based resolution + handle = None + if pci_bus_id and pci_bus_id != "unknown": + handle = _get_amdsmi_handle_by_pci(pci_bus_id) + if handle is None: + handles = amdsmi.amdsmi_get_processor_handles() + if 0 <= physical_idx < len(handles): + handle = handles[physical_idx] + + if handle is not None: + uuid = amdsmi.amdsmi_get_gpu_device_uuid(handle) + if uuid: + return str(uuid) + else: + logger.debug( + "physical_idx %d out of range or PCI resolve failed", + physical_idx, + ) + finally: + amdsmi.amdsmi_shut_down() + except ImportError: + 
logger.debug("amdsmi not available") + except Exception as e: + logger.debug( + "amdsmi UUID query failed for device %d (physical %d): %s", + device_idx, + physical_idx, + e, + ) + + else: + logger.debug("Unknown vendor: %s", vendor) + + return f"gpu-{socket.gethostname()}-{device_idx}" + + +def _get_numa_node(pci_bus_id: str) -> int: + """ + Detect the NUMA node affinity for a GPU via sysfs. + Uses the PCI bus ID to read the NUMA node from + the kernel's PCI sysfs interface. + Returns -1 if unknown. + """ + if not pci_bus_id or pci_bus_id == "unknown": + return -1 + + pci_addr = pci_bus_id.lower() + # sysfs always uses full domain prefix (0000:xx:xx.x) + if len(pci_addr.split(":")) == 2: + pci_addr = f"0000:{pci_addr}" + + numa_path = f"/sys/bus/pci/devices/{pci_addr}/numa_node" + try: + with open(numa_path) as f: + node = int(f.read().strip()) + return node + except (FileNotFoundError, ValueError, OSError) as e: + logger.debug("Failed to read NUMA node from %s: %s", numa_path, e) + return -1 + + +def _detect_infiniband() -> Tuple[bool, List[str]]: + """Detect InfiniBand/RDMA devices on this node via the kernel sysfs interface.""" + ib_class_path = "/sys/class/infiniband" + try: + if os.path.isdir(ib_class_path): + devices = os.listdir(ib_class_path) + if devices: + return True, devices + except OSError as e: + logger.debug("Failed to read %s: %s", ib_class_path, e) + + return False, [] + + +def _detect_intra_node_topology( + local_pci_bus_ids: List[str], + vendor: str, +) -> Tuple[Optional[List[List[int]]], Optional[List[List[bool]]]]: + """ + Detect intra-node GPU-to-GPU topology (link types and P2P accessibility). + + Args: + local_pci_bus_ids: PCI bus IDs of the GPUs visible to this process, + indexed by local device index (matching torch.cuda device ordering). + Used to correctly map physical CLI tool output to local indices. + vendor: "nvidia" or "amd". + + Returns: + (link_types, p2p_access) — both indexed by local device index (gpu_id). 
+ link_types may be None if CLI tools are unavailable. + """ + num_gpus = len(local_pci_bus_ids) + link_types = None + p2p_access = None + + # P2P access detection uses PyTorch device indices, which already respect + # visibility environment variables (e.g., CUDA_VISIBLE_DEVICES / HIP_VISIBLE_DEVICES) + # and any logical device remapping applied by the runtime. + try: + p2p_access = [] + for i in range(num_gpus): + row = [] + for j in range(num_gpus): + if i == j: + row.append(True) + else: + row.append(torch.cuda.can_device_access_peer(i, j)) + p2p_access.append(row) + except Exception as e: + logger.debug(f"P2P access detection failed: {e}") + p2p_access = None + + # Link type detection parses CLI tools that report physical topology. + # We must map physical GPU indices to our local device indices via PCI bus IDs. + if vendor == "nvidia": + link_types = _parse_nvidia_topo(local_pci_bus_ids) + elif vendor == "amd": + link_types = _parse_amd_topo(local_pci_bus_ids) + + return link_types, p2p_access + + +# Map NVML topology levels to IntraNodeLinkType +_NVML_TOPO_TO_LINK = { + 0: IntraNodeLinkType.SELF, # NVML_TOPOLOGY_INTERNAL + 10: IntraNodeLinkType.PCIE_SWITCH, # NVML_TOPOLOGY_SINGLE + 20: IntraNodeLinkType.PCIE_SWITCH, # NVML_TOPOLOGY_MULTIPLE + 30: IntraNodeLinkType.PCIE_HOST_BRIDGE, # NVML_TOPOLOGY_HOSTBRIDGE + 40: IntraNodeLinkType.PCIE_NUMA, # NVML_TOPOLOGY_NODE + 50: IntraNodeLinkType.PCIE_SYSTEM, # NVML_TOPOLOGY_SYSTEM +} + +# Max NVLink connections per GPU (18 for B200) +_MAX_NVLINK_LINKS = 18 + + +def _nvml_check_nvlink(handle_src, handle_dst, pynvml) -> bool: + """Check if any active NVLink connects two GPU handles.""" + try: + pci_dst = pynvml.nvmlDeviceGetPciInfo(handle_dst) + target_bus = pci_dst.busId + if isinstance(target_bus, bytes): + target_bus = target_bus.decode("utf-8") + target_bus = target_bus.lower() + except Exception: + return False + + for link in range(_MAX_NVLINK_LINKS): + try: + state = pynvml.nvmlDeviceGetNvLinkState(handle_src, 
link) + if not state: + continue + remote_pci = pynvml.nvmlDeviceGetNvLinkRemotePciInfo(handle_src, link) + remote_bus = remote_pci.busId + if isinstance(remote_bus, bytes): + remote_bus = remote_bus.decode("utf-8") + if remote_bus.lower() == target_bus: + return True + except Exception: + break # Link index doesn't exist; indices are contiguous from 0 + + return False + + +def _parse_nvidia_topo(local_pci_bus_ids: List[str]) -> Optional[List[List[int]]]: + """ + Build GPU-GPU link type matrix using pynvml topology queries. + + Args: + local_pci_bus_ids: PCI bus IDs for each local PyTorch device index. + + Returns: + Link type matrix indexed by local device index, or None on failure. + """ + num_local = len(local_pci_bus_ids) + + try: + import pynvml + + pynvml.nvmlInit() + try: + # Resolve local PCI bus IDs directly to NVML handles + handles = [] + for pci in local_pci_bus_ids: + norm = _normalize_pci_bus_id(pci) + if len(norm.split(":")) == 2: + norm = f"0000:{norm}" + try: + handle = pynvml.nvmlDeviceGetHandleByPciBusId(norm.encode()) + handles.append(handle) + except pynvml.NVMLError as e: + logger.warning("Could not get NVML handle for PCI %s: %s", pci, e) + return None + + # Build the matrix + matrix = [[IntraNodeLinkType.UNKNOWN] * num_local for _ in range(num_local)] + for i in range(num_local): + for j in range(num_local): + if i == j: + matrix[i][j] = IntraNodeLinkType.SELF + continue + + # NVLink takes priority over PCIe topology level + if _nvml_check_nvlink(handles[i], handles[j], pynvml): + matrix[i][j] = IntraNodeLinkType.NVLINK + continue + + try: + level = pynvml.nvmlDeviceGetTopologyCommonAncestor(handles[i], handles[j]) + matrix[i][j] = _NVML_TOPO_TO_LINK.get(level, IntraNodeLinkType.UNKNOWN) + except pynvml.NVMLError: + matrix[i][j] = IntraNodeLinkType.UNKNOWN + + return matrix + finally: + pynvml.nvmlShutdown() + + except ImportError: + logger.debug("pynvml not available for topology query") + return None + except Exception as e: + 
logger.debug("NVML topology query failed: %s", e) + return None + + +def _parse_amd_topo(local_pci_bus_ids: List[str]) -> Optional[List[List[int]]]: + """ + Build GPU-GPU link type matrix using amdsmi topology queries. + + Args: + local_pci_bus_ids: PCI bus IDs for each local PyTorch device index. + + Returns: + Link type matrix indexed by local device index, or None on failure. + """ + num_local = len(local_pci_bus_ids) + + try: + import amdsmi + + amdsmi.amdsmi_init() + try: + all_handles = amdsmi.amdsmi_get_processor_handles() + + # Build BDF -> handle map for all physical GPUs + bdf_to_handle = {} + for handle in all_handles: + try: + bdf = amdsmi.amdsmi_get_gpu_device_bdf(handle) + if bdf: + bdf_to_handle[_normalize_pci_bus_id(str(bdf))] = handle + except Exception: + continue + + # Resolve local PCI bus IDs to amdsmi handles + handles = [] + for pci in local_pci_bus_ids: + norm = _normalize_pci_bus_id(pci) + handle = bdf_to_handle.get(norm) + if handle is None: + logger.warning( + "Could not find amdsmi handle for PCI %s. Known BDFs: %s", + pci, + list(bdf_to_handle.keys()), + ) + return None + handles.append(handle) + + # Pre-compute XGMI neighbor sets for each local GPU. + # amdsmi_get_link_topology_nearest returns all GPUs reachable + # via a given link type from a source GPU. 
+ xgmi_neighbors: List[Set[str]] = [] + for handle in handles: + neighbors: Set[str] = set() + try: + result = amdsmi.amdsmi_get_link_topology_nearest( + handle, amdsmi.AmdSmiLinkType.AMDSMI_LINK_TYPE_XGMI + ) + for peer in result.get("processor_list", []): + try: + peer_bdf = amdsmi.amdsmi_get_gpu_device_bdf(peer) + if peer_bdf: + neighbors.add(_normalize_pci_bus_id(str(peer_bdf))) + except Exception: + continue + except Exception: + pass # No XGMI on this GPU (PCIe-only system) + xgmi_neighbors.append(neighbors) + + # Normalized BDFs for each local device index + local_bdfs = [_normalize_pci_bus_id(pci) for pci in local_pci_bus_ids] + + # Build the matrix + matrix = [[IntraNodeLinkType.UNKNOWN] * num_local for _ in range(num_local)] + for i in range(num_local): + for j in range(num_local): + if i == j: + matrix[i][j] = IntraNodeLinkType.SELF + continue + + # Check if GPU j's BDF is in GPU i's XGMI neighbor set + if local_bdfs[j] in xgmi_neighbors[i]: + matrix[i][j] = IntraNodeLinkType.NVLINK # XGMI maps to NVLINK enum + else: + matrix[i][j] = IntraNodeLinkType.PCIE_SWITCH + + return matrix + finally: + amdsmi.amdsmi_shut_down() + + except ImportError: + logger.debug("amdsmi not available for topology query") + return None + except Exception as e: + logger.debug("amdsmi topology query failed: %s", e) + return None + + +def _all_gather_strings(local_string: str, world_size: int) -> List[str]: + """ + All-gather a string from each rank using PyTorch distributed. + + Compatible with Iris's use of torch.distributed (NCCL/RCCL backend). 
+ """ + local_bytes = local_string.encode("utf-8") + local_len = len(local_bytes) + + # All-gather lengths + len_tensor = torch.tensor([local_len], dtype=torch.long, device="cuda") + len_list = [torch.zeros(1, dtype=torch.long, device="cuda") for _ in range(world_size)] + dist.all_gather(len_list, len_tensor) + max_len = max(t.item() for t in len_list) + + if max_len == 0: + return [""] * world_size + + # All-gather padded byte tensors + # Use frombuffer to avoid iterating every byte through Python space. + # bytearray is required because frombuffer needs a writable buffer, + # and copy=True ensures NCCL gets an owned contiguous CUDA tensor. + padded = bytearray(local_bytes) + bytearray(max_len - local_len) + local_tensor = torch.frombuffer(padded, dtype=torch.uint8).to("cuda", copy=True) + gathered = [torch.zeros(max_len, dtype=torch.uint8, device="cuda") for _ in range(world_size)] + dist.all_gather(gathered, local_tensor) + + results = [] + for t, length_t in zip(gathered, len_list): + length = int(length_t.item()) + results.append(bytes(t[:length].cpu().tolist()).decode("utf-8")) + return results + + +class TopologyDiscovery: + """ + Multi-node multi-GPU topology discovery for Iris. + + Integrates with Iris's existing PyTorch distributed setup and produces + a TopologyMap that classifies every GPU pair into one of three tiers + based on hostname (intra-node), fabric info (intra-rack), or neither (RDMA). 
+ + The fabric domain detection uses: + AMD: amdsmi_get_gpu_fabric_info -> (ppod_id, vpod_id) + NVIDIA: nvmlDeviceGetGpuFabricInfoV -> (clusterUuid, cliqueId) + """ + + def __init__(self, iris_ctx=None): + self._iris_ctx = iris_ctx + self._topology: Optional[TopologyMap] = None + + if iris_ctx is not None: + self.rank = iris_ctx.cur_rank + self.world_size = iris_ctx.num_ranks + self.gpu_id = iris_ctx.gpu_id + else: + num_gpus = torch.cuda.device_count() + if num_gpus <= 0: + raise RuntimeError("TopologyDiscovery requires at least one GPU") + + # Use LOCAL_RANK (set by torchrun/SLURM) for per-node GPU assignment. + # This is more robust than global_rank % num_gpus, which breaks when + # ranks aren't distributed in a way that aligns with device_count + # (e.g., 2 nodes with 8 GPUs each but only 4 ranks per node). + # The % num_gpus clamp handles isolation (LOCAL_RANK=3, device_count=1). + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + self.gpu_id = local_rank % num_gpus + # MUST set device BEFORE init_process_group — NCCL needs a CUDA + # device assigned to this process, otherwise all ranks fight over + # GPU 0 and init either fails or produces world_size=1. + torch.cuda.set_device(self.gpu_id) + if dist.is_initialized(): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + else: + raise RuntimeError("TopologyDiscovery requires an initialized distributed process group.") + + @property + def topology(self) -> Optional[TopologyMap]: + """The discovered topology, or None if discover() hasn't been called.""" + return self._topology + + def discover(self) -> TopologyMap: + """ + Perform topology discovery across the cluster. + + This is a collective operation — all ranks must call it. + + Steps: + 1. Each rank probes its local GPU (device name, PCI, NUMA, UUID) + 2. Each rank queries fabric info (ppod/vpod on AMD, clusterUuid/cliqueId on NVIDIA) + 3. Each rank probes node-level info (IB devices, intra-node topology) + 4. 
All-gather GPU info + node info across all ranks + 5. Build the global TopologyMap with fabric domain grouping + """ + vendor = _detect_vendor() + hostname = socket.gethostname() + + logger.debug( + f"[Rank {self.rank}] Starting topology discovery on {hostname}, GPU {self.gpu_id}, vendor={vendor}" + ) + + # Probe local GPU info + device_name = torch.cuda.get_device_name(self.gpu_id) + total_memory_mb = _get_total_memory_mb(self.gpu_id) + pci_bus_id = _get_pci_bus_id(self.gpu_id, vendor) + # Pass PCI bus ID to UUID query for PCI-based handle resolution + gpu_uuid = _get_gpu_uuid(self.gpu_id, vendor, pci_bus_id=pci_bus_id) + # Pass PCI bus ID + numa_node = _get_numa_node(pci_bus_id) + + # Query fabric info — pass PCI bus ID for PCI-based handle resolution + fabric_info = get_gpu_fabric_info(self.gpu_id, vendor, pci_bus_id=pci_bus_id) + logger.debug( + f"[Rank {self.rank}] Fabric info: cluster_uuid={fabric_info.cluster_uuid}, " + f"clique_id={fabric_info.clique_id}, domain_key={fabric_info.domain_key}" + ) + + local_gpu_info = GPUInfo( + global_rank=self.rank, + local_rank=self.gpu_id, + hostname=hostname, + gpu_id=self.gpu_id, + pci_bus_id=pci_bus_id, + device_name=device_name, + total_memory_mb=total_memory_mb, + numa_node=numa_node, + vendor=vendor, + uuid=gpu_uuid, + fabric_info=fabric_info, + ) + + # Probe node-level info + # Gather PCI bus IDs for ALL visible GPUs on this node (not just this rank's). + # This is needed for correct CLI tool output remapping. + num_local_gpus = torch.cuda.device_count() + if num_local_gpus <= 1 and self.world_size > 1: + logger.warning( + f"[Rank {self.rank}] torch.cuda.device_count() = {num_local_gpus}. " + f"CUDA_VISIBLE_DEVICES may be restricting GPU visibility. " + f"Intra-node topology detection will be limited." 
+ ) + + local_pci_bus_ids = [] + for dev_idx in range(num_local_gpus): + local_pci_bus_ids.append(_get_pci_bus_id(dev_idx, vendor)) + + has_ib, ib_devices = _detect_infiniband() + link_types, p2p_access = _detect_intra_node_topology(local_pci_bus_ids, vendor) + + # All-gather + local_gpu_json = json.dumps(local_gpu_info.to_dict()) + all_gpu_jsons = _all_gather_strings(local_gpu_json, self.world_size) + + node_info_json = json.dumps( + { + "link_types": link_types, + "p2p_access": p2p_access, + "has_ib": has_ib, + "ib_devices": ib_devices, + "num_visible_gpus": num_local_gpus, + } + ) + all_node_jsons = _all_gather_strings(node_info_json, self.world_size) + + # Build global topology + gpu_info_map: Dict[int, GPUInfo] = {} + for gpu_json in all_gpu_jsons: + info = GPUInfo.from_dict(json.loads(gpu_json)) + gpu_info_map[info.global_rank] = info + + all_node_infos = [json.loads(s) for s in all_node_jsons] + + # Group ranks by hostname + hostname_to_ranks: Dict[str, List[int]] = {} + for rank, info in gpu_info_map.items(): + hostname_to_ranks.setdefault(info.hostname, []).append(rank) + + # local_rank ordering: assign sequential index within each node + for hostname, ranks in hostname_to_ranks.items(): + for local_idx, rank in enumerate(sorted(ranks)): + gpu_info_map[rank].local_rank = local_idx + + # Build NodeInfo + nodes: Dict[str, NodeInfo] = {} + for hostname, ranks in hostname_to_ranks.items(): + sorted_ranks = sorted(ranks) + representative = sorted_ranks[0] + nd = all_node_infos[representative] + + # Collect per-rank gpu_ids and deduplicate physical GPUs by PCI bus ID. 
+ gpu_ids_per_rank = [gpu_info_map[r].gpu_id for r in sorted_ranks] + pci_ids_per_rank = [gpu_info_map[r].pci_bus_id for r in sorted_ranks] + unique_pci_ids = sorted(set(p for p in pci_ids_per_rank if p != "unknown")) + unique_gpu_ids = sorted(set(gpu_ids_per_rank)) + num_physical_gpus = len(unique_pci_ids) if unique_pci_ids else len(unique_gpu_ids) + + # Collect all unique fabric domain keys and warn if mixed + domain_keys: Set[str] = set() + for r in sorted_ranks: + dk = gpu_info_map[r].fabric_info.domain_key + if dk: + domain_keys.add(dk) + + if len(domain_keys) > 1: + logger.warning( + f"Node '{hostname}' has GPUs in multiple fabric domains: " + f"{domain_keys}. This may indicate a misconfiguration." + ) + + sorted_domain_keys = sorted(domain_keys) + node_domain_key = sorted_domain_keys[0] if sorted_domain_keys else "" + + # Log oversubscription + if len(sorted_ranks) > num_physical_gpus: + logger.info( + f"Node '{hostname}': {len(sorted_ranks)} ranks oversubscribed " + f"on {num_physical_gpus} physical GPUs " + f"(pci_ids={unique_pci_ids or unique_gpu_ids})" + ) + + nodes[hostname] = NodeInfo( + hostname=hostname, + ranks=sorted_ranks, + gpu_ids=gpu_ids_per_rank, + unique_gpu_ids=unique_gpu_ids, + unique_pci_ids=unique_pci_ids, + num_gpus=num_physical_gpus, + num_ranks=len(sorted_ranks), + has_infiniband=nd["has_ib"], + ib_devices=nd["ib_devices"], + link_types=nd["link_types"], + p2p_access=nd["p2p_access"], + fabric_domain_key=node_domain_key, + fabric_domain_keys=sorted_domain_keys, + ) + + # Build fabric domain map by registering each node under ALL its domains + fabric_domains: Dict[str, List[str]] = {} + for hostname, node in nodes.items(): + for dk in node.fabric_domain_keys: + fabric_domains.setdefault(dk, []).append(hostname) + + self._topology = TopologyMap( + world_size=self.world_size, + num_nodes=len(nodes), + gpu_info=gpu_info_map, + nodes=nodes, + fabric_domains=fabric_domains, + ) + + logger.debug( + f"[Rank {self.rank}] Topology discovery 
complete: {len(nodes)} nodes, {len(fabric_domains)} fabric domains" + ) + return self._topology + + def get_communication_groups(self) -> Dict[InterconnectLevel, List[List[int]]]: + """ + Build communication groups for each interconnect level. + """ + if self._topology is None: + raise RuntimeError("Must call discover() before get_communication_groups()") + + topo = self._topology + + # Intra-node: one group per physical node (sorted by hostname for determinism) + intra_node_groups = [ + sorted(node.ranks) for hostname, node in sorted(topo.nodes.items(), key=lambda item: item[0]) + ] + + # Fabric-level: group by fabric domain, plus standalone nodes + if topo.fabric_domains: + # Sort fabric domains for deterministic ordering of fabric groups + fabric_domain_keys = sorted(topo.fabric_domains) + fabric_groups = [sorted(topo.get_ranks_for_fabric_domain(dk)) for dk in fabric_domain_keys] + + # Include standalone nodes (no fabric domain) as their own groups + # so they aren't orphaned from the fabric tier. + nodes_in_fabric = set() + for hostnames in topo.fabric_domains.values(): + nodes_in_fabric.update(hostnames) + for hostname, node in sorted(topo.nodes.items(), key=lambda item: item[0]): + if hostname not in nodes_in_fabric: + fabric_groups.append(sorted(node.ranks)) + else: + # No fabric at all + fabric_groups = [] + + # RDMA: everyone + rdma_groups = [sorted(topo.gpu_info.keys())] + + return { + InterconnectLevel.INTRA_NODE: intra_node_groups, + InterconnectLevel.INTRA_RACK_FABRIC: fabric_groups, + InterconnectLevel.INTER_NODE_RDMA: rdma_groups, + } + + def get_heap_distribution_plan(self) -> Dict[int, Dict[str, Any]]: + """ + Generate a plan for distributing symmetric heap bases across the cluster. 
+ + For each rank, classifies every peer into one of three tiers: + - ipc_peers: Same node -> use cudaIpcMemHandle / hipIpcMemHandle + - fabric_peers: Same fabric domain, different node -> use fabric + memory handles (cuMemExportToShareableHandle on NVIDIA, + or equivalent on AMD) + - rdma_peers: Different fabric domain -> use RDMA + """ + if self._topology is None: + raise RuntimeError("Must call discover() before get_heap_distribution_plan()") + + topo = self._topology + plan: Dict[int, Dict[str, Any]] = {} + + for rank, info in topo.gpu_info.items(): + ipc_peers = sorted(topo.get_node_peers(rank)) + fabric_peers = sorted(topo.get_fabric_domain_peers(rank) - topo.get_node_peers(rank)) + rdma_peers = sorted(topo.get_rdma_peers(rank)) + + plan[rank] = { + "ipc_peers": ipc_peers, + "fabric_peers": fabric_peers, + "rdma_peers": rdma_peers, + "node": info.hostname, + "local_rank": info.local_rank, + "gpu_id": info.gpu_id, + "fabric_domain": info.fabric_info.domain_key, + } + + return plan diff --git a/iris/tracing/core.py b/iris/tracing/core.py index 9b3ca9a62..317fc0bbf 100644 --- a/iris/tracing/core.py +++ b/iris/tracing/core.py @@ -71,6 +71,9 @@ def enable(self, max_events=1_000_000): self.iris.info(f"Device tracing enabled with max {max_events} events") + # Rebuild the cached device context to include tracing fields + self.iris._build_device_context() + def reset(self): """ Reset trace counter to start a new trace capture. diff --git a/pyproject.toml b/pyproject.toml index 18e71badb..4a8f1916c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,13 @@ Repository = "https://github.com/ROCm/iris" Documentation = "https://rocm.github.io/iris/" [project.optional-dependencies] +nvidia = [ + "nvidia-ml-py", # Install via `pip install iris[nvidia]` on NVIDIA system. 
+] +amd = [ + # amdsmi is provided by the ROCm system package + # Listed here for completeness +] dev = [ "pytest", "black", diff --git a/scripts/roccap_wrapper.py b/scripts/roccap_wrapper.py index a3f76fb5c..f82540675 100644 --- a/scripts/roccap_wrapper.py +++ b/scripts/roccap_wrapper.py @@ -41,6 +41,8 @@ # Set simulation env so Iris uses torch allocator os.environ["IRIS_SIMULATION"] = "1" +# Pass kernel name so Iris can name heap_bases output to match -k (e.g. persistent_all_gather_heap_bases.json) +os.environ["IRIS_HEAP_BASES_PREFIX"] = parsed.kernel # Disable PyTorch caching allocator for simple allocations in simulation os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1" diff --git a/tests/unittests/test_barriers.py b/tests/unittests/test_barriers.py new file mode 100644 index 000000000..79f6e8351 --- /dev/null +++ b/tests/unittests/test_barriers.py @@ -0,0 +1,339 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import gc +from typing import Literal + +import pytest +import torch +import triton +import triton.language as tl +import iris + + +BarrierType = Literal["host", "device"] +BARRIER_TYPES: list[BarrierType] = ["host", "device"] + + +def _call_barrier(shmem: iris.Iris, barrier_type: BarrierType) -> None: + if barrier_type == "host": + shmem.barrier() + else: + shmem.device_barrier() + + +@triton.jit +def _read_remote_kernel( + buf_ptr, + result_ptr, + cur_rank: tl.constexpr, + remote_rank: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, +): + offsets = tl.arange(0, BLOCK_SIZE) + data = iris.load(buf_ptr + offsets, cur_rank, remote_rank, heap_bases) + tl.store(result_ptr + offsets, data) + + +@triton.jit +def _write_remote_kernel( + buf_ptr, + value, + cur_rank: tl.constexpr, + remote_rank: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, +): + offsets = tl.arange(0, BLOCK_SIZE) + data = tl.full([BLOCK_SIZE], value, dtype=tl.float32) + iris.store(buf_ptr + 
offsets, data, cur_rank, remote_rank, heap_bases) + + +@pytest.mark.parametrize("n", [1, 10]) +@pytest.mark.parametrize("barrier_type", BARRIER_TYPES) +def test_barrier_basic(barrier_type, n): + shmem = iris.iris(1 << 20) + _call_barrier(shmem, barrier_type) + + try: + for _ in range(n): + _call_barrier(shmem, barrier_type) + finally: + _call_barrier(shmem, barrier_type) + del shmem + gc.collect() + + +@pytest.mark.parametrize("n", [1, 2, 5, 10]) +@pytest.mark.parametrize("barrier_type", BARRIER_TYPES) +def test_barrier_state_reuse(barrier_type, n): + """Verify device barrier reuses the same flags tensor across calls.""" + shmem = iris.iris(1 << 20) + _call_barrier(shmem, barrier_type) + + try: + shmem.device_barrier() + assert None in shmem._device_barrier_state + flags = shmem._device_barrier_state[None] + flags_ptr = flags.data_ptr() + + for _ in range(n): + shmem.device_barrier() + assert shmem._device_barrier_state[None].data_ptr() == flags_ptr + finally: + _call_barrier(shmem, barrier_type) + del shmem + gc.collect() + + +def _cross_rank_eager( + shmem, + barrier_type, + op, + num_barriers, + rounds, + N, + rank, + neighbor, + writer, + heap_bases, + buf, + result, +): + if op == "load": + for i in range(rounds): + buf.fill_(float(rank + i * 100)) + + for _ in range(num_barriers): + _call_barrier(shmem, barrier_type) + + _read_remote_kernel[(1,)]( + buf, + result, + rank, + neighbor, + N, + heap_bases, + ) + + for _ in range(num_barriers): + _call_barrier(shmem, barrier_type) + + expected_val = float(neighbor + i * 100) + expected = torch.full((N,), expected_val, dtype=torch.float32, device="cuda") + torch.testing.assert_close(result, expected, rtol=0, atol=0) + else: + for i in range(rounds): + buf.fill_(0.0) + + for _ in range(num_barriers): + _call_barrier(shmem, barrier_type) + + write_val = float(rank + i * 100) + _write_remote_kernel[(1,)]( + buf, + write_val, + rank, + neighbor, + N, + heap_bases, + ) + + for _ in range(num_barriers): + 
_call_barrier(shmem, barrier_type) + + expected_val = float(writer + i * 100) + expected = torch.full((N,), expected_val, dtype=torch.float32, device="cuda") + torch.testing.assert_close(buf, expected, rtol=0, atol=0) + + +def _cross_rank_graph( + shmem, + op, + num_barriers, + rounds, + N, + rank, + neighbor, + writer, + heap_bases, + buf, + result, +): + capture_stream = torch.cuda.Stream() + + if op == "load": + buf.fill_(float(rank)) + + # Warmup on capture stream. + with torch.cuda.stream(capture_stream): + for _ in range(num_barriers): + shmem.device_barrier() + _read_remote_kernel[(1,)]( + buf, + result, + rank, + neighbor, + N, + heap_bases, + ) + for _ in range(num_barriers): + shmem.device_barrier() + capture_stream.synchronize() + + # Capture. + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=capture_stream): + for _ in range(num_barriers): + shmem.device_barrier() + _read_remote_kernel[(1,)]( + buf, + result, + rank, + neighbor, + N, + heap_bases, + ) + for _ in range(num_barriers): + shmem.device_barrier() + + # Replay with fresh data. + for i in range(rounds): + val = float(rank + (i + 1) * 10) + with torch.cuda.stream(capture_stream): + buf.fill_(val) + shmem.device_barrier() + graph.replay() + capture_stream.synchronize() + + expected = torch.full( + (N,), + float(neighbor + (i + 1) * 10), + dtype=torch.float32, + device="cuda", + ) + torch.testing.assert_close(result, expected, rtol=0, atol=0) + else: + buf.fill_(0.0) + + # Warmup on capture stream. + with torch.cuda.stream(capture_stream): + for _ in range(num_barriers): + shmem.device_barrier() + _write_remote_kernel[(1,)]( + buf, + float(rank), + rank, + neighbor, + N, + heap_bases, + ) + for _ in range(num_barriers): + shmem.device_barrier() + capture_stream.synchronize() + + # Capture. 
+ graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=capture_stream): + for _ in range(num_barriers): + shmem.device_barrier() + _write_remote_kernel[(1,)]( + buf, + float(rank), + rank, + neighbor, + N, + heap_bases, + ) + for _ in range(num_barriers): + shmem.device_barrier() + + # Replay and verify. + for _ in range(rounds): + with torch.cuda.stream(capture_stream): + buf.fill_(0.0) + shmem.device_barrier() + graph.replay() + capture_stream.synchronize() + + with torch.cuda.stream(capture_stream): + shmem.device_barrier() + capture_stream.synchronize() + expected = torch.full((N,), float(writer), dtype=torch.float32, device="cuda") + torch.testing.assert_close(buf, expected, rtol=0, atol=0) + + +# Host barrier is not graph-capturable (uses NCCL which crashes with +# hipErrorStreamCaptureUnsupported on ROCm). Skip host+graph combos. +@pytest.mark.parametrize("N", [1, 64, 256, 1024]) +@pytest.mark.parametrize("num_barriers", [1, 2, 4]) +@pytest.mark.parametrize("mode", ["eager", "graph"]) +@pytest.mark.parametrize("op", ["load", "store", "both"]) +@pytest.mark.parametrize("barrier_type", BARRIER_TYPES) +def test_barrier_cross_rank(barrier_type, op, mode, num_barriers, N, rounds=3): + """Verify cross-rank data visibility after barrier. + + - op: load (iris.load from neighbor), store (iris.store to neighbor), or both + - mode: eager (direct calls) or graph (CUDA graph capture + replay) + - num_barriers: consecutive barriers to test idempotency + - N: number of elements (must be power of 2 for Triton BLOCK_SIZE) + - rounds: number of iterations with changing data (default 3) + + Each mode runs multiple rounds with changing data to stress correctness. + Graph mode captures barrier + kernel into a CUDA graph, then replays + with fresh data to verify correctness through the captured graph. + """ + if mode == "graph" and barrier_type == "host": + pytest.skip( + "Host barrier uses NCCL which is not graph-capturable on ROCm. 
See https://github.com/ROCm/HIP/issues/3876" + ) + + shmem = iris.iris(1 << 20) + _call_barrier(shmem, barrier_type) + rank = shmem.get_rank() + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + neighbor = (rank + 1) % num_ranks + writer = (rank - 1 + num_ranks) % num_ranks + + buf = shmem.zeros((N,), dtype=torch.float32) + result = shmem.zeros((N,), dtype=torch.float32) + + ops = ["load", "store"] if op == "both" else [op] + + try: + for single_op in ops: + if mode == "eager": + _cross_rank_eager( + shmem, + barrier_type, + single_op, + num_barriers, + rounds, + N, + rank, + neighbor, + writer, + heap_bases, + buf, + result, + ) + else: + _cross_rank_graph( + shmem, + single_op, + num_barriers, + rounds, + N, + rank, + neighbor, + writer, + heap_bases, + buf, + result, + ) + finally: + _call_barrier(shmem, barrier_type) + del shmem + gc.collect() diff --git a/tests/unittests/test_copy_cache_modifiers.py b/tests/unittests/test_copy_cache_modifiers.py index b7c278ea2..892b4e0d9 100644 --- a/tests/unittests/test_copy_cache_modifiers.py +++ b/tests/unittests/test_copy_cache_modifiers.py @@ -26,46 +26,22 @@ def copy_kernel_local_read_remote_write( offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < BLOCK_SIZE - # Copy from current rank to other ranks + # Copy from current rank to other ranks. + # Both load and store cache modifiers are passed unconditionally. 
for target_rank in range(num_ranks): src_data = data + BLOCK_SIZE * cur_rank dest_data = results + BLOCK_SIZE * cur_rank - if load_cache_modifier is None and store_cache_modifier is None: - iris.copy(src_data + offsets, dest_data + offsets, cur_rank, target_rank, cur_rank, heap_bases, mask=mask) - elif load_cache_modifier is None: - iris.copy( - src_data + offsets, - dest_data + offsets, - cur_rank, - target_rank, - cur_rank, - heap_bases, - mask=mask, - store_cache_modifier=store_cache_modifier, - ) - elif store_cache_modifier is None: - iris.copy( - src_data + offsets, - dest_data + offsets, - cur_rank, - target_rank, - cur_rank, - heap_bases, - mask=mask, - load_cache_modifier=load_cache_modifier, - ) - else: - iris.copy( - src_data + offsets, - dest_data + offsets, - cur_rank, - target_rank, - cur_rank, - heap_bases, - mask=mask, - load_cache_modifier=load_cache_modifier, - store_cache_modifier=store_cache_modifier, - ) + iris.copy( + src_data + offsets, + dest_data + offsets, + cur_rank, + target_rank, + cur_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) @triton.jit @@ -85,70 +61,43 @@ def copy_kernel_remote_read_local_write( offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < BLOCK_SIZE - # Copy from other ranks to current rank + # Copy from other ranks to current rank. + # Both load and store cache modifiers are passed unconditionally. 
for source_rank in range(num_ranks): src_data = data + BLOCK_SIZE * source_rank dest_data = results + BLOCK_SIZE * source_rank - if load_cache_modifier is None and store_cache_modifier is None: - iris.copy(src_data + offsets, dest_data + offsets, source_rank, cur_rank, cur_rank, heap_bases, mask=mask) - elif load_cache_modifier is None: - iris.copy( - src_data + offsets, - dest_data + offsets, - source_rank, - cur_rank, - cur_rank, - heap_bases, - mask=mask, - store_cache_modifier=store_cache_modifier, - ) - elif store_cache_modifier is None: - iris.copy( - src_data + offsets, - dest_data + offsets, - source_rank, - cur_rank, - cur_rank, - heap_bases, - mask=mask, - load_cache_modifier=load_cache_modifier, - ) - else: - iris.copy( - src_data + offsets, - dest_data + offsets, - source_rank, - cur_rank, - cur_rank, - heap_bases, - mask=mask, - load_cache_modifier=load_cache_modifier, - store_cache_modifier=store_cache_modifier, - ) - - -# Define cache modifiers for load and store operations + iris.copy( + src_data + offsets, + dest_data + offsets, + source_rank, + cur_rank, + cur_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + + +# Define cache modifiers for load and store operations. +# Both load and store modifiers are passed unconditionally to tl.load()/tl.store(). +# It is the caller's responsibility to use appropriate modifiers for local vs. remote ops. 
LOAD_CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"] -# Remote stores (cross-GPU IPC) cannot use cache modifier bits -# Only default (None or empty string) works - cache bits break coherency -STORE_CACHE_MODIFIERS_REMOTE_WRITE = [None, ""] -# For testing remote reads (which work with all load modifiers), -# we can use all store modifiers since the store is local -LOAD_CACHE_MODIFIERS_REMOTE_READ = [None, "", ".ca", ".cg", ".cv"] -STORE_CACHE_MODIFIERS_LOCAL_WRITE = [None, "", ".wb", ".cg", ".cs", ".wt"] +STORE_CACHE_MODIFIERS = [None, "", ".wb", ".cg", ".cs", ".wt"] @pytest.mark.parametrize( - "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS_REMOTE_WRITE)) + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) ) def test_copy_local_read_remote_write(load_cache_modifier, store_cache_modifier): """Test copy: local read → remote write Direction: from_rank=cur_rank (local), to_rank=other (remote) - - Load: from LOCAL memory (all cache modifiers should work) - - Store: to REMOTE memory (only None/"" work, cache bits break coherency) + - Load: from LOCAL memory + - Store: to REMOTE memory - This tests that load cache modifiers work for local reads. + store_cache_modifier is passed unconditionally to the remote tl.store(). It is the + caller's responsibility to use modifiers appropriately. """ shmem = iris.iris(1 << 20) num_ranks = shmem.get_num_ranks() @@ -162,6 +111,11 @@ def test_copy_local_read_remote_write(load_cache_modifier, store_cache_modifier) data[i, :] = base * (i + 1) results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + + # Barrier to ensure all ranks have initialized their data before any rank launches + # the kernel (which reads remote data in the remote-read case). 
+ shmem.barrier() + grid = lambda meta: (1,) copy_kernel_local_read_remote_write[grid]( data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases, load_cache_modifier, store_cache_modifier @@ -181,16 +135,17 @@ def test_copy_local_read_remote_write(load_cache_modifier, store_cache_modifier) @pytest.mark.parametrize( "load_cache_modifier,store_cache_modifier", - list(product(LOAD_CACHE_MODIFIERS_REMOTE_READ, STORE_CACHE_MODIFIERS_LOCAL_WRITE)), + list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)), ) def test_copy_remote_read_local_write(load_cache_modifier, store_cache_modifier): """Test copy: remote read → local write Direction: from_rank=other (remote), to_rank=cur_rank (local) - - Load: from REMOTE memory (test if cache modifiers work for remote reads) - - Store: to LOCAL memory (all cache modifiers should work) + - Load: from REMOTE memory + - Store: to LOCAL memory - This tests whether load cache modifiers work for remote reads. + Both cache modifiers are passed unconditionally. It is the caller's responsibility + to use appropriate modifiers for local vs. remote operations. """ shmem = iris.iris(1 << 20) num_ranks = shmem.get_num_ranks() @@ -204,6 +159,11 @@ def test_copy_remote_read_local_write(load_cache_modifier, store_cache_modifier) data[i, :] = base * (i + 1) results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + + # Barrier to ensure all ranks have initialized their data before any rank launches + # the kernel (which reads remote data in the remote-read case). 
+ shmem.barrier() + grid = lambda meta: (1,) copy_kernel_remote_read_local_write[grid]( data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases, load_cache_modifier, store_cache_modifier diff --git a/tests/unittests/test_get_cache_modifiers.py b/tests/unittests/test_get_cache_modifiers.py index 58cb9d485..4b35a32d0 100644 --- a/tests/unittests/test_get_cache_modifiers.py +++ b/tests/unittests/test_get_cache_modifiers.py @@ -27,42 +27,19 @@ def get_kernel( acc = tl.zeros([BLOCK_SIZE], dtype=data.type.element_ty) - # Loop over all ranks, get the stored data with cache modifiers - # We test default values set by the function when parameters are None + # Loop over all ranks and get data with cache modifiers applied unconditionally. + # The load is remote when from_rank != cur_rank; the store to results is always local. for target_rank in range(num_ranks): - if load_cache_modifier is None and store_cache_modifier is None: - iris.get(data + offsets, results + offsets, cur_rank, target_rank, heap_bases, mask=mask) - elif load_cache_modifier is None: - iris.get( - data + offsets, - results + offsets, - cur_rank, - target_rank, - heap_bases, - mask=mask, - store_cache_modifier=store_cache_modifier, - ) - elif store_cache_modifier is None: - iris.get( - data + offsets, - results + offsets, - cur_rank, - target_rank, - heap_bases, - mask=mask, - load_cache_modifier=load_cache_modifier, - ) - else: - iris.get( - data + offsets, - results + offsets, - cur_rank, - target_rank, - heap_bases, - mask=mask, - load_cache_modifier=load_cache_modifier, - store_cache_modifier=store_cache_modifier, - ) + iris.get( + data + offsets, + results + offsets, + cur_rank, + target_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) acc += tl.load(results + offsets, mask=mask) # Store the accumulated value back to the output @@ -78,7 +55,12 @@ def get_kernel( "load_cache_modifier,store_cache_modifier", 
list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) ) def test_get_cache_modifiers(load_cache_modifier, store_cache_modifier): - """Test get (copy from other rank) with various cache modifiers.""" + """Test get (copy from other rank) with various cache modifiers. + + load_cache_modifier is passed unconditionally; it applies to the remote load when + from_rank != to_rank. store_cache_modifier applies to the always-local store to to_ptr. + It is the caller's responsibility to use modifiers appropriately. + """ shmem = iris.iris(1 << 20) num_ranks = shmem.get_num_ranks() heap_bases = shmem.get_heap_bases() diff --git a/tests/unittests/test_get_other_triton.py b/tests/unittests/test_get_other_triton.py index 501710dc6..412d9710a 100644 --- a/tests/unittests/test_get_other_triton.py +++ b/tests/unittests/test_get_other_triton.py @@ -74,10 +74,11 @@ def test_get_other_api(dtype, BLOCK_SIZE): # Verify the results # First half should contain loaded values accumulated from all ranks (num_ranks * 1.0) - # Second half should contain accumulated "other" values (num_ranks * -1.0) + # Second half stays at 0.0 because iris.get stores with mask, so masked-out positions + # in `results` are never written; tl.load(results + offsets) reads 0.0 from them. 
expected = torch.zeros(BLOCK_SIZE, dtype=dtype, device="cuda") expected[: BLOCK_SIZE // 2] = num_ranks * 1.0 - expected[BLOCK_SIZE // 2 :] = num_ranks * other_value + expected[BLOCK_SIZE // 2 :] = 0.0 try: torch.testing.assert_close(results, expected, rtol=0, atol=0) diff --git a/tests/unittests/test_load_cache_modifiers.py b/tests/unittests/test_load_cache_modifiers.py index 5c1473002..cf9c0fc74 100644 --- a/tests/unittests/test_load_cache_modifiers.py +++ b/tests/unittests/test_load_cache_modifiers.py @@ -10,7 +10,7 @@ @triton.jit -def kernel( +def load_kernel( data, results, source_rank: tl.constexpr, @@ -23,25 +23,21 @@ def kernel( pid = tl.program_id(0) partner = int((source_rank + num_ranks // 2) % num_ranks) - # Compute start index of this block block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) - - # Guard for out-of-bounds accesses mask = offsets < BLOCK_SIZE - if cache_modifier is None: - result = iris.load(data + offsets, source_rank, partner, heap_bases, mask=mask, volatile=volatile) - else: - result = iris.load( - data + offsets, - source_rank, - partner, - heap_bases, - mask=mask, - cache_modifier=cache_modifier, - volatile=volatile, - ) + # cache_modifier is passed unconditionally; it is the caller's responsibility + # to use an appropriate modifier for local vs. remote loads. + result = iris.load( + data + offsets, + source_rank, + partner, + heap_bases, + mask=mask, + cache_modifier=cache_modifier, + volatile=volatile, + ) tl.store(results + offsets, result, mask=mask) @@ -53,7 +49,11 @@ def kernel( @pytest.mark.parametrize("cache_modifier,volatile", list(product(CACHE_MODIFIERS, VOLATILE_OPTIONS))) def test_load_cache_modifiers(cache_modifier, volatile): - """Test load with various cache modifiers and volatile settings.""" + """Test load with various cache modifiers and volatile settings. + + cache_modifier is passed unconditionally to tl.load(). 
It is the caller's + responsibility to use modifiers appropriately for local vs. remote loads. + """ shmem = iris.iris(1 << 20) num_ranks = shmem.get_num_ranks() heap_bases = shmem.get_heap_bases() @@ -67,10 +67,10 @@ def test_load_cache_modifiers(cache_modifier, volatile): shmem.barrier() grid = lambda meta: (1,) - kernel[grid](data, results, source_rank, num_ranks, BLOCK_SIZE, heap_bases, cache_modifier, volatile) + load_kernel[grid](data, results, source_rank, num_ranks, BLOCK_SIZE, heap_bases, cache_modifier, volatile) shmem.barrier() - # Verify the result + # Verify the result - should have loaded from partner rank expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") * partner try: diff --git a/tests/unittests/test_load_other_triton.py b/tests/unittests/test_load_other_triton.py index f9fe0e00c..e7db690b2 100644 --- a/tests/unittests/test_load_other_triton.py +++ b/tests/unittests/test_load_other_triton.py @@ -57,8 +57,9 @@ def test_load_other_api(dtype, BLOCK_SIZE): source_rank = shmem.get_rank() partner = int((source_rank + num_ranks // 2) % num_ranks) - # Fill data with partner rank value - data = shmem.full((BLOCK_SIZE,), partner, dtype=dtype) + # Fill data with source rank value so remote reads match expected values: + # each rank's data[i] = source_rank, so loading from partner gives partner's rank value + data = shmem.full((BLOCK_SIZE,), source_rank, dtype=dtype) results = shmem.zeros_like(data) # Use -1 as the "other" value for masked-out elements diff --git a/tests/unittests/test_put_cache_modifiers.py b/tests/unittests/test_put_cache_modifiers.py index 25d169047..3c4083602 100644 --- a/tests/unittests/test_put_cache_modifiers.py +++ b/tests/unittests/test_put_cache_modifiers.py @@ -13,7 +13,8 @@ def put_kernel( data, results, - cur_rank: tl.constexpr, + from_rank: tl.constexpr, + to_rank: tl.constexpr, BLOCK_SIZE: tl.constexpr, heap_bases: tl.tensor, load_cache_modifier: tl.constexpr, @@ -23,61 +24,30 @@ def put_kernel( block_start = 
pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < BLOCK_SIZE - - # Put data locally (same rank) with cache modifiers. - # store_cache_modifier only applies to local stores (from_rank == to_rank). - # Remote stores do not support cache modifiers. - if load_cache_modifier is None and store_cache_modifier is None: - iris.put(data + offsets, results + offsets, cur_rank, cur_rank, heap_bases, mask=mask) - elif load_cache_modifier is None: - iris.put( - data + offsets, - results + offsets, - cur_rank, - cur_rank, - heap_bases, - mask=mask, - store_cache_modifier=store_cache_modifier, - ) - elif store_cache_modifier is None: - iris.put( - data + offsets, - results + offsets, - cur_rank, - cur_rank, - heap_bases, - mask=mask, - load_cache_modifier=load_cache_modifier, - ) - else: - iris.put( - data + offsets, - results + offsets, - cur_rank, - cur_rank, - heap_bases, - mask=mask, - load_cache_modifier=load_cache_modifier, - store_cache_modifier=store_cache_modifier, - ) + iris.put( + data + offsets, + results + offsets, + from_rank, + to_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) # Define cache modifiers for load and store operations LOAD_CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"] -# store_cache_modifier is only effective for local stores (from_rank == to_rank) +# store_cache_modifier is passed unconditionally; it is the caller's responsibility +# to choose appropriate modifiers for local vs. remote stores. STORE_CACHE_MODIFIERS = [None, "", ".wb", ".cg", ".cs", ".wt"] @pytest.mark.parametrize( "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) ) -def test_put_cache_modifiers(load_cache_modifier, store_cache_modifier): - """Test put (local copy) with various cache modifiers. - - store_cache_modifier is only effective for local stores (from_rank == to_rank). 
- Remote stores do not support cache modifiers. - This test exercises only local puts to verify cache modifier behavior. - """ +def test_put_cache_modifiers_local(load_cache_modifier, store_cache_modifier): + """Test local put (from_rank == to_rank) with various cache modifiers.""" shmem = iris.iris(1 << 20) heap_bases = shmem.get_heap_bases() cur_rank = shmem.get_rank() @@ -89,19 +59,63 @@ def test_put_cache_modifiers(load_cache_modifier, store_cache_modifier): shmem.barrier() grid = lambda meta: (1,) - put_kernel[grid](data, results, cur_rank, BLOCK_SIZE, heap_bases, load_cache_modifier, store_cache_modifier) + put_kernel[grid]( + data, results, cur_rank, cur_rank, BLOCK_SIZE, heap_bases, load_cache_modifier, store_cache_modifier + ) shmem.barrier() - # Verify the result - should have the data that was put (local copy) expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") - try: torch.testing.assert_close(results, expected, rtol=0, atol=0) except AssertionError as e: print( - f"PUT test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + f"LOCAL PUT test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" ) print(e) - print("Expected:", expected) - print("Actual:", results) raise + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_put_cache_modifiers_remote(load_cache_modifier, store_cache_modifier): + """Test remote put (from_rank != to_rank) with various cache modifiers. + + store_cache_modifier is passed unconditionally to the remote tl.store(). It is the + caller's responsibility to use modifiers appropriately for remote operations. 
+ """ + shmem = iris.iris(1 << 20) + heap_bases = shmem.get_heap_bases() + num_ranks = shmem.get_num_ranks() + cur_rank = shmem.get_rank() + + if num_ranks < 2: + pytest.skip("Remote put test requires at least 2 ranks") + + BLOCK_SIZE = 16 + data = shmem.ones(BLOCK_SIZE, dtype=torch.float32) + results = shmem.zeros(BLOCK_SIZE, dtype=torch.float32) + + shmem.barrier() + + # rank 0 puts to rank 1 + remote_rank = (cur_rank + 1) % num_ranks + grid = lambda meta: (1,) + if cur_rank == 0: + put_kernel[grid]( + data, results, cur_rank, remote_rank, BLOCK_SIZE, heap_bases, load_cache_modifier, store_cache_modifier + ) + + shmem.barrier() + + # rank 1 checks the data it received from rank 0 + if cur_rank == 1: + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print( + f"REMOTE PUT test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + print(e) + raise diff --git a/tests/unittests/test_put_other_triton.py b/tests/unittests/test_put_other_triton.py index db78bfc3d..51db50f85 100644 --- a/tests/unittests/test_put_other_triton.py +++ b/tests/unittests/test_put_other_triton.py @@ -68,11 +68,12 @@ def test_put_other_api(dtype, BLOCK_SIZE): shmem.barrier() # Verify the results - # First half should contain the value 1.0 (from data) - # Second half should contain the "other" value (-1.0) since mask was False + # First half should contain the value 1.0 (from data, written via masked put) + # Second half stays at 0.0 because iris.put stores with mask, so masked-out positions + # in results are never written. 
expected = torch.zeros(BLOCK_SIZE, dtype=dtype, device="cuda") expected[: BLOCK_SIZE // 2] = 1.0 - expected[BLOCK_SIZE // 2 :] = other_value + expected[BLOCK_SIZE // 2 :] = 0.0 try: torch.testing.assert_close(results, expected, rtol=0, atol=0) diff --git a/tests/unittests/test_store_cache_modifiers.py b/tests/unittests/test_store_cache_modifiers.py index 7b459012e..4c46feac0 100644 --- a/tests/unittests/test_store_cache_modifiers.py +++ b/tests/unittests/test_store_cache_modifiers.py @@ -9,38 +9,40 @@ @triton.jit -def kernel( +def local_store_kernel( data, results, - destination_rank: tl.constexpr, + cur_rank: tl.constexpr, BLOCK_SIZE: tl.constexpr, heap_bases: tl.tensor, cache_modifier: tl.constexpr, ): pid = tl.program_id(0) - block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < BLOCK_SIZE - - # Load the data from src for this block value = tl.load(data + offsets, mask=mask) + # Local store: from_rank == to_rank == cur_rank + iris.store(results + offsets, value, cur_rank, cur_rank, heap_bases, mask=mask, cache_modifier=cache_modifier) - # Store data locally (same rank) with the specified cache modifier. - # Cache modifiers only apply to local stores; remote stores do not support them. 
- if cache_modifier is None: - iris.store(results + offsets, value, destination_rank, destination_rank, heap_bases, mask=mask) - else: - iris.store( - results + offsets, - value, - destination_rank, - destination_rank, - heap_bases, - mask=mask, - cache_modifier=cache_modifier, - ) + +@triton.jit +def remote_store_kernel( + data, + results, + from_rank: tl.constexpr, + to_rank: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, + cache_modifier: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + value = tl.load(data + offsets, mask=mask) + # Remote store: from_rank != to_rank + iris.store(results + offsets, value, from_rank, to_rank, heap_bases, mask=mask, cache_modifier=cache_modifier) # Define cache modifiers for store operations @@ -48,15 +50,11 @@ def kernel( @pytest.mark.parametrize("cache_modifier", CACHE_MODIFIERS) -def test_store_cache_modifiers(cache_modifier): - """Test local store with various cache modifiers. - - Cache modifiers are only effective for local stores (from_rank == to_rank). - Remote stores do not support cache modifiers. 
- """ +def test_store_cache_modifiers_local(cache_modifier): + """Test local store (from_rank == to_rank) with various cache modifiers.""" shmem = iris.iris(1 << 20) heap_bases = shmem.get_heap_bases() - destination_rank = shmem.get_rank() + cur_rank = shmem.get_rank() BLOCK_SIZE = 16 src = shmem.ones(BLOCK_SIZE, dtype=torch.float32) @@ -65,16 +63,53 @@ def test_store_cache_modifiers(cache_modifier): shmem.barrier() grid = lambda meta: (1,) - kernel[grid](src, results, destination_rank, BLOCK_SIZE, heap_bases, cache_modifier) + local_store_kernel[grid](src, results, cur_rank, BLOCK_SIZE, heap_bases, cache_modifier) shmem.barrier() - # Verify the result expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") - try: torch.testing.assert_close(results, expected, rtol=0, atol=0) except AssertionError as e: + print(f"LOCAL STORE test failed with cache_modifier={cache_modifier}") print(e) - print("Expected:", expected) - print("Actual:", results) raise + + +@pytest.mark.parametrize("cache_modifier", CACHE_MODIFIERS) +def test_store_cache_modifiers_remote(cache_modifier): + """Test remote store (from_rank != to_rank) with various cache modifiers. + + Cache modifiers are passed through unconditionally to tl.store(). It is the + caller's responsibility to use them appropriately for remote operations. 
+ """ + shmem = iris.iris(1 << 20) + heap_bases = shmem.get_heap_bases() + num_ranks = shmem.get_num_ranks() + cur_rank = shmem.get_rank() + + if num_ranks < 2: + pytest.skip("Remote store test requires at least 2 ranks") + + BLOCK_SIZE = 16 + src = shmem.ones(BLOCK_SIZE, dtype=torch.float32) + results = shmem.zeros(BLOCK_SIZE, dtype=torch.float32) + + shmem.barrier() + + # rank 0 stores to rank 1 + remote_rank = (cur_rank + 1) % num_ranks + grid = lambda meta: (1,) + if cur_rank == 0: + remote_store_kernel[grid](src, results, cur_rank, remote_rank, BLOCK_SIZE, heap_bases, cache_modifier) + + shmem.barrier() + + # rank 1 checks the data it received from rank 0 + if cur_rank == 1: + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print(f"REMOTE STORE test failed with cache_modifier={cache_modifier}") + print(e) + raise diff --git a/tests/unittests/test_topology.py b/tests/unittests/test_topology.py new file mode 100644 index 000000000..22836180e --- /dev/null +++ b/tests/unittests/test_topology.py @@ -0,0 +1,971 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Tests for iris.topology — Multi-GPU topology discovery. + +Some tests are pure unit tests (no GPU needed), others require a distributed +process group with real GPUs. The distributed tests are marked and will skip +gracefully if the environment isn't set up. 
+"""
+
+import json
+import socket
+
+import pytest
+import torch
+import torch.distributed as dist
+
+from iris.topology import (
+    FabricInfo,
+    GPUInfo,
+    IntraNodeLinkType,
+    InterconnectLevel,
+    NodeInfo,
+    TopologyDiscovery,
+    TopologyMap,
+    _all_gather_strings,
+    _normalize_pci_bus_id,
+    _logical_to_physical_gpu_index,
+)
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def get_rank():
+    """Return ``dist.get_rank()`` if a process group is initialized, else 0."""
+    return dist.get_rank() if dist.is_initialized() else 0
+
+
+def get_world_size():
+    """Return ``dist.get_world_size()`` if a process group is initialized, else 1."""
+    return dist.get_world_size() if dist.is_initialized() else 1
+
+
+# ---------------------------------------------------------------------------
+# Unit tests — no GPU or distributed required
+# ---------------------------------------------------------------------------
+
+
+class TestFabricInfo:
+    """Tests for FabricInfo data class."""
+
+    def test_empty_fabric_info(self):
+        """A default-constructed FabricInfo is empty, invalid, and has an empty domain key."""
+        fi = FabricInfo()
+        assert fi.cluster_uuid == ""
+        assert fi.clique_id == 0
+        assert fi.is_valid is False
+        assert fi.domain_key == ""
+
+    def test_valid_fabric_info(self):
+        """A populated FabricInfo is valid and its domain key is "<cluster_uuid>:<clique_id>"."""
+        fi = FabricInfo(cluster_uuid="00aabbccdd112233", clique_id=1)
+        assert fi.is_valid is True
+        assert fi.domain_key == "00aabbccdd112233:1"
+
+    def test_domain_key_comparison(self):
+        """Domain keys are equal only when both cluster_uuid and clique_id agree."""
+        fi_a = FabricInfo(cluster_uuid="aabb", clique_id=0)
+        fi_b = FabricInfo(cluster_uuid="aabb", clique_id=0)
+        fi_c = FabricInfo(cluster_uuid="aabb", clique_id=1)
+        fi_d = FabricInfo(cluster_uuid="ccdd", clique_id=0)
+        assert fi_a.domain_key == fi_b.domain_key
+        assert fi_a.domain_key != fi_c.domain_key  # different clique
+        assert fi_a.domain_key != fi_d.domain_key  # different cluster
+
+    def test_empty_domain_keys_do_not_match(self):
+        """Two empty domain keys compare equal as strings, but are falsy.
+
+        Only the equality and the falsiness are asserted here; the intent
+        (per the test name) is that callers can use truthiness to avoid
+        treating two fabric-less GPUs as sharing a domain.
+        """
+        fi_a = FabricInfo()
+        fi_b = FabricInfo()
+        assert fi_a.domain_key == fi_b.domain_key  # "" == "" is True
+        assert not fi_a.domain_key
+
+    def test_serialization_roundtrip(self):
+        """to_dict() followed by from_dict() preserves cluster_uuid and clique_id."""
+        fi = FabricInfo(cluster_uuid="deadbeef", clique_id=42)
+        d = fi.to_dict()
+        fi2 = FabricInfo.from_dict(d)
+        assert fi2.cluster_uuid == fi.cluster_uuid
+        assert fi2.clique_id == fi.clique_id
+
+    def test_from_dict_missing_keys(self):
+        """from_dict() on an empty dict yields the same values as the default constructor."""
+        fi = FabricInfo.from_dict({})
+        assert fi.cluster_uuid == ""
+        assert fi.clique_id == 0
+
+
+class TestGPUInfo:
+    """Tests for GPUInfo data class."""
+
+    def _make_gpu_info(
+        self,
+        rank=0,
+        hostname="node-a",
+        gpu_id=0,
+        pci_bus_id="0000:41:00.0",
+        fabric=None,
+    ):
+        """Build a GPUInfo with fixed placeholder hardware fields.
+
+        ``local_rank`` is set equal to ``gpu_id``; when ``fabric`` is falsy an
+        empty ``FabricInfo()`` is used.
+        """
+        return GPUInfo(
+            global_rank=rank,
+            local_rank=gpu_id,
+            hostname=hostname,
+            gpu_id=gpu_id,
+            pci_bus_id=pci_bus_id,
+            device_name="Test GPU",
+            total_memory_mb=81920,
+            numa_node=0,
+            vendor="amd",
+            uuid=f"gpu-{hostname}-{gpu_id}",
+            fabric_info=fabric or FabricInfo(),
+        )
+
+    def test_serialization_roundtrip(self):
+        """to_dict()/from_dict() preserves scalar fields and the nested fabric info."""
+        fi = FabricInfo(cluster_uuid="aabb", clique_id=1)
+        gpu = self._make_gpu_info(fabric=fi)
+        d = gpu.to_dict()
+        gpu2 = GPUInfo.from_dict(d)
+        assert gpu2.global_rank == gpu.global_rank
+        assert gpu2.hostname == gpu.hostname
+        assert gpu2.pci_bus_id == gpu.pci_bus_id
+        assert gpu2.fabric_info.cluster_uuid == "aabb"
+        assert gpu2.fabric_info.clique_id == 1
+
+    def test_from_dict_does_not_mutate_input(self):
+        """Regression: old code used dict.pop() which mutated the input."""
+        d = {
+            "global_rank": 0,
+            "local_rank": 0,
+            "hostname": "h",
+            "gpu_id": 0,
+            "pci_bus_id": "x",
+            "device_name": "x",
+            "total_memory_mb": 0,
+            "numa_node": 0,
+            "vendor": "amd",
+            "uuid": "x",
+            "fabric_info": {"cluster_uuid": "abc", "clique_id": 1},
+        }
+        original_keys = set(d.keys())
+        GPUInfo.from_dict(d)
+        assert set(d.keys()) == original_keys
+        assert "fabric_info" in d  # must not have been popped
+
+    def test_from_dict_missing_fabric(self):
+        """A dict without a "fabric_info" key yields a GPUInfo whose fabric info is invalid."""
+        d = {
+            "global_rank": 0,
+            "local_rank": 0,
+            "hostname": "h",
+            "gpu_id": 0,
+            "pci_bus_id": "x",
+            "device_name": "x",
+            "total_memory_mb": 0,
+            "numa_node": 0,
+            "vendor": "amd",
+            "uuid": "x",
+        }
+        gpu = GPUInfo.from_dict(d)
+        assert gpu.fabric_info.is_valid is False
+
+
+class TestNormalizePCIBusId:
+    """Tests for PCI bus ID normalization."""
+
+    def test_standard_format(self):
+        """An already-normalized ID is returned unchanged."""
+        assert _normalize_pci_bus_id("0000:41:00.0") == "0000:41:00.0"
+
+    def test_uppercase(self):
+        """Hex digits are lower-cased."""
+        assert _normalize_pci_bus_id("0000:4A:00.0") == "0000:4a:00.0"
+
+    def test_nvidia_8char_domain(self):
+        """nvidia-smi sometimes uses 8-char domain like 00000000:41:00.0"""
+        result = _normalize_pci_bus_id("00000000:41:00.0")
+        assert result == "0000:41:00.0"
+
+    def test_prefix_junk(self):
+        """Leading non-ID text before the bus ID is dropped."""
+        result = _normalize_pci_bus_id("GPU 0000:41:00.0")
+        assert result == "0000:41:00.0"
+
+    def test_no_match(self):
+        """Input that contains no recognizable bus ID is returned as-is."""
+        result = _normalize_pci_bus_id("garbage")
+        assert result == "garbage"
+
+
+class TestNodeInfo:
+    """Tests for NodeInfo safe accessors."""
+
+    def test_get_link_type_self(self):
+        """get_link_type(i, j) returns the [i][j] entry of the link_types matrix."""
+        node = NodeInfo(
+            hostname="h",
+            link_types=[
+                [IntraNodeLinkType.SELF, IntraNodeLinkType.NVLINK],
+                [IntraNodeLinkType.NVLINK, IntraNodeLinkType.SELF],
+            ],
+        )
+        assert node.get_link_type(0, 0) == IntraNodeLinkType.SELF
+        assert node.get_link_type(0, 1) == IntraNodeLinkType.NVLINK
+
+    def test_get_link_type_out_of_bounds(self):
+        """Oversubscription safety: local_rank=3 on a 2-GPU node."""
+        node = NodeInfo(
+            hostname="h",
+            link_types=[
+                [IntraNodeLinkType.SELF, IntraNodeLinkType.NVLINK],
+                [IntraNodeLinkType.NVLINK, IntraNodeLinkType.SELF],
+            ],
+        )
+        # gpu_id 3 is out of bounds for a 2x2 matrix (checked for both row and column)
+        assert node.get_link_type(3, 0) == IntraNodeLinkType.UNKNOWN
+        assert node.get_link_type(0, 3) == IntraNodeLinkType.UNKNOWN
+
+    def test_get_link_type_no_matrix(self):
+        """With link_types=None, lookups between distinct GPUs return UNKNOWN."""
+        node = NodeInfo(hostname="h", link_types=None)
+        assert node.get_link_type(0, 1) == IntraNodeLinkType.UNKNOWN
+
+    def test_p2p_access_out_of_bounds(self):
+        """can_p2p_access() returns False, rather than raising, for an out-of-range index."""
+        node = NodeInfo(
+            hostname="h",
+            p2p_access=[[True, True], [True, True]],
+        )
+        assert node.can_p2p_access(0, 1) is True
+        assert node.can_p2p_access(3, 0) is False  # out of bounds -> False
+
+    def test_p2p_access_self_always_true(self):
+        """A GPU can always access itself, even when no p2p_access matrix is present."""
+        node = NodeInfo(hostname="h", p2p_access=None)
+        assert node.can_p2p_access(0, 0) is True
+
+
+class TestTopologyMap:
+    """Tests for TopologyMap with synthetic data.
+
+    Tests that need TopologyDiscovery create it with ``__new__`` (which skips
+    ``__init__``) and assign ``_topology`` directly, so only the synthetic
+    TopologyMap is consulted.
+    """
+
+    def _make_topology(self):
+        """
+        Build a synthetic 8-rank topology:
+          node-a: ranks 0,1,2,3 (GPU0-3), fabric "aabb:0"
+          node-b: ranks 4,5     (GPU0-1), fabric "aabb:0"
+          node-c: ranks 6,7     (GPU0-1), no fabric
+        """
+        fabric_ab = FabricInfo(cluster_uuid="aabb", clique_id=0)
+        gpus = {}
+        # node-a: four GPUs, PCI IDs 0000:40..43, all in fabric "aabb:0"
+        for r in range(4):
+            gpus[r] = GPUInfo(
+                global_rank=r,
+                local_rank=r,
+                hostname="node-a",
+                gpu_id=r,
+                pci_bus_id=f"0000:4{r}:00.0",
+                device_name="MI300X",
+                total_memory_mb=81920,
+                numa_node=0,
+                vendor="amd",
+                uuid=f"gpu-a-{r}",
+                fabric_info=fabric_ab,
+            )
+        # node-b: two GPUs, same fabric as node-a; local indices restart at 0
+        for r in range(4, 6):
+            gpus[r] = GPUInfo(
+                global_rank=r,
+                local_rank=r - 4,
+                hostname="node-b",
+                gpu_id=r - 4,
+                pci_bus_id=f"0000:8{r - 4}:00.0",
+                device_name="MI300X",
+                total_memory_mb=81920,
+                numa_node=0,
+                vendor="amd",
+                uuid=f"gpu-b-{r - 4}",
+                fabric_info=fabric_ab,
+            )
+        # node-c: two GPUs with an empty (invalid) FabricInfo
+        for r in range(6, 8):
+            gpus[r] = GPUInfo(
+                global_rank=r,
+                local_rank=r - 6,
+                hostname="node-c",
+                gpu_id=r - 6,
+                pci_bus_id=f"0000:c{r - 6}:00.0",
+                device_name="A100",
+                total_memory_mb=81920,
+                numa_node=0,
+                vendor="nvidia",
+                uuid=f"gpu-c-{r - 6}",
+                fabric_info=FabricInfo(),
+            )
+
+        nodes = {
+            "node-a": NodeInfo(
+                hostname="node-a",
+                ranks=[0, 1, 2, 3],
+                gpu_ids=[0, 1, 2, 3],
+                unique_gpu_ids=[0, 1, 2, 3],
+                unique_pci_ids=[
+                    "0000:40:00.0",
+                    "0000:41:00.0",
+                    "0000:42:00.0",
+                    "0000:43:00.0",
+                ],
+                num_gpus=4,
+                num_ranks=4,
+                fabric_domain_key="aabb:0",
+                fabric_domain_keys=["aabb:0"],
+            ),
+            "node-b": NodeInfo(
+                hostname="node-b",
+                ranks=[4, 5],
+                gpu_ids=[0, 1],
+                unique_gpu_ids=[0, 1],
+                unique_pci_ids=["0000:80:00.0", "0000:81:00.0"],
+                num_gpus=2,
+                num_ranks=2,
+                fabric_domain_key="aabb:0",
+                fabric_domain_keys=["aabb:0"],
+            ),
+            "node-c": NodeInfo(
+                hostname="node-c",
+                ranks=[6, 7],
+                gpu_ids=[0, 1],
+                unique_gpu_ids=[0, 1],
+                unique_pci_ids=["0000:c0:00.0", "0000:c1:00.0"],
+                num_gpus=2,
+                num_ranks=2,
+                fabric_domain_key="",
+                fabric_domain_keys=[],
+            ),
+        }
+
+        # Only the two fabric-attached hosts are listed; node-c has no domain entry.
+        fabric_domains = {"aabb:0": ["node-a", "node-b"]}
+
+        return TopologyMap(
+            world_size=8,
+            num_nodes=3,
+            gpu_info=gpus,
+            nodes=nodes,
+            fabric_domains=fabric_domains,
+        )
+
+    # --- Interconnect level classification ---
+
+    def test_same_rank_is_intra_node(self):
+        """A rank compared with itself is classified INTRA_NODE."""
+        topo = self._make_topology()
+        assert topo.get_interconnect_level(0, 0) == InterconnectLevel.INTRA_NODE
+
+    def test_same_node_is_intra_node(self):
+        """Two ranks on the same host are classified INTRA_NODE."""
+        topo = self._make_topology()
+        assert topo.get_interconnect_level(0, 3) == InterconnectLevel.INTRA_NODE
+        assert topo.get_interconnect_level(1, 2) == InterconnectLevel.INTRA_NODE
+
+    def test_same_fabric_different_node_is_fabric(self):
+        """Ranks on different hosts in the same fabric domain are INTRA_RACK_FABRIC."""
+        topo = self._make_topology()
+        # rank 0 (node-a) <-> rank 4 (node-b): both in fabric "aabb:0"
+        assert topo.get_interconnect_level(0, 4) == InterconnectLevel.INTRA_RACK_FABRIC
+        assert topo.get_interconnect_level(3, 5) == InterconnectLevel.INTRA_RACK_FABRIC
+
+    def test_no_fabric_is_rdma(self):
+        """A pair where one side has no fabric falls back to INTER_NODE_RDMA."""
+        topo = self._make_topology()
+        # rank 0 (node-a, fabric) <-> rank 6 (node-c, no fabric)
+        assert topo.get_interconnect_level(0, 6) == InterconnectLevel.INTER_NODE_RDMA
+        assert topo.get_interconnect_level(4, 7) == InterconnectLevel.INTER_NODE_RDMA
+
+    # --- Peer groups ---
+
+    def test_node_peers(self):
+        """get_node_peers() returns the other ranks on the same host, excluding self."""
+        topo = self._make_topology()
+        assert topo.get_node_peers(0) == {1, 2, 3}
+        assert topo.get_node_peers(4) == {5}
+        assert topo.get_node_peers(6) == {7}
+
+    def test_fabric_domain_peers(self):
+        """get_fabric_domain_peers() includes same-node ranks and excludes self."""
+        topo = self._make_topology()
+        # rank 0: fabric peers = all of (node-a + node-b) minus self
+        assert topo.get_fabric_domain_peers(0) == {1, 2, 3, 4, 5}
+        # rank 4: fabric peers = all of (node-a + node-b) minus self
+        assert topo.get_fabric_domain_peers(4) == {0, 1, 2, 3, 5}
+        # rank 6: no fabric -> empty
+        assert topo.get_fabric_domain_peers(6) == set()
+
+    def test_rdma_peers(self):
+        """get_rdma_peers() returns the ranks reachable neither intra-node nor via fabric."""
+        topo = self._make_topology()
+        # rank 0: RDMA peers = everyone not in same node or fabric = {6, 7}
+        assert topo.get_rdma_peers(0) == {6, 7}
+        # rank 6: RDMA peers = everyone not on node-c = {0,1,2,3,4,5}
+        assert topo.get_rdma_peers(6) == {0, 1, 2, 3, 4, 5}
+
+    def test_peer_groups_partition_world(self):
+        """Node peers + fabric-only peers + RDMA peers + self = world."""
+        topo = self._make_topology()
+        for rank in range(8):
+            node = topo.get_node_peers(rank)
+            # Fabric peers include same-node ranks, so subtract them to get the cross-node part.
+            fabric_only = topo.get_fabric_domain_peers(rank) - node
+            rdma = topo.get_rdma_peers(rank)
+            all_peers = node | fabric_only | rdma | {rank}
+            assert all_peers == set(range(8)), (
+                f"Rank {rank}: partition incomplete. Missing: {set(range(8)) - all_peers}"
+            )
+            # No overlaps
+            assert not (node & rdma), f"Rank {rank}: node∩rdma overlap"
+            assert not (fabric_only & rdma), f"Rank {rank}: fabric∩rdma overlap"
+
+    # --- Communication groups ---
+
+    def test_comm_groups_intra_node(self):
+        """The INTRA_NODE tier has one group per host, each listing that host's ranks."""
+        topo = self._make_topology()
+        # Bypass __init__ and inject the synthetic topology directly.
+        td = TopologyDiscovery.__new__(TopologyDiscovery)
+        td._topology = topo
+        groups = td.get_communication_groups()
+        intra = groups[InterconnectLevel.INTRA_NODE]
+        # Sort by first rank so the comparison does not depend on group order.
+        assert sorted(intra, key=lambda g: g[0]) == [
+            [0, 1, 2, 3],
+            [4, 5],
+            [6, 7],
+        ]
+
+    def test_comm_groups_fabric_includes_standalone(self):
+        """
+        When fabric domains exist, standalone nodes (no fabric) must still
+        appear in the fabric tier so they aren't orphaned.
+        """
+        topo = self._make_topology()
+        td = TopologyDiscovery.__new__(TopologyDiscovery)
+        td._topology = topo
+        groups = td.get_communication_groups()
+        fabric = groups[InterconnectLevel.INTRA_RACK_FABRIC]
+        # Only coverage is checked here: the union of all fabric-tier groups
+        # must be the whole world. The number of groups is not asserted.
+        all_ranks_in_fabric = set()
+        for g in fabric:
+            all_ranks_in_fabric.update(g)
+        assert all_ranks_in_fabric == set(range(8)), (
+            f"Ranks missing from fabric groups: {set(range(8)) - all_ranks_in_fabric}"
+        )
+
+    def test_comm_groups_fabric_domain_group_content(self):
+        """Fabric domain group should contain exactly the ranks in that domain."""
+        topo = self._make_topology()
+        td = TopologyDiscovery.__new__(TopologyDiscovery)
+        td._topology = topo
+        groups = td.get_communication_groups()
+        fabric = groups[InterconnectLevel.INTRA_RACK_FABRIC]
+        # The "aabb:0" domain (node-a + node-b) must appear as one group.
+        # Membership is checked with `in`, so group order is not asserted.
+        assert [0, 1, 2, 3, 4, 5] in fabric
+        # The standalone node-c must appear as its own group.
+        assert [6, 7] in fabric
+
+    def test_comm_groups_fabric_is_not_empty_when_domains_exist(self):
+        """When fabric domains exist, fabric groups must be non-empty."""
+        topo = self._make_topology()
+        td = TopologyDiscovery.__new__(TopologyDiscovery)
+        td._topology = topo
+        groups = td.get_communication_groups()
+        assert len(groups[InterconnectLevel.INTRA_RACK_FABRIC]) > 0
+
+    def test_comm_groups_rdma_is_world(self):
+        """The INTER_NODE_RDMA tier is a single group holding every rank in order."""
+        topo = self._make_topology()
+        td = TopologyDiscovery.__new__(TopologyDiscovery)
+        td._topology = topo
+        groups = td.get_communication_groups()
+        assert groups[InterconnectLevel.INTER_NODE_RDMA] == [list(range(8))]
+
+    # --- Heap distribution plan ---
+
+    def test_heap_plan_completeness(self):
+        """The heap distribution plan has exactly one entry per rank."""
+        topo = self._make_topology()
+        td = TopologyDiscovery.__new__(TopologyDiscovery)
+        td._topology = topo
+        plan = td.get_heap_distribution_plan()
+        assert set(plan.keys()) == set(range(8))
+
+    def test_heap_plan_rank4(self):
+        """Spot-check rank 4 (node-b): its peers are split by tier, as ordered lists."""
+        topo = self._make_topology()
+        td = TopologyDiscovery.__new__(TopologyDiscovery)
+        td._topology = topo
+        plan = td.get_heap_distribution_plan()
+        p4 = plan[4]
+        assert p4["ipc_peers"] == [5]  # same node
+        assert p4["fabric_peers"] == [0, 1, 2, 3]  # same fabric, diff node
+        assert p4["rdma_peers"] == [6, 7]  # no fabric
+
+    def test_heap_plan_no_peer_overlap(self):
+        """For every rank the three peer lists are pairwise disjoint and exclude the rank itself."""
+        topo = self._make_topology()
+        td = TopologyDiscovery.__new__(TopologyDiscovery)
+        td._topology = topo
+        plan = td.get_heap_distribution_plan()
+        for rank, p in plan.items():
+            ipc = set(p["ipc_peers"])
+            fabric = set(p["fabric_peers"])
+            rdma = set(p["rdma_peers"])
+            assert not (ipc & fabric), f"Rank {rank}: ipc∩fabric"
+            assert not (ipc & rdma), f"Rank {rank}: ipc∩rdma"
+            assert not (fabric & rdma), f"Rank {rank}: fabric∩rdma"
+            assert rank not in (ipc | fabric | rdma), f"Rank {rank}: self in peers"
+
+    # --- Topology summary ---
+
+    def test_summary_contains_all_nodes(self):
+        """summary() mentions every hostname and contains the text "Fabric Domains"."""
+        topo = self._make_topology()
+        s = topo.summary()
+        assert "node-a" in s
+        assert "node-b" in s
+        assert "node-c" in s
+        assert "Fabric Domains" in s
+
+    def test_ranks_for_fabric_domain(self):
+        """get_ranks_for_fabric_domain() returns the domain's ranks as an ordered list."""
+        topo = self._make_topology()
+        ranks = topo.get_ranks_for_fabric_domain("aabb:0")
+        assert ranks == [0, 1, 2, 3, 4, 5]
+
+    def test_ranks_for_nonexistent_domain(self):
+        """An unknown domain key yields an empty list rather than raising."""
+        topo = self._make_topology()
+        assert topo.get_ranks_for_fabric_domain("nonexistent") == []
+
+
+class TestOversubscription:
+    """
+    Tests for the oversubscription scenario:
+    4 ranks sharing 2 physical GPUs on the same node.
+    """
+
+    def _make_oversubscribed_topology(self):
+        """
+        node-x: ranks 0,1,2,3
+          ranks 0,1 -> gpu_id=0, PCI=0000:41:00.0
+          ranks 2,3 -> gpu_id=1, PCI=0000:42:00.0
+        """
+        gpus = {}
+        for r in range(4):
+            # Two consecutive ranks map onto each physical GPU.
+            gid = r // 2
+            gpus[r] = GPUInfo(
+                global_rank=r,
+                local_rank=r,
+                hostname="node-x",
+                gpu_id=gid,
+                pci_bus_id=f"0000:4{gid + 1}:00.0",
+                device_name="MI300X",
+                total_memory_mb=81920,
+                numa_node=0,
+                vendor="amd",
+                uuid=f"gpu-x-{gid}",
+            )
+
+        nodes = {
+            "node-x": NodeInfo(
+                hostname="node-x",
+                ranks=[0, 1, 2, 3],
+                gpu_ids=[0, 0, 1, 1],
+                unique_gpu_ids=[0, 1],
+                unique_pci_ids=["0000:41:00.0", "0000:42:00.0"],
+                num_gpus=2,  # 2 physical GPUs, not 4 ranks
+                num_ranks=4,
+                # Matrices are 2x2: sized by physical GPU count, not by rank count.
+                link_types=[
+                    [IntraNodeLinkType.SELF, IntraNodeLinkType.NVLINK],
+                    [IntraNodeLinkType.NVLINK, IntraNodeLinkType.SELF],
+                ],
+                p2p_access=[[True, True], [True, True]],
+            ),
+        }
+
+        return TopologyMap(
+            world_size=4,
+            num_nodes=1,
+            gpu_info=gpus,
+            nodes=nodes,
+            fabric_domains={},
+        )
+
+    def test_num_gpus_is_physical_count(self):
+        """num_gpus and num_ranks are tracked separately (2 GPUs, 4 ranks)."""
+        topo = self._make_oversubscribed_topology()
+        assert topo.nodes["node-x"].num_gpus == 2
+        assert topo.nodes["node-x"].num_ranks == 4
+
+    def test_link_type_by_gpu_id(self):
+        """Topology lookup uses gpu_id (device index), not local_rank."""
+        node = self._make_oversubscribed_topology().nodes["node-x"]
+        assert node.get_link_type(0, 1) == IntraNodeLinkType.NVLINK
+        assert node.get_link_type(0, 0) == IntraNodeLinkType.SELF
+
+    def test_p2p_by_gpu_id(self):
+        """P2P lookup between the two physical GPU indices succeeds."""
+        node = self._make_oversubscribed_topology().nodes["node-x"]
+        assert node.can_p2p_access(0, 1) is True
+
+    def test_all_ranks_are_node_peers(self):
+        """Ranks sharing a physical GPU are still distinct node peers of each other."""
+        topo = self._make_oversubscribed_topology()
+        assert topo.get_node_peers(0) == {1, 2, 3}
+        assert topo.get_node_peers(2) == {0, 1, 3}
+
+
+class TestIsolationCollapse:
+    """
+    Tests for the GPU isolation (CUDA_VISIBLE_DEVICES per-process) scenario.
+
+    When SLURM/K8s isolates GPUs, every rank reports gpu_id=0, but their
+    PCI bus IDs differ. The node must NOT collapse to num_gpus=1.
+    """
+
+    def _make_isolated_topology(self):
+        """
+        node-y: ranks 0,1
+          rank 0 -> gpu_id=0 (isolated), PCI=0000:c1:00.0
+          rank 1 -> gpu_id=0 (isolated), PCI=0000:c2:00.0
+        """
+        gpus = {}
+        for r in range(2):
+            gpus[r] = GPUInfo(
+                global_rank=r,
+                local_rank=r,
+                hostname="node-y",
+                gpu_id=0,  # both report 0 due to isolation
+                pci_bus_id=f"0000:c{r + 1}:00.0",  # but different physical GPUs
+                device_name="A100",
+                total_memory_mb=81920,
+                numa_node=0,
+                vendor="nvidia",
+                uuid=f"gpu-y-{r}",
+            )
+
+        nodes = {
+            "node-y": NodeInfo(
+                hostname="node-y",
+                ranks=[0, 1],
+                gpu_ids=[0, 0],
+                unique_gpu_ids=[0],  # gpu_id dedup gives 1
+                unique_pci_ids=["0000:c1:00.0", "0000:c2:00.0"],
+                num_gpus=2,  # PCI dedup correctly gives 2
+                num_ranks=2,
+            ),
+        }
+
+        return TopologyMap(
+            world_size=2,
+            num_nodes=1,
+            gpu_info=gpus,
+            nodes=nodes,
+            fabric_domains={},
+        )
+
+    def test_num_gpus_not_collapsed(self):
+        """
+        Regression: with gpu_id dedup, num_gpus would be 1.
+        With PCI dedup, it's correctly 2.
+        """
+        # NOTE(review): num_gpus=2 is supplied by the fixture itself, so this
+        # checks the stored value rather than the dedup logic — confirm the
+        # dedup path is covered elsewhere.
+        topo = self._make_isolated_topology()
+        assert topo.nodes["node-y"].num_gpus == 2
+
+    def test_both_ranks_are_node_peers(self):
+        """Ranks with identical gpu_id on one host are still node peers of each other."""
+        topo = self._make_isolated_topology()
+        assert topo.get_node_peers(0) == {1}
+        assert topo.get_node_peers(1) == {0}
+
+
+class TestNoFabricCluster:
+    """Tests for a cluster with NO fabric domains at all."""
+
+    def _make_no_fabric_topology(self):
+        """Build a 4-rank, 2-host topology (node-a: ranks 0,1; node-b: ranks 2,3) with no fabric."""
+        gpus = {}
+        for r in range(4):
+            node = "node-a" if r < 2 else "node-b"
+            gpus[r] = GPUInfo(
+                global_rank=r,
+                local_rank=r % 2,
+                hostname=node,
+                gpu_id=r % 2,
+                pci_bus_id=f"0000:{r}0:00.0",
+                device_name="T4",
+                total_memory_mb=16384,
+                numa_node=0,
+                vendor="nvidia",
+                uuid=f"gpu-{r}",
+            )
+
+        nodes = {
+            "node-a": NodeInfo(hostname="node-a", ranks=[0, 1], num_gpus=2, num_ranks=2),
+            "node-b": NodeInfo(hostname="node-b", ranks=[2, 3], num_gpus=2, num_ranks=2),
+        }
+
+        return TopologyMap(
+            world_size=4,
+            num_nodes=2,
+            gpu_info=gpus,
+            nodes=nodes,
+            fabric_domains={},
+        )
+
+    def test_no_fabric_all_rdma(self):
+        """Without fabric, cross-host pairs are INTER_NODE_RDMA and same-host pairs stay INTRA_NODE."""
+        topo = self._make_no_fabric_topology()
+        assert topo.get_interconnect_level(0, 2) == InterconnectLevel.INTER_NODE_RDMA
+        assert topo.get_interconnect_level(0, 1) == InterconnectLevel.INTRA_NODE
+
+    def test_fabric_peers_empty(self):
+        """Without fabric, get_fabric_domain_peers() returns an empty set."""
+        topo = self._make_no_fabric_topology()
+        assert topo.get_fabric_domain_peers(0) == set()
+
+    def test_comm_groups_no_fabric_is_empty(self):
+        """With no fabric, fabric groups should be empty — not mirrored from intra-node."""
+        topo = self._make_no_fabric_topology()
+        # Bypass __init__ and inject the synthetic topology directly.
+        td = TopologyDiscovery.__new__(TopologyDiscovery)
+        td._topology = topo
+        groups = td.get_communication_groups()
+        assert groups[InterconnectLevel.INTRA_RACK_FABRIC] == []
+
+    def test_comm_groups_intra_node_still_correct(self):
+        """Intra-node groups are unaffected by the absence of fabric."""
+        topo = self._make_no_fabric_topology()
+        td = TopologyDiscovery.__new__(TopologyDiscovery)
+        td._topology = topo
+        groups = td.get_communication_groups()
+        intra = groups[InterconnectLevel.INTRA_NODE]
+        assert sorted(intra, key=lambda g: g[0]) == [[0, 1], [2, 3]]
+
+    def test_comm_groups_rdma_still_covers_world(self):
+        """RDMA group contains all ranks even when no fabric exists."""
+        topo = self._make_no_fabric_topology()
+        td = TopologyDiscovery.__new__(TopologyDiscovery)
+        td._topology = topo
+        groups = td.get_communication_groups()
+        assert groups[InterconnectLevel.INTER_NODE_RDMA] == [[0, 1, 2, 3]]
+
+    def test_heap_plan_no_fabric_peers(self):
+        """With no fabric, heap plan should have empty fabric_peers for all ranks."""
+        topo = self._make_no_fabric_topology()
+        td = TopologyDiscovery.__new__(TopologyDiscovery)
+        td._topology = topo
+        plan = td.get_heap_distribution_plan()
+        for rank, p in plan.items():
+            assert p["fabric_peers"] == [], f"Rank {rank}: expected empty fabric_peers"
+            assert p["fabric_domain"] == "", f"Rank {rank}: expected empty fabric_domain"
+
+
+class TestAllFabricCluster:
+    """Tests for a cluster where ALL nodes are in fabric domains."""
+
+    def _make_all_fabric_topology(self):
+        """
+        node-a: ranks 0,1 (GPU0-1), fabric "aabb:0"
+        node-b: ranks 2,3 (GPU0-1), fabric "aabb:0"
+        """
+        fabric = FabricInfo(cluster_uuid="aabb", clique_id=0)
+        gpus = {}
+        for r in range(2):
+            gpus[r] = GPUInfo(
+                global_rank=r,
+                local_rank=r,
+                hostname="node-a",
+                gpu_id=r,
+                pci_bus_id=f"0000:4{r}:00.0",
+                device_name="MI300X",
+                total_memory_mb=81920,
+                numa_node=0,
+                vendor="amd",
+                uuid=f"gpu-a-{r}",
+                fabric_info=fabric,
+            )
+        for r in range(2, 4):
+            gpus[r] = GPUInfo(
+                global_rank=r,
+                local_rank=r - 2,
+                hostname="node-b",
+                gpu_id=r - 2,
+                pci_bus_id=f"0000:8{r - 2}:00.0",
+                device_name="MI300X",
+                total_memory_mb=81920,
+                numa_node=0,
+                vendor="amd",
+                uuid=f"gpu-b-{r - 2}",
+                fabric_info=fabric,
+            )
+
+        nodes = {
+            "node-a": NodeInfo(
+                hostname="node-a",
+                ranks=[0, 1],
+                num_gpus=2,
+                num_ranks=2,
+                fabric_domain_key="aabb:0",
+                fabric_domain_keys=["aabb:0"],
+            ),
+            "node-b": NodeInfo(
+                hostname="node-b",
+                ranks=[2, 3],
+                num_gpus=2,
+                num_ranks=2,
+                fabric_domain_key="aabb:0",
+                fabric_domain_keys=["aabb:0"],
+            ),
+        }
+
+        return TopologyMap(
+            world_size=4,
+            num_nodes=2,
+            gpu_info=gpus,
+            nodes=nodes,
+            fabric_domains={"aabb:0": ["node-a", "node-b"]},
+        )
+
+    def test_comm_groups_fabric_spans_nodes(self):
+        """Fabric group merges ranks from both nodes."""
+        topo = self._make_all_fabric_topology()
+        # Bypass __init__ and inject the synthetic topology directly.
+        td = TopologyDiscovery.__new__(TopologyDiscovery)
+        td._topology = topo
+        groups = td.get_communication_groups()
+        fabric = groups[InterconnectLevel.INTRA_RACK_FABRIC]
+        assert fabric == [[0, 1, 2, 3]]
+
+    def test_comm_groups_no_standalone_groups(self):
+        """When all nodes are in fabric, there should be no standalone groups."""
+        topo = self._make_all_fabric_topology()
+        td = TopologyDiscovery.__new__(TopologyDiscovery)
+        td._topology = topo
+        groups = td.get_communication_groups()
+        fabric = groups[InterconnectLevel.INTRA_RACK_FABRIC]
+        # Only one group — no standalone appendages
+        assert len(fabric) == 1
+
+    def test_comm_groups_intra_node_still_per_host(self):
+        """Intra-node groups stay per-host even when fabric spans nodes."""
+        topo = self._make_all_fabric_topology()
+        td = TopologyDiscovery.__new__(TopologyDiscovery)
+        td._topology = topo
+        groups = td.get_communication_groups()
+        intra = groups[InterconnectLevel.INTRA_NODE]
+        assert sorted(intra, key=lambda g: g[0]) == [[0, 1], [2, 3]]
+
+    def test_heap_plan_fabric_peers_cross_node(self):
+        """Fabric peers should be cross-node ranks in the same domain."""
+        topo = self._make_all_fabric_topology()
+        td = TopologyDiscovery.__new__(TopologyDiscovery)
+        td._topology = topo
+        plan = td.get_heap_distribution_plan()
+        # rank 0 on node-a: fabric peers = node-b ranks (cross-node, same domain)
+        assert plan[0]["fabric_peers"] == [2, 3]
+        assert plan[0]["ipc_peers"] == [1]
+        assert plan[0]["rdma_peers"] == []
+
+    def test_interconnect_cross_node_is_fabric(self):
+        """Every cross-host pair in this topology is classified INTRA_RACK_FABRIC."""
+        topo = self._make_all_fabric_topology()
+        assert topo.get_interconnect_level(0, 2) == InterconnectLevel.INTRA_RACK_FABRIC
+        assert topo.get_interconnect_level(1, 3) == InterconnectLevel.INTRA_RACK_FABRIC
+
+
+# ---------------------------------------------------------------------------
+# Distributed tests — require real GPUs and torchrun
+# ---------------------------------------------------------------------------
+
+
+# The skipif conditions are plain expressions, so they are evaluated once when
+# this module is imported: the process group must already be initialized at
+# that point or the whole class is skipped.
+# NOTE(review): presumably the torchrun entry script initializes it before
+# invoking pytest — confirm.
+@pytest.mark.skipif(not dist.is_initialized(), reason="No distributed process group")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA/ROCm GPUs")
+class TestDistributed:
+    """Tests that run within a real distributed process group."""
+
+    def test_all_gather_strings(self):
+        """Test the string all-gather primitive."""
+        rank = get_rank()
+        world_size = get_world_size()
+        local_str = f"hello_from_rank_{rank}"
+        results = _all_gather_strings(local_str, world_size)
+        assert len(results) == world_size
+        # Results must be indexed by rank.
+        for r in range(world_size):
+            assert results[r] == f"hello_from_rank_{r}"
+
+    def test_all_gather_strings_empty(self):
+        """Gathering an empty string from every rank yields a list of empty strings."""
+        world_size = get_world_size()
+        results = _all_gather_strings("", world_size)
+        assert results == [""] * world_size
+
+    def test_all_gather_strings_large_payload(self):
+        """Simulate large JSON payloads (a few KB each)."""
+        rank = get_rank()
+        world_size = get_world_size()
+        payload = json.dumps({"rank": rank, "data": "x" * 4096})
+        results = _all_gather_strings(payload, world_size)
+        assert len(results) == world_size
+        for r in range(world_size):
+            parsed = json.loads(results[r])
+            assert parsed["rank"] == r
+            assert len(parsed["data"]) == 4096
+
+
+# Same import-time skip conditions as TestDistributed.
+@pytest.mark.skipif(not dist.is_initialized(), reason="No distributed process group")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA/ROCm GPUs")
+class TestFullDiscovery:
+    """End-to-end topology discovery tests."""
+
+    def test_discover_returns_topology(self):
+        """discover() returns a TopologyMap sized to the current world with at least one node."""
+        td = TopologyDiscovery()
+        topo = td.discover()
+        assert isinstance(topo, TopologyMap)
+        assert topo.world_size == get_world_size()
+        assert topo.num_nodes >= 1
+
+    def test_local_rank_is_unique_per_node(self):
+        """No two ranks on the same host share a local_rank."""
+        td = TopologyDiscovery()
+        topo = td.discover()
+        for hostname, node in topo.nodes.items():
+            local_ranks = [topo.gpu_info[r].local_rank for r in node.ranks]
+            assert len(set(local_ranks)) == len(local_ranks), f"Duplicate local_ranks on {hostname}: {local_ranks}"
+
+    def test_own_rank_info_correct(self):
+        """The entry for this process's own rank matches locally observable facts."""
+        td = TopologyDiscovery()
+        topo = td.discover()
+        rank = get_rank()
+        info = topo.gpu_info[rank]
+        assert info.global_rank == rank
+        assert info.hostname == socket.gethostname()
+        assert info.vendor in ("amd", "nvidia")
+        assert info.total_memory_mb > 0
+        assert info.pci_bus_id != ""
+
+    def test_interconnect_symmetry(self):
+        """Interconnect level should be symmetric: level(a,b) == level(b,a)."""
+        td = TopologyDiscovery()
+        topo = td.discover()
+        ranks = sorted(topo.gpu_info.keys())
+        # Visit each unordered pair once.
+        for i, a in enumerate(ranks):
+            for b in ranks[i + 1 :]:
+                level_ab = topo.get_interconnect_level(a, b)
+                level_ba = topo.get_interconnect_level(b, a)
+                assert level_ab == level_ba, f"Asymmetric: level({a},{b})={level_ab} != level({b},{a})={level_ba}"
+
+    def test_peer_partition_exhaustive(self):
+        """For every rank, peers must partition the world."""
+        td = TopologyDiscovery()
+        topo = td.discover()
+        world = set(range(get_world_size()))
+        for rank in world:
+            node = topo.get_node_peers(rank)
+            fabric_only = topo.get_fabric_domain_peers(rank) - node
+            rdma = topo.get_rdma_peers(rank)
+            # Only coverage is asserted here; disjointness is not checked.
+            union = node | fabric_only | rdma | {rank}
+            assert union == world, f"Rank {rank}: missing {world - union}"
+
+
+class TestLogicalToPhysicalGpuIndex:
+    """Tests for CUDA_VISIBLE_DEVICES / HIP_VISIBLE_DEVICES index translation."""
+
+    def test_no_env_var_returns_logical(self, monkeypatch):
+        """With CUDA_VISIBLE_DEVICES unset, the logical index is returned unchanged."""
+        monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
+        assert _logical_to_physical_gpu_index(0, "nvidia") == 0
+        assert _logical_to_physical_gpu_index(3, "nvidia") == 3
+
+    def test_nvidia_remapping(self, monkeypatch):
+        """For vendor "nvidia", the logical index selects an entry of CUDA_VISIBLE_DEVICES."""
+        monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "2,3")
+        assert _logical_to_physical_gpu_index(0, "nvidia") == 2
+        assert _logical_to_physical_gpu_index(1, "nvidia") == 3
+
+    def test_amd_hip_visible(self, monkeypatch):
+        """For vendor "amd", the logical index selects an entry of HIP_VISIBLE_DEVICES."""
+        monkeypatch.setenv("HIP_VISIBLE_DEVICES", "4,5,6")
+        assert _logical_to_physical_gpu_index(0, "amd") == 4
+        assert _logical_to_physical_gpu_index(2, "amd") == 6
+
+    def test_amd_rocr_fallback(self, monkeypatch):
+        """With HIP_VISIBLE_DEVICES unset, ROCR_VISIBLE_DEVICES is used for vendor "amd"."""
+        monkeypatch.delenv("HIP_VISIBLE_DEVICES", raising=False)
+        monkeypatch.setenv("ROCR_VISIBLE_DEVICES", "1,3")
+        assert _logical_to_physical_gpu_index(0, "amd") == 1
+        assert _logical_to_physical_gpu_index(1, "amd") == 3
+
+    def test_hip_takes_priority_over_rocr(self, monkeypatch):
+        """When both AMD variables are set, HIP_VISIBLE_DEVICES wins."""
+        monkeypatch.setenv("HIP_VISIBLE_DEVICES", "7")
+        monkeypatch.setenv("ROCR_VISIBLE_DEVICES", "9")
+        assert _logical_to_physical_gpu_index(0, "amd") == 7
+
+    def test_logical_out_of_range_returns_logical(self, monkeypatch):
+        """A logical index past the end of the visible-devices list is returned unchanged."""
+        monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "2,3")
+        # logical index 5 is beyond the 2-entry list
+        assert _logical_to_physical_gpu_index(5, "nvidia") == 5
+
+    def test_uuid_style_entry_returns_logical(self, monkeypatch):
+        """A non-numeric (UUID-style) entry cannot be translated; the logical index is returned."""
+        monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "GPU-abcdef12-3456-7890")
+        assert _logical_to_physical_gpu_index(0, "nvidia") == 0
+
+    def test_negative_index_passthrough(self):
+        """A negative logical index is returned unchanged."""
+        # NOTE(review): unlike its siblings this test does not pin
+        # CUDA_VISIBLE_DEVICES via monkeypatch, so which code path it exercises
+        # depends on the ambient environment — consider setting it explicitly.
+        assert _logical_to_physical_gpu_index(-1, "nvidia") == -1