Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/scripts/run_new_examples.sh
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ EXIT_CODE=0
fi
fi
echo \"Running: \$example_file with $NUM_RANKS ranks\"
torchrun --nproc_per_node=$NUM_RANKS --standalone \"\$example_file\"
torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:0 --nnodes=1 --nproc_per_node=$NUM_RANKS \"\$example_file\"
fi
done
" || { EXIT_CODE=$?; }
Expand Down
2 changes: 1 addition & 1 deletion .github/scripts/run_perf_benchmark.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ echo "[PERF-BENCHMARK] Using GPUs: $GPU_DEVICES"

cd /iris_workspace
pip install -e .
torchrun --nproc_per_node=8 examples/${EXAMPLE_PATH}/benchmark.py \
torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:0 --nnodes=1 --nproc_per_node=8 examples/${EXAMPLE_PATH}/benchmark.py \
--benchmark \
--validate \
${BENCHMARK_ARGS} \
Expand Down
2 changes: 1 addition & 1 deletion .github/scripts/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ EXIT_CODE=0
for test_file in tests/$TEST_DIR/test_*.py; do
if [ -f \"\$test_file\" ]; then
echo \"Testing: \$test_file with $NUM_RANKS ranks (install: $INSTALL_METHOD)\"
torchrun --nproc_per_node=$NUM_RANKS --standalone tests/run_tests_distributed.py \"\$test_file\" -v --tb=short --durations=10
torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:0 --nnodes=1 --nproc_per_node=$NUM_RANKS tests/run_tests_distributed.py \"\$test_file\" -v --tb=short --durations=10
fi
done
" || { EXIT_CODE=$?; }
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/iris-external-validation-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:
cd /iris_workspace
pip install git+https://github.com/${{ github.repository }}.git@${{ github.sha }}
wget -O test_iris_distributed.py https://gist.githubusercontent.com/mawad-amd/6375dc078e39e256828f379e03310ec7/raw/0827d023eaf8e9755b17cbe8ab06f2ce258e746a/test_iris_distributed.py
torchrun --nproc_per_node=2 test_iris_distributed.py
torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:0 --nnodes=1 --nproc_per_node=2 test_iris_distributed.py
"
echo "::endgroup::"

Expand Down Expand Up @@ -103,7 +103,7 @@ jobs:
cd /iris_workspace
pip install git+https://github.com/${{ github.repository }}.git@${{ github.sha }}
wget -O test_iris_gluon_distributed.py https://gist.githubusercontent.com/mawad-amd/2666dde8ebe2755eb0c4f2108709fcd5/raw/c5544943e2832c75252160bd9084600bf01a6b06/test_iris_gluon_distributed.py
torchrun --nproc_per_node=2 test_iris_gluon_distributed.py
torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:0 --nnodes=1 --nproc_per_node=2 test_iris_gluon_distributed.py
"
echo "::endgroup::"

Expand Down
9 changes: 4 additions & 5 deletions .github/workflows/iris-performance-regression-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,10 @@ jobs:
matrix:
# Performance baselines measured on AMD Instinct MI325X (8 GPUs)
include:
# Disabled https://github.com/ROCm/iris/issues/238
#- example_name: "GEMM All-Scatter WG Specialization"
# example_path: "10_gemm_all_scatter_wg_specialization"
# tflops_threshold: 1600 # Actual: ~2182 TFLOPs
# benchmark_args: "-m 16384 -n 16384 -k 16384 --BLK_M 128 --BLK_N 128 --BLK_K 64 --gsize_m 6 --gemm_sms 256"
- example_name: "GEMM All-Scatter WG Specialization"
example_path: "10_gemm_all_scatter_wg_specialization"
tflops_threshold: 1440 # Actual: ~1802 TFLOPs (80% regression threshold)
benchmark_args: "-m 16384 -n 16384 -k 16384 --BLK_M 128 --BLK_N 128 --BLK_K 64 --gsize_m 6 --gemm_sms 256"

- example_name: "GEMM All-Scatter"
example_path: "07_gemm_all_scatter"
Expand Down
52 changes: 39 additions & 13 deletions examples/10_gemm_all_scatter_wg_specialization/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
Expand Down Expand Up @@ -132,7 +134,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
total_blocks_N = triton.cdiv(args["n"], args["BLK_N"])
total_tiles = total_blocks_M * total_blocks_N

locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int8)
locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)

bias = None

Expand All @@ -153,13 +155,18 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
# Allocate Timestamps
timestamps = Timestamps(num_tiles=total_tiles)

def preamble():
# Barrier 1: ensure all ranks finish previous iteration before clearing locks
shmem.barrier()
locks.zero_()
# Barrier 2: ensure all ranks see zeroed locks before any rank starts the kernel
shmem.barrier()

def run_experiment():
nonlocal local_C
nonlocal global_C
nonlocal kernel_timing

shmem.barrier()

if args["trace_tiles"]:
timestamps.reset()
shmem.barrier()
Expand Down Expand Up @@ -215,6 +222,16 @@ def run_experiment():
kernel_timing[k]["experiments"] = 0

if args["validate"]:
# Run a dedicated validation kernel to ensure all cross-GPU writes are fully
# propagated before checking results. The warmup above may leave some
# iris.put stores in-flight on the xGMI interconnect; the extra
# preamble + run + barrier cycle guarantees all ranks have flushed their
# GPU caches and that rank-0 sees every scattered tile before we call
# validate_gemm.
preamble()
run_experiment()
shmem.barrier()

shmem.info("Validating...")
matmul.set_debug(True)
# Validate global result
Expand All @@ -241,7 +258,7 @@ def run_experiment():
matmul.set_debug(False)
shmem.info("Benchmarking...")
perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3)
triton_ms = iris.do_bench(run_experiment, shmem.barrier)
triton_ms = iris.do_bench(run_experiment, shmem.barrier, preamble)
triton_tflops = perf(triton_ms)
algo_string = "all_scatter"
shmem.info(
Expand Down Expand Up @@ -275,15 +292,24 @@ def run_experiment():
def main():
args = parse_args()

num_ranks = args["num_ranks"]

init_url = "tcp://127.0.0.1:29500"
mp.spawn(
fn=_worker,
args=(num_ranks, init_url, args),
nprocs=num_ranks,
join=True,
)
# Check if running with torchrun (detected by environment variables)
if "RANK" in os.environ and "LOCAL_RANK" in os.environ:
# torchrun handles process spawning, so call _worker directly
print("Detected torchrun execution mode")
rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
init_url = os.environ.get("MASTER_ADDR", "127.0.0.1") + ":" + os.environ.get("MASTER_PORT", "29500")
_worker(rank, world_size, f"tcp://{init_url}", args)
else:
# Use multiprocessing spawn for backward compatibility
num_ranks = args["num_ranks"]
init_url = "tcp://127.0.0.1:29500"
mp.spawn(
fn=_worker,
args=(num_ranks, init_url, args),
nprocs=num_ranks,
join=True,
)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,7 @@ def persistent_gemm_all_scatter_wg_specialization(
tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp)

tl.store(c_global + global_offset, c, mask=sub_mask, cache_modifier=".wt")
tl.debug_barrier()
tl.store(locks + tile_id, 1, cache_modifier=".wt")
tl.atomic_xchg(locks + tile_id, 1, sem="release", scope="gpu")

else: # pid >= GEMM_SMS
COMM_SMS = NUM_SMS - GEMM_SMS
Expand All @@ -163,8 +162,11 @@ def persistent_gemm_all_scatter_wg_specialization(
global_offset = rm[:, None] * stride_cm_global + (rn[None, :] + cur_rank * N) * stride_cn_global
# End: masks/offset calculations.

# Spin-wait: first check with a cheap volatile load, then acquire-CAS to
# ensure memory ordering once the lock is observed set.
while tl.load(locks + tile_id, cache_modifier=".cv", volatile=True) != 1:
pass
tl.atomic_cas(locks + tile_id, 1, 1, sem="acquire", scope="gpu")

for remote_rank in range(world_size):
if remote_rank != cur_rank:
Expand Down
19 changes: 18 additions & 1 deletion examples/25_ccl_all_gather/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch.distributed as dist

import iris
from iris.ccl import Config


def parse_args():
Expand All @@ -30,6 +31,12 @@ def parse_args():
parser.add_argument("--heap_size", type=int, default=1 << 31, help="Iris heap size")
parser.add_argument("--datatype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type")
parser.add_argument("-v", "--validate", action="store_true", help="Validate output against reference")
parser.add_argument("--block_size_m", type=int, default=32, help="Block size for M dimension tiling")
parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension tiling")
parser.add_argument("--comm_sms", type=int, default=64, help="Number of SMs for all-gather kernel")
parser.add_argument("--num_stages", type=int, default=1, help="Number of stages")
parser.add_argument("--num_warps", type=int, default=4, help="Number of warps")
parser.add_argument("--waves_per_eu", type=int, default=0, help="Number of waves per EU")
return vars(parser.parse_args())


Expand All @@ -53,8 +60,18 @@ def main():
input_tensor.fill_(float(rank + 1))
output_tensor = ctx.zeros((world_size * M, N), dtype=dtype)

config_kwargs = {
"block_size_m": args["block_size_m"],
"block_size_n": args["block_size_n"],
"comm_sms": args["comm_sms"],
"num_stages": args["num_stages"],
"num_warps": args["num_warps"],
"waves_per_eu": args["waves_per_eu"],
}
config = Config(**config_kwargs)

ctx.barrier()
ctx.ccl.all_gather(output_tensor, input_tensor)
ctx.ccl.all_gather(output_tensor, input_tensor, config=config)
torch.cuda.synchronize()

if rank == 0:
Expand Down
19 changes: 18 additions & 1 deletion examples/26_ccl_all_to_all/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch.distributed as dist

import iris
from iris.ccl import Config


def parse_args():
Expand All @@ -28,6 +29,12 @@ def parse_args():
parser.add_argument("-m", type=int, default=512, help="Number of rows")
parser.add_argument("-n", type=int, default=128, help="Number of columns per rank slice")
parser.add_argument("--heap_size", type=int, default=1 << 31, help="Iris heap size")
parser.add_argument("--block_size_m", type=int, default=32, help="Block size for M dimension tiling")
parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension tiling")
parser.add_argument("--comm_sms", type=int, default=64, help="Number of SMs for all-to-all kernel")
parser.add_argument("--num_stages", type=int, default=1, help="Number of stages")
parser.add_argument("--num_warps", type=int, default=4, help="Number of warps")
parser.add_argument("--waves_per_eu", type=int, default=0, help="Number of waves per EU")
parser.add_argument("--datatype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type")
parser.add_argument("-v", "--validate", action="store_true", help="Validate output against reference")
return vars(parser.parse_args())
Expand All @@ -54,8 +61,18 @@ def main():
input_tensor[:, target_rank * N : (target_rank + 1) * N] = float(rank * 10 + target_rank + 1)
output_tensor = ctx.zeros((M, N * world_size), dtype=dtype)

config_kwargs = {
"block_size_m": args["block_size_m"],
"block_size_n": args["block_size_n"],
"comm_sms": args["comm_sms"],
"num_stages": args["num_stages"],
"num_warps": args["num_warps"],
"waves_per_eu": args["waves_per_eu"],
}
config = Config(**config_kwargs)

ctx.barrier()
ctx.ccl.all_to_all(output_tensor, input_tensor)
ctx.ccl.all_to_all(output_tensor, input_tensor, config=config)
torch.cuda.synchronize()

if rank == 0:
Expand Down
114 changes: 114 additions & 0 deletions examples/27_ccl_reduce_scatter/example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

"""
Example: iris.ccl.reduce_scatter

Each rank has input (M, N); each rank reduces its assigned tiles from all ranks
and stores the result only to its own output (same shape (M, N)).

Run with:
torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:0 --nnodes=1 --nproc_per_node=<num_gpus> example.py [--validate]
"""

import argparse
import os

import torch
import torch.distributed as dist

import iris
from iris.ccl import Config


def parse_args():
    """Parse command-line options for the reduce-scatter example.

    Returns:
        dict: option name -> parsed value (``vars()`` of the argparse
        Namespace), keyed by the long option names without dashes.
    """
    ap = argparse.ArgumentParser(
        description="CCL reduce-scatter example",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # (flags, keyword arguments) for every supported option, in help order.
    option_table = [
        (("-m",), dict(type=int, default=1024, help="Number of rows")),
        (("-n",), dict(type=int, default=512, help="Number of columns")),
        (("--heap_size",), dict(type=int, default=1 << 31, help="Iris heap size")),
        (("--block_size_m",), dict(type=int, default=32, help="Block size for M dimension tiling")),
        (("--block_size_n",), dict(type=int, default=64, help="Block size for N dimension tiling")),
        (("--comm_sms",), dict(type=int, default=64, help="Number of SMs for reduce-scatter kernel")),
        (("--num_stages",), dict(type=int, default=1, help="Number of stages")),
        (("--num_warps",), dict(type=int, default=4, help="Number of warps")),
        (("--waves_per_eu",), dict(type=int, default=0, help="Number of waves per EU")),
        (("--datatype",), dict(type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type")),
        (("-v", "--validate"), dict(action="store_true", help="Validate output against reference")),
    ]
    for flags, kwargs in option_table:
        ap.add_argument(*flags, **kwargs)
    return vars(ap.parse_args())


def main():
    """Run a one-process-per-GPU reduce-scatter demo and optionally validate it.

    Expects to be launched by torchrun: reads LOCAL_RANK from the environment
    to pick the GPU, initializes a gloo process group for bootstrap/reference
    collectives, and performs the actual reduce-scatter through the iris
    context's CCL API.
    """
    args = parse_args()

    # torchrun exports LOCAL_RANK; default to 0 for a bare single-process run.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    # gloo is only used for process-group setup and the validation all_gather;
    # the reduce-scatter itself does not go through torch.distributed.
    dist.init_process_group(backend="gloo")

    ctx = iris.iris(heap_size=args["heap_size"])
    rank = ctx.get_rank()
    world_size = ctx.get_num_ranks()

    dtype_map = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}
    dtype = dtype_map[args["datatype"]]
    M, N = args["m"], args["n"]

    # Each rank fills its input with (rank + 1)
    # (so the reduced value everywhere is sum(1..world_size), easy to check).
    input_tensor = ctx.zeros((M, N), dtype=dtype)
    input_tensor.fill_(float(rank + 1))
    output_tensor = ctx.zeros((M, N), dtype=dtype)

    # Kernel-tuning knobs forwarded verbatim from the CLI to the iris config.
    # NOTE(review): all_reduce_distribution=1 is hard-coded — presumably it
    # selects the contiguous block-tile distribution that the validation mask
    # below reconstructs; confirm against iris.ccl.Config.
    config_kwargs = {
        "block_size_m": args["block_size_m"],
        "block_size_n": args["block_size_n"],
        "comm_sms": args["comm_sms"],
        "num_stages": args["num_stages"],
        "num_warps": args["num_warps"],
        "waves_per_eu": args["waves_per_eu"],
        "all_reduce_distribution": 1,
    }
    config = Config(**config_kwargs)

    # Barrier before launching so no rank starts while peers are still allocating.
    ctx.barrier()
    ctx.ccl.reduce_scatter(output_tensor, input_tensor, config=config)
    torch.cuda.synchronize()

    if rank == 0:
        ctx.info(f"reduce_scatter: world_size={world_size}, shape=({M},{N}), dtype={dtype}")

    if args["validate"]:
        # Reference: gather all inputs, sum, then each rank checks its assigned tiles
        # NOTE(review): all_gather over gloo with CUDA tensors — verify the gloo
        # backend supports this on the target stack, or gather on CPU copies.
        ref_list = [torch.empty(M, N, dtype=dtype, device=input_tensor.device) for _ in range(world_size)]
        dist.all_gather(ref_list, input_tensor)
        full_reduced = sum(ref_list).float()

        # Recreate the kernel's tile partitioning: tiles numbered row-major over
        # a (num_pid_m x num_pid_n) grid, dealt to ranks in contiguous chunks.
        block_size_m = args["block_size_m"]
        block_size_n = args["block_size_n"]
        num_pid_m = (M + block_size_m - 1) // block_size_m
        num_pid_n = (N + block_size_n - 1) // block_size_n
        total_tiles = num_pid_m * num_pid_n
        tiles_per_rank = (total_tiles + world_size - 1) // world_size
        start_tile = rank * tiles_per_rank

        # Build mask of (i,j) belonging to this rank's tiles (block distribution)
        pid_m = torch.arange(M, device=output_tensor.device) // block_size_m
        pid_n = torch.arange(N, device=output_tensor.device) // block_size_n
        tile_id = pid_m[:, None] * num_pid_n + pid_n[None, :]
        mask = (tile_id >= start_tile) & (tile_id < start_tile + tiles_per_rank)

        # Compare in fp32; atol=0.6 absorbs fp16/bf16 rounding of the summed inputs.
        out_float = output_tensor.float()
        expected_where = full_reduced[mask]
        actual_where = out_float[mask]
        assert torch.allclose(actual_where, expected_where, atol=0.6), f"Rank {rank}: output mismatch on assigned tiles"
        if rank == 0:
            ctx.info("Validation passed: output matches reference")

    # Final barrier so no rank tears down the process group while peers
    # are still validating.
    ctx.barrier()
    dist.destroy_process_group()


# Script entry point (launched directly or via torchrun).
if __name__ == "__main__":
    main()
Loading