From 56ca569d01741d2ec9937c8e70620369a8ea8609 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 13 Mar 2026 16:04:31 +0000 Subject: [PATCH 1/3] Initial plan From 4492999649280ae829aee7671eba4a33890eed94 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 13 Mar 2026 16:20:43 +0000 Subject: [PATCH 2/3] Sync with main, enforce no cache modifiers on remote stores, update tests Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> --- .devcontainer/devcontainer.json | 2 +- .devcontainer/ensure-ssh-agent.sh | 22 +- .github/scripts/acquire_gpus.sh | 40 + .github/scripts/container_build.sh | 43 +- .github/scripts/container_exec.sh | 35 +- .github/scripts/examples_config.json | 5 + .github/scripts/gpu_allocator.sh | 227 +++ .github/scripts/release_gpus.sh | 28 + .github/scripts/run_new_examples.sh | 64 + .github/scripts/run_perf_benchmark.sh | 23 +- .github/scripts/run_tests.sh | 47 +- .github/workflows/copilot-setup-steps.yml | 40 + .../iris-external-validation-test.yml | 90 +- .../iris-performance-regression-test.yml | 53 +- .github/workflows/iris-tests.yml | 230 ++- .gitignore | 11 +- AGENTS.md | 100 + apptainer/iris.def | 14 +- benchmark/examples/benchmark_moe.py | 399 ++++ docker/Dockerfile | 13 +- docker/Dockerfile.ccl | 2 +- docs/conf.py | 8 +- docs/index.md | 1 + docs/reference/api-reference.md | 5 + docs/reference/gluon/ccl.md | 22 + docs/reference/gluon/class.md | 2 + docs/reference/gluon/overview.md | 1 + docs/reference/gluon/tensor-creation.md | 6 + docs/reference/talks-and-papers.md | 103 + docs/reference/triton/ccl.md | 28 + docs/reference/triton/class.md | 1 + docs/reference/triton/ops.md | 23 + docs/reference/triton/overview.md | 2 + docs/reference/triton/tensor-creation.md | 6 + docs/sphinx/_toc.yml | 3 + docs/sphinx/_toc.yml.in | 3 + .../message_passing_device_context.py | 188 ++ examples/07_gemm_all_scatter/benchmark.py | 30 +- .../07_gemm_all_scatter/gemm_all_scatter.py | 2 +- .../gemm_all_reduce_atomics.py | 2 +- .../gemm_one_shot_all_reduce.py | 2 +- .../gemm_all_scatter_wg_specialization.py | 2 +- .../benchmark.py | 36 +- .../gemm_all_scatter_producer_consumer.py | 2 +- .../benchmark.py | 36 +- .../gemm_all_scatter_bulk_synchronous.py | 2 +- .../gemm_all_reduce_ring_based.py | 2 +- .../gemm_all_scatter_bulk_synchronous.py | 2 +- .../gemm_one_shot_all_reduce_independent.py | 2 +- .../gemm_reduce_scatter.py | 2 +- .../23_gemm_all_scatter_tracing/benchmark.py | 306 +++ .../gemm_all_scatter.py | 166 ++ .../matmul_wrapper.py | 177 ++ examples/24_ccl_all_reduce/example.py | 77 + examples/25_ccl_all_gather/example.py | 78 + examples/26_ccl_all_to_all/example.py | 80 + examples/28_ops_matmul_all_reduce/example.py | 78 + examples/29_ops_all_gather_matmul/example.py | 85 + examples/30_ops_matmul_all_gather/example.py | 88 + examples/31_expert_sharded_moe/combine.py | 114 ++ examples/31_expert_sharded_moe/dispatch.py | 119 ++ examples/31_expert_sharded_moe/example_run.py | 188 ++ .../expert_assignment.py | 73 + .../fused_exp_matmul_ep_to_dp.py | 221 ++ .../31_expert_sharded_moe/grouped_matmul.py | 159 ++ examples/31_expert_sharded_moe/moe.py | 328 +++ .../31_expert_sharded_moe/ragged_metadata.py | 78 + examples/31_expert_sharded_moe/reduce.py | 109 + examples/31_expert_sharded_moe/topk.py | 121 ++ examples/31_message_passing/example.py | 217 ++ examples/common/utils.py | 16 +- iris/__init__.py | 48 +- iris/_distributed_helpers.py | 26 +- iris/allocators/__init__.py | 3 +- iris/allocators/base.py | 47 +- iris/allocators/torch_allocator.py | 110 +- iris/allocators/vmem_allocator.py | 323 +++ iris/ccl/all_gather.py | 157 +- iris/ccl/all_reduce.py | 14 +- iris/ccl/all_to_all.py | 2 + iris/ccl/config.py | 13 + iris/device_utils.py | 93 + iris/experimental/iris_gluon.py | 271 +-- iris/hip.py | 440 +++- iris/iris.py | 1789 +++++++++-------- iris/ops/all_gather_matmul.py | 163 +- iris/ops/matmul_all_gather.py | 149 +- iris/ops/matmul_all_reduce.py | 136 +- iris/ops/matmul_reduce_scatter.py | 86 +- iris/symmetric_heap.py | 244 ++- iris/tensor_creation.py | 881 ++++++++ iris/tensor_utils.py | 132 ++ iris/tracing/__init__.py | 12 + iris/tracing/core.py | 327 +++ iris/tracing/device.py | 185 ++ iris/tracing/events.py | 96 + iris/util.py | 59 + iris/x/__init__.py | 27 +- iris/x/all_gather.py | 13 +- iris/x/all_reduce.py | 33 +- iris/x/all_to_all.py | 5 +- iris/x/core.py | 130 +- iris/x/gather.py | 5 +- iris/x/reduce_scatter.py | 9 +- pyproject.toml | 2 + scripts/roccap_wrapper.py | 67 + tests/ccl/test_all_gather.py | 96 +- tests/ccl/test_all_reduce.py | 14 +- tests/ccl/test_all_to_all.py | 14 +- tests/examples/test_expert_sharded_moe.py | 106 + tests/ops/test_matmul_all_gather.py | 2 +- tests/run_tests_distributed.py | 135 +- tests/unittests/test_arange.py | 48 +- tests/unittests/test_device_context.py | 506 +++++ tests/unittests/test_dmabuf_apis.py | 60 +- .../test_dmabuf_controlled_va_import.py | 398 ++++ tests/unittests/test_dmabuf_vmem_import.py | 183 ++ tests/unittests/test_empty.py | 70 +- tests/unittests/test_full.py | 72 +- tests/unittests/test_hip_apis.py | 147 ++ tests/unittests/test_hip_vmem_primitives.py | 257 +++ tests/unittests/test_iris_helpers.py | 16 +- tests/unittests/test_is_symmetric.py | 38 + tests/unittests/test_linspace.py | 70 +- tests/unittests/test_ones.py | 64 +- tests/unittests/test_put_cache_modifiers.py | 88 +- tests/unittests/test_pytorch_dmabuf_export.py | 198 ++ .../test_pytorch_import_mechanism.py | 290 +++ tests/unittests/test_rand.py | 70 +- tests/unittests/test_randint.py | 70 +- tests/unittests/test_randn.py | 74 +- tests/unittests/test_rocr_behaviors.py | 344 ++++ tests/unittests/test_store_cache_modifiers.py | 39 +- tests/unittests/test_vmem_allocator.py | 252 +++ .../unittests/test_vmem_cumulative_access.py | 305 +++ .../test_vmem_imported_tensor_rma.py | 226 +++ .../test_vmem_minimal_export_with_ondemand.py | 401 ++++ .../unittests/test_vmem_multi_alloc_export.py | 198 ++ tests/unittests/test_vmem_offset_check.py | 76 + .../test_vmem_peer_dmabuf_exchange.py | 280 +++ tests/unittests/test_vmem_segmented_export.py | 192 ++ tests/unittests/test_zeros.py | 60 +- tests/unittests/test_zeros_like.py | 28 +- tests/x/test_all_gather.py | 24 +- tests/x/test_all_reduce.py | 38 +- tests/x/test_all_to_all.py | 10 +- tests/x/test_gather.py | 16 +- tests/x/test_reduce_scatter.py | 10 +- 148 files changed, 14090 insertions(+), 2482 deletions(-) mode change 100644 => 100755 .devcontainer/ensure-ssh-agent.sh create mode 100755 .github/scripts/acquire_gpus.sh create mode 100644 .github/scripts/examples_config.json create mode 100755 .github/scripts/gpu_allocator.sh create mode 100755 .github/scripts/release_gpus.sh create mode 100755 .github/scripts/run_new_examples.sh create mode 100644 .github/workflows/copilot-setup-steps.yml create mode 100644 AGENTS.md create mode 100644 benchmark/examples/benchmark_moe.py create mode 100644 docs/reference/gluon/ccl.md create mode 100644 docs/reference/talks-and-papers.md create mode 100644 docs/reference/triton/ccl.md create mode 100644 docs/reference/triton/ops.md create mode 100644 examples/06_message_passing/message_passing_device_context.py create mode 100644 examples/23_gemm_all_scatter_tracing/benchmark.py create mode 100644 examples/23_gemm_all_scatter_tracing/gemm_all_scatter.py create mode 100644 examples/23_gemm_all_scatter_tracing/matmul_wrapper.py create mode 100644 examples/24_ccl_all_reduce/example.py create mode 100644 examples/25_ccl_all_gather/example.py create mode 100644 examples/26_ccl_all_to_all/example.py create mode 100644 examples/28_ops_matmul_all_reduce/example.py create mode 100644 examples/29_ops_all_gather_matmul/example.py create mode 100644 examples/30_ops_matmul_all_gather/example.py create mode 100644 examples/31_expert_sharded_moe/combine.py create mode 100644 examples/31_expert_sharded_moe/dispatch.py create mode 100644 examples/31_expert_sharded_moe/example_run.py create mode 100644 examples/31_expert_sharded_moe/expert_assignment.py create mode 100644 examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp.py create mode 100644 examples/31_expert_sharded_moe/grouped_matmul.py create mode 100644 examples/31_expert_sharded_moe/moe.py create mode 100644 examples/31_expert_sharded_moe/ragged_metadata.py create mode 100644 examples/31_expert_sharded_moe/reduce.py create mode 100644 examples/31_expert_sharded_moe/topk.py create mode 100644 examples/31_message_passing/example.py create mode 100644 iris/allocators/vmem_allocator.py create mode 100644 iris/device_utils.py create mode 100644 iris/tensor_creation.py create mode 100644 iris/tensor_utils.py create mode 100644 iris/tracing/__init__.py create mode 100644 iris/tracing/core.py create mode 100644 iris/tracing/device.py create mode 100644 iris/tracing/events.py create mode 100644 scripts/roccap_wrapper.py create mode 100644 tests/examples/test_expert_sharded_moe.py create mode 100644 tests/unittests/test_device_context.py create mode 100644 tests/unittests/test_dmabuf_controlled_va_import.py create mode 100644 tests/unittests/test_dmabuf_vmem_import.py create mode 100644 tests/unittests/test_hip_apis.py create mode 100644 tests/unittests/test_hip_vmem_primitives.py create mode 100644 tests/unittests/test_is_symmetric.py create mode 100644 tests/unittests/test_pytorch_dmabuf_export.py create mode 100644 tests/unittests/test_pytorch_import_mechanism.py create mode 100644 tests/unittests/test_rocr_behaviors.py create mode 100644 tests/unittests/test_vmem_allocator.py create mode 100644 tests/unittests/test_vmem_cumulative_access.py create mode 100644 tests/unittests/test_vmem_imported_tensor_rma.py create mode 100644 tests/unittests/test_vmem_minimal_export_with_ondemand.py create mode 100644 tests/unittests/test_vmem_multi_alloc_export.py create mode 100644 tests/unittests/test_vmem_offset_check.py create mode 100644 tests/unittests/test_vmem_peer_dmabuf_exchange.py create mode 100644 tests/unittests/test_vmem_segmented_export.py diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index bb5adb9fd..6fc53f97f 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -9,7 +9,7 @@ // Creates a stable agent socket at ~/.ssh/ssh-agent.sock and optionally loads ~/.ssh/id_rsa. "initializeCommand": "bash -lc \"bash '${localWorkspaceFolder}/.devcontainer/ensure-ssh-agent.sh'\"", "runArgs": [ - "--name=${localEnv:USER}-iris-dev", + "--name=${localEnv:USER}-${localWorkspaceFolderBasename}-dev", "--network=host", "--device=/dev/kfd", "--device=/dev/dri", diff --git a/.devcontainer/ensure-ssh-agent.sh b/.devcontainer/ensure-ssh-agent.sh old mode 100644 new mode 100755 index 434af46e9..f9905a710 --- a/.devcontainer/ensure-ssh-agent.sh +++ b/.devcontainer/ensure-ssh-agent.sh @@ -14,13 +14,29 @@ SOCK="${HOME}/.ssh/ssh-agent.sock" mkdir -p "${HOME}/.ssh" +# Check if socket exists AND has keys loaded if [[ -S "${SOCK}" ]]; then - exit 0 + if SSH_AUTH_SOCK="${SOCK}" ssh-add -l >/dev/null 2>&1; then + # Agent is running and has keys loaded + exit 0 + fi + + # Check if agent is alive but just has no keys + if SSH_AUTH_SOCK="${SOCK}" ssh-add -l 2>&1 | grep -q "no identities"; then + # Agent is alive, just needs keys loaded - continue to key loading below + : + else + # Agent is dead or socket is stale, remove it + rm -f "${SOCK}" 2>/dev/null || true + fi fi -rm -f "${SOCK}" -ssh-agent -a "${SOCK}" -t 8h >/dev/null +# Start agent if socket doesn't exist +if [[ ! -S "${SOCK}" ]]; then + ssh-agent -a "${SOCK}" -t 8h >/dev/null || true +fi +# Load SSH key if it exists if [[ -f "${HOME}/.ssh/id_rsa" ]]; then SSH_AUTH_SOCK="${SOCK}" ssh-add "${HOME}/.ssh/id_rsa" >/dev/null 2>&1 || true fi diff --git a/.github/scripts/acquire_gpus.sh b/.github/scripts/acquire_gpus.sh new file mode 100755 index 000000000..8aaded388 --- /dev/null +++ b/.github/scripts/acquire_gpus.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# Acquire GPUs for CI workflows - to be called as a workflow step +# Usage: acquire_gpus.sh +# +# Exports GPU_DEVICES environment variable to $GITHUB_ENV for use in subsequent steps + +set -e + +NUM_GPUS=$1 + +if [ -z "$NUM_GPUS" ]; then + echo "[ERROR] Missing required argument" + echo "Usage: $0 " + exit 1 +fi + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +echo "[ACQUIRE-GPUS] Acquiring $NUM_GPUS GPU(s)" +source "$SCRIPT_DIR/gpu_allocator.sh" +acquire_gpus "$NUM_GPUS" + +echo "[ACQUIRE-GPUS] Allocated GPUs: $GPU_DEVICES" +echo "[ACQUIRE-GPUS] GPU allocation details:" +echo " GPU_DEVICES=$GPU_DEVICES" +echo " ALLOCATED_GPU_BITMAP=$ALLOCATED_GPU_BITMAP" + +# Export to GITHUB_ENV so subsequent steps can use these variables +if [ -n "$GITHUB_ENV" ]; then + { + echo "GPU_DEVICES=$GPU_DEVICES" + echo "ALLOCATED_GPU_BITMAP=$ALLOCATED_GPU_BITMAP" + } >> "$GITHUB_ENV" + echo "[ACQUIRE-GPUS] Exported variables to GITHUB_ENV" +else + echo "[ACQUIRE-GPUS] WARNING: GITHUB_ENV not set, variables not exported" +fi diff --git a/.github/scripts/container_build.sh b/.github/scripts/container_build.sh index e3cd241cd..5e8bda7bd 100755 --- a/.github/scripts/container_build.sh +++ b/.github/scripts/container_build.sh @@ -32,34 +32,34 @@ echo "✅ /dev/shm size OK (${shm_size_gb}GB)" if [ "$CONTAINER_RUNTIME" = "apptainer" ]; then echo "[INFO] Building with Apptainer..." - # Create persistent Apptainer directory - mkdir -p ~/apptainer - - # Define paths - IMAGE_PATH=~/apptainer/iris-dev.sif - DEF_FILE=apptainer/iris.def - CHECKSUM_FILE=~/apptainer/iris-dev.sif.checksum - # Verify def file exists + DEF_FILE=apptainer/iris.def if [ ! -f "$DEF_FILE" ]; then echo "[ERROR] Definition file $DEF_FILE not found" exit 1 fi - # Calculate checksum of the def file - NEW_CHECKSUM=$(sha256sum "$DEF_FILE" | awk '{print $1}') + # Calculate checksum of the def file to use as subdirectory name + DEF_CHECKSUM=$(sha256sum "$DEF_FILE" | awk '{print $1}') + + # Create persistent Apptainer directory with checksum subdirectory + mkdir -p "${HOME}/iris-apptainer-images/${DEF_CHECKSUM}" + + # Define paths + IMAGE_PATH="${HOME}/iris-apptainer-images/${DEF_CHECKSUM}/iris-dev.sif" + CHECKSUM_FILE="${HOME}/iris-apptainer-images/${DEF_CHECKSUM}/iris-dev.sif.checksum" # Check if image exists and has a valid checksum REBUILD_NEEDED=true if [ -f "$IMAGE_PATH" ] && [ -f "$CHECKSUM_FILE" ]; then OLD_CHECKSUM=$(head -n1 "$CHECKSUM_FILE" 2>/dev/null) # Validate checksum format (64 hex characters for SHA256) - if [[ "$OLD_CHECKSUM" =~ ^[a-f0-9]{64}$ ]] && [ "$OLD_CHECKSUM" = "$NEW_CHECKSUM" ]; then - echo "[INFO] Def file unchanged (checksum: $NEW_CHECKSUM)" + if [[ "$OLD_CHECKSUM" =~ ^[a-f0-9]{64}$ ]] && [ "$OLD_CHECKSUM" = "$DEF_CHECKSUM" ]; then + echo "[INFO] Def file unchanged (checksum: $DEF_CHECKSUM)" echo "[INFO] Skipping rebuild, using existing image at $IMAGE_PATH" REBUILD_NEEDED=false else - echo "[INFO] Def file changed (old: ${OLD_CHECKSUM:-}, new: $NEW_CHECKSUM)" + echo "[INFO] Def file changed (old: ${OLD_CHECKSUM:-}, new: $DEF_CHECKSUM)" echo "[INFO] Rebuilding Apptainer image..." fi else @@ -70,9 +70,9 @@ if [ "$CONTAINER_RUNTIME" = "apptainer" ]; then if [ "$REBUILD_NEEDED" = true ]; then if apptainer build --force "$IMAGE_PATH" "$DEF_FILE"; then # Store the checksum only if build succeeded - echo "$NEW_CHECKSUM" > "$CHECKSUM_FILE" + echo "$DEF_CHECKSUM" > "$CHECKSUM_FILE" echo "[INFO] Built image: $IMAGE_PATH" - echo "[INFO] Checksum saved: $NEW_CHECKSUM" + echo "[INFO] Checksum saved: $DEF_CHECKSUM" else echo "[ERROR] Apptainer build failed" exit 1 @@ -83,14 +83,19 @@ elif [ "$CONTAINER_RUNTIME" = "docker" ]; then echo "[INFO] Checking Docker images..." # Use GitHub variable if set, otherwise default to iris-dev IMAGE_NAME=${DOCKER_IMAGE_NAME:-"iris-dev"} - + # Check if the image exists if docker image inspect "$IMAGE_NAME" &> /dev/null; then echo "[INFO] Using existing Docker image: $IMAGE_NAME" else - echo "[WARNING] Docker image $IMAGE_NAME not found" - echo "[INFO] Please build it using: ./build_triton_image.sh" - echo "[INFO] Or pull it if available from registry" + echo "[INFO] Docker image $IMAGE_NAME not found, building..." + DOCKER_DIR="$(dirname "$(realpath "$0")")/../../docker" + if docker build -t "$IMAGE_NAME" "$DOCKER_DIR"; then + echo "[INFO] Built Docker image: $IMAGE_NAME" + else + echo "[ERROR] Docker build failed" + exit 1 + fi fi fi diff --git a/.github/scripts/container_exec.sh b/.github/scripts/container_exec.sh index 1c5313f2e..7fe05d2d7 100755 --- a/.github/scripts/container_exec.sh +++ b/.github/scripts/container_exec.sh @@ -51,13 +51,21 @@ if [ "$CONTAINER_RUNTIME" = "apptainer" ]; then # Find image if [ -n "$CUSTOM_IMAGE" ]; then IMAGE="$CUSTOM_IMAGE" - elif [ -f ~/apptainer/iris-dev.sif ]; then - IMAGE=~/apptainer/iris-dev.sif - elif [ -f apptainer/images/iris.sif ]; then - IMAGE="apptainer/images/iris.sif" else - echo "[ERROR] Apptainer image not found" >&2 - exit 1 + # Calculate checksum of def file to find the correct subdirectory + DEF_FILE=apptainer/iris.def + if [ ! -f "$DEF_FILE" ]; then + echo "[ERROR] Definition file $DEF_FILE not found" >&2 + exit 1 + fi + DEF_CHECKSUM=$(sha256sum "$DEF_FILE" | awk '{print $1}') + + if [ -f "${HOME}/iris-apptainer-images/${DEF_CHECKSUM}/iris-dev.sif" ]; then + IMAGE="${HOME}/iris-apptainer-images/${DEF_CHECKSUM}/iris-dev.sif" + else + echo "[ERROR] Apptainer image not found" >&2 + exit 1 + fi fi # Create temporary overlay in workspace with unique name based on PID and timestamp @@ -99,24 +107,11 @@ elif [ "$CONTAINER_RUNTIME" = "docker" ]; then fi # Build run command with proper GPU access - # Get video and render group IDs from host - VIDEO_GID=$(getent group video | cut -d: -f3) - RENDER_GID=$(getent group render | cut -d: -f3) - RUN_CMD="docker run --rm --network=host --device=/dev/kfd --device=/dev/dri" RUN_CMD="$RUN_CMD --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" RUN_CMD="$RUN_CMD -v ${PWD}:/iris_workspace -w /iris_workspace" RUN_CMD="$RUN_CMD --shm-size=16G --ulimit memlock=-1 --ulimit stack=67108864" - RUN_CMD="$RUN_CMD --user $(id -u):$(id -g)" - - # Add video and render groups for GPU access - if [ -n "$VIDEO_GID" ]; then - RUN_CMD="$RUN_CMD --group-add $VIDEO_GID" - fi - if [ -n "$RENDER_GID" ]; then - RUN_CMD="$RUN_CMD --group-add $RENDER_GID" - fi - + RUN_CMD="$RUN_CMD -e HOME=/iris_workspace" RUN_CMD="$RUN_CMD --entrypoint bash" diff --git a/.github/scripts/examples_config.json b/.github/scripts/examples_config.json new file mode 100644 index 000000000..1fa9ce4c4 --- /dev/null +++ b/.github/scripts/examples_config.json @@ -0,0 +1,5 @@ +{ + "31_message_passing": { + "required_ranks": 2 + } +} diff --git a/.github/scripts/gpu_allocator.sh b/.github/scripts/gpu_allocator.sh new file mode 100755 index 000000000..77c2e8aff --- /dev/null +++ b/.github/scripts/gpu_allocator.sh @@ -0,0 +1,227 @@ +#!/bin/bash +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# Lightweight GPU allocator for CI jobs +# Provides isolation and efficient utilization for variable GPU requests +# +# Design: +# - Uses flock for atomic state management +# - Maintains shared state file with 8-bit bitmap (one bit per GPU) +# - Supports variable GPU requests (1, 2, 4, 8 GPUs) +# - Non-contiguous allocation: any available GPUs can be used +# - Out-of-order release safe: each GPU tracked independently +# - Throughput-oriented: first-available scheduling (non-FIFO) +# - Automatic cleanup on job exit +# +# Usage: +# source gpu_allocator.sh +# acquire_gpus # Blocks until GPUs available, sets GPU_DEVICES and ALLOCATED_GPU_BITMAP +# enable_gpu_cleanup_trap # Optional: enable automatic cleanup on EXIT +# # ... run your job with HIP_VISIBLE_DEVICES=$GPU_DEVICES ... +# release_gpus # Releases allocated GPUs back to pool + +# Note: Do not modify caller's shell options (e.g., set -e) when sourced. + +# Configuration +GPU_STATE_FILE="${GPU_STATE_FILE:-/tmp/iris_gpu_state}" +GPU_LOCK_FILE="${GPU_STATE_FILE}.lock" +MAX_GPUS="${MAX_GPUS:-8}" +RETRY_DELAY="${RETRY_DELAY:-60}" # 1 minute between checks +MAX_RETRIES="${MAX_RETRIES:-180}" # 3 hours total wait time (180 * 1 min) + +# Initialize GPU state file and validate its contents +# State format: 8-bit bitmap where bit N=1 means GPU N is allocated +init_gpu_state() { + # Use flock to ensure atomic initialization and validation + ( + flock -x 200 + + if [ ! -f "$GPU_STATE_FILE" ]; then + # Initialize with all GPUs free (bitmap = 0) + echo "0" > "$GPU_STATE_FILE" + echo "[GPU-ALLOC] Initialized GPU bitmap: 0 (all GPUs free)" >&2 + else + # Validate existing state file contents + local current_state + current_state=$(cat "$GPU_STATE_FILE" 2>/dev/null || echo "") + + # Ensure the state is a non-negative integer + if ! [[ "$current_state" =~ ^[0-9]+$ ]]; then + echo "0" > "$GPU_STATE_FILE" + echo "[GPU-ALLOC] Detected invalid GPU bitmap ('$current_state'); reset to 0" >&2 + # Ensure the bitmap is within valid range (0-255 for 8 GPUs) + elif [ "$current_state" -lt 0 ] || [ "$current_state" -gt 255 ]; then + echo "0" > "$GPU_STATE_FILE" + echo "[GPU-ALLOC] Detected out-of-range GPU bitmap ($current_state); reset to 0" >&2 + fi + fi + ) 200>"$GPU_LOCK_FILE" +} + +# Acquire N GPUs from the pool using bitmap allocation +# Sets GPU_DEVICES environment variable with comma-separated GPU IDs +# Sets ALLOCATED_GPU_BITMAP for cleanup (bitmap of allocated GPUs) +# Blocks until requested GPUs are available +acquire_gpus() { + local num_gpus="$1" + + # Validate input is provided and is numeric + if [ -z "$num_gpus" ]; then + echo "[GPU-ALLOC ERROR] GPU count not specified" >&2 + return 1 + fi + + # Check if numeric + if ! [[ "$num_gpus" =~ ^[0-9]+$ ]]; then + echo "[GPU-ALLOC ERROR] GPU count must be a number: $num_gpus" >&2 + return 1 + fi + + # Validate range + if [ "$num_gpus" -lt 1 ] || [ "$num_gpus" -gt "$MAX_GPUS" ]; then + echo "[GPU-ALLOC ERROR] Invalid GPU count: $num_gpus (must be 1-$MAX_GPUS)" >&2 + return 1 + fi + + # Initialize state if needed + init_gpu_state + + local attempt=0 + + echo "[GPU-ALLOC] Configuration: MAX_GPUS=$MAX_GPUS, MAX_RETRIES=$MAX_RETRIES, RETRY_DELAY=$RETRY_DELAY" >&2 + echo "[GPU-ALLOC] Requesting $num_gpus GPU(s)..." >&2 + + while [ "$attempt" -lt "$MAX_RETRIES" ]; do + # Try to allocate GPUs atomically using bitmap + local allocated_gpus="" + local allocated_bitmap=0 + local result_file + local lock_exit_code + result_file=$(mktemp) + + ( + flock -x 200 + + # Read current bitmap + local bitmap + bitmap=$(cat "$GPU_STATE_FILE") + + # Find N free GPUs (bits that are 0) + local found_gpus=() + local gpu_id + for gpu_id in $(seq 0 $((MAX_GPUS - 1))); do + # Check if bit gpu_id is 0 (GPU is free) + if [ $(( (bitmap >> gpu_id) & 1 )) -eq 0 ]; then + found_gpus+=("$gpu_id") + if [ "${#found_gpus[@]}" -eq "$num_gpus" ]; then + break + fi + fi + done + + # Check if we found enough GPUs + if [ "${#found_gpus[@]}" -eq "$num_gpus" ]; then + # Mark these GPUs as allocated in the bitmap + local new_bitmap=$bitmap + local allocated_mask=0 + for gpu_id in "${found_gpus[@]}"; do + new_bitmap=$(( new_bitmap | (1 << gpu_id) )) + allocated_mask=$(( allocated_mask | (1 << gpu_id) )) + done + + # Update state file with new bitmap + echo "$new_bitmap" > "$GPU_STATE_FILE" + + # Write results to file while holding the lock + # Format: "gpu_ids|allocated_mask" + local gpu_list + gpu_list=$(IFS=,; echo "${found_gpus[*]}") + echo "${gpu_list}|${allocated_mask}" > "$result_file" + + echo "[GPU-ALLOC] Allocated GPUs: $gpu_list (bitmap: $new_bitmap)" >&2 + exit 0 + else + # Not enough GPUs available + local available_count="${#found_gpus[@]}" + echo "[GPU-ALLOC] Not enough GPUs: need $num_gpus, only $available_count available (bitmap: $bitmap)" >&2 + exit 1 + fi + ) 200>"$GPU_LOCK_FILE" && lock_exit_code=0 || lock_exit_code=$? + + if [ "$lock_exit_code" -eq 0 ]; then + # Read the allocated GPU IDs and mask from the result file + local result_line + result_line=$(cat "$result_file") + rm -f "$result_file" + + allocated_gpus="${result_line%|*}" + allocated_bitmap="${result_line#*|}" + + # Export variables + GPU_DEVICES="$allocated_gpus" + ALLOCATED_GPU_BITMAP="$allocated_bitmap" + export GPU_DEVICES ALLOCATED_GPU_BITMAP + + echo "[GPU-ALLOC] Set GPU_DEVICES=$GPU_DEVICES" >&2 + return 0 + else + rm -f "$result_file" + fi + + # Sleep before retry + attempt=$((attempt + 1)) + if [ "$attempt" -lt "$MAX_RETRIES" ]; then + echo "[GPU-ALLOC] Retrying... (attempt $((attempt + 1))/$MAX_RETRIES)" >&2 + sleep "$RETRY_DELAY" + fi + done + + # If we got here, allocation failed + echo "[GPU-ALLOC ERROR] Failed to allocate $num_gpus GPU(s) after $attempt attempts (MAX_RETRIES=$MAX_RETRIES)" >&2 + return 1 +} + +# Release allocated GPUs back to the pool using bitmap +# Uses ALLOCATED_GPU_BITMAP environment variable +release_gpus() { + if [ -z "$ALLOCATED_GPU_BITMAP" ]; then + echo "[GPU-ALLOC] No GPUs to release" >&2 + return 0 + fi + + echo "[GPU-ALLOC] Releasing GPUs (bitmap mask: $ALLOCATED_GPU_BITMAP)" >&2 + + # Save the bitmap to release before entering subshell + local bitmap_to_release="$ALLOCATED_GPU_BITMAP" + + # Unset immediately to prevent double-release + unset GPU_DEVICES ALLOCATED_GPU_BITMAP + + ( + flock -x 200 + + # Read current bitmap + local bitmap + bitmap=$(cat "$GPU_STATE_FILE") + + # Clear the bits for the GPUs we're releasing (bitwise AND with inverse of mask) + local new_bitmap + new_bitmap=$(( bitmap & ~bitmap_to_release )) + + # Update state file + echo "$new_bitmap" > "$GPU_STATE_FILE" + + echo "[GPU-ALLOC] Released GPUs. New bitmap: $new_bitmap" >&2 + ) 200>"$GPU_LOCK_FILE" +} + +# Clean up function to ensure GPUs are released +cleanup_gpus() { + if [ -n "$ALLOCATED_GPU_BITMAP" ]; then + echo "[GPU-ALLOC] Cleanup: releasing GPUs on exit" >&2 + release_gpus + fi +} + + diff --git a/.github/scripts/release_gpus.sh b/.github/scripts/release_gpus.sh new file mode 100755 index 000000000..aab61a309 --- /dev/null +++ b/.github/scripts/release_gpus.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# Release GPUs for CI workflows - to be called as a workflow step with if: always() +# Usage: release_gpus.sh +# +# Reads GPU allocation details from environment variables set by acquire_gpus.sh + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +# Check if we have GPU allocation details +if [ -z "$GPU_DEVICES" ] && [ -z "$ALLOCATED_GPU_BITMAP" ]; then + echo "[RELEASE-GPUS] No GPU allocation found, nothing to release" + exit 0 +fi + +echo "[RELEASE-GPUS] Releasing GPUs" +echo "[RELEASE-GPUS] GPU allocation details:" +echo " GPU_DEVICES=$GPU_DEVICES" +echo " ALLOCATED_GPU_BITMAP=$ALLOCATED_GPU_BITMAP" + +source "$SCRIPT_DIR/gpu_allocator.sh" +release_gpus + +echo "[RELEASE-GPUS] GPUs released successfully" diff --git a/.github/scripts/run_new_examples.sh b/.github/scripts/run_new_examples.sh new file mode 100755 index 000000000..076e7de46 --- /dev/null +++ b/.github/scripts/run_new_examples.sh @@ -0,0 +1,64 @@ +#!/bin/bash +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# Run new example scripts (numbered 24+) directly with torchrun. +# Usage: run_new_examples.sh [install_method] +# num_ranks: number of GPU ranks (2, 4, or 8) +# install_method: "git", "editable", or "install" (default: "editable") + +set -e + +NUM_RANKS=${1:-2} +INSTALL_METHOD=${2:-"editable"} + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +# GPU_DEVICES should be set by the workflow-level acquire_gpus.sh step +GPU_ARG="" +if [ -n "$GPU_DEVICES" ]; then + GPU_ARG="--gpus $GPU_DEVICES" +fi + +# Build install command based on method +if [ "$INSTALL_METHOD" = "git" ]; then + REPO=${GITHUB_REPOSITORY:-"ROCm/iris"} + SHA=${GITHUB_SHA:-"HEAD"} + INSTALL_CMD="pip install git+https://github.com/${REPO}.git@${SHA}" +elif [ "$INSTALL_METHOD" = "editable" ]; then + INSTALL_CMD="pip install -e ." +elif [ "$INSTALL_METHOD" = "install" ]; then + INSTALL_CMD="pip install ." +else + echo "[ERROR] Invalid install_method: $INSTALL_METHOD" + exit 1 +fi + +EXIT_CODE=0 +# shellcheck disable=SC2086 +"$SCRIPT_DIR/container_exec.sh" $GPU_ARG " + set -e + + echo \"Installing iris using method: $INSTALL_METHOD\" + $INSTALL_CMD + + # Run new examples (numbered 24 and above) + for example_file in examples/2[4-9]_*/example.py examples/3[0-9]_*/example.py; do + if [ -f \"\$example_file\" ]; then + # Check examples_config.json in CI scripts dir for rank requirements + example_name=\$(basename \$(dirname \"\$example_file\")) + config_file=\".github/scripts/examples_config.json\" + if [ -f \"\$config_file\" ]; then + required_ranks=\$(python3 -c \"import json; d=json.load(open('\$config_file')); print(d.get('\$example_name', {}).get('required_ranks',''))\" 2>/dev/null) + if [ -n \"\$required_ranks\" ] && [ \"\$required_ranks\" != \"$NUM_RANKS\" ]; then + echo \"Skipping: \$example_file (requires \$required_ranks ranks, got $NUM_RANKS)\" + continue + fi + fi + echo \"Running: \$example_file with $NUM_RANKS ranks\" + torchrun --nproc_per_node=$NUM_RANKS --standalone \"\$example_file\" + fi + done +" || { EXIT_CODE=$?; } + +exit $EXIT_CODE diff --git a/.github/scripts/run_perf_benchmark.sh b/.github/scripts/run_perf_benchmark.sh index 695537751..85e580a86 100755 --- a/.github/scripts/run_perf_benchmark.sh +++ b/.github/scripts/run_perf_benchmark.sh @@ -20,30 +20,19 @@ fi SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +# Use GPU_DEVICES from environment if set, otherwise default to all 8 GPUs +GPU_DEVICES=${GPU_DEVICES:-"0,1,2,3,4,5,6,7"} +echo "[PERF-BENCHMARK] Using GPUs: $GPU_DEVICES" + # Run benchmark in container -"$SCRIPT_DIR/container_exec.sh" --gpus "0,1,2,3,4,5,6,7" " +"$SCRIPT_DIR/container_exec.sh" --gpus "$GPU_DEVICES" " set -e - # Install tritonBLAS (required dependency) - echo \"Installing tritonBLAS...\" - if [ ! -d \"/tmp/tritonBLAS\" ]; then - cd /tmp && git clone https://github.com/ROCm/tritonBLAS.git 2>&1 | tail -3 - fi - if [ -d \"/tmp/tritonBLAS\" ]; then - cd /tmp/tritonBLAS - git checkout 47768c93acb7f89511d797964b84544c30ab81ad 2>&1 | tail -2 - pip install -e . 2>&1 | tail -3 - else - echo \"Warning: Could not clone tritonBLAS, trying pip install from git...\" - pip install git+https://github.com/ROCm/tritonBLAS.git@47768c93acb7f89511d797964b84544c30ab81ad 2>&1 | tail -3 - fi - cd /iris_workspace pip install -e . - python examples/${EXAMPLE_PATH}/benchmark.py \ + torchrun --nproc_per_node=8 examples/${EXAMPLE_PATH}/benchmark.py \ --benchmark \ --validate \ - -r 8 \ ${BENCHMARK_ARGS} \ --output_file perf_result.json " diff --git a/.github/scripts/run_tests.sh b/.github/scripts/run_tests.sh index 4abf4a717..bbcba6585 100755 --- a/.github/scripts/run_tests.sh +++ b/.github/scripts/run_tests.sh @@ -2,11 +2,11 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. # -# Run Iris tests in a container +# Run Iris tests in a container with automatic GPU allocation # Usage: run_tests.sh [gpu_devices] [install_method] # test_dir: subdirectory under tests/ (e.g., examples, unittests, ccl) # num_ranks: number of GPU ranks (1, 2, 4, or 8) -# gpu_devices: comma-separated GPU device IDs (optional) +# gpu_devices: comma-separated GPU device IDs (optional, if not provided will use allocator) # install_method: pip install method - "git", "editable", or "install" (optional, default: "editable") # - "git": pip install git+https://github.com/${{ github.repository }}.git@${{ github.sha }} # - "editable": pip install -e . @@ -16,7 +16,7 @@ set -e TEST_DIR=$1 NUM_RANKS=$2 -GPU_DEVICES=${3:-""} +GPU_DEVICES=${3:-${GPU_DEVICES:-""}} INSTALL_METHOD=${4:-"editable"} if [ -z "$TEST_DIR" ] || [ -z "$NUM_RANKS" ]; then @@ -43,7 +43,14 @@ fi SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -# Build GPU argument if provided +# GPU_DEVICES should be provided by workflow-level acquire_gpus.sh step +# or via command-line argument for backward compatibility +if [ -z "$GPU_DEVICES" ]; then + echo "[RUN-TESTS] WARNING: No GPUs allocated. GPU_DEVICES not set." + echo "[RUN-TESTS] Tests may fail if they require GPUs." +fi + +# Build GPU argument GPU_ARG="" if [ -n "$GPU_DEVICES" ]; then GPU_ARG="--gpus $GPU_DEVICES" @@ -63,32 +70,11 @@ elif [ "$INSTALL_METHOD" = "install" ]; then fi # Run tests in container +EXIT_CODE=0 +# shellcheck disable=SC2086 "$SCRIPT_DIR/container_exec.sh" $GPU_ARG " set -e - # Install tritonBLAS if not already installed (required for iris/ops) - echo \"Checking for tritonBLAS...\" - if ! python -c 'import tritonblas' 2>/dev/null; then - echo \"Installing tritonBLAS...\" - # Use workspace directory for tritonBLAS since /opt may not be writable - TRITONBLAS_DIR=\"./tritonblas_install\" - if [ ! -d \"\$TRITONBLAS_DIR\" ]; then - git clone https://github.com/ROCm/tritonBLAS.git \"\$TRITONBLAS_DIR\" - cd \"\$TRITONBLAS_DIR\" - git checkout 47768c93acb7f89511d797964b84544c30ab81ad - else - cd \"\$TRITONBLAS_DIR\" - git fetch - git checkout 47768c93acb7f89511d797964b84544c30ab81ad - fi - # Install with dependencies - pip install -e . - cd .. - echo \"tritonBLAS installed successfully\" - else - echo \"tritonBLAS already installed\" - fi - echo \"Installing iris using method: $INSTALL_METHOD\" $INSTALL_CMD @@ -96,7 +82,10 @@ fi for test_file in tests/$TEST_DIR/test_*.py; do if [ -f \"\$test_file\" ]; then echo \"Testing: \$test_file with $NUM_RANKS ranks (install: $INSTALL_METHOD)\" - python tests/run_tests_distributed.py --num_ranks $NUM_RANKS \"\$test_file\" -v --tb=short --durations=10 + torchrun --nproc_per_node=$NUM_RANKS --standalone tests/run_tests_distributed.py \"\$test_file\" -v --tb=short --durations=10 fi done -" \ No newline at end of file +" || { EXIT_CODE=$?; } + +# GPU cleanup is now handled by workflow-level release_gpus.sh step +exit $EXIT_CODE \ No newline at end of file diff --git a/.github/workflows/copilot-setup-steps.yml b/.github/workflows/copilot-setup-steps.yml new file mode 100644 index 000000000..82d383360 --- /dev/null +++ b/.github/workflows/copilot-setup-steps.yml @@ -0,0 +1,40 @@ +name: Copilot Setup Steps + +on: + workflow_dispatch: + issue_comment: + types: [created, edited] + +jobs: + copilot-setup-steps: + if: >- + github.event_name == 'workflow_dispatch' || + (github.event.issue.pull_request && contains(github.event.comment.body, '@copilot')) + runs-on: [self-hosted, copilot, apptainer, iris] + + permissions: + contents: read + pull-requests: read + + timeout-minutes: 600 + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Create task venv for Copilot + run: | + python3 -m venv $GITHUB_WORKSPACE/.venv + source $GITHUB_WORKSPACE/.venv/bin/activate + python -m pip install --upgrade pip + + - name: Make venv default for subsequent steps + run: | + echo "$GITHUB_WORKSPACE/.venv/bin" >> $GITHUB_PATH + + - name: Verify ROCm and GPU visibility + run: | + echo "=== rocminfo ===" + rocminfo | head -50 || true + echo "=== rocm-smi ===" + rocm-smi || true diff --git a/.github/workflows/iris-external-validation-test.yml b/.github/workflows/iris-external-validation-test.yml index 4b1eb295a..1330e8d3c 100644 --- a/.github/workflows/iris-external-validation-test.yml +++ b/.github/workflows/iris-external-validation-test.yml @@ -15,8 +15,10 @@ env: DOCKER_IMAGE_NAME: ${{ vars.DOCKER_IMAGE_NAME || 'iris-dev-triton-aafec41' }} jobs: - build-container-image: - runs-on: [self-hosted, mi3xx] + external-validation-test: + name: External Validation Test + runs-on: [linux-mi325-8gpu-ossci-rad] + timeout-minutes: 180 steps: - name: Checkout repository @@ -35,21 +37,11 @@ jobs: - name: Build Iris container run: | - # Use the universal container build script bash .github/scripts/container_build.sh - external-validation-test: - name: External Validation Test - needs: build-container-image - runs-on: [self-hosted, mi3xx] - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Cleanup lingering ports before tests + - name: Acquire GPUs run: | - bash .github/scripts/cleanup_ports.sh + bash .github/scripts/acquire_gpus.sh 2 - name: Run External Validation Test run: | @@ -59,69 +51,65 @@ jobs: bash .github/scripts/container_exec.sh " set -e - # Install tritonBLAS (required dependency) - echo \"Installing tritonBLAS...\" - if [ ! -d \"/tmp/tritonBLAS\" ]; then - cd /tmp && git clone https://github.com/ROCm/tritonBLAS.git 2>&1 | tail -3 - fi - if [ -d \"/tmp/tritonBLAS\" ]; then - cd /tmp/tritonBLAS - git checkout 47768c93acb7f89511d797964b84544c30ab81ad 2>&1 | tail -2 - pip install -e . 2>&1 | tail -3 - else - echo \"Warning: Could not clone tritonBLAS, trying pip install from git...\" - pip install git+https://github.com/ROCm/tritonBLAS.git@47768c93acb7f89511d797964b84544c30ab81ad 2>&1 | tail -3 - fi - cd /iris_workspace pip install git+https://github.com/${{ github.repository }}.git@${{ github.sha }} - wget -O test_iris_distributed.py https://gist.githubusercontent.com/mawad-amd/6375dc078e39e256828f379e03310ec7/raw/a527c3192bee4615292769e340b1c73676f6945a/test_iris_distributed.py - python test_iris_distributed.py + wget -O test_iris_distributed.py https://gist.githubusercontent.com/mawad-amd/6375dc078e39e256828f379e03310ec7/raw/0827d023eaf8e9755b17cbe8ab06f2ce258e746a/test_iris_distributed.py + torchrun --nproc_per_node=2 test_iris_distributed.py " echo "::endgroup::" echo "✅ External validation test passed!" + - name: Release GPUs + if: always() + run: | + bash .github/scripts/release_gpus.sh + external-gluon-validation-test: name: External Gluon Validation Test - needs: build-container-image - runs-on: [self-hosted, mi3xx] + runs-on: [linux-mi325-8gpu-ossci-rad] steps: - name: Checkout repository uses: actions/checkout@v4 - - name: Cleanup lingering ports before tests + - name: Setup Apptainer (if not available) + run: | + if ! command -v apptainer &> /dev/null && ! command -v docker &> /dev/null; then + echo "Neither Apptainer nor Docker found, installing Apptainer..." + apt-get update && apt-get install -y software-properties-common + add-apt-repository -y ppa:apptainer/ppa + apt-get update && apt-get install -y apptainer + else + echo "Container runtime already available" + fi + + - name: Build Iris container + run: | + bash .github/scripts/container_build.sh + + - name: Acquire GPUs run: | - bash .github/scripts/cleanup_ports.sh + bash .github/scripts/acquire_gpus.sh 2 - name: Run External Gluon Validation Test run: | set -e echo "::group::Running external gluon validation test" - bash .github/scripts/container_exec.sh --gpus "0,1" " + bash .github/scripts/container_exec.sh --gpus "$GPU_DEVICES" " set -e - # Install tritonBLAS (required dependency) - echo \"Installing tritonBLAS...\" - if [ ! -d \"/tmp/tritonBLAS\" ]; then - cd /tmp && git clone https://github.com/ROCm/tritonBLAS.git 2>&1 | tail -3 - fi - if [ -d \"/tmp/tritonBLAS\" ]; then - cd /tmp/tritonBLAS - git checkout 47768c93acb7f89511d797964b84544c30ab81ad 2>&1 | tail -2 - pip install -e . 2>&1 | tail -3 - else - echo \"Warning: Could not clone tritonBLAS, trying pip install from git...\" - pip install git+https://github.com/ROCm/tritonBLAS.git@47768c93acb7f89511d797964b84544c30ab81ad 2>&1 | tail -3 - fi - cd /iris_workspace pip install git+https://github.com/${{ github.repository }}.git@${{ github.sha }} - wget -O test_iris_gluon_distributed.py https://gist.githubusercontent.com/mawad-amd/2666dde8ebe2755eb0c4f2108709fcd5/raw/aa567ef3185c37a80d25bc9724ae9589548261b4/test_iris_gluon_distributed.py - python test_iris_gluon_distributed.py + wget -O test_iris_gluon_distributed.py https://gist.githubusercontent.com/mawad-amd/2666dde8ebe2755eb0c4f2108709fcd5/raw/c5544943e2832c75252160bd9084600bf01a6b06/test_iris_gluon_distributed.py + torchrun --nproc_per_node=2 test_iris_gluon_distributed.py " echo "::endgroup::" echo "✅ External gluon validation test passed!" + + - name: Release GPUs + if: always() + run: | + bash .github/scripts/release_gpus.sh diff --git a/.github/workflows/iris-performance-regression-test.yml b/.github/workflows/iris-performance-regression-test.yml index f4853718e..ebde87df3 100644 --- a/.github/workflows/iris-performance-regression-test.yml +++ b/.github/workflows/iris-performance-regression-test.yml @@ -15,35 +15,10 @@ env: DOCKER_IMAGE_NAME: ${{ vars.DOCKER_IMAGE_NAME || 'iris-dev-triton-aafec41' }} jobs: - build-container-image: - runs-on: [self-hosted, mi3xx] - timeout-minutes: 20 - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Setup Apptainer (if not available) - run: | - if ! command -v apptainer &> /dev/null && ! command -v docker &> /dev/null; then - echo "Neither Apptainer nor Docker found, installing Apptainer..." - apt-get update && apt-get install -y software-properties-common - add-apt-repository -y ppa:apptainer/ppa - apt-get update && apt-get install -y apptainer - else - echo "Container runtime already available" - fi - - - name: Build Iris container - run: | - # Use the universal container build script - bash .github/scripts/container_build.sh - performance-test: name: ${{ matrix.example_name }} - needs: build-container-image - runs-on: [self-hosted, mi3xx] - timeout-minutes: 30 + runs-on: [linux-mi325-8gpu-ossci-rad] + timeout-minutes: 180 strategy: fail-fast: false matrix: @@ -74,9 +49,24 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 - - name: Cleanup lingering ports before tests + - name: Setup Apptainer (if not available) + run: | + if ! command -v apptainer &> /dev/null && ! command -v docker &> /dev/null; then + echo "Neither Apptainer nor Docker found, installing Apptainer..." + apt-get update && apt-get install -y software-properties-common + add-apt-repository -y ppa:apptainer/ppa + apt-get update && apt-get install -y apptainer + else + echo "Container runtime already available" + fi + + - name: Build Iris container run: | - bash .github/scripts/cleanup_ports.sh + bash .github/scripts/container_build.sh + + - name: Acquire GPUs + run: | + bash .github/scripts/acquire_gpus.sh 8 - name: Run ${{ matrix.example_name }} Benchmark (8 ranks) run: | @@ -89,3 +79,8 @@ jobs: ${{ matrix.benchmark_args }} echo "::endgroup::" + - name: Release GPUs + if: always() + run: | + bash .github/scripts/release_gpus.sh + diff --git a/.github/workflows/iris-tests.yml b/.github/workflows/iris-tests.yml index fdfef7330..89aa16928 100644 --- a/.github/workflows/iris-tests.yml +++ b/.github/workflows/iris-tests.yml @@ -15,32 +15,10 @@ env: DOCKER_IMAGE_NAME: ${{ vars.DOCKER_IMAGE_NAME || 'iris-dev-triton-aafec41' }} jobs: - build-container-image: - runs-on: [self-hosted, mi3xx] - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Setup Apptainer (if not available) - run: | - if ! command -v apptainer &> /dev/null && ! command -v docker &> /dev/null; then - echo "Neither Apptainer nor Docker found, installing Apptainer..." - apt-get update && apt-get install -y software-properties-common - add-apt-repository -y ppa:apptainer/ppa - apt-get update && apt-get install -y apptainer - else - echo "Container runtime already available" - fi - - - name: Build Iris container - run: | - bash .github/scripts/container_build.sh - test-git: name: Test ${{ matrix.test_dir }} (${{ matrix.num_ranks }} ranks, git install) - needs: build-container-image - runs-on: [self-hosted, mi3xx] + runs-on: [linux-mi325-8gpu-ossci-rad] + timeout-minutes: 180 strategy: fail-fast: false matrix: @@ -48,72 +26,67 @@ jobs: # Test each subdirectory with each rank count using git install - test_dir: examples num_ranks: 1 - gpu_devices: "0,1" - test_dir: examples num_ranks: 2 - gpu_devices: "2,3" - test_dir: examples num_ranks: 4 - gpu_devices: "4,5,6,7" - test_dir: examples num_ranks: 8 - gpu_devices: "0,1,2,3,4,5,6,7" - test_dir: unittests num_ranks: 1 - gpu_devices: "0,1" - test_dir: unittests num_ranks: 2 - gpu_devices: "2,3" - test_dir: unittests num_ranks: 4 - gpu_devices: "4,5,6,7" - test_dir: unittests num_ranks: 8 - gpu_devices: "0,1,2,3,4,5,6,7" - test_dir: ccl num_ranks: 1 - gpu_devices: "0,1" - test_dir: ccl num_ranks: 2 - gpu_devices: "2,3" - test_dir: ccl num_ranks: 4 - gpu_devices: "4,5,6,7" - test_dir: ccl num_ranks: 8 - gpu_devices: "0,1,2,3,4,5,6,7" - test_dir: x num_ranks: 1 - gpu_devices: "0,1" - test_dir: x num_ranks: 2 - gpu_devices: "2,3" - test_dir: x num_ranks: 4 - gpu_devices: "4,5,6,7" - test_dir: x num_ranks: 8 - gpu_devices: "0,1,2,3,4,5,6,7" - test_dir: ops num_ranks: 1 - gpu_devices: "0,1" - test_dir: ops num_ranks: 2 - gpu_devices: "2,3" - test_dir: ops num_ranks: 4 - gpu_devices: "4,5,6,7" - test_dir: ops num_ranks: 8 - gpu_devices: "0,1,2,3,4,5,6,7" steps: - name: Checkout repository uses: actions/checkout@v4 - - name: Cleanup lingering ports before tests + - name: Setup Apptainer (if not available) run: | - bash .github/scripts/cleanup_ports.sh + if ! command -v apptainer &> /dev/null && ! command -v docker &> /dev/null; then + echo "Neither Apptainer nor Docker found, installing Apptainer..." + apt-get update && apt-get install -y software-properties-common + add-apt-repository -y ppa:apptainer/ppa + apt-get update && apt-get install -y apptainer + else + echo "Container runtime already available" + fi + + - name: Build Iris container + run: | + bash .github/scripts/container_build.sh + + - name: Acquire GPUs + run: | + bash .github/scripts/acquire_gpus.sh "${{ matrix.num_ranks }}" - name: Run ${{ matrix.test_dir }} tests with ${{ matrix.num_ranks }} ranks (git install) env: @@ -125,15 +98,21 @@ jobs: bash .github/scripts/run_tests.sh \ "${{ matrix.test_dir }}" \ "${{ matrix.num_ranks }}" \ - "${{ matrix.gpu_devices }}" \ + "" \ "git" echo "::endgroup::" echo "✅ ${{ matrix.test_dir }} tests with ${{ matrix.num_ranks }} ranks (git) passed!" + - name: Release GPUs + if: always() + run: | + bash .github/scripts/release_gpus.sh + test-editable: name: Test ${{ matrix.test_dir }} (${{ matrix.num_ranks }} ranks, editable install) - needs: [build-container-image, test-git] - runs-on: [self-hosted, mi3xx] + needs: [test-git] + runs-on: [linux-mi325-8gpu-ossci-rad] + timeout-minutes: 180 strategy: fail-fast: false matrix: @@ -141,72 +120,67 @@ jobs: # Test each subdirectory with each rank count using editable install - test_dir: examples num_ranks: 1 - gpu_devices: "0,1" - test_dir: examples num_ranks: 2 - gpu_devices: "2,3" - test_dir: examples num_ranks: 4 - gpu_devices: "4,5,6,7" - test_dir: examples num_ranks: 8 - gpu_devices: "0,1,2,3,4,5,6,7" - test_dir: unittests num_ranks: 1 - gpu_devices: "0,1" - test_dir: unittests num_ranks: 2 - gpu_devices: "2,3" - test_dir: unittests num_ranks: 4 - gpu_devices: "4,5,6,7" - test_dir: unittests num_ranks: 8 - gpu_devices: "0,1,2,3,4,5,6,7" - test_dir: ccl num_ranks: 1 - gpu_devices: "0,1" - test_dir: ccl num_ranks: 2 - gpu_devices: "2,3" - test_dir: ccl num_ranks: 4 - gpu_devices: "4,5,6,7" - test_dir: ccl num_ranks: 8 - gpu_devices: "0,1,2,3,4,5,6,7" - test_dir: x num_ranks: 1 - gpu_devices: "0,1" - test_dir: x num_ranks: 2 - gpu_devices: "2,3" - test_dir: x num_ranks: 4 - gpu_devices: "4,5,6,7" - test_dir: x num_ranks: 8 - gpu_devices: "0,1,2,3,4,5,6,7" - test_dir: ops num_ranks: 1 - gpu_devices: "0,1" - test_dir: ops num_ranks: 2 - gpu_devices: "2,3" - test_dir: ops num_ranks: 4 - gpu_devices: "4,5,6,7" - test_dir: ops num_ranks: 8 - gpu_devices: "0,1,2,3,4,5,6,7" steps: - name: Checkout repository uses: actions/checkout@v4 - - name: Cleanup lingering ports before tests + - name: Setup Apptainer (if not available) + run: | + if ! command -v apptainer &> /dev/null && ! command -v docker &> /dev/null; then + echo "Neither Apptainer nor Docker found, installing Apptainer..." + apt-get update && apt-get install -y software-properties-common + add-apt-repository -y ppa:apptainer/ppa + apt-get update && apt-get install -y apptainer + else + echo "Container runtime already available" + fi + + - name: Build Iris container + run: | + bash .github/scripts/container_build.sh + + - name: Acquire GPUs run: | - bash .github/scripts/cleanup_ports.sh + bash .github/scripts/acquire_gpus.sh "${{ matrix.num_ranks }}" - name: Run ${{ matrix.test_dir }} tests with ${{ matrix.num_ranks }} ranks (editable install) run: | @@ -215,15 +189,20 @@ jobs: bash .github/scripts/run_tests.sh \ "${{ matrix.test_dir }}" \ "${{ matrix.num_ranks }}" \ - "${{ matrix.gpu_devices }}" \ + "" \ "editable" echo "::endgroup::" echo "✅ ${{ matrix.test_dir }} tests with ${{ matrix.num_ranks }} ranks (editable) passed!" + - name: Release GPUs + if: always() + run: | + bash .github/scripts/release_gpus.sh + test-install: name: Test ${{ matrix.test_dir }} (${{ matrix.num_ranks }} ranks, pip install) - needs: [build-container-image, test-editable] - runs-on: [self-hosted, mi3xx] + needs: [test-editable] + runs-on: [linux-mi325-8gpu-ossci-rad] strategy: fail-fast: false matrix: @@ -231,72 +210,67 @@ jobs: # Test each subdirectory with each rank count using pip install - test_dir: examples num_ranks: 1 - gpu_devices: "0,1" - test_dir: examples num_ranks: 2 - gpu_devices: "2,3" - test_dir: examples num_ranks: 4 - gpu_devices: "4,5,6,7" - test_dir: examples num_ranks: 8 - gpu_devices: "0,1,2,3,4,5,6,7" - test_dir: unittests num_ranks: 1 - gpu_devices: "0,1" - test_dir: unittests num_ranks: 2 - gpu_devices: "2,3" - test_dir: unittests num_ranks: 4 - gpu_devices: "4,5,6,7" - test_dir: unittests num_ranks: 8 - gpu_devices: "0,1,2,3,4,5,6,7" - test_dir: ccl num_ranks: 1 - gpu_devices: "0,1" - test_dir: ccl num_ranks: 2 - gpu_devices: "2,3" - test_dir: ccl num_ranks: 4 - gpu_devices: "4,5,6,7" - test_dir: ccl num_ranks: 8 - gpu_devices: "0,1,2,3,4,5,6,7" - test_dir: x num_ranks: 1 - gpu_devices: "0,1" - test_dir: x num_ranks: 2 - gpu_devices: "2,3" - test_dir: x num_ranks: 4 - gpu_devices: "4,5,6,7" - test_dir: x num_ranks: 8 - gpu_devices: "0,1,2,3,4,5,6,7" - test_dir: ops num_ranks: 1 - gpu_devices: "0,1" - test_dir: ops num_ranks: 2 - gpu_devices: "2,3" - test_dir: ops num_ranks: 4 - gpu_devices: "4,5,6,7" - test_dir: ops num_ranks: 8 - gpu_devices: "0,1,2,3,4,5,6,7" steps: - name: Checkout repository uses: actions/checkout@v4 - - name: Cleanup lingering ports before tests + - name: Setup Apptainer (if not available) + run: | + if ! command -v apptainer &> /dev/null && ! command -v docker &> /dev/null; then + echo "Neither Apptainer nor Docker found, installing Apptainer..." + apt-get update && apt-get install -y software-properties-common + add-apt-repository -y ppa:apptainer/ppa + apt-get update && apt-get install -y apptainer + else + echo "Container runtime already available" + fi + + - name: Build Iris container run: | - bash .github/scripts/cleanup_ports.sh + bash .github/scripts/container_build.sh + + - name: Acquire GPUs + run: | + bash .github/scripts/acquire_gpus.sh "${{ matrix.num_ranks }}" - name: Run ${{ matrix.test_dir }} tests with ${{ matrix.num_ranks }} ranks (pip install) run: | @@ -305,8 +279,70 @@ jobs: bash .github/scripts/run_tests.sh \ "${{ matrix.test_dir }}" \ "${{ matrix.num_ranks }}" \ - "${{ matrix.gpu_devices }}" \ + "" \ "install" echo "::endgroup::" echo "✅ ${{ matrix.test_dir }} tests with ${{ matrix.num_ranks }} ranks (install) passed!" + - name: Release GPUs + if: always() + run: | + bash .github/scripts/release_gpus.sh + + test-new-examples: + name: New examples (${{ matrix.num_ranks }} ranks, ${{ matrix.install_method }}) + runs-on: [linux-mi325-8gpu-ossci-rad] + timeout-minutes: 180 + permissions: + contents: read + strategy: + fail-fast: false + matrix: + include: + - num_ranks: 2 + install_method: editable + - num_ranks: 4 + install_method: editable + - num_ranks: 8 + install_method: editable + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup Apptainer (if not available) + run: | + if ! command -v apptainer &> /dev/null && ! command -v docker &> /dev/null; then + echo "Neither Apptainer nor Docker found, installing Apptainer..." + apt-get update && apt-get install -y software-properties-common + add-apt-repository -y ppa:apptainer/ppa + apt-get update && apt-get install -y apptainer + else + echo "Container runtime already available" + fi + + - name: Build Iris container + run: | + bash .github/scripts/container_build.sh + + - name: Acquire GPUs + run: | + bash .github/scripts/acquire_gpus.sh "${{ matrix.num_ranks }}" + + - name: Run new examples with ${{ matrix.num_ranks }} ranks (${{ matrix.install_method }}) + env: + GITHUB_REPOSITORY: ${{ github.repository }} + GITHUB_SHA: ${{ github.sha }} + run: | + set -e + echo "::group::Running new examples with ${{ matrix.num_ranks }} ranks (install: ${{ matrix.install_method }})" + bash .github/scripts/run_new_examples.sh \ + "${{ matrix.num_ranks }}" \ + "${{ matrix.install_method }}" + echo "::endgroup::" + echo "✅ New examples with ${{ matrix.num_ranks }} ranks passed!" + + - name: Release GPUs + if: always() + run: | + bash .github/scripts/release_gpus.sh diff --git a/.gitignore b/.gitignore index 57d842401..d8f9754f7 100644 --- a/.gitignore +++ b/.gitignore @@ -48,4 +48,13 @@ __pycache__/ *.pyzw *.pyzwz -!.devcontainer/devcontainer.json \ No newline at end of file +!.devcontainer/devcontainer.json +!.github/scripts/examples_config.json + + +resources/ +gpucore.* +logs/ +*.cap +hsakmt_counters.csv +core \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..31ee7095f --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,100 @@ +# AGENTS.md + +## Project Overview + +Iris is a Triton-based framework for Remote Memory Access (RMA) operations on AMD GPUs. It provides SHMEM-like APIs within Triton for Multi-GPU programming with: + +- Clean abstractions with a full symmetric heap implementation +- Pythonic PyTorch-like host APIs for tensor operations +- Triton-style device APIs for load, store, and atomic operations +- Minimal dependencies (Triton, PyTorch, HIP runtime) +- Comprehensive examples showing communication/computation overlap + +**Supported GPUs**: MI300X, MI350X & MI355X (other ROCm-compatible AMD GPUs may work) + +## Dev Environment Setup + +Install Iris in development mode: + +```bash +pip install -e ".[dev]" +``` + +### Accessing Triton Source Code + +> **Important**: Always read the Triton source code before attempting any Triton-related task. Do not guess at APIs, behavior, or error causes — read the source directly. The source code will show you working examples, explain error messages, and reveal workarounds. + +First check whether Triton is already installed: + +```bash +pip show triton +``` + +The `Location` field in the output shows where the package is installed. Browse the source at that path. If Triton is not found, clone it in shallow mode: + +```bash +git clone --depth 1 https://github.com/triton-lang/triton +``` + +## Code Style + +- Use `ruff` for linting and formatting (configured in `pyproject.toml`). +- Line length: 120 characters. +- Double quotes for strings. +- Run before every commit: + +```bash +ruff check . --fix +ruff format . +``` + +## Testing Instructions + +Tests require at least **2 AMD GPUs**. Use `torchrun` via the helper script: + +```bash +# Run all unit tests (2 ranks) +python tests/run_tests_distributed.py tests/unittests/ --num_ranks 2 -v + +# Run all example tests (2 ranks) +python tests/run_tests_distributed.py tests/examples/ --num_ranks 2 -v + +# Run a single test file +python tests/run_tests_distributed.py tests/unittests/test_load_triton.py --num_ranks 2 -v +``` + +> **Environment note**: The test runner sets `HSA_NO_SCRATCH_RECLAIM=1` automatically, which is required for RCCL on ROCm. + +## Repository Structure + +``` +iris/ +├── iris/ # Main Python package +│ ├── ops/ # RMA operation kernels (load, store, atomics) +│ ├── ccl/ # Collective communication primitives +│ ├── experimental/ # Gluon-based experimental APIs +│ └── allocators/ # Symmetric heap allocators +├── csrc/ # C++/HIP source code +├── examples/ # Ready-to-run algorithm examples +├── tests/ +│ ├── unittests/ # Per-operation unit tests +│ ├── examples/ # End-to-end example tests +│ └── run_tests_distributed.py # torchrun test launcher +├── docs/ # Sphinx documentation +├── docker/ # Docker build/run scripts +└── pyproject.toml # Project metadata and tool config +``` + +## PR Guidelines + +- Create a feature branch: `git checkout -b $USER/feature-name` +- Run linting and tests before opening a PR: + +```bash +ruff check . --fix && ruff format . +python tests/run_tests_distributed.py tests/unittests/ --num_ranks 2 -v +``` + +- Add or update tests for any code you change. +- Update documentation under `docs/` for user-visible behavior changes. +- Fill in the PR description with a clear summary of what changed and why. diff --git a/apptainer/iris.def b/apptainer/iris.def index a5f3c3088..7a1f39849 100644 --- a/apptainer/iris.def +++ b/apptainer/iris.def @@ -8,7 +8,6 @@ From: rocm/pytorch:rocm7.1_ubuntu24.04_py3.13_pytorch_release_2.9.1 /bin/bash -c " # Set environment variables export TRITON_PATH=/opt/triton - export TRITONBLAS_PATH=/opt/tritonBLAS export ROCM_PATH=/opt/rocm export LD_LIBRARY_PATH=\$ROCM_PATH/lib:\$LD_LIBRARY_PATH export PATH=\"\$ROCM_PATH/bin:\$PATH\" @@ -16,7 +15,7 @@ From: rocm/pytorch:rocm7.1_ubuntu24.04_py3.13_pytorch_release_2.9.1 # Install system packages apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get install -y \ - git wget ninja-build cmake python3-pip python3-dev build-essential jq && \ + git wget ninja-build cmake python3-pip python3-dev build-essential jq libdwarf-dev && \ rm -rf /var/lib/apt/lists/* # Create groups if they don't exist @@ -34,20 +33,15 @@ From: rocm/pytorch:rocm7.1_ubuntu24.04_py3.13_pytorch_release_2.9.1 git checkout bcbcabdd0cff6539c7168299075992b2a23ff38e pip3 install -e . - # Clone and install tritonBLAS - cd /opt - git clone https://github.com/ROCm/tritonBLAS.git - cd tritonBLAS - git checkout 47768c93acb7f89511d797964b84544c30ab81ad - pip3 install -e . + # Make the venv writable by all + chmod -R 777 /opt/venv " %environment # Define environment variables export TRITON_PATH=/opt/triton - export TRITONBLAS_PATH=/opt/tritonBLAS export ROCM_PATH=/opt/rocm - export PYTHONPATH=$TRITON_PATH:$TRITONBLAS_PATH + export PYTHONPATH=$TRITON_PATH export LD_LIBRARY_PATH=$ROCM_PATH/lib:$LD_LIBRARY_PATH export PATH="$ROCM_PATH/bin:$PATH" export OMPI_MCA_mtl="^ofi" diff --git a/benchmark/examples/benchmark_moe.py b/benchmark/examples/benchmark_moe.py new file mode 100644 index 000000000..46ab4a3ee --- /dev/null +++ b/benchmark/examples/benchmark_moe.py @@ -0,0 +1,399 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark script for the expert-sharded MoE example. + +This script follows the style and sweep intent of Triton's bench_mlp.py +for GPT-OSS-like sizes, adapted to the current Iris MoE example: + - examples/31_expert_sharded_moe + +It benchmarks distributed MoE forward: + mixture_of_expt_epsharded(...) + +Optional: + - Validate against single-device reference (nosharded) output + - Benchmark single-device reference latency on rank 0 for comparison + +Run: + HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 benchmark/examples/benchmark_moe.py \ + --num_ranks 8 --benchmark --output_file moe_gpt_oss.json +""" + +import argparse +import functools +import json +import os +import sys +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import iris + + +def _load_example_modules(): + project_root = Path(__file__).resolve() + while not (project_root / "tests").is_dir() or not (project_root / "examples").is_dir(): + if project_root == project_root.parent: + raise FileNotFoundError("Could not find project root") + project_root = project_root.parent + + example_dir = project_root / "examples" / "31_expert_sharded_moe" + sys.path.insert(0, str(example_dir)) + + from expert_assignment import make_expt_assignment, make_expt_dict_uniform + from moe import MoeFusionConfig, mixture_of_expt_epsharded, mixture_of_expt_nosharded + + return ( + make_expt_assignment, + make_expt_dict_uniform, + MoeFusionConfig, + mixture_of_expt_epsharded, + mixture_of_expt_nosharded, + ) + + +( + make_expt_assignment, + make_expt_dict_uniform, + MoeFusionConfig, + mixture_of_expt_epsharded, + mixture_of_expt_nosharded, +) = _load_example_modules() + + +def gpt_oss_batch_per_expert_sweep() -> list[int]: + # Matches Triton bench_mlp.py: + # batch_ranges = [(2**(2+k), 2**(3+k), min(2**k, 32)) for k in range(8)] + # batch_sizes = list(chain(*[range(*r) for r in batch_ranges])) + out: list[int] = [] + for k in range(8): + start = 2 ** (2 + k) + end = 2 ** (3 + k) + step = min(2**k, 32) + out.extend(list(range(start, end, step))) + return out + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark expert-sharded MoE with GPT-OSS-style sweep", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--num_ranks", type=int, default=8, help="Number of ranks (GPUs)") + parser.add_argument("--init_port", type=int, default=29531, help="TCP port for torch.distributed init") + + # GPT-OSS-like defaults from Triton bench_mlp.py end-goal. + parser.add_argument("--d_model", type=int, default=5760, help="Model hidden dimension") + parser.add_argument("--n_expts_tot", type=int, default=128, help="Total experts") + parser.add_argument("--n_expts_act", type=int, default=4, help="Top-k experts per token") + + parser.add_argument( + "--datatype", + type=str, + default="bf16", + choices=["fp16", "bf16", "fp32"], + help="Activation/weight dtype", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size in bytes") + + parser.add_argument( + "--batch_per_expt", + type=int, + nargs="*", + default=None, + help="Optional explicit batch_per_expert values; default uses GPT-OSS sweep", + ) + + parser.add_argument("--benchmark", action="store_true", help="Run timing benchmark") + parser.add_argument( + "--validate", action="store_true", help="Validate distributed output vs single-device reference" + ) + parser.add_argument( + "--compare_single_gpu", + action="store_true", + help="Also benchmark single-device reference path on rank 0 for latency comparison", + ) + + parser.add_argument("--warmup", type=int, default=25, help="Warmup iterations for do_bench") + parser.add_argument("--repeat", type=int, default=100, help="Benchmark iterations for do_bench") + parser.add_argument("--breakdown", action="store_true", help="Print per-stage timing breakdown (rank 0)") + + parser.add_argument("--output_dir", type=str, default="benchmark/results/moe", help="Output directory") + parser.add_argument("--output_file", type=str, default="benchmark_moe.json", help="Output JSON filename") + parser.add_argument( + "--fusion_mode", + type=str, + default="unfused", + choices=[ + "unfused", + "fused_grouped_matmul_convert_ep_to_dp", + "fused_convert_dp_to_ep_grouped_matmul", + "fused_convert_dp_to_ep_grouped_matmul__grouped_matmul_convert_ep_to_dp", + ], + help="MoE fusion mode selector", + ) + return parser.parse_args() + + +def _dtype_from_str(s: str) -> torch.dtype: + return {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}[s] + + +def _make_heap_resetter(allocator, offset): + """Return a callable that resets the bump allocator to *offset*.""" + + def _reset(): + allocator.heap_offset = offset + + return _reset + + +def _run_dist_once( + x_dp_local, + l_dp_local, + w_ep_local, + b_ep_local, + expt_assignment, + n_expts_act, + shmem, + fusion_config, +): + return mixture_of_expt_epsharded( + x_dp_local, + l_dp_local, + w_ep_local, + b_ep_local, + expt_assignment, + n_expts_act, + shmem, + fusion_config=fusion_config, + ) + + +def _worker(rank: int, world_size: int, init_url: str, args): + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=rank, + device_id=torch.device(f"cuda:{rank}"), + ) + torch.cuda.set_device(rank) + shmem = iris.iris(args.heap_size) + + try: + ws = shmem.get_num_ranks() + device = torch.device(f"cuda:{rank}") + dtype = _dtype_from_str(args.datatype) + fusion_config = MoeFusionConfig.from_mode_name(args.fusion_mode) + + if args.n_expts_tot % ws != 0: + raise ValueError(f"n_expts_tot ({args.n_expts_tot}) must be divisible by world_size ({ws})") + + if args.batch_per_expt: + sweep = args.batch_per_expt + else: + sweep = gpt_oss_batch_per_expert_sweep() + + if rank == 0: + os.makedirs(args.output_dir, exist_ok=True) + + results: list[dict] = [] + sweep_heap_base = shmem.heap.allocator.heap_offset + + for bpe in sweep: + shmem.heap.allocator.heap_offset = sweep_heap_base + n_tokens = bpe * args.n_expts_tot // args.n_expts_act + if n_tokens % ws != 0: + if rank == 0: + print(f"Skipping bpe={bpe}: n_tokens={n_tokens} not divisible by world_size={ws}") + continue + + n_tokens_local = n_tokens // ws + + torch.manual_seed(0) + x_global = torch.randn(n_tokens, args.d_model, device=device, dtype=dtype) + l_global = torch.rand(n_tokens, args.n_expts_tot, device=device, dtype=torch.float32) + w_global = torch.randn(args.n_expts_tot, args.d_model, args.d_model, device=device, dtype=dtype) + b_global = torch.randn(args.n_expts_tot, args.d_model, device=device, dtype=torch.float32) + + dist.broadcast(x_global, src=0) + dist.broadcast(l_global, src=0) + dist.broadcast(w_global, src=0) + dist.broadcast(b_global, src=0) + + expt_dict = make_expt_dict_uniform(ws, args.n_expts_tot) + expt_assignment = make_expt_assignment(ws, args.n_expts_tot, expt_dict, device) + + first = rank * n_tokens_local + last = first + n_tokens_local + x_dp_local = x_global[first:last].contiguous() + l_dp_local = l_global[first:last].contiguous() + w_ep_local = w_global[expt_assignment.expt_boolmask[rank]].contiguous() + b_ep_local = b_global[expt_assignment.expt_boolmask[rank]].contiguous() + + run_dist = functools.partial( + _run_dist_once, + x_dp_local, + l_dp_local, + w_ep_local, + b_ep_local, + expt_assignment, + args.n_expts_act, + shmem, + fusion_config, + ) + + if args.validate or args.compare_single_gpu: + y_ref = mixture_of_expt_nosharded(x_global, l_global, w_global, b_global, args.n_expts_act) + + # Warmup one run for graph/kernels. + z_dp_local = run_dist() + y_tri = torch.empty((n_tokens, args.d_model), dtype=dtype, device=device) + dist.all_gather_into_tensor(y_tri, z_dp_local.contiguous()) + + if args.breakdown: + N_BREAKDOWN_ITERS = 10 + stage_ms = {} + for _ in range(N_BREAKDOWN_ITERS): + shmem.heap.allocator.heap_offset = sweep_heap_base + td = [] if rank == 0 else None + mixture_of_expt_epsharded( + x_dp_local, + l_dp_local, + w_ep_local, + b_ep_local, + expt_assignment, + args.n_expts_act, + shmem, + fusion_config=fusion_config, + timing_dict=td, + ) + if rank == 0: + for j in range(1, len(td)): + key = td[j][0] + ms = td[j - 1][1].elapsed_time(td[j][1]) + stage_ms.setdefault(key, []).append(ms) + if rank == 0: + total_avg = sum(sum(v) / len(v) for v in stage_ms.values()) + parts = [] + for k, v in stage_ms.items(): + avg = sum(v) / len(v) + pct = 100 * avg / total_avg if total_avg > 0 else 0 + parts.append("{}={:.2f}ms ({:.1f}%)".format(k, avg, pct)) + print(" [breakdown bpe={} total={:.2f}ms] ".format(bpe, total_avg) + " ".join(parts)) + + result = { + "world_size": ws, + "batch_per_expt": bpe, + "n_tokens": n_tokens, + "d_model": args.d_model, + "n_expts_tot": args.n_expts_tot, + "n_expts_act": args.n_expts_act, + "dtype": args.datatype, + "fusion_mode": fusion_config.mode_name(), + } + + if args.validate: + diff = (y_ref.float() - y_tri.float()).abs() + result["validate_max_diff"] = float(diff.max().item()) + result["validate_mean_diff"] = float(diff.mean().item()) + result["validate_pass"] = bool(torch.allclose(y_ref, y_tri, atol=1e-2, rtol=1e-2)) + + if args.benchmark: + heap_snapshot = shmem.heap.allocator.heap_offset + reset_heap = _make_heap_resetter(shmem.heap.allocator, heap_snapshot) + saved_refresh = shmem.heap.refresh_peer_access + shmem.heap.refresh_peer_access = lambda: None + dist_ms = iris.do_bench( + run_dist, + barrier_fn=shmem.barrier, + preamble_fn=reset_heap, + n_warmup=args.warmup, + n_repeat=args.repeat, + return_mode="mean", + ) + shmem.heap.refresh_peer_access = saved_refresh + reset_heap() + result["dist_ms"] = float(dist_ms) + + if args.compare_single_gpu: + if rank == 0: + + def run_ref(): + return mixture_of_expt_nosharded(x_global, l_global, w_global, b_global, args.n_expts_act) + + ref_ms = iris.do_bench( + run_ref, + barrier_fn=torch.cuda.synchronize, + n_warmup=args.warmup, + n_repeat=args.repeat, + return_mode="mean", + ) + result["single_gpu_ref_ms"] = float(ref_ms) + if args.benchmark and ref_ms > 0: + result["speedup_vs_single_gpu"] = float(ref_ms / dist_ms) + + # keep all ranks aligned before next config + shmem.barrier() + + if rank == 0: + print( + f"[bpe={bpe:4d}] n_tokens={n_tokens:6d}" + + (f" dist={result.get('dist_ms', 0.0):8.3f} ms" if args.benchmark else "") + + ( + f" ref={result.get('single_gpu_ref_ms', 0.0):8.3f} ms" + if args.compare_single_gpu and "single_gpu_ref_ms" in result + else "" + ) + + (f" max_diff={result.get('validate_max_diff', 0.0):.4f}" if args.validate else "") + ) + results.append(result) + + shmem.barrier() + + if rank == 0: + out_path = Path(args.output_dir) / args.output_file + payload = { + "sweep": "gpt_oss_batch_per_expt" if not args.batch_per_expt else "custom", + "results": results, + } + with open(out_path, "w") as f: + json.dump(payload, f, indent=2) + print(f"Saved benchmark results: {out_path}") + + finally: + try: + shmem.barrier() + except Exception: + pass + del shmem + import gc + + gc.collect() + dist.destroy_process_group() + + +def main(): + args = parse_args() + if not args.benchmark and not args.validate and not args.compare_single_gpu: + print("No mode selected. Use at least one of: --benchmark, --validate, --compare_single_gpu") + sys.exit(1) + + init_url = f"tcp://127.0.0.1:{args.init_port}" + mp.spawn( + fn=_worker, + args=(args.num_ranks, init_url, args), + nprocs=args.num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/docker/Dockerfile b/docker/Dockerfile index a0f97d1c5..c01e86a8e 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -22,12 +22,14 @@ ENV OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1 \ # Install system packages RUN apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get install -y \ - git wget ninja-build cmake python3-pip python3-dev build-essential && \ + git wget ninja-build cmake python3-pip python3-dev build-essential libdwarf-dev && \ rm -rf /var/lib/apt/lists/* RUN groupadd -r video 2>/dev/null || true && \ groupadd -r render 2>/dev/null || true +# Make the venv writable by all +RUN chmod -R 777 /opt/venv # Install Python packages with pip RUN pip3 install --upgrade pip && \ pip3 install wheel jupyter @@ -39,13 +41,6 @@ RUN git checkout bcbcabdd0cff6539c7168299075992b2a23ff38e RUN pip3 install -e . ENV PYTHONPATH=$TRITON_PATH -# Install tritonBLAS -WORKDIR /opt -RUN git clone https://github.com/ROCm/tritonBLAS.git && \ - cd tritonBLAS && \ - git checkout 47768c93acb7f89511d797964b84544c30ab81ad && \ - pip3 install -e . - # Set up workspace WORKDIR /workspace @@ -60,4 +55,4 @@ RUN echo '#!/bin/bash' > /entrypoint.sh && \ chmod +x /entrypoint.sh # Set the entrypoint -ENTRYPOINT ["/bin/bash", "-c", "source /entrypoint.sh && exec bash"] \ No newline at end of file +ENTRYPOINT ["/bin/bash", "-c", "source /entrypoint.sh && exec bash"] diff --git a/docker/Dockerfile.ccl b/docker/Dockerfile.ccl index 4db2a3be9..8271c31dd 100644 --- a/docker/Dockerfile.ccl +++ b/docker/Dockerfile.ccl @@ -21,7 +21,7 @@ ENV OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1 \ # Install system packages RUN apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get install -y \ - git wget ninja-build cmake python3-pip python3-dev build-essential && \ + git wget ninja-build cmake python3-pip python3-dev build-essential libdwarf-dev && \ rm -rf /var/lib/apt/lists/* # Install Python packages with pip diff --git a/docs/conf.py b/docs/conf.py index d341b6264..c58e3dd66 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -95,6 +95,7 @@ "numpy", "iris._distributed_helpers", "iris.hip", + "tritonblas", ] # Custom mocks that preserve docstrings for Triton Gluon @@ -104,7 +105,10 @@ class PreserveDocstringMock: """Mock decorator that preserves docstrings and function attributes.""" - def __call__(self, func): + def __call__(self, func=None, **kwargs): + # Handle both @decorator and @decorator() usage patterns + if func is None or not callable(func): + return lambda f: f # Return the original function unchanged to preserve docstrings return func @@ -120,6 +124,7 @@ def __call__(self, func): class TritonMock: jit = PreserveDocstringMock() language = triton_language_mock + constexpr_function = PreserveDocstringMock() sys.modules["triton"] = TritonMock() @@ -128,6 +133,7 @@ class TritonMock: # Mock gluon with docstring-preserving jit class GluonMock: jit = PreserveDocstringMock() + constexpr_function = PreserveDocstringMock() sys.modules["triton.experimental"] = MagicMock() diff --git a/docs/index.md b/docs/index.md index 2130fdb16..58a4206ee 100644 --- a/docs/index.md +++ b/docs/index.md @@ -209,6 +209,7 @@ For other setup methods, see the [Installation Guide](getting-started/installati - **[Taxonomy](conceptual/taxonomy.md)**: Multi-GPU programming patterns ### 📖 **Reference** +- **[Talks and Papers](reference/talks-and-papers.md)**: Publications, conference talks, and videos - **[API Reference](reference/api-reference.md)**: Structured API documentation - **[Triton APIs](reference/triton/overview.md)**: Standard Iris APIs with Triton - **[Gluon APIs (Experimental)](reference/gluon/overview.md)**: Cleaner API with Gluon decorators diff --git a/docs/reference/api-reference.md b/docs/reference/api-reference.md index fa235f782..ab14487a5 100644 --- a/docs/reference/api-reference.md +++ b/docs/reference/api-reference.md @@ -5,6 +5,8 @@ Explore Iris APIs. The reference is broken down into focused sections to mirror - The `Iris` class itself (constructor and helper utilities) - Tensor-like creation methods on the `Iris` context - Triton device-side functions for remote memory ops and atomics +- Collective communication operations (CCL) +- Fused GEMM + CCL operations - Experimental Gluon APIs (using `@aggregate` and `@gluon.jit`) Use the links below to navigate: @@ -13,8 +15,11 @@ Use the links below to navigate: - [Iris Class](triton/class.md) - [Tensor Creation](triton/tensor-creation.md) - [Device Functions](triton/device-functions.md) + - [Collective Communication (CCL)](triton/ccl.md) + - [Fused GEMM + CCL Operations](triton/ops.md) - [Gluon (Experimental)](gluon/overview.md) - [Iris Class](gluon/class.md) - [Tensor Creation](gluon/tensor-creation.md) - [Device Functions](gluon/device-functions.md) + - [Collective Communication (CCL)](gluon/ccl.md) diff --git a/docs/reference/gluon/ccl.md b/docs/reference/gluon/ccl.md new file mode 100644 index 000000000..cb260c834 --- /dev/null +++ b/docs/reference/gluon/ccl.md @@ -0,0 +1,22 @@ +# Collective Communication Operations + +```{warning} +The Gluon API is **experimental** and may undergo breaking changes in future releases. +``` + +Collective communication operations accessible via the `ccl` attribute on the `IrisGluon` instance (e.g. `ctx.ccl.all_to_all(...)`). + +## all_to_all +```{eval-rst} +.. automethod:: iris.experimental.iris_gluon.IrisGluon.CCL.all_to_all +``` + +## all_gather +```{eval-rst} +.. automethod:: iris.experimental.iris_gluon.IrisGluon.CCL.all_gather +``` + +## reduce_scatter +```{eval-rst} +.. automethod:: iris.experimental.iris_gluon.IrisGluon.CCL.reduce_scatter +``` diff --git a/docs/reference/gluon/class.md b/docs/reference/gluon/class.md index bac809036..337f45aad 100644 --- a/docs/reference/gluon/class.md +++ b/docs/reference/gluon/class.md @@ -50,3 +50,5 @@ Broadcast data from a source rank to all ranks. This method automatically detect .. automethod:: iris.experimental.iris_gluon.IrisGluon.broadcast ``` + + diff --git a/docs/reference/gluon/overview.md b/docs/reference/gluon/overview.md index b7fb9295d..8736f04e6 100644 --- a/docs/reference/gluon/overview.md +++ b/docs/reference/gluon/overview.md @@ -52,6 +52,7 @@ Explore the API by section: - [Iris Class](class.md) - [Tensor Creation](tensor-creation.md) - [Device Functions](device-functions.md) +- [Collective Communication (CCL)](ccl.md) ## Complete Example: Producer-Consumer Pattern diff --git a/docs/reference/gluon/tensor-creation.md b/docs/reference/gluon/tensor-creation.md index d1be3a8b3..6d3d7825a 100644 --- a/docs/reference/gluon/tensor-creation.md +++ b/docs/reference/gluon/tensor-creation.md @@ -13,3 +13,9 @@ APIs on `IrisGluon` that create and initialize tensors on the Iris symmetric hea .. automethod:: iris.experimental.iris_gluon.IrisGluon.full ``` +## Symmetric Heap Utilities + +```{eval-rst} +.. automethod:: iris.experimental.iris_gluon.IrisGluon.is_symmetric +``` + diff --git a/docs/reference/talks-and-papers.md b/docs/reference/talks-and-papers.md new file mode 100644 index 000000000..ad707f4dc --- /dev/null +++ b/docs/reference/talks-and-papers.md @@ -0,0 +1,103 @@ + + +# Talks and Papers + +This page collects publications, conference talks, and videos related to Iris. + +## Papers + +### Iris: First-Class Multi-GPU Programming Experience in Triton + +> Muhammad Awad, Muhammad Osama, Brandon Potter — *arXiv, November 2025* + +Introduces the Iris framework and its SHMEM-like Remote Memory Access (RMA) APIs for multi-GPU programming inside Triton kernels, demonstrating programmability and competitive performance on AMD MI300X GPUs. + +- 📄 [arXiv:2511.12500](https://arxiv.org/abs/2511.12500) +- 🔖 DOI: [10.48550/arXiv.2511.12500](https://doi.org/10.48550/arXiv.2511.12500) + +**BibTeX** + +```bibtex +@misc{Awad:2025:IFM, + author = {Muhammad Awad and Muhammad Osama and Brandon Potter}, + title = {Iris: First-Class Multi-{GPU} Programming Experience in {Triton}}, + year = {2025}, + archivePrefix = {arXiv}, + eprint = {2511.12500}, + primaryClass = {cs.DC}, + doi = {10.48550/arXiv.2511.12500} +} +``` + +--- + +### Eliminating Multi-GPU Performance Taxes: A Systems Approach to Efficient Distributed LLMs + +> Octavian Alexandru Trifan, Karthik Sangaiah, Muhammad Awad, Muhammad Osama, Sumanth Gudaparthi, Alexandru Nicolau, Alexander Veidenbaum, Ganesh Dasika — *arXiv, November 2025* + +Presents a systems-level approach for reducing communication overhead in distributed large language model inference, leveraging Iris for fine-grained GPU-to-GPU data movement. + +- 📄 [arXiv:2511.02168](https://arxiv.org/abs/2511.02168) +- 🔖 DOI: [10.48550/arXiv.2511.02168](https://doi.org/10.48550/arXiv.2511.02168) + +**BibTeX** + +```bibtex +@misc{Trifan:2025:EMT, + author = {Octavian Alexandru Trifan and Karthik Sangaiah and Muhammad Awad and Muhammad Osama and Sumanth Gudaparthi and Alexandru Nicolau and Alexander Veidenbaum and Ganesh Dasika}, + title = {Eliminating Multi-{GPU} Performance Taxes: A Systems Approach to Efficient Distributed {LLMs}}, + year = {2025}, + archivePrefix = {arXiv}, + eprint = {2511.02168}, + primaryClass = {cs.DC}, + doi = {10.48550/arXiv.2511.02168} +} +``` + +--- + +## Software Citation + +If you use the Iris software directly, please also cite the software release: + +```bibtex +@software{Awad:2025:IFM:Software, + author = {Muhammad Awad and Muhammad Osama and Brandon Potter}, + title = {Iris: First-Class Multi-{GPU} Programming Experience in {Triton}}, + year = 2025, + month = oct, + doi = {10.5281/zenodo.17382307}, + url = {https://github.com/ROCm/iris} +} +``` + +--- + +## Talks and Videos + +### Iris at GPU Mode — September 2025 + +Iris was presented at the GPU Mode meetup, covering the design of the RMA API, the symmetric heap, and performance results on multi-GPU workloads. + +- 🎬 [Watch on YouTube](https://www.youtube.com/watch?v=i6Y2EelEC04) +- 📊 [Slides (PDF)](https://github.com/ROCm/iris/blob/main/docs/slides/Awad-Osama-Potter%20-%20Iris%20Multi-GPU%20Programming%20Made%20Easier%20(GPU%20Mode).pdf) + +--- + +### Iris All-Scatter Taxonomy — August 2025 + +A deep-dive video on the taxonomy of multi-GPU programming patterns, with a focus on All-Scatter and GEMM + communication overlap. + +- 🎬 [Watch on YouTube](https://youtu.be/fYMdPe9UpHE) +- 📖 [Taxonomy Documentation](../conceptual/taxonomy.md) + +--- + +### Iris Presented in Chinese — September 2025 + +Iris was presented in Chinese for participants of the AMD Distributed Inference Kernel Contest. + +- 🎬 [Watch on YouTube](https://youtu.be/wW14w1QNrY8) diff --git a/docs/reference/triton/ccl.md b/docs/reference/triton/ccl.md new file mode 100644 index 000000000..f888381c9 --- /dev/null +++ b/docs/reference/triton/ccl.md @@ -0,0 +1,28 @@ +# Collective Communication Operations + +Collective communication operations accessible via the `ccl` attribute on the `Iris` instance (e.g. `ctx.ccl.all_reduce(...)`). + +## all_to_all +```{eval-rst} +.. automethod:: iris.iris.Iris.CCL.all_to_all +``` + +## all_gather +```{eval-rst} +.. automethod:: iris.iris.Iris.CCL.all_gather +``` + +## all_reduce_preamble +```{eval-rst} +.. automethod:: iris.iris.Iris.CCL.all_reduce_preamble +``` + +## all_reduce +```{eval-rst} +.. automethod:: iris.iris.Iris.CCL.all_reduce +``` + +## reduce_scatter +```{eval-rst} +.. automethod:: iris.iris.Iris.CCL.reduce_scatter +``` diff --git a/docs/reference/triton/class.md b/docs/reference/triton/class.md index 84d2215f8..bdae5a5e4 100644 --- a/docs/reference/triton/class.md +++ b/docs/reference/triton/class.md @@ -12,6 +12,7 @@ Prefer using the convenience factory over calling the constructor directly: ```{eval-rst} .. automethod:: iris.iris.Iris.get_heap_bases +.. automethod:: iris.iris.Iris.get_device_context .. automethod:: iris.iris.Iris.barrier .. automethod:: iris.iris.Iris.get_device .. automethod:: iris.iris.Iris.get_cu_count diff --git a/docs/reference/triton/ops.md b/docs/reference/triton/ops.md new file mode 100644 index 000000000..93be2b31d --- /dev/null +++ b/docs/reference/triton/ops.md @@ -0,0 +1,23 @@ +# Fused GEMM + CCL Operations + +Fused matrix multiplication and collective communication operations accessible via the `ops` property on the `Iris` instance (e.g. `ctx.ops.matmul_all_reduce(...)`). + +## matmul_all_reduce +```{eval-rst} +.. automethod:: iris.ops.OpsNamespace.matmul_all_reduce +``` + +## all_gather_matmul +```{eval-rst} +.. automethod:: iris.ops.OpsNamespace.all_gather_matmul +``` + +## matmul_all_gather +```{eval-rst} +.. automethod:: iris.ops.OpsNamespace.matmul_all_gather +``` + +## matmul_reduce_scatter +```{eval-rst} +.. automethod:: iris.ops.OpsNamespace.matmul_reduce_scatter +``` diff --git a/docs/reference/triton/overview.md b/docs/reference/triton/overview.md index fb3728597..a166dc408 100644 --- a/docs/reference/triton/overview.md +++ b/docs/reference/triton/overview.md @@ -35,4 +35,6 @@ Explore the API by section: - [Iris Class](class.md) - [Tensor Creation](tensor-creation.md) - [Device Functions](device-functions.md) +- [Collective Communication (CCL)](ccl.md) +- [Fused GEMM + CCL Operations](ops.md) diff --git a/docs/reference/triton/tensor-creation.md b/docs/reference/triton/tensor-creation.md index 0fe3c3a52..40a999f7e 100644 --- a/docs/reference/triton/tensor-creation.md +++ b/docs/reference/triton/tensor-creation.md @@ -8,6 +8,7 @@ APIs on `Iris` that create and initialize tensors on the Iris symmetric heap. .. automethod:: iris.iris.Iris.ones .. automethod:: iris.iris.Iris.full .. automethod:: iris.iris.Iris.empty +.. automethod:: iris.iris.Iris.rand .. automethod:: iris.iris.Iris.randn .. automethod:: iris.iris.Iris.uniform .. automethod:: iris.iris.Iris.randint @@ -15,4 +16,9 @@ APIs on `Iris` that create and initialize tensors on the Iris symmetric heap. .. automethod:: iris.iris.Iris.arange ``` +## Symmetric Heap Utilities +```{eval-rst} +.. automethod:: iris.iris.Iris.as_symmetric +.. automethod:: iris.iris.Iris.is_symmetric +``` diff --git a/docs/sphinx/_toc.yml b/docs/sphinx/_toc.yml index a6a839994..758f8174b 100644 --- a/docs/sphinx/_toc.yml +++ b/docs/sphinx/_toc.yml @@ -12,6 +12,7 @@ subtrees: - file: conceptual/taxonomy.md - caption: Reference entries: + - file: reference/talks-and-papers.md - file: reference/api-reference.md entries: - file: reference/triton/overview.md @@ -19,6 +20,8 @@ subtrees: - file: reference/triton/class.md - file: reference/triton/tensor-creation.md - file: reference/triton/device-functions.md + - file: reference/triton/ccl.md + - file: reference/triton/ops.md - file: reference/gluon/overview.md entries: - file: reference/gluon/class.md diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in index a6a839994..758f8174b 100644 --- a/docs/sphinx/_toc.yml.in +++ b/docs/sphinx/_toc.yml.in @@ -12,6 +12,7 @@ subtrees: - file: conceptual/taxonomy.md - caption: Reference entries: + - file: reference/talks-and-papers.md - file: reference/api-reference.md entries: - file: reference/triton/overview.md @@ -19,6 +20,8 @@ subtrees: - file: reference/triton/class.md - file: reference/triton/tensor-creation.md - file: reference/triton/device-functions.md + - file: reference/triton/ccl.md + - file: reference/triton/ops.md - file: reference/gluon/overview.md entries: - file: reference/gluon/class.md diff --git a/examples/06_message_passing/message_passing_device_context.py b/examples/06_message_passing/message_passing_device_context.py new file mode 100644 index 000000000..0f7ccb236 --- /dev/null +++ b/examples/06_message_passing/message_passing_device_context.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Message Passing with DeviceContext API + +This example demonstrates the DeviceContext API - an object-oriented interface +for Iris operations that follows the gluon pattern. + +""" + +import argparse +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import triton +import triton.language as tl + +import iris +from iris import DeviceContext + + +@triton.jit +def device_context_producer_kernel( + context_tensor, # Encoded context from iris.get_device_context() + source_buffer, + target_buffer, + flag, + buffer_size, + rank: tl.constexpr, + world_size: tl.constexpr, + consumer_rank: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + Producer kernel using DeviceContext API. + + Note how we don't need to pass heap_bases - it's encapsulated in DeviceContext. + """ + # Initialize device context from encoded tensor + ctx = DeviceContext.initialize(context_tensor, rank, world_size) + + pid = tl.program_id(0) + + # Compute start index of this block + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # Guard for out-of-bounds accesses + mask = offsets < buffer_size + + # Load from local buffer (no translation needed, so we just use tl.load) + values = tl.load(source_buffer + offsets, mask=mask) + + # Store to remote buffer using DeviceContext (much cleaner API!) + ctx.store(target_buffer + offsets, values, to_rank=consumer_rank, mask=mask) + + # Signal completion with atomic CAS + ctx.atomic_cas(flag + pid, 0, 1, to_rank=consumer_rank, sem="release", scope="sys") + + +@triton.jit +def device_context_consumer_kernel( + context_tensor, + buffer, + flag, + buffer_size, + rank: tl.constexpr, + world_size: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Consumer kernel using DeviceContext API.""" + # Initialize device context from encoded tensor + ctx = DeviceContext.initialize(context_tensor, rank, world_size) + + pid = tl.program_id(0) + + # Spin-wait on flag + while ctx.atomic_cas(flag + pid, 1, 1, to_rank=rank, sem="acquire", scope="sys") != 1: + pass + + # Process the data (just read and verify it exists) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < buffer_size + + # Load the received data + data = tl.load(buffer + offsets, mask=mask) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + # Initialize Iris + ctx = iris.iris(heap_size=args["heap_size"]) + rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + + # Get device context tensor for kernels + context_tensor = ctx.get_device_context() + + # Allocate buffers + buffer_size = args["buffer_size"] + block_size = args["block_size"] + source_buffer = ctx.zeros(buffer_size, dtype=torch.float32) + target_buffer = ctx.zeros(buffer_size, dtype=torch.float32) + num_blocks = (buffer_size + block_size - 1) // block_size + flag = ctx.zeros(num_blocks, dtype=torch.int32) + + # Initialize source buffer with data + source_buffer.copy_(torch.arange(buffer_size, dtype=torch.float32)) + + # Determine producer/consumer + producer_rank = 0 + consumer_rank = 1 if world_size > 1 else 0 + + ctx.barrier() + + if rank == producer_rank: + ctx.info(f"Producer: Sending {buffer_size} elements to rank {consumer_rank}") + + # Launch producer kernel with DeviceContext + device_context_producer_kernel[(num_blocks,)]( + context_tensor, + source_buffer, + target_buffer, + flag, + buffer_size, + rank, + world_size, + consumer_rank, + block_size, + ) + + ctx.info("Producer: Data sent successfully using DeviceContext API") + + if rank == consumer_rank: + ctx.info(f"Consumer: Waiting for data from rank {producer_rank}") + + # Launch consumer kernel with DeviceContext + device_context_consumer_kernel[(num_blocks,)]( + context_tensor, + target_buffer, + flag, + buffer_size, + rank, + world_size, + block_size, + ) + + # Verify the data + expected = torch.arange(buffer_size, dtype=torch.float32, device=target_buffer.device) + if torch.allclose(target_buffer, expected): + ctx.info("Consumer: Data received and verified successfully using DeviceContext API!") + else: + ctx.error("Consumer: Data verification failed!") + + ctx.barrier() + dist.destroy_process_group() + + +def main(): + parser = argparse.ArgumentParser(description="DeviceContext Message Passing Example") + parser.add_argument("--buffer_size", type=int, default=1024, help="Buffer size") + parser.add_argument("--block_size", type=int, default=256, help="Block size") + parser.add_argument("--heap_size", type=int, default=1 << 30, help="Iris heap size (default: 1GB)") + parser.add_argument("--num_ranks", type=int, default=2, help="Number of ranks/processes") + args = vars(parser.parse_args()) + + world_size = args["num_ranks"] + init_url = "tcp://127.0.0.1:23456" + + print(f"Spawning {world_size} processes for DeviceContext example...") + mp.spawn(_worker, args=(world_size, init_url, args), nprocs=world_size, join=True) + print("DeviceContext example completed!") + + +if __name__ == "__main__": + main() diff --git a/examples/07_gemm_all_scatter/benchmark.py b/examples/07_gemm_all_scatter/benchmark.py index 24694f6af..994c10cad 100755 --- a/examples/07_gemm_all_scatter/benchmark.py +++ b/examples/07_gemm_all_scatter/benchmark.py @@ -3,6 +3,7 @@ # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. import argparse +import os import random import torch @@ -261,18 +262,27 @@ def run_experiment(): def main(): + print("Starting GEMM all_scatter benchmark...") args = parse_args() - # Use command line argument if provided, otherwise use num_ranks parameter - num_ranks = args["num_ranks"] - - init_url = "tcp://127.0.0.1:29500" - mp.spawn( - fn=_worker, - args=(num_ranks, init_url, args), - nprocs=num_ranks, - join=True, - ) + # Check if running with torchrun (detected by environment variables) + if "RANK" in os.environ and "LOCAL_RANK" in os.environ: + # torchrun handles process spawning, so call _worker directly + print("Detected torchrun execution mode") + rank = int(os.environ.get("RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + init_url = os.environ.get("MASTER_ADDR", "127.0.0.1") + ":" + os.environ.get("MASTER_PORT", "29500") + _worker(rank, world_size, f"tcp://{init_url}", args) + else: + # Use multiprocessing spawn for backward compatibility + num_ranks = args["num_ranks"] + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) if __name__ == "__main__": diff --git a/examples/07_gemm_all_scatter/gemm_all_scatter.py b/examples/07_gemm_all_scatter/gemm_all_scatter.py index b9cfcea91..937835d6f 100644 --- a/examples/07_gemm_all_scatter/gemm_all_scatter.py +++ b/examples/07_gemm_all_scatter/gemm_all_scatter.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from iris.device_utils import read_realtime import iris diff --git a/examples/08_gemm_all_reduce_atomics/gemm_all_reduce_atomics.py b/examples/08_gemm_all_reduce_atomics/gemm_all_reduce_atomics.py index a35356eef..b430abc60 100644 --- a/examples/08_gemm_all_reduce_atomics/gemm_all_reduce_atomics.py +++ b/examples/08_gemm_all_reduce_atomics/gemm_all_reduce_atomics.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from iris.device_utils import read_realtime import iris diff --git a/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py b/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py index aa8ead4a2..34ee3c9ef 100644 --- a/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py +++ b/examples/09_gemm_one_shot_all_reduce/gemm_one_shot_all_reduce.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from iris.device_utils import read_realtime import iris diff --git a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py index 7857e546e..4d9c28255 100644 --- a/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py +++ b/examples/10_gemm_all_scatter_wg_specialization/gemm_all_scatter_wg_specialization.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from iris.device_utils import read_realtime import iris diff --git a/examples/11_gemm_all_scatter_producer_consumer/benchmark.py b/examples/11_gemm_all_scatter_producer_consumer/benchmark.py index 94b2c9af0..75819f3fa 100755 --- a/examples/11_gemm_all_scatter_producer_consumer/benchmark.py +++ b/examples/11_gemm_all_scatter_producer_consumer/benchmark.py @@ -2,13 +2,15 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +import argparse +import math +import os +import random + import torch import torch.distributed as dist import torch.multiprocessing as mp import triton -import random -import argparse -import math from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set from examples.common.validation import validate_gemm @@ -304,17 +306,27 @@ def run_experiment(): def main(): + print("Starting GEMM all_scatter producer-consumer benchmark...") args = parse_args() - num_ranks = args["num_ranks"] - - init_url = "tcp://127.0.0.1:29500" - mp.spawn( - fn=_worker, - args=(num_ranks, init_url, args), - nprocs=num_ranks, - join=True, - ) + # Check if running with torchrun (detected by environment variables) + if "RANK" in os.environ and "LOCAL_RANK" in os.environ: + # torchrun handles process spawning, so call _worker directly + print("Detected torchrun execution mode") + rank = int(os.environ.get("RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + init_url = os.environ.get("MASTER_ADDR", "127.0.0.1") + ":" + os.environ.get("MASTER_PORT", "29500") + _worker(rank, world_size, f"tcp://{init_url}", args) + else: + # Use multiprocessing spawn for backward compatibility + num_ranks = args["num_ranks"] + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) if __name__ == "__main__": diff --git a/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py b/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py index 32d68a88b..f564137da 100644 --- a/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py +++ b/examples/11_gemm_all_scatter_producer_consumer/gemm_all_scatter_producer_consumer.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from iris.device_utils import read_realtime import iris diff --git a/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py b/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py index 2c271bb74..4b5457d88 100755 --- a/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py +++ b/examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py @@ -2,13 +2,15 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +import argparse +import math +import os +import random + import torch import torch.distributed as dist import torch.multiprocessing as mp import triton -import random -import argparse -import math from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set from examples.common.validation import validate_gemm @@ -300,17 +302,27 @@ def run_experiment(): def main(): + print("Starting GEMM all_scatter bulk synchronous benchmark...") args = parse_args() - num_ranks = args["num_ranks"] - - init_url = "tcp://127.0.0.1:29500" - mp.spawn( - fn=_worker, - args=(num_ranks, init_url, args), - nprocs=num_ranks, - join=True, - ) + # Check if running with torchrun (detected by environment variables) + if "RANK" in os.environ and "LOCAL_RANK" in os.environ: + # torchrun handles process spawning, so call _worker directly + print("Detected torchrun execution mode") + rank = int(os.environ.get("RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + init_url = os.environ.get("MASTER_ADDR", "127.0.0.1") + ":" + os.environ.get("MASTER_PORT", "29500") + _worker(rank, world_size, f"tcp://{init_url}", args) + else: + # Use multiprocessing spawn for backward compatibility + num_ranks = args["num_ranks"] + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) if __name__ == "__main__": diff --git a/examples/12_gemm_all_scatter_bulk_synchronous/gemm_all_scatter_bulk_synchronous.py b/examples/12_gemm_all_scatter_bulk_synchronous/gemm_all_scatter_bulk_synchronous.py index e63447a07..6bfddd3c5 100644 --- a/examples/12_gemm_all_scatter_bulk_synchronous/gemm_all_scatter_bulk_synchronous.py +++ b/examples/12_gemm_all_scatter_bulk_synchronous/gemm_all_scatter_bulk_synchronous.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from iris.device_utils import read_realtime import iris diff --git a/examples/15_gemm_all_reduce_ring_based/gemm_all_reduce_ring_based.py b/examples/15_gemm_all_reduce_ring_based/gemm_all_reduce_ring_based.py index d72dac188..255ab6a87 100644 --- a/examples/15_gemm_all_reduce_ring_based/gemm_all_reduce_ring_based.py +++ b/examples/15_gemm_all_reduce_ring_based/gemm_all_reduce_ring_based.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from iris.device_utils import read_realtime import iris diff --git a/examples/20_gemm_all_scatter_independent/gemm_all_scatter_bulk_synchronous.py b/examples/20_gemm_all_scatter_independent/gemm_all_scatter_bulk_synchronous.py index 924a12280..d4323e5bf 100644 --- a/examples/20_gemm_all_scatter_independent/gemm_all_scatter_bulk_synchronous.py +++ b/examples/20_gemm_all_scatter_independent/gemm_all_scatter_bulk_synchronous.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from iris.device_utils import read_realtime import iris diff --git a/examples/21_gemm_one_shot_all_reduce_independent/gemm_one_shot_all_reduce_independent.py b/examples/21_gemm_one_shot_all_reduce_independent/gemm_one_shot_all_reduce_independent.py index 7d0f4e751..2a41a8830 100644 --- a/examples/21_gemm_one_shot_all_reduce_independent/gemm_one_shot_all_reduce_independent.py +++ b/examples/21_gemm_one_shot_all_reduce_independent/gemm_one_shot_all_reduce_independent.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from iris.device_utils import read_realtime import iris diff --git a/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/gemm_reduce_scatter.py b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/gemm_reduce_scatter.py index 60312b6b8..1ef3e3d7c 100644 --- a/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/gemm_reduce_scatter.py +++ b/examples/22_gemm_one_shot_reduce_scatter_wg_specialization/gemm_reduce_scatter.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from examples.common.utils import read_realtime +from iris.device_utils import read_realtime import iris diff --git a/examples/23_gemm_all_scatter_tracing/benchmark.py b/examples/23_gemm_all_scatter_tracing/benchmark.py new file mode 100644 index 000000000..dd4dd716d --- /dev/null +++ b/examples/23_gemm_all_scatter_tracing/benchmark.py @@ -0,0 +1,306 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import argparse +import random + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import triton +from matmul_wrapper import matmul + +import iris +from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set +from examples.common.validation import validate_gemm + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Parse matrix dimensions and configuration.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=8192, help="Number of rows in matrix A") + parser.add_argument("-n", type=int, default=4608, help="Number of columns in matrix B") + parser.add_argument("-k", type=int, default=36864, help="Common dimension between matrices A and B") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-t", "--trace_tiles", action="store_true", help="Enable tile-tracing mode") + parser.add_argument("--trace", action="store_true", help="Enable device tracing") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype of computation", + ) + parser.add_argument( + "--output_file", + type=str, + default="log.json", + help="Output file", + ) + parser.add_argument("--BLK_M", type=int, default=256, help="Block size M") + parser.add_argument("--BLK_N", type=int, default=64, help="Block size N") + parser.add_argument("--BLK_K", type=int, default=64, help="Block size K") + parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter") + parser.add_argument("--num_stages", type=int, default=2, help="Number of stages") + parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size") + parser.add_argument( + "--gemm_sms", + type=int, + default=None, + help="Number of SMs for persistent GEMM algorithm (default: auto-detected)", + ) + parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes") + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + # Main benchmark logic + ctx = iris.iris(args["heap_size"]) + rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + + # Enable device tracing if requested + if args["trace"]: + ctx.tracing.enable(max_events=1_000_000) + ctx.info("Device tracing enabled") + + # Get device context + context_tensor = ctx.get_device_context() + + # Set default SM values if not provided + if args["gemm_sms"] is None: + # For all_scatter: use total CU count + cu_count = torch.cuda.get_device_properties(rank).multi_processor_count + args["gemm_sms"] = cu_count + + # GEMM + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + assert args["n"] % world_size == 0, f"N ({args['n']}) must be divisible by world size ({world_size})." + assert args["k"] % world_size == 0, f"K ({args['k']}) must be divisible by world size ({world_size})." + + A = ctx.randn(args["m"], args["k"], device="cuda", dtype=datatype) + B = ctx.randn(args["n"], args["k"], device="cuda", dtype=datatype).T + + args["M"] = args["m"] + args["N"] = args["n"] + args["K"] = args["k"] + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + json_writer.add_field("enable_tracing", args["trace"]) + + # Splitting + args["n"] = args["n"] // world_size + local_B = B[:, rank * args["n"] : (rank + 1) * args["n"]].clone() + local_A = A + + for key, value in args.items(): + json_writer.add_field(key, value) + + global_C = ctx.zeros((args["M"], args["N"]), device="cuda", dtype=A.dtype) + local_C = ctx.zeros((args["m"], args["n"]), device="cuda", dtype=A.dtype) + + total_blocks_M = triton.cdiv(args["m"], args["BLK_M"]) + total_blocks_N = triton.cdiv(args["n"], args["BLK_N"]) + total_tiles = total_blocks_M * total_blocks_N + + bias = None + + gemm_stream = torch.cuda.Stream() + + json_writer.add_field("gemm_sms", args["gemm_sms"]) + + kernel_timing = { + "gemm": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + # Allocate Timestamps + timestamps = Timestamps(num_tiles=total_tiles) + + def run_experiment(enable_tracing=False): + nonlocal local_C + nonlocal global_C + nonlocal kernel_timing + + ctx.barrier() + + if args["trace_tiles"]: + timestamps.reset() + ctx.barrier() + + torch.cuda.nvtx.range_push("GEMM + Communication") + torch.cuda.nvtx.range_push("GEMM") + with torch.cuda.stream(gemm_stream): + kernel_timing["gemm"]["start_event"].record() + local_C = matmul.apply( + local_A, + local_B, + local_C, + global_C, + bias, + rank, + world_size, + args["gemm_sms"], + args["BLK_M"], + args["BLK_N"], + args["BLK_K"], + args["gsize_m"], + args["num_stages"], + context_tensor, + "gfx942", + enable_tracing, + args["trace_tiles"], + timestamps.mm_begin_timestamp, + timestamps.mm_end_timestamp, + ) + kernel_timing["gemm"]["end_event"].record() + kernel_timing["gemm"]["experiments"] += 1 + + torch.cuda.nvtx.range_pop() + ctx.barrier() + + for k in ["gemm"]: + ms = kernel_timing[k]["start_event"].elapsed_time(kernel_timing[k]["end_event"]) + kernel_timing[k]["ms"] += ms + + torch.cuda.nvtx.range_pop() + + # Synchronize across all GPUs + ctx.barrier() + + # Warmup + num_warmup_iters = 10 + for i in range(num_warmup_iters): + run_experiment(enable_tracing=False) + ctx.barrier() + + # If tracing enabled, reset and run one clean iteration + if args["trace"]: + ctx.tracing.reset() + ctx.barrier() + run_experiment(enable_tracing=True) + ctx.barrier() + ctx.info(f"Captured clean trace after {num_warmup_iters} warmup iterations") + + for k in ["gemm"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + if args["validate"]: + ctx.info("Validating...") + matmul.set_debug(True) + # Validate global result + success = validate_gemm(A, B, global_C, ctx) + passed_str = "passed" if success else "failed" + ctx.info(f"Final C validation {passed_str}.") + + # Wait for all to finish validation + ctx.barrier() + ctx.info("Validating local C...") + + json_writer.add_field("success", success) + + if not is_triton_interpret_set(): + gemm_registers = matmul.get_matmul_registers() + gemm_spills = matmul.get_matmul_spills() + + json_writer.add_field("gemm_registers", gemm_registers) + json_writer.add_field("gemm_spills", gemm_spills) + + ctx.info("Validation completed") + + if args["benchmark"]: + matmul.set_debug(False) + ctx.info("Benchmarking...") + perf = lambda ms: 2 * args["M"] * args["N"] * args["K"] * 1e-12 / (ms * 1e-3) + triton_ms = iris.do_bench(lambda: run_experiment(enable_tracing=args["trace"]), ctx.barrier) + triton_tflops = perf(triton_ms) + algo_string = "all_scatter" + tracing_str = " (with tracing)" if args["trace"] else "" + ctx.info( + f"tile matmul + {algo_string}{tracing_str} (total_tiles={total_tiles}): {triton_ms:.3f} ms {triton_tflops:.3f} tflops" + ) + + json_writer.add_field("tflops", triton_tflops) + json_writer.add_field("total_ms", triton_ms) + + for k in ["gemm"]: + json_writer.add_field(k + "_ms", kernel_timing[k]["ms"] / kernel_timing[k]["experiments"]) + json_writer.add_field(k + "_experiments", kernel_timing[k]["experiments"]) + + # Wait for all to finish benchmarking + ctx.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + if args["trace_tiles"] and rank == 0: + gpu_freq = iris.hip.get_wall_clock_rate(rank) * 1e-3 + algo_string = "all_scatter" + filename = f"gemm_tiles_{algo_string}_trace_rank{rank}.json" + timestamps.to_json(filename, gpu_freq) + + # Export device traces if enabled + if args["trace"]: + ctx.barrier() # Ensure all kernels finished + # Export per-rank and merged trace + ctx.tracing.export("device_trace.json", merge=True) + + ctx.barrier() + + dist.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + + # Use command line argument if provided, otherwise use num_ranks parameter + num_ranks = args["num_ranks"] + + init_url = "tcp://127.0.0.1:29500" + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/23_gemm_all_scatter_tracing/gemm_all_scatter.py b/examples/23_gemm_all_scatter_tracing/gemm_all_scatter.py new file mode 100644 index 000000000..89be23c5f --- /dev/null +++ b/examples/23_gemm_all_scatter_tracing/gemm_all_scatter.py @@ -0,0 +1,166 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import triton +import triton.language as tl + +from iris import DeviceContext, TraceEvent +from iris.device_utils import read_realtime + + +@triton.jit() +def persistent_gemm_all_scatter( + A, + B, + C, + c_global, + bias_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_cm_global, + stride_cn_global, + stride_bias, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + BIAS: tl.constexpr, + EVEN_K: tl.constexpr, + context_tensor: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + TRACING: tl.constexpr = False, + COLLECT_TIMESTAMPS: tl.constexpr = False, + mm_begin_timestamp_ptr: tl.tensor = None, + mm_end_timestamp_ptr: tl.tensor = None, +): + # Initialize DeviceContext with tracing + ctx = DeviceContext.initialize(context_tensor, cur_rank, world_size, tracing=TRACING) + + pid = tl.program_id(0) + + if NUM_XCDS != 1: + pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + + acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + + for tile_id in range(pid, total_tiles, NUM_SMS): + if COLLECT_TIMESTAMPS: + timestamp = read_realtime() + tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + + rk = tl.arange(0, BLOCK_SIZE_K) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + loop_k = tl.cdiv(K, BLOCK_SIZE_K) + if not EVEN_K: + loop_k -= 1 + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + for k in range(0, loop_k): + a = tl.load(tl.multiple_of(A_BASE, (1, 16))) + b = tl.load(tl.multiple_of(B_BASE, (16, 1))) + acc += tl.dot(a, b) + A_BASE += BLOCK_SIZE_K * stride_ak + B_BASE += BLOCK_SIZE_K * stride_bk + + if not EVEN_K: + k = loop_k + rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + A_BASE = tl.multiple_of(A_BASE, (1, 16)) + B_BASE = tl.multiple_of(B_BASE, (16, 1)) + a = tl.load(A_BASE, mask=rk[None, :] < K, other=0.0) + b = tl.load(B_BASE, mask=rk[:, None] < K, other=0.0) + acc += tl.dot(a, b) + + # Accumulator registers with C results + c = acc.to(C.type.element_ty) + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + + # Add compiler hints + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + # Define the C-mask (BLOCK_SIZE_M, 1) x (1, BLOCK_SIZE_N) + sub_mask = (rm[:, None] < M) & (rn[None, :] < N) + + # Calculate the "global" offset of C based on the rank. + # Note how the N-dimension is being multiplied by current rank. + # This is because each rank is computing a portion of the N-dimension + # locally and then scattering it to all other ranks to complete + # the global N-dimension. + global_offset = rm[:, None] * stride_cm_global + (rn[None, :] + cur_rank * N) * stride_cn_global + + # Timestamp for GEMM before store + if COLLECT_TIMESTAMPS: + timestamp = read_realtime() + tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp) + + # Store local result first (needed for put operations) + C_ptr = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + tl.store(C_ptr, c, mask=sub_mask) + + # Store data to the global result using DeviceContext + for remote_rank in range(world_size): + if remote_rank == cur_rank: + # For the current rank, we can use store + tl.store(c_global + global_offset, c, mask=sub_mask) + else: + # Record duration event around remote store (compiles away if tracing=False) + # Pass 2D pointer tensor; record_event_start takes min as representative address + # op_index is automatically tracked internally (0, 1, 2, ...) + # payload_size is automatically calculated from mask + handle = ctx.tracing.record_event_start( + event_id=TraceEvent().put, + target_rank=remote_rank, + address=c_global + global_offset, + pid_m=pid_m, + pid_n=pid_n, + mask=sub_mask, + ) + + # Use DeviceContext.put for remote stores + # Put from local C to remote c_global + ctx.put(C_ptr, c_global + global_offset, to_rank=remote_rank, mask=sub_mask) + + # End duration event + ctx.tracing.record_event_end(handle) diff --git a/examples/23_gemm_all_scatter_tracing/matmul_wrapper.py b/examples/23_gemm_all_scatter_tracing/matmul_wrapper.py new file mode 100644 index 000000000..2d5587499 --- /dev/null +++ b/examples/23_gemm_all_scatter_tracing/matmul_wrapper.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch + +from gemm_all_scatter import persistent_gemm_all_scatter +from examples.common.utils import is_triton_interpret_set +import iris + +gemm_kernel = persistent_gemm_all_scatter + + +class matmul(torch.autograd.Function): + _debug = False + _registers = None + _spills = None + _asm = None + + _num_xcds = iris.hip.get_num_xcc() + + @staticmethod + def set_debug(debug: bool): + matmul._debug = debug + + @staticmethod + def get_matmul_registers(): + if matmul._debug: + return matmul._registers + else: + raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.") + + @staticmethod + def get_matmul_spills(): + if matmul._debug: + return matmul._spills + else: + raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.") + + @staticmethod + def get_matmul_asm(): + if matmul._debug: + return matmul._asm + else: + raise RuntimeError("Debug mode is not enabled. Call set_debug(True) first.") + + @staticmethod + def _call( + a: torch.Tensor, + b: torch.Tensor, + c: torch.Tensor, + c_global: torch.Tensor, + bias: torch.Tensor, + rank: int, + world_size: int, + num_sms: int, + BLK_M: int, + BLK_N: int, + BLK_K: int, + gsize_m: int, + num_stages: int, + context_tensor: torch.Tensor = None, + arch: str = "gfx942", + TRACING: bool = False, + COLLECT_TIMESTAMPS: bool = False, + mm_begin_timestamp: torch.Tensor = None, + mm_end_timestamp: torch.Tensor = None, + ): + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + + num_xcds = matmul._num_xcds + + # TODO: Use arch-specific values. + num_warps = 8 + waves_per_eu = 0 + mfma = 16 + kpack = 1 + + even_k = K % BLK_K == 0 + use_bias = False + + # compute grid (work to do per SM on the first wave) + stride_bias = bias.stride(0) if use_bias else 0 + kk = gemm_kernel[(num_sms,)]( + a, + b, + c, + c_global, + bias, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + c_global.stride(0), + c_global.stride(1), + stride_bias, + BLOCK_SIZE_M=BLK_M, + BLOCK_SIZE_N=BLK_N, + BLOCK_SIZE_K=BLK_K, + GROUP_SIZE_M=gsize_m, + NUM_SMS=num_sms, + NUM_XCDS=num_xcds, + BIAS=use_bias, + EVEN_K=even_k, + num_stages=num_stages, + num_warps=num_warps, + waves_per_eu=waves_per_eu, + matrix_instr_nonkdim=mfma, + kpack=kpack, + context_tensor=context_tensor, + cur_rank=rank, + world_size=world_size, + TRACING=TRACING, + COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + mm_begin_timestamp_ptr=mm_begin_timestamp, + mm_end_timestamp_ptr=mm_end_timestamp, + ) + + if matmul._debug and not is_triton_interpret_set(): + matmul._registers = kk.n_regs + matmul._spills = kk.n_spills + matmul._asm = kk.asm["amdgcn"] + + return c + + @staticmethod + def forward( + ctx, + a: torch.Tensor, + b: torch.Tensor, + c: torch.Tensor, + c_global: torch.Tensor, + bias: torch.Tensor, + rank: int, + world_size: int, + num_sms: int, + BLK_M: int, + BLK_N: int, + BLK_K: int, + gsize_m: int, + num_stages: int, + context_tensor: torch.Tensor = None, + arch: str = "gfx942", + TRACING: bool = False, + COLLECT_TIMESTAMPS: bool = False, + mm_begin_timestamp: torch.Tensor = None, + mm_end_timestamp: torch.Tensor = None, + ): + matmul._call( + a=a, + b=b, + c=c, + c_global=c_global, + bias=bias, + rank=rank, + world_size=world_size, + num_sms=num_sms, + BLK_M=BLK_M, + BLK_N=BLK_N, + BLK_K=BLK_K, + gsize_m=gsize_m, + num_stages=num_stages, + context_tensor=context_tensor, + arch=arch, + TRACING=TRACING, + COLLECT_TIMESTAMPS=COLLECT_TIMESTAMPS, + mm_begin_timestamp=mm_begin_timestamp, + mm_end_timestamp=mm_end_timestamp, + ) + return c diff --git a/examples/24_ccl_all_reduce/example.py b/examples/24_ccl_all_reduce/example.py new file mode 100644 index 000000000..b0ba322ef --- /dev/null +++ b/examples/24_ccl_all_reduce/example.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Example: iris.ccl.all_reduce + +Each rank contributes its local tensor; the result on every rank is the element-wise sum. + +Run with: + torchrun --nproc_per_node= --standalone example.py [--validate] +""" + +import argparse +import os + +import torch +import torch.distributed as dist + +import iris + + +def parse_args(): + parser = argparse.ArgumentParser( + description="CCL all-reduce example", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=1024, help="Number of rows") + parser.add_argument("-n", type=int, default=512, help="Number of columns") + parser.add_argument("--heap_size", type=int, default=1 << 31, help="Iris heap size") + parser.add_argument("--datatype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type") + parser.add_argument("-v", "--validate", action="store_true", help="Validate output against reference") + return vars(parser.parse_args()) + + +def main(): + args = parse_args() + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="gloo") + + ctx = iris.iris(heap_size=args["heap_size"]) + rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + + dtype_map = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} + dtype = dtype_map[args["datatype"]] + M, N = args["m"], args["n"] + + # Each rank fills its input with (rank + 1) + input_tensor = ctx.zeros((M, N), dtype=dtype) + input_tensor.fill_(float(rank + 1)) + output_tensor = ctx.zeros((M, N), dtype=dtype) + + ctx.barrier() + ctx.ccl.all_reduce(output_tensor, input_tensor) + torch.cuda.synchronize() + + if rank == 0: + ctx.info(f"all_reduce: world_size={world_size}, shape=({M},{N}), dtype={dtype}") + + if args["validate"]: + # Expected: sum of (r+1) for r in 0..world_size-1 + expected = float(world_size * (world_size + 1) // 2) + assert torch.allclose(output_tensor, torch.full_like(output_tensor, expected), atol=0.5), ( + f"Rank {rank}: mismatch. Got {output_tensor[0, 0].item():.1f}, expected {expected:.1f}" + ) + if rank == 0: + ctx.info(f"Validation passed: output[0,0] = {output_tensor[0, 0].item():.1f}") + + ctx.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/examples/25_ccl_all_gather/example.py b/examples/25_ccl_all_gather/example.py new file mode 100644 index 000000000..2c15f59e1 --- /dev/null +++ b/examples/25_ccl_all_gather/example.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Example: iris.ccl.all_gather + +Each rank contributes an (M, N) tensor; every rank receives the concatenated (world_size*M, N) result. + +Run with: + torchrun --nproc_per_node= --standalone example.py [--validate] +""" + +import argparse +import os + +import torch +import torch.distributed as dist + +import iris + + +def parse_args(): + parser = argparse.ArgumentParser( + description="CCL all-gather example", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=512, help="Number of rows per rank") + parser.add_argument("-n", type=int, default=256, help="Number of columns") + parser.add_argument("--heap_size", type=int, default=1 << 31, help="Iris heap size") + parser.add_argument("--datatype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type") + parser.add_argument("-v", "--validate", action="store_true", help="Validate output against reference") + return vars(parser.parse_args()) + + +def main(): + args = parse_args() + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="gloo") + + ctx = iris.iris(heap_size=args["heap_size"]) + rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + + dtype_map = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} + dtype = dtype_map[args["datatype"]] + M, N = args["m"], args["n"] + + # Each rank fills its input with (rank + 1) + input_tensor = ctx.zeros((M, N), dtype=dtype) + input_tensor.fill_(float(rank + 1)) + output_tensor = ctx.zeros((world_size * M, N), dtype=dtype) + + ctx.barrier() + ctx.ccl.all_gather(output_tensor, input_tensor) + torch.cuda.synchronize() + + if rank == 0: + ctx.info(f"all_gather: world_size={world_size}, input=({M},{N}), output=({world_size * M},{N}), dtype={dtype}") + + if args["validate"]: + for r in range(world_size): + expected = float(r + 1) + chunk = output_tensor[r * M : (r + 1) * M] + assert torch.allclose(chunk, torch.full_like(chunk, expected), atol=0.5), ( + f"Rank {rank}: chunk {r} mismatch. Got {chunk[0, 0].item():.1f}, expected {expected:.1f}" + ) + if rank == 0: + ctx.info(f"Validation passed: output[0,0] = {output_tensor[0, 0].item():.1f}") + + ctx.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/examples/26_ccl_all_to_all/example.py b/examples/26_ccl_all_to_all/example.py new file mode 100644 index 000000000..d24fbd909 --- /dev/null +++ b/examples/26_ccl_all_to_all/example.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Example: iris.ccl.all_to_all + +Input and output are both (M, N*world_size): input[:, r*N:(r+1)*N] is sent to rank r. + +Run with: + torchrun --nproc_per_node= --standalone example.py [--validate] +""" + +import argparse +import os + +import torch +import torch.distributed as dist + +import iris + + +def parse_args(): + parser = argparse.ArgumentParser( + description="CCL all-to-all example", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=512, help="Number of rows") + parser.add_argument("-n", type=int, default=128, help="Number of columns per rank slice") + parser.add_argument("--heap_size", type=int, default=1 << 31, help="Iris heap size") + parser.add_argument("--datatype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type") + parser.add_argument("-v", "--validate", action="store_true", help="Validate output against reference") + return vars(parser.parse_args()) + + +def main(): + args = parse_args() + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="gloo") + + ctx = iris.iris(heap_size=args["heap_size"]) + rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + + dtype_map = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} + dtype = dtype_map[args["datatype"]] + M, N = args["m"], args["n"] + + # input[:, r*N:(r+1)*N] is the slice sent to rank r; fill with unique values + input_tensor = ctx.zeros((M, N * world_size), dtype=dtype) + for target_rank in range(world_size): + input_tensor[:, target_rank * N : (target_rank + 1) * N] = float(rank * 10 + target_rank + 1) + output_tensor = ctx.zeros((M, N * world_size), dtype=dtype) + + ctx.barrier() + ctx.ccl.all_to_all(output_tensor, input_tensor) + torch.cuda.synchronize() + + if rank == 0: + ctx.info(f"all_to_all: world_size={world_size}, shape=({M},{N * world_size}), dtype={dtype}") + + if args["validate"]: + for src_rank in range(world_size): + expected = float(src_rank * 10 + rank + 1) + chunk = output_tensor[:, src_rank * N : (src_rank + 1) * N] + assert torch.allclose(chunk, torch.full_like(chunk, expected), atol=0.5), ( + f"Rank {rank}: chunk from rank {src_rank} mismatch. " + f"Got {chunk[0, 0].item():.1f}, expected {expected:.1f}" + ) + if rank == 0: + ctx.info(f"Validation passed: output[0,0] = {output_tensor[0, 0].item():.1f}") + + ctx.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/examples/28_ops_matmul_all_reduce/example.py b/examples/28_ops_matmul_all_reduce/example.py new file mode 100644 index 000000000..f66bf3c90 --- /dev/null +++ b/examples/28_ops_matmul_all_reduce/example.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Example: iris.ops.matmul_all_reduce + +Fused GEMM + all-reduce: output = all_reduce(A @ B). + +Run with: + torchrun --nproc_per_node= --standalone example.py [--validate] +""" + +import argparse +import os + +import torch +import torch.distributed as dist + +import iris + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Fused matmul + all-reduce example", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=512, help="Rows of A") + parser.add_argument("-n", type=int, default=128, help="Columns of B") + parser.add_argument("-k", type=int, default=256, help="Inner dimension") + parser.add_argument("--heap_size", type=int, default=1 << 31, help="Iris heap size") + parser.add_argument("--datatype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type") + parser.add_argument("-v", "--validate", action="store_true", help="Validate output against reference") + return vars(parser.parse_args()) + + +def main(): + args = parse_args() + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="gloo") + + ctx = iris.iris(heap_size=args["heap_size"]) + rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + + dtype_map = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} + dtype = dtype_map[args["datatype"]] + M, K, N = args["m"], args["k"], args["n"] + + torch.manual_seed(42) + A = ctx.randn((M, K), dtype=dtype) + B = ctx.randn((K, N), dtype=dtype) + output = ctx.zeros((M, N), dtype=dtype) + + ctx.barrier() + ctx.ops.matmul_all_reduce(output, A, B) + torch.cuda.synchronize() + + if rank == 0: + ctx.info(f"matmul_all_reduce: world_size={world_size}, A=({M},{K}), B=({K},{N}), dtype={dtype}") + + if args["validate"]: + # Each rank computes the same GEMM; all-reduce sums world_size copies + ref = torch.matmul(A.clone().float(), B.clone().float()).to(dtype) * world_size + assert torch.allclose(output.float(), ref.float(), atol=1.0, rtol=0.05), ( + f"Rank {rank}: mismatch. Max diff: {(output.float() - ref.float()).abs().max().item():.4f}" + ) + if rank == 0: + ctx.info(f"Validation passed: output[0,0] = {output[0, 0].item():.4f}") + + ctx.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/examples/29_ops_all_gather_matmul/example.py b/examples/29_ops_all_gather_matmul/example.py new file mode 100644 index 000000000..9c8ac031d --- /dev/null +++ b/examples/29_ops_all_gather_matmul/example.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Example: iris.ops.all_gather_matmul + +Fused all-gather + GEMM: output = all_gather(A_sharded) @ B. +A is column-sharded across ranks; each rank holds A[:, k_start:k_end]. + +Run with: + torchrun --nproc_per_node= --standalone example.py [--validate] + Example: + torchrun --nproc_per_node=4 --standalone example.py -m 4096 -n 128 +""" + +import argparse +import os + +import torch +import torch.distributed as dist + +import iris + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Fused all-gather + matmul example", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=512, help="Rows of A") + parser.add_argument("-n", type=int, default=256, help="Columns of B") + parser.add_argument("--k_local", type=int, default=128, help="Columns of A per rank (K_local)") + parser.add_argument("--heap_size", type=int, default=1 << 31, help="Iris heap size") + parser.add_argument("--datatype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type") + parser.add_argument("-v", "--validate", action="store_true", help="Validate output against reference") + return vars(parser.parse_args()) + + +def main(): + args = parse_args() + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="gloo") + + ctx = iris.iris(heap_size=args["heap_size"]) + rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + + dtype_map = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} + dtype = dtype_map[args["datatype"]] + M, K_local, N = args["m"], args["k_local"], args["n"] + K = K_local * world_size + + torch.manual_seed(42 + rank) + A_sharded = ctx.randn((M, K_local), dtype=dtype) + torch.manual_seed(0) + B = ctx.randn((K, N), dtype=dtype) + output = ctx.zeros((M, N), dtype=dtype) + + ctx.barrier() + ctx.ops.all_gather_matmul(output, A_sharded, B) + torch.cuda.synchronize() + + if rank == 0: + ctx.info(f"all_gather_matmul: world_size={world_size}, A_sharded=({M},{K_local}), B=({K},{N}), dtype={dtype}") + + if args["validate"]: + A_shards = [torch.zeros(M, K_local, dtype=dtype, device=A_sharded.device) for _ in range(world_size)] + dist.all_gather(A_shards, A_sharded) + A_full = torch.cat(A_shards, dim=1) + ref = torch.matmul(A_full.float(), B.clone().float()).to(dtype) + assert torch.allclose(output.float(), ref.float(), atol=1.0, rtol=0.05), ( + f"Rank {rank}: mismatch. Max diff: {(output.float() - ref.float()).abs().max().item():.4f}" + ) + if rank == 0: + ctx.info(f"Validation passed: output[0,0] = {output[0, 0].item():.4f}") + + ctx.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/examples/30_ops_matmul_all_gather/example.py b/examples/30_ops_matmul_all_gather/example.py new file mode 100644 index 000000000..3b70f5dd1 --- /dev/null +++ b/examples/30_ops_matmul_all_gather/example.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Example: iris.ops.matmul_all_gather + +Fused GEMM + all-gather along M: output = all_gather(A_local @ B). +A is row-sharded across ranks; every rank gets the full (M, N) output. + +Run with: + torchrun --nproc_per_node= --standalone example.py [--validate] +""" + +import argparse +import os + +import torch +import torch.distributed as dist + +import iris + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Fused matmul + all-gather example", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=4096, help="Total rows (must be divisible by world_size)") + parser.add_argument("-n", type=int, default=128, help="Columns of B") + parser.add_argument("-k", type=int, default=256, help="Inner dimension") + parser.add_argument("--heap_size", type=int, default=1 << 31, help="Iris heap size") + parser.add_argument("--datatype", type=str, default="fp16", choices=["fp16", "fp32", "bf16"], help="Data type") + parser.add_argument("-v", "--validate", action="store_true", help="Validate output against reference") + return vars(parser.parse_args()) + + +def main(): + args = parse_args() + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="gloo") + + ctx = iris.iris(heap_size=args["heap_size"]) + rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + + dtype_map = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} + dtype = dtype_map[args["datatype"]] + M, K, N = args["m"], args["k"], args["n"] + + if M % world_size != 0: + raise ValueError( + f"M ({M}) must be divisible by world_size ({world_size}). Please adjust -m to be a multiple of {world_size}." + ) + M_local = M // world_size + + torch.manual_seed(42 + rank) + A_local = ctx.randn((M_local, K), dtype=dtype) + torch.manual_seed(0) + B = ctx.randn((K, N), dtype=dtype) + output = ctx.zeros((M, N), dtype=dtype) + + ctx.barrier() + ctx.ops.matmul_all_gather(output, A_local, B) + torch.cuda.synchronize() + + if rank == 0: + ctx.info(f"matmul_all_gather: world_size={world_size}, A_local=({M_local},{K}), B=({K},{N}), dtype={dtype}") + + if args["validate"]: + C_local = torch.matmul(A_local.float(), B.clone().float()).to(dtype) + C_shards = [torch.zeros(M_local, N, dtype=dtype, device=C_local.device) for _ in range(world_size)] + dist.all_gather(C_shards, C_local) + ref = torch.cat(C_shards, dim=0) + assert torch.allclose(output.float(), ref.float(), atol=1.0, rtol=0.05), ( + f"Rank {rank}: mismatch. Max diff: {(output.float() - ref.float()).abs().max().item():.4f}" + ) + if rank == 0: + ctx.info(f"Validation passed: output[0,0] = {output[0, 0].item():.4f}") + + ctx.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/examples/31_expert_sharded_moe/combine.py b/examples/31_expert_sharded_moe/combine.py new file mode 100644 index 000000000..5f6ec11b4 --- /dev/null +++ b/examples/31_expert_sharded_moe/combine.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +""" +EP-to-DP result combine via iris symmetric heap. + +Closely follows triton_kernels/distributed.py _convert_ep_to_dp: + https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/distributed.py + +Each rank iterates over its expert-sorted output rows. For every row the +kernel looks up the global flat index via col_sorted_indx, determines +which rank owns the originating token, and writes the result into that +rank's per-rank destination buffer using iris.store. + +Destination buffer shape per rank: (n_slots_per_rank, d_model) where +n_slots_per_rank = n_tokens_global // world_size. +""" + +import triton +import triton.language as tl +import iris + + +@triton.jit +def _convert_ep_to_dp( + dst_ptr, + dst_stride_m, + src_ptr, + src_stride_m, + src_shape_n, + expt_filter_ptr, + expt_filter_stride_m, + expt_indx_ptr, + dst_row_indx_ptr, + n_slots_per_rank, + heap_bases, + BLOCK: tl.constexpr, + SRC_RANK: tl.constexpr, + N_RANKS: tl.constexpr, +): + pid_m = tl.program_id(0) + + dst_indx_global = tl.load(dst_row_indx_ptr + pid_m) + if dst_indx_global < 0: + return + + dst_rank = dst_indx_global // n_slots_per_rank + + dst_expt_indx = tl.load(expt_indx_ptr + dst_indx_global).to(tl.int32) + expt_filter_ptr_local = expt_filter_ptr + SRC_RANK * expt_filter_stride_m + has_dst_expt = (tl.load(expt_filter_ptr_local + dst_expt_indx // 32) >> (dst_expt_indx % 32)) & 1 + if not has_dst_expt.to(tl.int1): + return + + dst_indx_local = dst_indx_global - dst_rank * n_slots_per_rank + + offs_n = tl.arange(0, BLOCK) + offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK), BLOCK) + for start_n in range(0, src_shape_n, BLOCK): + mask_n = start_n + offs_n < src_shape_n + src = tl.load( + src_ptr + pid_m * src_stride_m + start_n + offs_n, + mask=mask_n, + other=0.0, + ) + dst_off = dst_indx_local * dst_stride_m + start_n + offs_n + for r in tl.static_range(N_RANKS): + if dst_rank == r: + iris.store(dst_ptr + dst_off, src, SRC_RANK, r, heap_bases, mask=mask_n, hint=16) + + +def convert_ep_to_dp(src, expt_assignment, expt_indx, topk_indx, shmem): + """Scatter expert results back to token-owning ranks. + + Matches the upstream convert_ep_to_dp interface. + + Args: + src: (n_total_slots, d_model) expert-sorted matmul output. + expt_assignment: ExptAssignment with bitmask. + expt_indx: (n_tokens_global * n_expts_act,) flat expert ids. + topk_indx: (n_total_slots,) col_sorted_indx (combine order). + shmem: iris.Iris instance. + + Returns: + dst_local: (n_slots_per_rank, d_model) this rank's combine buffer. + """ + expt_bitmask = expt_assignment.expt_bitmask + n_tokens_global, d_model = src.shape + n_tokens_local = n_tokens_global // shmem.get_num_ranks() + + dst_local = shmem.zeros((n_tokens_local, d_model), dtype=src.dtype) + shmem.barrier() + + BLOCK = min(triton.next_power_of_2(d_model), 512) + grid = (n_tokens_global,) + + _convert_ep_to_dp[grid]( + dst_local, + dst_local.stride(0), + src, + src.stride(0), + src.shape[1], + expt_bitmask, + expt_bitmask.stride(0), + expt_indx, + topk_indx, + n_tokens_local, + shmem.get_heap_bases(), + BLOCK=BLOCK, + SRC_RANK=shmem.get_rank(), + N_RANKS=shmem.get_num_ranks(), + ) + + shmem.barrier() + return dst_local diff --git a/examples/31_expert_sharded_moe/dispatch.py b/examples/31_expert_sharded_moe/dispatch.py new file mode 100644 index 000000000..55c491c1f --- /dev/null +++ b/examples/31_expert_sharded_moe/dispatch.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +""" +DP-to-EP token dispatch via iris symmetric heap. + +Closely follows triton_kernels/distributed.py _convert_dp_to_ep: + https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/distributed.py + +One Triton program per local token. For each of its k expert activations, +the kernel determines which rank owns the expert using a bitmask lookup and +scatters the token's activation row into that rank's destination buffer +via iris.store. +""" + +import triton +import triton.language as tl +import iris + + +@triton.jit +def _convert_dp_to_ep( + dst_ptr, + dst_stride_m, + src_ptr, + src_stride_m, + src_shape_n, + expt_filter_ptr, + expt_filter_stride_m, + expt_indx_ptr, + expt_indx_stride_m, + dst_row_indx_ptr, + dst_row_indx_stride_m, + n_tokens_local, + heap_bases, + SRC_RANK: tl.constexpr, + N_EXPT_ACT: tl.constexpr, + N_RANKS: tl.constexpr, + BLOCK: tl.constexpr, +): + pid_m = tl.program_id(0) + off_m_global = pid_m + n_tokens_local * SRC_RANK + off_m_local = pid_m + + offs_n = tl.arange(0, BLOCK) + offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK), BLOCK) + + for act in tl.static_range(N_EXPT_ACT): + dst_row = tl.load(dst_row_indx_ptr + off_m_global * dst_row_indx_stride_m + act) + if dst_row >= 0: + expt_id = tl.load(expt_indx_ptr + off_m_global * expt_indx_stride_m + act) + + dst_rank = 0 + for r in tl.static_range(N_RANKS): + word = expt_id // 32 + bit = expt_id % 32 + filt = tl.load(expt_filter_ptr + r * expt_filter_stride_m + word) + if (filt >> bit) & 1: + dst_rank = r + + for start_n in range(0, src_shape_n, BLOCK): + mask_n = start_n + offs_n < src_shape_n + src = tl.load( + src_ptr + off_m_local * src_stride_m + start_n + offs_n, + mask=mask_n, + other=0.0, + ) + dst_off = dst_row * dst_stride_m + start_n + offs_n + for r in tl.static_range(N_RANKS): + if dst_rank == r: + iris.store(dst_ptr + dst_off, src, SRC_RANK, r, heap_bases, mask=mask_n, hint=16) + + +def convert_dp_to_ep(src, expt_assignment, expt_indx, gate_indx, shmem): + """Dispatch local tokens to expert-owning ranks. + + Args: + src: (n_tokens_local, d_model) local activations. + expt_assignment: ExptAssignment with bitmask. + expt_indx: (n_tokens_global, n_expts_act) int16/int32 expert ids. + gate_indx: (n_tokens_global * n_expts_act,) row_sorted_indx (dispatch order). + shmem: iris.Iris instance. + + Returns: + dst_local: (n_tokens_global * n_expts_act, d_model) dispatch buffer + on this rank's iris heap. + """ + expt_bitmask = expt_assignment.expt_bitmask + device = src.device + n_tokens_local, d_model = src.shape + n_tokens_global, n_expt_act = expt_indx.shape + + dst_local = shmem.zeros((n_tokens_global * n_expt_act, d_model), dtype=src.dtype) + shmem.barrier() + + BLOCK = min(triton.next_power_of_2(d_model), 512) + grid = (n_tokens_local,) + + _convert_dp_to_ep[grid]( + dst_local, + dst_local.stride(0), + src, + src.stride(0), + src.shape[1], + expt_bitmask, + expt_bitmask.stride(0), + expt_indx, + expt_indx.stride(0), + gate_indx, + n_expt_act, + n_tokens_local, + shmem.get_heap_bases(), + SRC_RANK=shmem.get_rank(), + N_EXPT_ACT=n_expt_act, + N_RANKS=shmem.get_num_ranks(), + BLOCK=BLOCK, + ) + + shmem.barrier() + return dst_local diff --git a/examples/31_expert_sharded_moe/example_run.py b/examples/31_expert_sharded_moe/example_run.py new file mode 100644 index 000000000..714897a05 --- /dev/null +++ b/examples/31_expert_sharded_moe/example_run.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +""" +Expert-sharded distributed MoE example using Iris. + +Validates against a single-device reference implementation. + +Usage: + HIP_VISIBLE_DEVICES=0,1 python example_run.py + HIP_VISIBLE_DEVICES=0,1 python example_run.py --num_ranks 2 --n_tokens 128 --d_model 64 +""" + +import os +import sys +import argparse + +import torch +import torch.multiprocessing as mp +import torch.distributed as dist +import iris + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from expert_assignment import make_expt_dict_uniform, make_expt_assignment +from moe import mixture_of_expt_nosharded, mixture_of_expt_epsharded + + +def parse_args(): + parser = argparse.ArgumentParser(description="Expert-sharded MoE example with Iris") + parser.add_argument("--num_ranks", type=int, default=2) + parser.add_argument("--n_tokens", type=int, default=128) + parser.add_argument("--d_model", type=int, default=64) + parser.add_argument("--n_expts_tot", type=int, default=8) + parser.add_argument("--n_expts_act", type=int, default=2) + parser.add_argument("--atol", type=float, default=1e-2) + parser.add_argument("--rtol", type=float, default=1e-2) + return parser.parse_args() + + +def run_worker(rank, world_size, init_url, args): + dist.init_process_group( + backend="nccl", + init_method=init_url, + world_size=world_size, + rank=rank, + device_id=torch.device(f"cuda:{rank}"), + ) + torch.cuda.set_device(rank) + shmem = iris.iris() + try: + _run_moe_example(rank, world_size, shmem, args) + finally: + del shmem + import gc + + gc.collect() + dist.destroy_process_group() + + +def _run_moe_example(rank, world_size, shmem, args): + n_tokens = args.n_tokens + d_model = args.d_model + n_expts_tot = args.n_expts_tot + n_expts_act = args.n_expts_act + n_tokens_local = n_tokens // world_size + device = torch.device(f"cuda:{rank}") + + if rank == 0: + print("=" * 60) + print("Expert-Sharded MoE Example (Iris)") + print("=" * 60) + print(f" ranks: {world_size}") + print(f" n_tokens: {n_tokens} ({n_tokens_local} per rank)") + print(f" d_model: {d_model}") + print(f" n_expts_tot: {n_expts_tot}") + print(f" n_expts_act: {n_expts_act}") + print(f" expts/rank: {n_expts_tot // world_size}") + print() + + torch.manual_seed(0) + x_global = torch.randn(n_tokens, d_model, device=device, dtype=torch.bfloat16) + l_global = torch.rand(n_tokens, n_expts_tot, device=device, dtype=torch.float32) + w_global = torch.randn(n_expts_tot, d_model, d_model, device=device, dtype=torch.bfloat16) + b_global = torch.randn(n_expts_tot, d_model, device=device, dtype=torch.float32) + + dist.broadcast(x_global, src=0) + dist.broadcast(l_global, src=0) + dist.broadcast(w_global, src=0) + dist.broadcast(b_global, src=0) + + n_shards = world_size + expt_dict = make_expt_dict_uniform(n_shards, n_expts_tot) + expt_assignment = make_expt_assignment(n_shards, n_expts_tot, expt_dict, device) + + if rank == 0: + print("Expert assignment:") + for s, expts in expt_dict.items(): + print(f" rank {s}: experts {expts}") + print() + + if rank == 0: + print("Computing reference (non-sharded) MoE...") + y_global_ref = mixture_of_expt_nosharded( + x_global, + l_global, + w_global, + b_global, + n_expts_act, + ) + + first = rank * n_tokens_local + last = first + n_tokens_local + x_dp_local = x_global[first:last].contiguous() + l_dp_local = l_global[first:last].contiguous() + w_ep_local = w_global[expt_assignment.expt_boolmask[rank]].contiguous() + b_ep_local = b_global[expt_assignment.expt_boolmask[rank]].contiguous() + + shmem.barrier() + + if rank == 0: + print("Running expert-sharded MoE pipeline...") + + z_dp_local = mixture_of_expt_epsharded( + x_dp_local, + l_dp_local, + w_ep_local, + b_ep_local, + expt_assignment, + n_expts_act, + shmem, + ) + + torch.cuda.synchronize() + shmem.barrier() + dist.barrier() + + y_global_tri = torch.empty_like(y_global_ref) + dist.all_gather_into_tensor(y_global_tri, z_dp_local.contiguous()) + + if rank == 0: + print() + print("--- Validation ---") + print(f" Reference output shape: {y_global_ref.shape}") + print(f" Sharded output shape: {y_global_tri.shape}") + print(f" Reference first row[:5]: {y_global_ref[0, :5]}") + print(f" Sharded first row[:5]: {y_global_tri[0, :5]}") + print() + + diff = (y_global_ref.float() - y_global_tri.float()).abs() + print(f" max diff = {diff.max().item():.6f}") + print(f" mean diff = {diff.mean().item():.6f}") + print() + + try: + torch.testing.assert_close( + y_global_ref, + y_global_tri, + atol=args.atol, + rtol=args.rtol, + ) + print("PASSED: sharded MoE matches reference") + except AssertionError as e: + print(f"FAILED: {str(e)[:500]}") + + print("=" * 60) + + +def main(): + args = parse_args() + assert args.n_tokens % args.num_ranks == 0, ( + f"n_tokens ({args.n_tokens}) must be divisible by num_ranks ({args.num_ranks})" + ) + assert args.n_expts_tot % args.num_ranks == 0, ( + f"n_expts_tot ({args.n_expts_tot}) must be divisible by num_ranks ({args.num_ranks})" + ) + + init_url = "tcp://127.0.0.1:29504" + mp.spawn( + fn=run_worker, + args=(args.num_ranks, init_url, args), + nprocs=args.num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/31_expert_sharded_moe/expert_assignment.py b/examples/31_expert_sharded_moe/expert_assignment.py new file mode 100644 index 000000000..0d951ce97 --- /dev/null +++ b/examples/31_expert_sharded_moe/expert_assignment.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Expert-to-rank assignment for expert-parallel MoE. + +Ported from triton_kernels/distributed.py: + https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/distributed.py +""" + +import torch +from dataclasses import dataclass + + +@dataclass +class ExptAssignment: + # (n_shards, ceil(n_expts_tot / 32)) -- packed int32 bitmask + # (expt_bitmask[i, j//32] >> j%32) & 1 == 1 iff expert j is owned by shard i + expt_bitmask: torch.Tensor + # (n_shards, n_expts_tot) -- boolean mask + expt_boolmask: torch.Tensor + # (n_shards, n_expts_tot) -- local expert id or -1 + expt_map: torch.Tensor + n_expts_per_shard: list[int] + + +def make_expt_dict_uniform(n_shards: int, n_expts_tot: int) -> dict[int, list[int]]: + """Contiguous assignment: shard i owns experts [i*E_per_shard, (i+1)*E_per_shard).""" + assert n_expts_tot % n_shards == 0, "n_expts_tot must be divisible by n_shards" + e_per_shard = n_expts_tot // n_shards + return {i: list(range(i * e_per_shard, (i + 1) * e_per_shard)) for i in range(n_shards)} + + +def make_expt_assignment( + n_shards: int, + n_expts_tot: int, + expt_dict: dict[int, list[int]], + device, +) -> ExptAssignment: + """Build bitmask, boolmask, and local-id map from an expert ownership dict.""" + words = (n_expts_tot + 31) // 32 + expt_bitmask = torch.zeros((n_shards, words), dtype=torch.int32) + expt_boolmask = torch.zeros((n_shards, n_expts_tot), dtype=torch.bool) + counts = {e: 0 for e in range(n_expts_tot)} + + for shard, experts in expt_dict.items(): + if not (0 <= shard < n_shards): + raise ValueError(f"shard {shard} out of range [0, {n_shards})") + if len(experts) == 0: + raise ValueError(f"shard {shard} has no experts") + for e in experts: + counts[e] += 1 + if not (0 <= e < n_expts_tot): + raise ValueError(f"expert id {e} out of range [0, {n_expts_tot})") + word = e >> 5 + bit = e & 31 + expt_bitmask[shard, word] |= 1 << bit + expt_boolmask[shard, e] = True + + if not all(c == 1 for c in counts.values()): + raise ValueError("each expert must be owned by exactly one shard") + + expt_bitmask = expt_bitmask.to(device) + expt_boolmask = expt_boolmask.to(device) + + expt_map = torch.full((n_shards, n_expts_tot), -1, dtype=torch.int32) + for shard, experts in expt_dict.items(): + for local_id, global_id in enumerate(sorted(experts)): + expt_map[shard, global_id] = local_id + expt_map = expt_map.to(device) + + n_expts_per_shard = [len(expt_dict[s]) for s in range(n_shards)] + return ExptAssignment(expt_bitmask, expt_boolmask, expt_map, n_expts_per_shard) diff --git a/examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp.py b/examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp.py new file mode 100644 index 000000000..ac163d1aa --- /dev/null +++ b/examples/31_expert_sharded_moe/fused_exp_matmul_ep_to_dp.py @@ -0,0 +1,221 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +""" +Fused expert matmul + EP->DP combine. + +This module fuses: + grouped_matmul(y_ep_local, w_ep_local, b_ep_local, ...) + + convert_ep_to_dp(...) + +into a single Triton kernel that: + 1) computes a tiled GEMM (BLOCK_M x BLOCK_N via tl.dot) for each expert + 2) scatters the output tile to token-owning ranks via per-rank 2D iris.store + +Grid: (n_n_tiles * n_local_experts,) -- same tiling as grouped_matmul. +Each program loops over M-tiles for one (expert, N-tile) pair, computes +the tile with tl.dot, then does per-rank masked 2D stores. +""" + +import torch +import triton +import triton.language as tl +import iris + +from ragged_metadata import RaggedTensorMetadata + + +@triton.jit +def _fused_exp_matmul_ep_to_dp_kernel( + dst_ptr, + dst_stride_m, + x_ptr, + x_stride_m, + x_stride_k, + w_ptr, + w_stride_e, + w_stride_k, + w_stride_n, + b_ptr, + b_stride_e, + b_stride_n, + slice_offs_ptr, + slice_sizes_ptr, + expt_filter_ptr, + expt_filter_stride_m, + expt_indx_ptr, + topk_indx_ptr, + n_local_experts, + n_slots_per_rank, + K, + N, + heap_bases, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + HAS_BIAS: tl.constexpr, + SRC_RANK: tl.constexpr, + N_RANKS: tl.constexpr, +): + pid = tl.program_id(0) + n_n_tiles = tl.cdiv(N, BLOCK_N) + + local_expert_id = pid // n_n_tiles + pid_n = pid % n_n_tiles + + if local_expert_id >= n_local_experts: + return + + local_expert_id_64 = local_expert_id.to(tl.int64) + slice_off = tl.load(slice_offs_ptr + local_expert_id_64).to(tl.int64) + slice_size = tl.load(slice_sizes_ptr + local_expert_id_64) + if slice_size == 0: + return + + n_m_tiles = tl.cdiv(slice_size, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK_N), BLOCK_N) + mask_n = offs_n < N + + for pid_m in range(0, n_m_tiles): + offs_m_local = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = slice_off + offs_m_local + mask_m = offs_m_local < slice_size + + # Pre-load scatter metadata for this M-tile. + dst_indx_globals = tl.load(topk_indx_ptr + offs_m, mask=mask_m, other=-1) + valid_dst = mask_m & (dst_indx_globals >= 0) + + safe_dst_indx = tl.where(valid_dst, dst_indx_globals, tl.zeros_like(dst_indx_globals)) + dst_expt_indxs = tl.load(expt_indx_ptr + safe_dst_indx, mask=valid_dst, other=0).to(tl.int32) + + expt_filter_ptr_local = expt_filter_ptr + SRC_RANK * expt_filter_stride_m + has_dst_expts = ( + (tl.load(expt_filter_ptr_local + dst_expt_indxs // 32, mask=valid_dst, other=0) >> (dst_expt_indxs % 32)) + & 1 + ).to(tl.int1) + + row_valid = valid_dst & has_dst_expts + dst_ranks = dst_indx_globals // n_slots_per_rank + dst_indx_locals = dst_indx_globals - dst_ranks * n_slots_per_rank + dst_indx_locals = tl.where(row_valid, dst_indx_locals, tl.zeros_like(dst_indx_locals)) + + # Tiled GEMM. + acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + + for start_k in range(0, K, BLOCK_K): + offs_k = start_k + tl.arange(0, BLOCK_K) + mask_k = offs_k < K + + x_ptrs = x_ptr + offs_m[:, None] * x_stride_m + offs_k[None, :] * x_stride_k + x = tl.load(x_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0) + + w_ptrs = ( + w_ptr + local_expert_id_64 * w_stride_e + offs_k[:, None] * w_stride_k + offs_n[None, :] * w_stride_n + ) + w = tl.load(w_ptrs, mask=mask_k[:, None] & mask_n[None, :], other=0.0) + + acc += tl.dot(x, w) + + if HAS_BIAS: + b_ptrs = b_ptr + local_expert_id_64 * b_stride_e + offs_n * b_stride_n + bias = tl.load(b_ptrs, mask=mask_n, other=0.0) + acc += bias[None, :] + + out = acc.to(dst_ptr.dtype.element_ty) + + # Per-rank 2D masked scatter. + dst_ptrs_2d = dst_ptr + dst_indx_locals[:, None] * dst_stride_m + offs_n[None, :] + for r in tl.static_range(N_RANKS): + rank_mask = row_valid & (dst_ranks == r) + store_mask = rank_mask[:, None] & mask_n[None, :] + if r == SRC_RANK: + tl.store(dst_ptrs_2d, out, mask=store_mask) + else: + iris.store(dst_ptrs_2d, out, SRC_RANK, r, heap_bases, mask=store_mask, hint=(1, 16)) + + +def fused_exp_matmul_ep_to_dp( + x_ep_local: torch.Tensor, + w_ep_local: torch.Tensor, + b_ep_local: torch.Tensor | None, + expt_assignment, + expt_map_local: torch.Tensor, + expt_indx_flat: torch.Tensor, + combine_indx: torch.Tensor, + shmem, + ragged_metadata: RaggedTensorMetadata | None = None, +) -> torch.Tensor: + """Compute expert matmul and scatter to DP-local output in one kernel. + + Uses tiled GEMM (tl.dot) with per-rank 2D masked scatter -- same + compute throughput as grouped_matmul but fused with the EP->DP combine. + + Args: + x_ep_local: (n_total_slots, d_model) dispatched activations. + w_ep_local: (n_local_experts, d_model, d_model) local expert weights. + b_ep_local: (n_local_experts, d_model) local expert biases or None. + expt_assignment: ExptAssignment with bitmask for ownership check. + expt_map_local: (n_expts_tot,) global expert -> local expert id or -1. + expt_indx_flat: (n_total_slots,) flat global expert ids by token-slot. + combine_indx: (n_total_slots,) col_sorted_indx. + shmem: iris.Iris instance. + ragged_metadata: local-expert-view ragged metadata (slice_offs, slice_sizes). + + Returns: + (n_slots_per_rank, d_model) DP-local combined output. + """ + expt_bitmask = expt_assignment.expt_bitmask + n_total_slots, d_model = x_ep_local.shape + n_local_experts = w_ep_local.shape[0] + n_slots_per_rank = n_total_slots // shmem.get_num_ranks() + K = d_model + N = d_model + + dst_local = shmem.zeros((n_slots_per_rank, d_model), dtype=x_ep_local.dtype) + shmem.barrier() + + BLOCK_M = 128 + BLOCK_N = min(triton.next_power_of_2(N), 128) + BLOCK_K = min(triton.next_power_of_2(K), 64) + + n_n_tiles = triton.cdiv(N, BLOCK_N) + grid = (n_n_tiles * n_local_experts,) + + _fused_exp_matmul_ep_to_dp_kernel[grid]( + dst_local, + dst_local.stride(0), + x_ep_local, + x_ep_local.stride(0), + x_ep_local.stride(1), + w_ep_local, + w_ep_local.stride(0), + w_ep_local.stride(1), + w_ep_local.stride(2), + b_ep_local if b_ep_local is not None else x_ep_local, + b_ep_local.stride(0) if b_ep_local is not None else 0, + b_ep_local.stride(1) if b_ep_local is not None else 0, + ragged_metadata.slice_offs, + ragged_metadata.slice_sizes, + expt_bitmask, + expt_bitmask.stride(0), + expt_indx_flat, + combine_indx, + n_local_experts, + n_slots_per_rank, + K, + N, + shmem.get_heap_bases(), + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + HAS_BIAS=(b_ep_local is not None), + SRC_RANK=shmem.get_rank(), + N_RANKS=shmem.get_num_ranks(), + num_warps=8, + num_stages=2, + matrix_instr_nonkdim=16, + kpack=1, + ) + + shmem.barrier() + return dst_local diff --git a/examples/31_expert_sharded_moe/grouped_matmul.py b/examples/31_expert_sharded_moe/grouped_matmul.py new file mode 100644 index 000000000..fbc4c096a --- /dev/null +++ b/examples/31_expert_sharded_moe/grouped_matmul.py @@ -0,0 +1,159 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Simplified grouped/ragged GEMM for expert-parallel MoE. + +Ported / simplified from triton_kernels matmul: + https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/matmul_details/_matmul.py + +Non-persistent, non-TMA tiled GEMM that handles variable-length expert +batches described by ragged metadata (slice_offs, slice_sizes). + + Y[offs[e]:offs[e+1], :] = X[offs[e]:offs[e+1], :] @ W[e, :, :] + bias[e, :] +""" + +import torch +import triton +import triton.language as tl +from ragged_metadata import RaggedTensorMetadata + + +@triton.jit +def _grouped_matmul_kernel( + X_ptr, + stride_x_m, + stride_x_k, + W_ptr, + stride_w_e, + stride_w_k, + stride_w_n, + B_ptr, + stride_b_e, + stride_b_n, + Y_ptr, + stride_y_m, + stride_y_n, + SliceOffs_ptr, + SliceSizes_ptr, + n_experts, + K, + N, + HAS_BIAS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + """Tiled GEMM over ragged expert batches. + + Grid: (n_n_tiles * n_experts,) + Each program handles one (expert, n_tile) pair and loops over M tiles. + """ + pid = tl.program_id(0) + n_n_tiles = tl.cdiv(N, BLOCK_N) + + expert_id = pid // n_n_tiles + pid_n = pid % n_n_tiles + + if expert_id >= n_experts: + return + + # int64 to prevent pointer-offset overflow when n_experts * K * N > 2^31 + expert_id = expert_id.to(tl.int64) + slice_off = tl.load(SliceOffs_ptr + expert_id).to(tl.int64) + slice_size = tl.load(SliceSizes_ptr + expert_id) + if slice_size == 0: + return + + n_m_tiles = tl.cdiv(slice_size, BLOCK_M) + + for pid_m in range(0, n_m_tiles): + offs_m = slice_off + pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + mask_m = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) < slice_size + + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_n = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK_N), BLOCK_N) + mask_n = offs_n < N + + acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + + for start_k in range(0, K, BLOCK_K): + offs_k = start_k + tl.arange(0, BLOCK_K) + mask_k = offs_k < K + + x_ptrs = X_ptr + offs_m[:, None] * stride_x_m + offs_k[None, :] * stride_x_k + x = tl.load(x_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0) + + w_ptrs = W_ptr + expert_id * stride_w_e + offs_k[:, None] * stride_w_k + offs_n[None, :] * stride_w_n + w = tl.load(w_ptrs, mask=mask_k[:, None] & mask_n[None, :], other=0.0) + + acc += tl.dot(x, w) + + if HAS_BIAS: + b_ptrs = B_ptr + expert_id * stride_b_e + offs_n * stride_b_n + bias = tl.load(b_ptrs, mask=mask_n, other=0.0) + acc += bias[None, :] + + y_ptrs = Y_ptr + offs_m[:, None] * stride_y_m + offs_n[None, :] * stride_y_n + tl.store(y_ptrs, acc.to(Y_ptr.dtype.element_ty), mask=mask_m[:, None] & mask_n[None, :]) + + +def grouped_matmul( + x: torch.Tensor, + w: torch.Tensor, + bias: torch.Tensor | None, + ragged_metadata: RaggedTensorMetadata, +) -> torch.Tensor: + """Ragged grouped GEMM: one matmul per expert slice. + + Args: + x: (total_tokens, K) activations in expert-sorted order. + w: (n_experts, K, N) weight matrices. + bias: (n_experts, N) bias vectors, or None. + ragged_metadata: which rows of x belong to which expert. + + Returns: + y: (total_tokens, N) output in the same ragged layout as x. + """ + total_tokens, K = x.shape + n_experts, _, N = w.shape + device = x.device + + y = torch.zeros((total_tokens, N), dtype=x.dtype, device=device) + + BLOCK_M = 128 + BLOCK_N = min(triton.next_power_of_2(N), 128) + BLOCK_K = min(triton.next_power_of_2(K), 64) + + n_n_tiles = triton.cdiv(N, BLOCK_N) + grid = (n_n_tiles * n_experts,) + + _grouped_matmul_kernel[grid]( + x, + x.stride(0), + x.stride(1), + w, + w.stride(0), + w.stride(1), + w.stride(2), + bias if bias is not None else x, + bias.stride(0) if bias is not None else 0, + bias.stride(1) if bias is not None else 0, + y, + y.stride(0), + y.stride(1), + ragged_metadata.slice_offs, + ragged_metadata.slice_sizes, + n_experts, + K, + N, + HAS_BIAS=(bias is not None), + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + num_warps=8, + num_stages=2, + matrix_instr_nonkdim=16, + kpack=1, + ) + return y diff --git a/examples/31_expert_sharded_moe/moe.py b/examples/31_expert_sharded_moe/moe.py new file mode 100644 index 000000000..8a9124293 --- /dev/null +++ b/examples/31_expert_sharded_moe/moe.py @@ -0,0 +1,328 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +""" +Expert-sharded MoE forward pass -- reference and distributed. + +Closely follows the test flow in triton_kernels/tests/test_distributed.py: + https://github.com/triton-lang/triton/blob/main/python/triton_kernels/tests/test_distributed.py + +Implements: + mixture_of_expt_nosharded -- single-device reference (uses PyTorch) + mixture_of_expt_epsharded -- expert-parallel 8-step pipeline using iris +""" + +from dataclasses import dataclass + +import torch +import triton +import triton.language as tl +import iris + +from ragged_metadata import make_ragged_tensor_metadata, remap_ragged_tensor_metadata +from topk import topk, _make_bitmatrix_metadata +from dispatch import convert_dp_to_ep +from combine import convert_ep_to_dp +from grouped_matmul import grouped_matmul +from fused_exp_matmul_ep_to_dp import fused_exp_matmul_ep_to_dp +from reduce import reduce + + +# --------------------------------------------------------------------------- +# Iris all-gather helper (push model) +# --------------------------------------------------------------------------- + + +@triton.jit +def _allgather_push_kernel( + src_ptr, + dst_ptr, + dst_offset, + src_numel, + heap_bases, + CUR_RANK: tl.constexpr, + N_RANKS: tl.constexpr, + BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + offs = tl.max_contiguous(tl.multiple_of(offs, BLOCK), BLOCK) + mask = offs < src_numel + data = tl.load(src_ptr + offs, mask=mask) + for r in tl.static_range(N_RANKS): + dst = dst_ptr + dst_offset + offs + iris.store(dst, data, CUR_RANK, r, heap_bases, mask=mask, hint=16) + + +def _allgather_iris(local_tensor, shmem): + """All-gather a 2-D tensor via iris push: each rank writes its chunk + to every rank's shared buffer at the correct offset. + + Sub-32-bit dtypes (e.g. int16) are promoted to int32 for the push + because iris.store can silently mishandle narrow element types when + the heap offset is not aligned to the natural store width. + """ + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + orig_dtype = local_tensor.dtype + need_promote = orig_dtype.itemsize < 4 + work_dtype = torch.int32 if need_promote else orig_dtype + + src = local_tensor.contiguous() + if need_promote: + src = src.to(work_dtype) + + n_local = src.shape[0] + rest = list(src.shape[1:]) + global_shape = [n_local * world_size] + rest + buf = shmem.zeros(global_shape, dtype=work_dtype) + # Match the other communication wrappers: ensure every rank has + # allocated its destination heap buffer before remote stores begin. + shmem.barrier() + heap_bases = shmem.get_heap_bases() + + src_flat = src.view(-1) + numel = src_flat.numel() + elem_offset = rank * numel + + BLOCK = min(triton.next_power_of_2(numel), 1024) + grid = (triton.cdiv(numel, BLOCK),) + _allgather_push_kernel[grid]( + src_flat, + buf.view(-1), + elem_offset, + numel, + heap_bases, + CUR_RANK=rank, + N_RANKS=world_size, + BLOCK=BLOCK, + ) + shmem.barrier() + if need_promote: + return buf.to(orig_dtype) + return buf + + +# --------------------------------------------------------------------------- +# Reference: single-device MoE (no sharding) +# --------------------------------------------------------------------------- + + +def mixture_of_expt_nosharded(x_global, l_global, w_global, b_global, n_expts_act): + """Reference MoE on a single device using our own kernels. + + Follows the upstream routing -> matmul -> reduce flow exactly: + 1. topk routing + 2. build bitmatrix & ragged metadata + 3. gather tokens into expert-sorted order (using col_sorted_indx / k) + 4. grouped matmul + 5. scatter results back to (n_tokens*k, d_model) + 6. reduce(dim=1) with validity mask + """ + n_tokens, d_model = x_global.shape + n_expts_tot = l_global.shape[1] + device = x_global.device + + topk_result = topk(l_global, n_expts_act, apply_softmax=True) + active_indx = topk_result.indx + mask_metadata = _make_bitmatrix_metadata(active_indx.to(torch.int32), n_expts_tot) + + dispatch_indx = mask_metadata.row_sorted_indx + combine_indx = mask_metadata.col_sorted_indx + expt_sizes = mask_metadata.col_sum + + n_active = int(expt_sizes.sum().item()) + ragged_meta = make_ragged_tensor_metadata(expt_sizes, n_active) + + gather_idx = torch.div(combine_indx[:n_active], n_expts_act, rounding_mode="trunc") + + x_sorted = torch.zeros(n_active, d_model, dtype=x_global.dtype, device=device) + valid_gather = gather_idx >= 0 + x_sorted[valid_gather] = x_global[gather_idx[valid_gather].long()] + + y_sorted = grouped_matmul(x_sorted, w_global, b_global, ragged_meta) + + y_flat = torch.zeros(n_tokens * n_expts_act, d_model, dtype=x_global.dtype, device=device) + for i in range(n_active): + dst = combine_indx[i].item() + if dst >= 0: + y_flat[dst] = y_sorted[i] + + y_mask = (dispatch_indx != -1).view(n_tokens, n_expts_act, 1) + y_3d = y_flat.view(n_tokens, n_expts_act, d_model) + y_mask = y_mask.expand_as(y_3d).contiguous() + y_global, _ = reduce(y_3d, dim=1, mask=y_mask) + return y_global + + +# --------------------------------------------------------------------------- +# Distributed: expert-parallel MoE using iris +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class MoeFusionConfig: + """Fusion mode selector for expert-sharded MoE pipeline.""" + + fuse_convert_dp_to_ep_grouped_matmul: bool = False + fuse_grouped_matmul_convert_ep_to_dp: bool = False + + def mode_name(self) -> str: + parts: list[str] = [] + if self.fuse_convert_dp_to_ep_grouped_matmul: + parts.append("convert_dp_to_ep_grouped_matmul") + if self.fuse_grouped_matmul_convert_ep_to_dp: + parts.append("grouped_matmul_convert_ep_to_dp") + if not parts: + return "unfused" + return "fused_" + "__".join(parts) + + @staticmethod + def from_mode_name(name: str) -> "MoeFusionConfig": + if name == "unfused": + return MoeFusionConfig() + if name == "fused_grouped_matmul_convert_ep_to_dp": + return MoeFusionConfig(fuse_grouped_matmul_convert_ep_to_dp=True) + if name == "fused_convert_dp_to_ep_grouped_matmul": + return MoeFusionConfig(fuse_convert_dp_to_ep_grouped_matmul=True) + if name == "fused_convert_dp_to_ep_grouped_matmul__grouped_matmul_convert_ep_to_dp": + return MoeFusionConfig( + fuse_convert_dp_to_ep_grouped_matmul=True, + fuse_grouped_matmul_convert_ep_to_dp=True, + ) + raise ValueError(f"Unknown fusion mode name: {name}") + + +def mixture_of_expt_epsharded( + x_dp_local, + l_dp_local, + w_ep_local, + b_ep_local, + expt_assignment, + n_expts_act, + shmem, + fusion_config: MoeFusionConfig | None = None, + timing_dict: dict | None = None, +): + """Expert-parallel MoE forward using iris symmetric heap. + + Args: + x_dp_local: (n_tokens_local, d_model) local token activations. + l_dp_local: (n_tokens_local, n_expts_tot) local logits. + w_ep_local: (n_expts_local, d_model, d_model) local expert weights. + b_ep_local: (n_expts_local, d_model) local expert biases. + expt_assignment: ExptAssignment mapping experts to ranks. + n_expts_act: k (experts per token). + shmem: iris.Iris instance. + + Returns: + (n_tokens_local, d_model) output for this rank's tokens. + """ + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + n_tokens_local, d_model = x_dp_local.shape + n_tokens_global = n_tokens_local * world_size + n_expts_tot = l_dp_local.shape[1] + device = x_dp_local.device + + def _tick(label): + """Record a cuda event for timing breakdown. timing_dict is a list of (label, event) pairs.""" + if timing_dict is not None: + torch.cuda.synchronize() + ev = torch.cuda.Event(enable_timing=True) + ev.record() + timing_dict.append((label, ev)) + + _tick("start") + + # ------------------------------------------------------------------ + # Step 1: Top-k routing (local) + all-gather via iris + # ------------------------------------------------------------------ + local_topk = topk(l_dp_local, n_expts_act, apply_softmax=True) + _tick("topk") + + vals_global = _allgather_iris(local_topk.vals, shmem) + # Keep routing indices in int32 after gather. We observed rank-dependent + # corruption when converting gathered index buffers back to int16. + indx_global = _allgather_iris(local_topk.indx.contiguous().to(torch.int32), shmem) + _tick("allgather") + + # ------------------------------------------------------------------ + # Step 2: Extract routing metadata from global topk + # ------------------------------------------------------------------ + active_indx = indx_global + mask_metadata = _make_bitmatrix_metadata(active_indx.to(torch.int32), n_expts_tot) + + expt_sizes = mask_metadata.col_sum + dispatch_indx = mask_metadata.row_sorted_indx + combine_indx = mask_metadata.col_sorted_indx + + # ------------------------------------------------------------------ + # Step 3: Build ragged tensor metadata + # ------------------------------------------------------------------ + n_active = int(expt_sizes.sum().item()) + x_global_metadata = make_ragged_tensor_metadata(expt_sizes, n_active) + _tick("metadata") + + # ------------------------------------------------------------------ + # Step 4: DP -> EP dispatch (all-to-all via iris.store) + # ------------------------------------------------------------------ + y_ep_local = convert_dp_to_ep( + x_dp_local, + expt_assignment, + active_indx, + dispatch_indx, + shmem, + ) + _tick("dispatch") + + # ------------------------------------------------------------------ + # Step 5: Remap ragged metadata to local expert view + # ------------------------------------------------------------------ + expt_map = expt_assignment.expt_map[rank, :].contiguous() + y_ep_local_metadata = remap_ragged_tensor_metadata(x_global_metadata, expt_map) + + fusion_config = fusion_config or MoeFusionConfig() + if fusion_config.fuse_convert_dp_to_ep_grouped_matmul: + raise NotImplementedError("Fusion mode convert_dp_to_ep_grouped_matmul is not implemented yet.") + + # ------------------------------------------------------------------ + # grouped_matmul + convert_ep_to_dp (select fused/unfused variant) + # ------------------------------------------------------------------ + flat_expt_indx = active_indx.to(torch.int32).reshape(-1) + if fusion_config.fuse_grouped_matmul_convert_ep_to_dp: + y_dp_local = fused_exp_matmul_ep_to_dp( + y_ep_local, + w_ep_local, + b_ep_local, + expt_assignment, + expt_map, + flat_expt_indx, + combine_indx, + shmem, + ragged_metadata=y_ep_local_metadata, + ) + _tick("fused_matmul_scatter") + else: + y_ep_local = grouped_matmul(y_ep_local, w_ep_local, b_ep_local, y_ep_local_metadata) + _tick("matmul") + y_dp_local = convert_ep_to_dp( + y_ep_local, + expt_assignment, + flat_expt_indx, + combine_indx, + shmem, + ) + _tick("combine") + + # ------------------------------------------------------------------ + # Step 8: Reduce (unweighted sum, masked) + # ------------------------------------------------------------------ + y_dp_local = y_dp_local.view(-1, n_expts_act, d_model) + y_mask = (dispatch_indx != -1).view(n_tokens_global, n_expts_act, 1) + local_mask = y_mask[rank * n_tokens_local : (rank + 1) * n_tokens_local] + local_mask = local_mask.expand_as(y_dp_local).contiguous() + z_dp_local, _ = reduce(y_dp_local, dim=1, mask=local_mask) + _tick("reduce") + + torch.cuda.synchronize() + return z_dp_local diff --git a/examples/31_expert_sharded_moe/ragged_metadata.py b/examples/31_expert_sharded_moe/ragged_metadata.py new file mode 100644 index 000000000..3b837b6ed --- /dev/null +++ b/examples/31_expert_sharded_moe/ragged_metadata.py @@ -0,0 +1,78 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Ragged tensor metadata for grouped expert computation. + +Simplified port of triton_kernels/tensor_details/ragged_tensor.py: + https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/tensor_details/ragged_tensor.py + +Only the fields needed by the simplified grouped matmul are retained: +slice_sizes, slice_offs, and n_slices. +""" + +import torch +from dataclasses import dataclass + + +@dataclass +class RaggedTensorMetadata: + """Lightweight ragged tensor descriptor. + + Example with 4 experts receiving [3, 0, 5, 2] tokens: + slice_sizes = [3, 0, 5, 2] + slice_offs = [0, 3, 3, 8, 10] + """ + + slice_sizes: torch.Tensor # (n_slices,) int32 + slice_offs: torch.Tensor # (n_slices + 1,) int32 + + @property + def n_slices(self) -> int: + return self.slice_sizes.shape[0] + + +def make_ragged_tensor_metadata( + slice_sizes: torch.Tensor, + n_total_rows: int, +) -> RaggedTensorMetadata: + """Build ragged metadata from per-expert token counts. + + Args: + slice_sizes: (n_experts,) int32 tensor of token counts per expert. + n_total_rows: total number of active token-expert slots (for validation). + """ + assert slice_sizes.ndim == 1 + slice_sizes = slice_sizes.to(torch.int32) + offs = torch.zeros(slice_sizes.shape[0] + 1, dtype=torch.int32, device=slice_sizes.device) + offs[1:] = torch.cumsum(slice_sizes, dim=0) + return RaggedTensorMetadata(slice_sizes, offs) + + +def remap_ragged_tensor_metadata( + metadata: RaggedTensorMetadata, + expt_map: torch.Tensor, +) -> RaggedTensorMetadata: + """Remap global expert metadata to a local expert view. + + expt_map: (n_expts_tot,) int32 where expt_map[global_id] is the local id + on this rank, or -1 if the expert is not on this rank. + + Returns metadata containing only the experts owned by this rank, with + ORIGINAL global offsets preserved so the grouped matmul addresses the + correct positions in the globally-indexed dispatch buffer. + """ + valid = expt_map != -1 + local_ids = expt_map[valid] + n_local = int(local_ids.max().item()) + 1 if local_ids.numel() > 0 else 0 + device = metadata.slice_sizes.device + local_sizes = torch.zeros(n_local, dtype=torch.int32, device=device) + local_offs = torch.zeros(n_local + 1, dtype=torch.int32, device=device) + for g in range(expt_map.shape[0]): + lid = expt_map[g].item() + if lid >= 0: + local_sizes[lid] = metadata.slice_sizes[g] + local_offs[lid] = metadata.slice_offs[g] + if n_local > 0: + local_offs[n_local] = local_offs[n_local - 1] + local_sizes[n_local - 1] + return RaggedTensorMetadata(local_sizes, local_offs) diff --git a/examples/31_expert_sharded_moe/reduce.py b/examples/31_expert_sharded_moe/reduce.py new file mode 100644 index 000000000..6203908a7 --- /dev/null +++ b/examples/31_expert_sharded_moe/reduce.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +""" +Expert reduce for MoE. + +Matches triton_kernels/reduce.py semantics: + https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/reduce.py + + z[t, :] = sum_{a where mask[t,a,:]!=0} y[t, a, :] + +Plain (unweighted) sum over the k expert outputs per token, gated only +by a boolean validity mask. +""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _reduce_kernel( + Y_ptr, + stride_y_t, + stride_y_a, + stride_y_d, + Z_ptr, + stride_z_t, + stride_z_d, + Mask_ptr, + n_tokens, + d_model, + N_EXPTS_ACT: tl.constexpr, + BLOCK_D: tl.constexpr, + HAS_MASK: tl.constexpr, +): + pid_t = tl.program_id(0) + pid_d = tl.program_id(1) + if pid_t >= n_tokens: + return + + offs_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) + mask_d = offs_d < d_model + + acc = tl.zeros([BLOCK_D], dtype=tl.float32) + for act in range(N_EXPTS_ACT): + if HAS_MASK: + m = tl.load( + Mask_ptr + pid_t * N_EXPTS_ACT * d_model + act * d_model + offs_d, + mask=mask_d, + other=0, + ).to(tl.int1) + y = tl.load( + Y_ptr + pid_t * stride_y_t + act * stride_y_a + offs_d * stride_y_d, + mask=mask_d, + other=0.0, + ).to(tl.float32) + if HAS_MASK: + y = tl.where(m, y, 0.0) + acc += y + + tl.store( + Z_ptr + pid_t * stride_z_t + offs_d * stride_z_d, + acc.to(Z_ptr.dtype.element_ty), + mask=mask_d, + ) + + +def reduce( + y: torch.Tensor, + dim: int = 1, + mask: torch.Tensor | None = None, +) -> tuple[torch.Tensor, None]: + """Sum-reduce over *dim* with optional boolean mask. + + Matches the upstream ``reduce(y, dim=1, mask=mask)`` signature. + + Args: + y: (n_tokens, k, d_model) expert outputs. + dim: reduction dimension (must be 1). + mask: (n_tokens, k, d_model) bool/int mask; zero = skip. + + Returns: + (z, None) where z has shape (n_tokens, d_model). + """ + assert dim == 1 and y.ndim == 3 + n_tokens, k, d_model = y.shape + device = y.device + + z = torch.zeros((n_tokens, d_model), dtype=y.dtype, device=device) + + BLOCK_D = min(triton.next_power_of_2(d_model), 512) + grid = (n_tokens, triton.cdiv(d_model, BLOCK_D)) + + _reduce_kernel[grid]( + y, + y.stride(0), + y.stride(1), + y.stride(2), + z, + z.stride(0), + z.stride(1), + mask if mask is not None else y, + n_tokens, + d_model, + N_EXPTS_ACT=k, + BLOCK_D=BLOCK_D, + HAS_MASK=(mask is not None), + ) + return z, None diff --git a/examples/31_expert_sharded_moe/topk.py b/examples/31_expert_sharded_moe/topk.py new file mode 100644 index 000000000..8477cdd54 --- /dev/null +++ b/examples/31_expert_sharded_moe/topk.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Top-k expert routing for MoE. + +Ported / simplified from triton_kernels/topk.py and bitmatrix.py: + https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/topk.py + https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/tensor_details/bitmatrix.py + +Provides: + - PyTorch-based top-k + softmax (matches topk_torch reference) + - Host-side BitmatrixMetadata (col_sum, row_sorted_indx, col_sorted_indx) + - A convenience ``topk`` function +""" + +import torch +from dataclasses import dataclass + + +@dataclass +class BitmatrixMetadata: + """Routing indices derived from the top-k selection. + + col_sum: (n_expts,) histogram: tokens per expert + row_sorted_indx: (n_tokens * k,) flat token-expert slots grouped by expert (dispatch order) + col_sorted_indx: (n_tokens * k,) inverse permutation (combine order) + """ + + col_sum: torch.Tensor + row_sorted_indx: torch.Tensor + col_sorted_indx: torch.Tensor + + +@dataclass +class TopkResult: + vals: torch.Tensor # (n_tokens, k) softmax gating weights + indx: torch.Tensor # (n_tokens, k) expert indices (int16) + mask_metadata: BitmatrixMetadata + + +# --------------------------------------------------------------------------- +# Host-side bitmatrix metadata construction (torch reference) +# --------------------------------------------------------------------------- + + +def _make_bitmatrix_metadata(indx: torch.Tensor, n_expts: int) -> BitmatrixMetadata: + """Build dispatch/combine indices from the (n_tokens, k) expert-index tensor. + + Follows triton_kernels/tensor_details/bitmatrix.py (optimised convention): + col_sorted_indx[expert_sorted_pos] = original flat index + row_sorted_indx[original_flat_idx] = expert_sorted_pos + + Handles -1 (invalid) entries correctly. + """ + device = indx.device + flat_indx = indx.reshape(-1).to(torch.int32) + n_elements = flat_indx.numel() + + valid = flat_indx >= 0 + n_valid = valid.sum().item() + + col_sum = torch.histc( + flat_indx[valid].float(), + bins=n_expts, + min=0, + max=n_expts - 1, + ).to(torch.int32) + + col_sorted_indx = torch.full((n_elements,), -1, dtype=torch.int32, device=device) + row_sorted_indx = torch.full((n_elements,), -1, dtype=torch.int32, device=device) + + sort_keys = flat_indx.clone().long() + sort_keys[~valid] = n_expts + sorted_order = torch.argsort(sort_keys, stable=True).to(torch.int32) + + col_sorted_indx[:n_valid] = sorted_order[:n_valid] + expert_positions = torch.arange(n_valid, device=device, dtype=torch.int32) + row_sorted_indx.scatter_(0, sorted_order[:n_valid].long(), expert_positions) + + return BitmatrixMetadata( + col_sum=col_sum, + col_sorted_indx=col_sorted_indx, + row_sorted_indx=row_sorted_indx, + ) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def topk( + x: torch.Tensor, + k: int, + apply_softmax: bool = True, +) -> TopkResult: + """Compute top-k routing over expert logits. + + Uses PyTorch ops (matches upstream topk_torch reference). + + Args: + x: (n_tokens, n_expts) float32 logit tensor. + k: number of experts to activate per token. + apply_softmax: whether to softmax the selected values. + + Returns: + TopkResult with vals, indx, and mask_metadata. + """ + n_tokens, n_expts = x.shape + + vals, indx = torch.topk(x.float(), k, dim=1, sorted=True) + + if apply_softmax: + vals = torch.softmax(vals, dim=-1).to(x.dtype) + else: + vals = vals.to(x.dtype) + indx = indx.to(torch.int16) + + mask_metadata = _make_bitmatrix_metadata(indx.to(torch.int32), n_expts) + return TopkResult(vals=vals, indx=indx, mask_metadata=mask_metadata) diff --git a/examples/31_message_passing/example.py b/examples/31_message_passing/example.py new file mode 100644 index 000000000..9b4c233e0 --- /dev/null +++ b/examples/31_message_passing/example.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Example: producer-consumer message passing with Iris load/store. + +Producer rank writes to consumer's buffer; consumer spin-waits on flags then reads. +Requires exactly 2 ranks. + +Run with: + torchrun --nproc_per_node=2 --standalone example.py [--validate] +""" + +import argparse +import os +import random + +import torch +import torch.distributed as dist +import triton +import triton.language as tl + +import iris + + +@triton.jit +def producer_kernel( + source_buffer, # tl.tensor: pointer to source data + target_buffer, # tl.tensor: pointer to target data + flag, # tl.tensor: pointer to flags + buffer_size, # int32: total number of elements + producer_rank: tl.constexpr, + consumer_rank: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases_ptr: tl.tensor, # tl.tensor: pointer to heap bases pointers +): + pid = tl.program_id(0) + + # Compute start index of this block + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # Guard for out-of-bounds accesses + mask = offsets < buffer_size + + # Load chunk from source buffer + values = iris.load(source_buffer + offsets, producer_rank, producer_rank, heap_bases_ptr, mask=mask) + + # Store chunk to target buffer + iris.store( + target_buffer + offsets, + values, + producer_rank, + consumer_rank, + heap_bases_ptr, + mask=mask, + ) + + # Set flag to signal completion + iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, sem="release", scope="sys") + + +@triton.jit +def consumer_kernel( + buffer, # tl.tensor: pointer to shared buffer (read from target_rank) + flag, # tl.tensor: sync flag per block + buffer_size, # int32: total number of elements + consumer_rank: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases_ptr: tl.tensor, # tl.tensor: pointer to heap bases pointers +): + pid = tl.program_id(0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < buffer_size + + # Spin-wait until writer sets flag[pid] = 1 + done = 0 + while done == 0: + done = iris.atomic_cas( + flag + pid, 1, 0, consumer_rank, consumer_rank, heap_bases_ptr, sem="acquire", scope="sys" + ) + + # Read from the target buffer (written by producer) + values = iris.load(buffer + offsets, consumer_rank, consumer_rank, heap_bases_ptr, mask=mask) + + # Do something with values... + values = values * 2 + + # Store chunk to target buffer + iris.store( + buffer + offsets, + values, + consumer_rank, + consumer_rank, + heap_bases_ptr, + mask=mask, + ) + + # Optionally reset the flag for next iteration + tl.store(flag + pid, 0) + + +torch.manual_seed(123) +random.seed(123) + + +def torch_dtype_from_str(datatype: str) -> torch.dtype: + dtype_map = { + "fp16": torch.float16, + "fp32": torch.float32, + "int8": torch.int8, + "bf16": torch.bfloat16, + } + try: + return dtype_map[datatype] + except KeyError: + raise ValueError(f"Unknown datatype: {datatype}") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Message passing producer-consumer example (2 ranks).", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "-t", + "--datatype", + type=str, + default="fp32", + choices=["fp16", "fp32", "int8", "bf16"], + help="Datatype of computation", + ) + parser.add_argument("-s", "--buffer_size", type=int, default=4096, help="Buffer size") + parser.add_argument("-b", "--block_size", type=int, default=512, help="Block size") + parser.add_argument("--heap_size", type=int, default=1 << 16, help="Iris heap size") + parser.add_argument("-v", "--validate", action="store_true", help="Validate output against reference") + return vars(parser.parse_args()) + + +def main(): + args = parse_args() + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl") + + ctx = iris.iris(heap_size=args["heap_size"]) + cur_rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + + if world_size != 2: + raise ValueError("This example requires exactly two processes. Use: torchrun --nproc_per_node=2 ...") + + dtype = torch_dtype_from_str(args["datatype"]) + producer_rank = 0 + consumer_rank = 1 + + # Allocate source and destination buffers on the symmetric heap + source_buffer = ctx.zeros(args["buffer_size"], device="cuda", dtype=dtype) + if dtype.is_floating_point: + destination_buffer = ctx.randn(args["buffer_size"], device="cuda", dtype=dtype) + else: + ii = torch.iinfo(dtype) + destination_buffer = ctx.randint(ii.min, ii.max, (args["buffer_size"],), device="cuda", dtype=dtype) + + n_elements = source_buffer.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + num_blocks = triton.cdiv(n_elements, args["block_size"]) + flags = ctx.zeros((num_blocks,), device="cuda", dtype=torch.int32) + + heap_bases = ctx.get_heap_bases() + + if cur_rank == producer_rank: + ctx.info(f"Rank {cur_rank} is sending data to rank {consumer_rank}.") + producer_kernel[grid]( + source_buffer, + destination_buffer, + flags, + n_elements, + producer_rank, + consumer_rank, + args["block_size"], + heap_bases, + ) + else: + ctx.info(f"Rank {cur_rank} is receiving data from rank {producer_rank}.") + consumer_kernel[grid]( + destination_buffer, + flags, + n_elements, + consumer_rank, + args["block_size"], + heap_bases, + ) + + ctx.barrier() + ctx.info(f"Rank {cur_rank} has finished sending/receiving data.") + + if args["validate"]: + ctx.info("Validating output...") + if cur_rank == consumer_rank: + expected = source_buffer * 2 + if not torch.allclose(destination_buffer, expected, atol=1): + max_diff = (destination_buffer - expected).abs().max().item() + ctx.error(f"Validation failed. Max absolute difference: {max_diff}") + else: + ctx.info("Validation successful.") + + ctx.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/examples/common/utils.py b/examples/common/utils.py index f9ebba8d7..d01fa7214 100644 --- a/examples/common/utils.py +++ b/examples/common/utils.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. -import triton import triton.language as tl import torch @@ -154,16 +153,5 @@ def is_triton_interpret_set(): return "TRITON_INTERPRET" in os.environ -@triton.jit -def read_realtime(): - tmp = tl.inline_asm_elementwise( - asm="""s_waitcnt vmcnt(0) - s_memrealtime $0 - s_waitcnt lgkmcnt(0)""", - constraints=("=s"), - args=[], - dtype=tl.int64, - is_pure=False, - pack=1, - ) - return tmp +# Re-export device utility functions from iris module +# These are kept here for backward compatibility with existing examples diff --git a/iris/__init__.py b/iris/__init__.py index 7345d0fea..02f78d428 100644 --- a/iris/__init__.py +++ b/iris/__init__.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. """ Iris: Multi-GPU Communication and Memory Management Framework @@ -41,11 +41,11 @@ >>> ctx.load(buffer, 1) """ -# __init__.py - from .iris import ( Iris, iris, + DeviceContext, + TraceEvent, load, store, copy, @@ -63,17 +63,19 @@ from .util import ( do_bench, + get_device_id_for_rank, + is_simulation_env, ) -from . import hip +from .tensor_utils import ( + CUDAArrayInterface, + tensor_from_ptr, +) -# Import experimental features (optional, for users who want experimental APIs) +from . import hip from . import experimental - -# Import ops module (fused GEMM+CCL operations) from . import ops - -# Import logging functionality +from . import tensor_creation from .logging import ( set_logger_level, logger, @@ -83,11 +85,12 @@ ERROR, ) -# Launcher functionality is now user code - see examples and documentation - __all__ = [ "Iris", "iris", + "get_device_id_for_rank", + "DeviceContext", + "TraceEvent", "load", "store", "copy", @@ -102,9 +105,12 @@ "atomic_min", "atomic_max", "do_bench", + "CUDAArrayInterface", + "tensor_from_ptr", "hip", - "experimental", # Experimental features including iris_gluon - "ops", # Fused GEMM+CCL operations + "experimental", + "ops", + "tensor_creation", "set_logger_level", "logger", "DEBUG", @@ -112,3 +118,19 @@ "WARNING", "ERROR", ] + +# Patch torch.cuda.set_device to automatically handle device ID wrapping in simulation mode +# Only patch if in simulation mode +if is_simulation_env(): + import torch + + _original_set_device = torch.cuda.set_device + + def _patched_set_device(device): + """Patched version of torch.cuda.set_device that wraps device IDs in simulation mode.""" + num_devices = torch.cuda.device_count() + if num_devices > 0 and isinstance(device, int) and device >= num_devices: + device = device % num_devices + return _original_set_device(device) + + torch.cuda.set_device = _patched_set_device diff --git a/iris/_distributed_helpers.py b/iris/_distributed_helpers.py index a2319062e..9b222375e 100644 --- a/iris/_distributed_helpers.py +++ b/iris/_distributed_helpers.py @@ -57,20 +57,28 @@ def distributed_allgather(data): # Fast path: tensor all_gather if dtype is NCCL-supported or backend != nccl data_tensor = torch.from_numpy(data) - use_tensor_collective = backend != "nccl" or _nccl_dtype_supported(data_tensor) + # Gloo doesn't support uint64, so use object collective for uint64 with gloo + # For int64 with gloo, we can use tensor collective (gloo supports int64) + use_tensor_collective = (backend != "nccl" or _nccl_dtype_supported(data_tensor)) and not ( + backend == "gloo" and data_tensor.dtype == torch.uint64 + ) if use_tensor_collective: data_tensor = data_tensor.to(device) gathered_tensors = [torch.empty_like(data_tensor) for _ in range(world_size)] dist.all_gather(gathered_tensors, data_tensor) - return torch.stack(gathered_tensors, dim=0).to("cpu").numpy() - - # Fallback for NCCL-unsupported dtypes (e.g., uint64/bool/etc.) - obj_list = [None for _ in range(world_size)] - # Use object collective (works across backends) - dist.all_gather_object(obj_list, data) - # Ensure uniform shapes and stack - return np.stack(obj_list, axis=0) + stacked = torch.stack(gathered_tensors, dim=0) + cpu_tensor = stacked.to("cpu") + result = cpu_tensor.numpy() + return result + else: + # Fallback for NCCL-unsupported dtypes or gloo with uint64 (e.g., uint64/bool/etc.) + obj_list = [None for _ in range(world_size)] + # Use object collective (works across backends) + dist.all_gather_object(obj_list, data) + # Ensure uniform shapes and stack + result = np.stack(obj_list, axis=0) + return result def distributed_allgather_multidim(data): diff --git a/iris/allocators/__init__.py b/iris/allocators/__init__.py index 8e824c57a..460c53d39 100644 --- a/iris/allocators/__init__.py +++ b/iris/allocators/__init__.py @@ -7,5 +7,6 @@ from .base import BaseAllocator from .torch_allocator import TorchAllocator +from .vmem_allocator import VMemAllocator -__all__ = ["BaseAllocator", "TorchAllocator"] +__all__ = ["BaseAllocator", "TorchAllocator", "VMemAllocator"] diff --git a/iris/allocators/base.py b/iris/allocators/base.py index ed87418cd..4cfefabbc 100644 --- a/iris/allocators/base.py +++ b/iris/allocators/base.py @@ -6,7 +6,6 @@ """ from abc import ABC, abstractmethod -from typing import Optional, Dict, Any import torch @@ -14,8 +13,8 @@ class BaseAllocator(ABC): """ Abstract base class for Iris memory allocators. - Allocators manage GPU memory for the symmetric heap and handle - inter-process memory sharing. + Allocators manage GPU memory allocation for a single device. + Inter-process coordination is handled by SymmetricHeap. """ def __init__(self, heap_size: int, device_id: int, cur_rank: int, num_ranks: int): @@ -34,6 +33,17 @@ def __init__(self, heap_size: int, device_id: int, cur_rank: int, num_ranks: int self.num_ranks = num_ranks self.heap_offset = 0 + def get_minimum_allocation_size(self) -> int: + """ + Minimum size in bytes for a single allocation. + Callers must request at least this many bytes (or the allocator will bump); + the allocator uses this for tracking actual size for deallocation. + + Returns: + Minimum allocation size in bytes (default 0). + """ + return 0 + @abstractmethod def get_base_address(self) -> int: """ @@ -59,27 +69,6 @@ def allocate(self, num_elements: int, dtype: torch.dtype, alignment: int = 1024) """ pass - @abstractmethod - def get_shareable_handle(self) -> Any: - """ - Get a shareable handle for inter-process communication. - - Returns: - Shareable handle (implementation-specific: FD, IPC handle, etc.) - """ - pass - - @abstractmethod - def establish_peer_access(self, all_bases: Dict[int, int], connections: Optional[Dict] = None): - """ - Establish access to peer memory for symmetric addressing. - - Args: - all_bases: Dictionary mapping rank -> base address - connections: Optional peer connections for handle exchange - """ - pass - @abstractmethod def get_device(self) -> torch.device: """ @@ -90,16 +79,6 @@ def get_device(self) -> torch.device: """ pass - @abstractmethod - def get_heap_bases(self) -> torch.Tensor: - """ - Get heap base addresses for all ranks as a tensor. - - Returns: - Tensor of shape (num_ranks,) with base addresses - """ - pass - @abstractmethod def owns_tensor(self, tensor: torch.Tensor) -> bool: """ diff --git a/iris/allocators/torch_allocator.py b/iris/allocators/torch_allocator.py index 301e68554..3bc428922 100644 --- a/iris/allocators/torch_allocator.py +++ b/iris/allocators/torch_allocator.py @@ -15,8 +15,9 @@ import struct from .base import BaseAllocator -from iris.hip import export_dmabuf_handle, import_dmabuf_handle +from iris.hip import export_dmabuf_handle, import_dmabuf_handle, destroy_external_memory from iris.fd_passing import send_fd, recv_fd, managed_fd +from iris.util import is_simulation_env class TorchAllocator(BaseAllocator): @@ -40,8 +41,34 @@ def __init__(self, heap_size: int, device_id: int, cur_rank: int, num_ranks: int super().__init__(heap_size, device_id, cur_rank, num_ranks) self.device = f"cuda:{device_id}" - self.memory_pool = torch.empty(heap_size, device=self.device, dtype=torch.int8) - self.heap_bases_array = None # Will be set in establish_peer_access + if is_simulation_env(): + import json + + # In simulation, each rank allocates n distinct buffers; memory_pool is a shallow view of the ith. + self.rank_bools = [torch.empty(heap_size, device=self.device, dtype=torch.int8) for _ in range(num_ranks)] + self.memory_pool = self.rank_bools[cur_rank] + + heap_views = [self.rank_bools[r].data_ptr() for r in range(num_ranks)] + out_path = f"iris_rank_{cur_rank}_allocator_views.json" + with open(out_path, "w") as f: + json.dump( + { + "rank": cur_rank, + "num_ranks": num_ranks, + "heap_views": [hex(b) for b in heap_views], + }, + f, + indent=2, + ) + else: + self.rank_bools = None + self.memory_pool = torch.empty(heap_size, device=self.device, dtype=torch.int8) + + self._peer_ext_mem_handles: Dict[int, object] = {} + + def get_minimum_allocation_size(self) -> int: + """Minimum allocation size in bytes (PyTorch allows 0-size views).""" + return 0 def get_base_address(self) -> int: """Get the base address of the memory pool.""" @@ -93,65 +120,83 @@ def establish_peer_access(self, all_bases: Dict[int, int], connections: Optional all_bases: Dictionary mapping rank -> base address connections: Optional peer connections for handle exchange """ - # Use the original heap bases (no remapping for TorchAllocator) heap_bases_array = np.zeros(self.num_ranks, dtype=np.uint64) if connections is not None: - # Get shareable handle for our memory pool + for handle in self._peer_ext_mem_handles.values(): + try: + destroy_external_memory(handle) + except Exception: + pass + self._peer_ext_mem_handles.clear() + my_fd, my_base, my_size = self.get_shareable_handle() heap_base = self.get_base_address() - - # Pack metadata: (base_ptr, base_size, heap_ptr) as three 64-bit unsigned ints my_metadata = struct.pack("QQQ", my_base, my_size, heap_base) - # Use context manager for automatic cleanup with managed_fd(my_fd): - # Exchange handles with all peers for peer, sock in connections.items(): if peer == self.cur_rank: continue - # To avoid deadlock, higher rank sends first - # Send FD along with metadata (base_ptr, base_size, heap_ptr) + # Higher rank sends first to avoid deadlock if self.cur_rank > peer: send_fd(sock, my_fd, payload=my_metadata) - peer_handle, peer_metadata = recv_fd(sock, payload_size=24) # 3 * 8 bytes + peer_handle, peer_metadata = recv_fd(sock, payload_size=24) else: - peer_handle, peer_metadata = recv_fd(sock, payload_size=24) # 3 * 8 bytes + peer_handle, peer_metadata = recv_fd(sock, payload_size=24) send_fd(sock, my_fd, payload=my_metadata) - # Unpack peer's metadata peer_base, peer_size, peer_heap = struct.unpack("QQQ", peer_metadata) - # Use context manager for peer handle and import the DMA-BUF with managed_fd(peer_handle): - # Import peer's memory via DMA-BUF with proper offset correction - # peer_heap is where their heap starts (what they want us to use) - # peer_base is the base of their allocation buffer - # peer_size is the size of their allocation buffer - mapped_addr = import_dmabuf_handle( - peer_handle, - peer_size, # Import the full base allocation - peer_heap, # Original heap pointer (for offset calculation) - peer_base, # Base of allocation (for offset calculation) - ) - heap_bases_array[peer] = mapped_addr - - # Set our own base + mapped_ptr, ext_mem_handle = import_dmabuf_handle(peer_handle, peer_size, peer_heap, peer_base) + heap_bases_array[peer] = mapped_ptr + self._peer_ext_mem_handles[peer] = ext_mem_handle + heap_bases_array[self.cur_rank] = all_bases[self.cur_rank] else: - # Single rank, just set our own base heap_bases_array[self.cur_rank] = all_bases[self.cur_rank] self.heap_bases_array = heap_bases_array + def close(self): + """Release peer external memory handles.""" + for handle in self._peer_ext_mem_handles.values(): + try: + destroy_external_memory(handle) + except Exception: + pass + self._peer_ext_mem_handles.clear() + def get_device(self) -> torch.device: """Get the torch device.""" return self.memory_pool.device - def get_heap_bases(self) -> torch.Tensor: - """Get heap base addresses as a tensor.""" - return torch.from_numpy(self.heap_bases_array).to(device=self.device, dtype=torch.uint64) + def import_external_tensor(self, external_tensor: torch.Tensor) -> torch.Tensor: + """ + Place an external tensor's data on the symmetric heap by copying. + + Unlike the VMem allocator, this does not share memory with the external + tensor: it allocates on the heap and copies. Subsequent changes to the + external tensor are not visible in the returned tensor. + + Args: + external_tensor: External PyTorch tensor to copy from (must be CUDA, contiguous) + + Returns: + New tensor on the symmetric heap with the same data and shape. + """ + if not external_tensor.is_cuda: + raise RuntimeError("Can only import CUDA tensors") + if not external_tensor.is_contiguous(): + raise RuntimeError("Only contiguous tensors can be imported; call .contiguous() before as_symmetric()") + num_elements = external_tensor.numel() + dtype = external_tensor.dtype + shape = external_tensor.shape + heap_tensor = self.allocate(num_elements, dtype) + heap_tensor = heap_tensor.reshape(shape).copy_(external_tensor) + return heap_tensor def owns_tensor(self, tensor: torch.Tensor) -> bool: """ @@ -163,7 +208,6 @@ def owns_tensor(self, tensor: torch.Tensor) -> bool: Returns: True if tensor is within the heap, False otherwise """ - # Special case for empty tensors - they might not have a valid data_ptr if tensor.numel() == 0: return True diff --git a/iris/allocators/vmem_allocator.py b/iris/allocators/vmem_allocator.py new file mode 100644 index 000000000..e5427edff --- /dev/null +++ b/iris/allocators/vmem_allocator.py @@ -0,0 +1,323 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +VMem-based allocator using HIP's virtual memory management APIs. + +This allocator provides fine-grained control over virtual and physical memory, +enabling features like memory oversubscription and on-demand paging. +""" + +import torch +import os +from typing import Dict +from threading import Lock + +from .base import BaseAllocator +from ..hip import ( + get_allocation_granularity, + get_address_range, + export_dmabuf_handle, + mem_import_from_shareable_handle, + mem_create, + mem_address_reserve, + mem_map, + mem_unmap, + mem_address_free, + mem_release, + mem_set_access, + hipMemAccessDesc, + hipMemLocationTypeDevice, + hipMemAccessFlagsProtReadWrite, +) + + +class VMemAllocator(BaseAllocator): + """ + Virtual Memory allocator using HIP's VMem APIs. + + Features: + - Reserve large virtual address (VA) space upfront + - Map physical memory on demand + - Support memory oversubscription + - Fine-grained control over allocations + + Args: + heap_size: Total size of the heap in bytes + device: PyTorch device (e.g., "cuda:0") + rank: Current rank ID + world_size: Total number of ranks + va_multiplier: VA space multiplier (reserve more VA than physical) + """ + + def __init__( + self, + heap_size: int, + device_id: int, + rank: int, + world_size: int, + va_multiplier: float = 1.0, + ): + super().__init__(heap_size, device_id, rank, world_size) + self.va_multiplier = va_multiplier + self.device = torch.device(f"cuda:{device_id}") + self.lock = Lock() + self.granularity = get_allocation_granularity(self.device_id) + self.aligned_heap_size = (heap_size + self.granularity - 1) & ~(self.granularity - 1) + self.va_size = self.aligned_heap_size + self.base_va = mem_address_reserve(self.va_size, self.granularity, 0) + + self.minimal_size = min(2 << 20, self.aligned_heap_size // 2) + if self.minimal_size < self.granularity: + self.minimal_size = self.granularity + + self.minimal_handle = mem_create(self.minimal_size, self.device_id) + mem_map(self.base_va, self.minimal_size, 0, self.minimal_handle) + + # ROCm: mem_set_access must be called cumulatively from base_va (see rocm-systems#2667) + self.access_descs = [] + for peer_device_id in range(world_size): + desc = hipMemAccessDesc() + desc.location.type = hipMemLocationTypeDevice + desc.location.id = peer_device_id + desc.flags = hipMemAccessFlagsProtReadWrite + self.access_descs.append(desc) + + self.cumulative_mapped_size = self.minimal_size + mem_set_access(self.base_va, self.cumulative_mapped_size, self.access_descs) + + self.allocations: Dict[int, tuple] = {} + self.allocation_order = [] + self._track_allocation(0, self.minimal_size, False, self.minimal_handle, self.base_va) + self.current_offset = self.minimal_size + + self.world_size = world_size + + def get_base_address(self) -> int: + """Get the base address of the heap.""" + return self.base_va + + def _track_allocation(self, offset: int, size: int, is_imported: bool, handle, va: int): + """Track a new allocation for cleanup and segmented export.""" + self.allocations[offset] = (size, is_imported, handle, va) + self.allocation_order.append((offset, size)) + + def get_allocation_segments(self): + """ + Get list of allocation segments for segmented DMA-BUF export. + + Returns: + List of (offset, size, va) tuples for each allocation in order. + Each tuple describes one physically-backed segment that needs + to be exported/imported separately. + """ + segments = [] + for offset, size in self.allocation_order: + va = self.base_va + offset + segments.append((offset, size, va)) + return segments + + def get_minimum_allocation_size(self) -> int: + """Minimum allocation size in bytes (one granule; hipMemCreate(0) is invalid).""" + return self.granularity + + def allocate(self, num_elements: int, dtype: torch.dtype, alignment: int = 1024) -> torch.Tensor: + """ + Allocate memory from the VMem heap. + + Args: + num_elements: Number of elements to allocate + dtype: PyTorch data type + alignment: Alignment requirement in bytes + + Returns: + PyTorch tensor wrapping the allocated memory + + Raises: + RuntimeError: If allocation fails or heap is full + """ + with self.lock: + element_size = torch.tensor([], dtype=dtype).element_size() + size_bytes = num_elements * element_size + actual_size_bytes = max(size_bytes, self.get_minimum_allocation_size()) + aligned_size = (actual_size_bytes + self.granularity - 1) & ~(self.granularity - 1) + aligned_offset = (self.current_offset + alignment - 1) & ~(alignment - 1) + + if aligned_offset + aligned_size > self.aligned_heap_size: + raise RuntimeError( + f"Out of VMem address space for allocation: " + f"need {aligned_size} bytes at offset {aligned_offset}, " + f"but heap size is {self.aligned_heap_size}. " + f"Current offset: {self.current_offset}, " + f"available: {self.aligned_heap_size - aligned_offset} bytes" + ) + + target_va = self.base_va + aligned_offset + handle = mem_create(aligned_size, self.device_id) + mem_map(target_va, aligned_size, 0, handle) + + new_cumulative_size = aligned_offset + aligned_size + if new_cumulative_size > self.cumulative_mapped_size: + self.cumulative_mapped_size = new_cumulative_size + mem_set_access(self.base_va, self.cumulative_mapped_size, self.access_descs) + + self._track_allocation(aligned_offset, aligned_size, False, handle, target_va) + self.current_offset = aligned_offset + aligned_size + + interface_size = (aligned_size // element_size) * element_size + + class CUDAArrayInterface: + def __init__(self, ptr, size_bytes, device): + self.ptr = ptr + self.size_bytes = size_bytes + self.device = device + + @property + def __cuda_array_interface__(self): + return { + "shape": (self.size_bytes,), + "typestr": "|u1", + "data": (self.ptr, False), + "version": 3, + } + + cuda_array = CUDAArrayInterface(target_va, interface_size, self.device) + tensor_bytes = torch.as_tensor(cuda_array, device=self.device) + full = tensor_bytes.view(dtype) + if num_elements == 0: + tensor = full.narrow(0, 1, 0) + else: + tensor = full.narrow(0, 0, num_elements) + return tensor + + def get_device(self) -> torch.device: + """ + Get the PyTorch device for this allocator. + + Returns: + PyTorch device object + """ + return self.device + + def owns_tensor(self, tensor: torch.Tensor) -> bool: + """ + Check if a tensor's memory belongs to this allocator's heap. + + Args: + tensor: Tensor to check + + Returns: + True if tensor is within this allocator's heap, False otherwise + """ + if not tensor.is_cuda: + return False + if tensor.numel() == 0: + return True + + ptr = tensor.data_ptr() + return self.base_va <= ptr < self.base_va + self.aligned_heap_size + + def import_external_tensor(self, external_tensor: torch.Tensor) -> torch.Tensor: + """ + Import an external PyTorch tensor into the symmetric heap (as_symmetric). + + This creates a view into the symmetric heap that shares physical memory + with the external tensor, handling PyTorch caching allocator offsets. + + Args: + external_tensor: External PyTorch tensor to import + + Returns: + New tensor view in symmetric heap that shares memory with external tensor + + Raises: + RuntimeError: If import fails or tensor is not contiguous + """ + + with self.lock: + if not external_tensor.is_cuda: + raise RuntimeError("Can only import CUDA tensors") + if not external_tensor.is_contiguous(): + raise RuntimeError("Only contiguous tensors can be imported; call .contiguous() before as_symmetric()") + + external_ptr = external_tensor.data_ptr() + alloc_base, alloc_size = get_address_range(external_ptr) + offset_in_alloc = external_ptr - alloc_base + aligned_size = (alloc_size + self.granularity - 1) & ~(self.granularity - 1) + aligned_offset = (self.current_offset + self.granularity - 1) & ~(self.granularity - 1) + + if aligned_offset + aligned_size > self.aligned_heap_size: + raise RuntimeError( + f"Out of VMem address space for import: " + f"need {aligned_size} bytes at offset {aligned_offset}, " + f"but heap size is {self.aligned_heap_size}. " + f"Current offset: {self.current_offset}, " + f"available: {self.aligned_heap_size - aligned_offset} bytes" + ) + + dmabuf_fd, export_base, export_size = export_dmabuf_handle(alloc_base, alloc_size) + aligned_export_size = (export_size + self.granularity - 1) & ~(self.granularity - 1) + target_va = self.base_va + aligned_offset + imported_handle = mem_import_from_shareable_handle(dmabuf_fd) + os.close(dmabuf_fd) + + mem_map(target_va, aligned_export_size, 0, imported_handle) + + new_cumulative_size = aligned_offset + aligned_export_size + if new_cumulative_size > self.cumulative_mapped_size: + self.cumulative_mapped_size = new_cumulative_size + mem_set_access(self.base_va, self.cumulative_mapped_size, self.access_descs) + + tensor_va = target_va + offset_in_alloc + self._track_allocation(aligned_offset, aligned_export_size, True, imported_handle, target_va) + self.current_offset = aligned_offset + aligned_export_size + + tensor_size = external_tensor.numel() * external_tensor.element_size() + + class CUDAArrayInterface: + def __init__(self, ptr, size_bytes, device): + self.ptr = ptr + self.size_bytes = size_bytes + self.device = device + + @property + def __cuda_array_interface__(self): + return { + "shape": (self.size_bytes,), + "typestr": "|u1", + "data": (self.ptr, False), + "version": 3, + } + + cuda_array = CUDAArrayInterface(tensor_va, tensor_size, self.device) + tensor_bytes = torch.as_tensor(cuda_array, device=self.device) + imported_tensor = tensor_bytes.view(external_tensor.dtype).reshape(external_tensor.shape) + + return imported_tensor + + def close(self): + """Explicitly release VMem resources.""" + if hasattr(self, "_closed") and self._closed: + return + + with self.lock: + for offset, alloc_info in self.allocations.items(): + if len(alloc_info) == 4: + size, is_imported, handle, va = alloc_info + + if handle is not None: + aligned_size = (size + self.granularity - 1) & ~(self.granularity - 1) + mem_unmap(va, aligned_size) + mem_release(handle) + + self.allocations.clear() + + if hasattr(self, "base_va") and self.base_va: + mem_address_free(self.base_va, self.va_size) + self.base_va = 0 + + self._closed = True + + def __del__(self): + """Cleanup VMem resources on deletion.""" + self.close() diff --git a/iris/ccl/all_gather.py b/iris/ccl/all_gather.py index dde64dce9..190c96072 100644 --- a/iris/ccl/all_gather.py +++ b/iris/ccl/all_gather.py @@ -137,9 +137,148 @@ def persistent_all_gather( target_rank, heap_bases, mask=combined_mask, + hint=(1, BLOCK_SIZE_N), ) +@triton.jit() +def persistent_all_gather_partitioned( + input_ptr, + output_ptr, + M, + N, + stride_in_m, + stride_in_n, + stride_out_m, + stride_out_n, + heap_bases: tl.tensor, + group_rank: tl.constexpr, + iris_rank: tl.constexpr, + world_size: tl.constexpr, + rank_start: tl.constexpr, + rank_stride: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + COMM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + CHUNK_SIZE: tl.constexpr, +): + """ + Persistent all-gather kernel with rank-partitioned work distribution. + + Each PID is assigned to work on a specific destination rank, and multiple PIDs + partition the tiles for that rank. This avoids the inner loop over world_size. + + Work distribution: + - PIDs are partitioned across destination ranks + - PIDs_per_rank = COMM_SMS // world_size + - Each group of PIDs handles all tiles for one destination rank + - Within each rank group, PIDs partition the tiles + + Args: + input_ptr: Pointer to input tensor (local rank's data to send) of shape (M, N) + output_ptr: Pointer to output tensor (will receive from all ranks) of shape (world_size * M, N) + M: Number of rows per rank (output will be world_size * M rows) + N: Number of columns + stride_in_m, stride_in_n: Strides for input tensor + stride_out_m, stride_out_n: Strides for output tensor + heap_bases: Heap base pointers for all ranks + group_rank: Rank within the ProcessGroup (0 to group_size-1), used for tile assignment and comparisons + iris_rank: Rank in the iris context, used for iris RMA operations (heap_bases indexing) + world_size: Total number of ranks in the group + BLOCK_SIZE_M, BLOCK_SIZE_N: Block sizes for tiling + GROUP_SIZE_M: Group size for M dimension tiling + COMM_SMS: Number of SMs for communication (must be divisible by world_size) + NUM_XCDS: Number of XCDs + CHUNK_SIZE: Chunk size for chiplet transform + """ + pid = tl.program_id(0) + + # Partition PIDs across destination ranks + pids_per_rank = COMM_SMS // world_size + dest_rank_idx = pid // pids_per_rank # Which destination rank this PID works on (0 to world_size-1) + pid_in_rank_group = pid % pids_per_rank # Which PID within the rank group + + # Compute the actual target rank + target_rank = rank_start + dest_rank_idx * rank_stride + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + tl.assume(total_tiles > 0) + + # Iterate over tiles with this PID's offset and stride within the rank group + for tile_id in range(pid_in_rank_group, total_tiles, pids_per_rank): + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + tl.assume(tile_id >= 0) + tl.assume(stride_in_m >= 0) + tl.assume(stride_in_n >= 0) + tl.assume(stride_out_m >= 0) + tl.assume(stride_out_n >= 0) + + # Compute local row and column indices for input tensor + rm_base = pid_m * BLOCK_SIZE_M + rn_base = pid_n * BLOCK_SIZE_N + rm_input = rm_base + tl.arange(0, BLOCK_SIZE_M) + rn = rn_base + tl.arange(0, BLOCK_SIZE_N) + rm_input = tl.max_contiguous(tl.multiple_of(rm_input, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + # Mask for local input bounds + input_mask = (rm_input[:, None] < M) & (rn[None, :] < N) + + # Compute input offset and load local shard data once + input_base_m = rm_input[:, None] * stride_in_m + input_base_n = rn[None, :] * stride_in_n + input_offset = input_base_m + input_base_n + input_ptr_source = input_ptr + input_offset + input_ptr_source = tl.multiple_of(input_ptr_source, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + + # Load local input data once for this tile + data = tl.load(input_ptr_source, mask=input_mask, other=0.0) + + # Compute global output row indices: offset by group_rank * M + rm_output = rm_input + group_rank * M + + # Output mask: only write where input was valid + output_mask = (rm_output[:, None] < (group_rank + 1) * M) & (rn[None, :] < N) + + # Combine masks: must be valid in both input and output + combined_mask = input_mask & output_mask + + # Compute output offset + output_base_m = rm_output[:, None] * stride_out_m + output_base_n = rn[None, :] * stride_out_n + output_offset = output_base_m + output_base_n + output_ptr_target = output_ptr + output_offset + output_ptr_target = tl.multiple_of(output_ptr_target, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + + # Send to the assigned destination rank + if dest_rank_idx == group_rank: + # Local destination: use direct store + tl.store(output_ptr_target, data, mask=combined_mask, cache_modifier=".wt") + else: + # Remote destination: use iris.store to send data to remote destination + iris.store( + output_ptr_target, + data, + iris_rank, + target_rank, + heap_bases, + mask=combined_mask, + hint=(1, BLOCK_SIZE_N), + ) + + def all_gather( output_tensor, input_tensor, @@ -169,6 +308,7 @@ def all_gather( Default: False. config: Config instance with kernel parameters (default: None). If None, uses default Config values. + Set config.all_gather_variant to choose variant: "persistent" or "partitioned" """ # Use provided config or create default one if config is None: @@ -187,6 +327,13 @@ def all_gather( # rank_global: global rank in iris context - passed as iris_rank to kernel for RMA operations rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, shmem) + # Validate COMM_SMS divisibility for partitioned variant + if config.all_gather_variant == "partitioned" and config.comm_sms % world_size != 0: + raise ValueError( + f"For all_gather_variant='partitioned', COMM_SMS ({config.comm_sms}) must be divisible by world_size ({world_size}). " + f"Please adjust config.comm_sms to be a multiple of {world_size}." + ) + M, N = input_tensor.shape[:2] expected_output_shape = (world_size * M, N) @@ -201,7 +348,15 @@ def all_gather( heap_bases = shmem.get_heap_bases() - persistent_all_gather[(config.comm_sms,)]( + # Dispatch to the appropriate kernel based on variant + if config.all_gather_variant == "persistent": + kernel_fn = persistent_all_gather + elif config.all_gather_variant == "partitioned": + kernel_fn = persistent_all_gather_partitioned + else: + raise ValueError(f"Unknown all_gather_variant: {config.all_gather_variant}") + + kernel_fn[(config.comm_sms,)]( input_tensor, output_tensor, M, diff --git a/iris/ccl/all_reduce.py b/iris/ccl/all_reduce.py index a0d445215..8503907a5 100644 --- a/iris/ccl/all_reduce.py +++ b/iris/ccl/all_reduce.py @@ -334,6 +334,7 @@ def persistent_all_reduce_spinlock( dest_rank, heap_bases, mask=mask, + hint=(1, BLOCK_SIZE_N), ) # Release lock for this tile at dest_rank @@ -539,6 +540,7 @@ def persistent_all_reduce_ring( next_rank, heap_bases, mask=mask, + hint=(1, BLOCK_SIZE_N), ) tl.debug_barrier() iris.atomic_xchg( @@ -668,7 +670,7 @@ def persistent_all_reduce_two_shot( remote_rank_idx = (start_rank_idx + i) % world_size remote_rank = rank_start + remote_rank_idx * rank_stride if remote_rank_idx != group_rank: - iris.store(out_ptr, reduced, iris_rank, remote_rank, heap_bases) + iris.store(out_ptr, reduced, iris_rank, remote_rank, heap_bases, hint=(1, BLOCK_SIZE_N)) # Slow path: MASKED (only boundary tiles land here) # This path handles tiles at tensor boundaries where not all elements are valid. @@ -691,7 +693,15 @@ def persistent_all_reduce_two_shot( remote_rank_idx = (start_rank_idx + i) % world_size remote_rank = rank_start + remote_rank_idx * rank_stride if remote_rank_idx != group_rank: - iris.store(out_ptr, reduced, iris_rank, remote_rank, heap_bases, mask=mask) + iris.store( + out_ptr, + reduced, + iris_rank, + remote_rank, + heap_bases, + mask=mask, + hint=(1, BLOCK_SIZE_N), + ) def all_reduce( diff --git a/iris/ccl/all_to_all.py b/iris/ccl/all_to_all.py index 9010ef066..9ff16a1bd 100644 --- a/iris/ccl/all_to_all.py +++ b/iris/ccl/all_to_all.py @@ -144,6 +144,7 @@ def persistent_all_to_all( iris_rank, target_rank, heap_bases, + hint=(1, BLOCK_SIZE_N), ) # Slow path: MASKED (only boundary tiles land here) @@ -183,6 +184,7 @@ def persistent_all_to_all( target_rank, heap_bases, mask=mask, + hint=(1, BLOCK_SIZE_N), ) diff --git a/iris/ccl/config.py b/iris/ccl/config.py index ce29f4c27..bb84c2ea2 100644 --- a/iris/ccl/config.py +++ b/iris/ccl/config.py @@ -31,6 +31,10 @@ class Config: num_xcds: Number of XCCs. If None, auto-detected from system (default: None) use_gluon: If True, use Gluon-based implementation (default: False) Gluon provides better control over warp-level traffic shaping + all_gather_variant: Variant for all-gather operation (default: "persistent") + Options: "persistent", "partitioned" + - "persistent": Each PID handles multiple tiles and sends to all ranks + - "partitioned": PIDs partitioned across ranks, eliminates inner loop all_reduce_variant: Variant for all-reduce operation (default: "atomic") Options: "atomic", "ring", "two_shot", "one_shot", "spinlock" all_reduce_distribution: Distribution for two-shot all-reduce (default: 0) @@ -57,6 +61,10 @@ class Config: >>> # All-reduce with ring variant >>> config = Config(all_reduce_variant="ring") >>> shmem.ccl.all_reduce(output_tensor, input_tensor, config=config) + + >>> # All-gather with partitioned variant + >>> config = Config(all_gather_variant="partitioned") + >>> shmem.ccl.all_gather(output_tensor, input_tensor, config=config) """ block_size_m: int = 32 @@ -66,6 +74,7 @@ class Config: num_xcds: int | None = None chunk_size: int | None = None use_gluon: bool = False + all_gather_variant: str = "persistent" all_reduce_variant: str = "two_shot" all_reduce_distribution: int = 1 all_reduce_num_rings: int = 1 @@ -91,6 +100,10 @@ def __post_init__(self): raise ValueError(f"comm_sms must be positive, got {self.comm_sms}") if self.num_xcds <= 0: raise ValueError(f"num_xcds must be positive, got {self.num_xcds}") + if self.all_gather_variant not in ["persistent", "partitioned"]: + raise ValueError( + f"all_gather_variant must be one of: 'persistent', 'partitioned', got {self.all_gather_variant}" + ) if self.all_reduce_variant not in ["atomic", "ring", "two_shot", "one_shot", "spinlock"]: raise ValueError( f"all_reduce_variant must be one of: 'atomic', 'ring', 'two_shot', 'one_shot', 'spinlock', got {self.all_reduce_variant}" diff --git a/iris/device_utils.py b/iris/device_utils.py new file mode 100644 index 000000000..1e328ebcf --- /dev/null +++ b/iris/device_utils.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Device-side utility functions for Iris. + +This module provides low-level device intrinsics for accessing hardware +information and timing within Triton kernels. +""" + +import triton +import triton.language as tl + + +@triton.jit +def read_realtime(): + """ + Read GPU wall clock timestamp from s_memrealtime. + + Returns a 64-bit timestamp from a constant 100MHz clock (not affected + by power modes or core clock frequency changes). + + Returns: + int64: Current timestamp in cycles (100MHz constant clock) + """ + tmp = tl.inline_asm_elementwise( + asm="""s_waitcnt vmcnt(0) + s_memrealtime $0 + s_waitcnt lgkmcnt(0)""", + constraints=("=s"), + args=[], + dtype=tl.int64, + is_pure=False, + pack=1, + ) + return tmp + + +@triton.jit +def get_xcc_id(): + """ + Get XCC (GPU chiplet) ID. + + Returns: + int32: XCC ID for the current execution + """ + xcc_id = tl.inline_asm_elementwise( + asm="s_getreg_b32 $0, hwreg(HW_REG_XCC_ID, 0, 16)", + constraints=("=s"), + args=[], + dtype=tl.int32, + is_pure=False, + pack=1, + ) + return xcc_id + + +@triton.jit +def get_cu_id(): + """ + Get Compute Unit ID. + + Returns: + int32: CU ID for the current execution + """ + cu_id = tl.inline_asm_elementwise( + asm="s_getreg_b32 $0, hwreg(HW_REG_HW_ID, 8, 4)", + constraints=("=s"), + args=[], + dtype=tl.int32, + is_pure=False, + pack=1, + ) + return cu_id + + +@triton.jit +def get_se_id(): + """ + Get Shader Engine ID. + + Returns: + int32: SE ID for the current execution + """ + se_id = tl.inline_asm_elementwise( + asm="s_getreg_b32 $0, hwreg(HW_REG_HW_ID, 13, 3)", + constraints=("=s"), + args=[], + dtype=tl.int32, + is_pure=False, + pack=1, + ) + return se_id diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index f2dc36080..8aead7c41 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -52,13 +52,15 @@ ) from iris.symmetric_heap import SymmetricHeap import numpy as np -import math import torch import logging # Import logging functionality from the separate logging module from ..logging import logger +# Import shared tensor-creation helpers +from .. import tensor_creation + @aggregate class IrisDeviceCtx: @@ -801,54 +803,6 @@ def broadcast(self, data, src_rank=0): else: return distributed_broadcast_scalar(data, src_rank) - def __allocate(self, num_elements, dtype): - """Internal method to allocate memory from the symmetric heap.""" - self.debug(f"allocate: num_elements = {num_elements}, dtype = {dtype}") - return self.heap.allocate(num_elements, dtype) - - def __parse_size(self, size): - """Parse size parameter and calculate number of elements.""" - # Handle nested tuples/lists by flattening them recursively - while len(size) == 1 and isinstance(size[0], (tuple, list)): - size = size[0] - num_elements = math.prod(size) - return size, num_elements - - def __throw_if_invalid_device(self, device): - """Check if the requested device is compatible with this Iris instance.""" - if not self.__is_valid_device(device): - raise ValueError( - f"Requested device {device} does not match Iris device {self.get_device()}. " - f"All Iris tensors must be on the same device as the Iris symmetric heap." - ) - - def __is_valid_device(self, device) -> bool: - """Check if the requested device is compatible with this Iris instance.""" - if device is None: - return True # None means use default device - - # Convert device strings to torch.device objects for proper comparison - requested_device = torch.device(device) if isinstance(device, str) else device - iris_device = self.get_device() - - # Check if both are CUDA devices - if requested_device.type == "cuda" and iris_device.type == "cuda": - # Check if index matches or if requested is "cuda" (any index) - if requested_device.index is None: - return True - else: - return requested_device.index == iris_device.index - - # For non-CUDA devices, always return False - return False - - def __apply_layout(self, tensor, layout): - """Apply the requested layout to the tensor.""" - if layout == torch.strided: - return tensor - else: - raise ValueError(f"Unsupported layout: {layout}") - def zeros( self, *size, @@ -871,37 +825,16 @@ def zeros( Returns: torch.Tensor: Zero-initialized tensor on the symmetric heap """ - # Use global default dtype if None is provided - if dtype is None: - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # Allocate memory from symmetric heap - tensor = self.__allocate(num_elements, dtype) - - # Zero-initialize - tensor.zero_() - - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor + return tensor_creation.zeros( + self.heap, + self.get_device(), + size, + out=out, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) def ones( self, @@ -937,44 +870,16 @@ def ones( >>> print(tensor.shape) # torch.Size([2, 3]) >>> print(tensor[0]) # tensor([1., 1., 1.], device='cuda:0') """ - self.debug(f"ones: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") - - # Use global default dtype if None is provided - if dtype is None: - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Fill with ones - out.fill_(1) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Fill with ones - tensor.fill_(1) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor + return tensor_creation.ones( + self.heap, + self.get_device(), + size, + out=out, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) def full( self, @@ -1012,54 +917,18 @@ def full( >>> print(tensor.shape) # torch.Size([2, 3]) >>> print(tensor[0]) # tensor([3.1400, 3.1400, 3.1400], device='cuda:0') """ - self.debug( - f"full: size = {size}, fill_value = {fill_value}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" + return tensor_creation.full( + self.heap, + self.get_device(), + size, + fill_value, + out=out, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, ) - # Infer dtype from fill_value if not provided - if dtype is None: - if isinstance(fill_value, (int, float)): - if isinstance(fill_value, float): - dtype = torch.get_default_dtype() - else: - dtype = torch.int64 - else: - # For other types (like tensors), use their dtype - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Fill with the specified value - out.fill_(fill_value) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Fill with the specified value - tensor.fill_(fill_value) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - def zeros_like( self, input, @@ -1094,55 +963,43 @@ def zeros_like( >>> zeros_tensor = ctx.zeros_like(input_tensor) >>> print(zeros_tensor.shape) # torch.Size([2, 3]) """ - self.debug( - f"zeros_like: input_shape = {input.shape}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" + return tensor_creation.zeros_like( + self.heap, + self.get_device(), + input, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + memory_format=memory_format, ) - # Use input's properties as defaults if not specified - if dtype is None: - dtype = input.dtype - if layout is None: - layout = input.layout - if device is None: - device = input.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Get the size from input tensor - size = input.size() - num_elements = input.numel() - - # Allocate new tensor with the same size - new_tensor = self.__allocate(num_elements, dtype) - new_tensor.zero_() - - # Reshape to match input size - new_tensor = new_tensor.reshape(size) - - # Apply the requested layout - new_tensor = self.__apply_layout(new_tensor, layout) - - # Set requires_grad if specified - if requires_grad: - new_tensor.requires_grad_() - - return new_tensor + def is_symmetric(self, tensor: torch.Tensor) -> bool: + """ + Check if a tensor is allocated on the symmetric heap. - def __throw_if_invalid_output_tensor(self, out, num_elements, dtype): - """Check if the output tensor is valid.""" - if out.numel() != num_elements: - raise RuntimeError(f"The output tensor has {out.numel()} elements, but {num_elements} are required") + This method checks whether a tensor resides in the symmetric heap, making it + accessible for RMA operations across ranks. Use this to validate tensors before + performing distributed operations. - if out.dtype != dtype: - raise RuntimeError(f"The output tensor has dtype {out.dtype}, but {dtype} is required") + Args: + tensor (torch.Tensor): PyTorch tensor to check - if not self.__on_symmetric_heap(out): - raise RuntimeError("The output tensor is not on the symmetric heap") + Returns: + bool: True if tensor is on the symmetric heap, False otherwise - def __on_symmetric_heap(self, tensor): - """Check if tensor is allocated on the symmetric heap.""" - return self.heap.on_symmetric_heap(tensor) + Example: + >>> import iris.experimental.iris_gluon as iris_gl + >>> ctx = iris_gl.iris(heap_size=2**30) + >>> # Create a symmetric tensor + >>> symmetric_tensor = ctx.zeros(1000, dtype=torch.float32) + >>> ctx.is_symmetric(symmetric_tensor) # True + >>> + >>> # Create an external tensor (not on symmetric heap) + >>> external_tensor = torch.zeros(1000, dtype=torch.float32, device='cuda') + >>> ctx.is_symmetric(external_tensor) # False + """ + return self.heap.is_symmetric(tensor) def iris(heap_size=1 << 30): diff --git a/iris/hip.py b/iris/hip.py index 96842a583..e6dc598d8 100644 --- a/iris/hip.py +++ b/iris/hip.py @@ -283,10 +283,6 @@ def export_dmabuf_handle(ptr, size): ptr_int = ptr if isinstance(ptr, int) else ptr.value ptr_arg = ctypes.c_void_p(ptr_int) - # First, get the base address and size of the allocation containing this pointer - # This is needed because hipMemGetHandleForAddressRange exports the entire - # allocation buffer (e.g., PyTorch's caching allocator buffer), not just the - # specific memory range requested. base_ptr = ctypes.c_void_p() base_size = ctypes.c_size_t() @@ -303,7 +299,6 @@ def export_dmabuf_handle(ptr, size): fd = ctypes.c_int(-1) - # Configure function signature to avoid truncation gpu_runtime.hipMemGetHandleForAddressRange.restype = ctypes.c_int gpu_runtime.hipMemGetHandleForAddressRange.argtypes = [ ctypes.POINTER(ctypes.c_int), # handle (DMA-BUF fd) @@ -313,9 +308,6 @@ def export_dmabuf_handle(ptr, size): ctypes.c_ulonglong, # flags ] - # hipMemRangeHandleTypeDmaBufFd = 1 - # Note: We pass the original ptr and size, but ROCm will export the entire - # base allocation buffer. The fd will refer to the base buffer. err = gpu_runtime.hipMemGetHandleForAddressRange(ctypes.byref(fd), ptr_arg, size, 1, 0) if err != 0: @@ -344,9 +336,11 @@ def import_dmabuf_handle(fd, size, original_ptr=None, base_ptr=None): and mapped_base is returned directly. Returns: - Mapped GPU address (integer). If original_ptr and base_ptr are provided, - returns the offset-corrected address to match the original pointer's position - within the mapped buffer. + tuple: (mapped_ptr, ext_mem_handle) where: + - mapped_ptr: GPU address (integer). If original_ptr and base_ptr are provided, + returns the offset-corrected address. + - ext_mem_handle: External memory handle that must be destroyed with + destroy_external_memory() when done. Raises: RuntimeError: If import fails or backend doesn't support it @@ -418,18 +412,428 @@ class hipExternalMemoryBufferDesc(ctypes.Structure): mapped_base = dev_ptr.value - # If original_ptr and base_ptr are provided, calculate the offset and return - # the correctly positioned pointer in the mapped address space if original_ptr is not None and base_ptr is not None: - # Normalize to integers to support both raw ints and ctypes pointers original_ptr_int = original_ptr if isinstance(original_ptr, int) else original_ptr.value base_ptr_int = base_ptr if isinstance(base_ptr, int) else base_ptr.value - - # Calculate and validate offset offset = original_ptr_int - base_ptr_int if offset < 0: raise ValueError(f"Invalid offset: original_ptr ({original_ptr_int}) < base_ptr ({base_ptr_int})") - return mapped_base + offset + return (mapped_base + offset, ext_mem) + + return (mapped_base, ext_mem) + + +def destroy_external_memory(ext_mem_handle): + """ + Destroy an external memory handle created by hipImportExternalMemory. + + Args: + ext_mem_handle: The external memory handle (hipExternalMemory_t) to destroy + + Raises: + RuntimeError: If destroy fails + """ + if not _is_amd_backend: + raise RuntimeError("External memory only supported on AMD/HIP backend") + + # hipExternalMemory_t is an opaque handle (pointer) + hipExternalMemory_t = ctypes.c_void_p + + gpu_runtime.hipDestroyExternalMemory.argtypes = [hipExternalMemory_t] + gpu_runtime.hipDestroyExternalMemory.restype = ctypes.c_int + + err = gpu_runtime.hipDestroyExternalMemory(ext_mem_handle) + if err != 0: + gpu_try(err) + + +def get_address_range(ptr): + """ + Query the base allocation and size for a given device pointer. + + Args: + ptr: Device pointer (integer or ctypes pointer) + + Returns: + tuple: (base_ptr, size) - base address and size of the allocation + + Raises: + RuntimeError: If query fails + """ + ptr_int = ptr if isinstance(ptr, int) else ptr.value + + base_ptr = ctypes.c_void_p() + size = ctypes.c_size_t() + gpu_runtime.hipMemGetAddressRange.argtypes = [ + ctypes.POINTER(ctypes.c_void_p), # void** pbase + ctypes.POINTER(ctypes.c_size_t), # size_t* psize + ctypes.c_void_p, # void* dptr + ] + gpu_runtime.hipMemGetAddressRange.restype = ctypes.c_int + + gpu_try(gpu_runtime.hipMemGetAddressRange(ctypes.byref(base_ptr), ctypes.byref(size), ctypes.c_void_p(ptr_int))) + + return base_ptr.value, size.value + + +# ============================================================================ +# HIP Virtual Memory (VMem) Management APIs +# ============================================================================ + +# Constants for VMem APIs +hipMemAllocationTypePinned = 0x1 +hipMemHandleTypePosixFileDescriptor = 0x1 +hipMemLocationTypeDevice = 0x1 +hipMemAllocationGranularityRecommended = 0x1 +hipMemAccessFlagsProtReadWrite = 0x3 + +# Type alias for VMem handle (pointer type) +hipMemGenericAllocationHandle_t = ctypes.c_void_p + + +class hipMemLocation(ctypes.Structure): + """Structure describing a memory location (device).""" + + _fields_ = [ + ("type", ctypes.c_int), # hipMemLocationType + ("id", ctypes.c_int), # Device ID + ] + + +class hipMemAllocationProp(ctypes.Structure): + """Properties for memory allocation.""" + + class _allocFlags(ctypes.Structure): + _fields_ = [ + ("smc", ctypes.c_ubyte), + ("l2", ctypes.c_ubyte), + ] + + _fields_ = [ + ("type", ctypes.c_int), # hipMemAllocationType + ("requestedHandleType", ctypes.c_int), # hipMemHandleType + ("location", hipMemLocation), # Memory location + ("win32Handle", ctypes.c_void_p), # Windows handle (unused on Linux) + ("allocFlags", _allocFlags), # Allocation flags + ] + + +class hipMemAccessDesc(ctypes.Structure): + """Memory access descriptor for setting access permissions.""" + + _fields_ = [ + ("location", hipMemLocation), # Device location + ("flags", ctypes.c_int), # Access flags + ] + + +def get_allocation_granularity(device_id): + """ + Get the allocation granularity for VMem allocations on a device. + + Args: + device_id: Device ID + + Returns: + Allocation granularity in bytes + + Raises: + RuntimeError: If query fails or backend doesn't support VMem + """ + if not _is_amd_backend: + raise RuntimeError("VMem only supported on AMD/HIP backend") + + prop = hipMemAllocationProp() + prop.type = hipMemAllocationTypePinned + prop.location.type = hipMemLocationTypeDevice + prop.location.id = device_id + prop.requestedHandleType = hipMemHandleTypePosixFileDescriptor + + granularity = ctypes.c_size_t() + + gpu_try( + gpu_runtime.hipMemGetAllocationGranularity( + ctypes.byref(granularity), + ctypes.byref(prop), + hipMemAllocationGranularityRecommended, + ) + ) + + return granularity.value + + +def mem_create(size, device_id): + """ + Create a physical memory allocation. + + Args: + size: Size in bytes (should be aligned to granularity) + device_id: Device ID + + Returns: + hipMemGenericAllocationHandle_t handle + + Raises: + RuntimeError: If creation fails or backend doesn't support VMem + """ + if not _is_amd_backend: + raise RuntimeError("VMem only supported on AMD/HIP backend") + + prop = hipMemAllocationProp() + prop.type = hipMemAllocationTypePinned + prop.location.type = hipMemLocationTypeDevice + prop.location.id = device_id + prop.requestedHandleType = hipMemHandleTypePosixFileDescriptor + + handle = hipMemGenericAllocationHandle_t() + + # Set argument types explicitly to avoid 32/64-bit issues + gpu_runtime.hipMemCreate.argtypes = [ + ctypes.POINTER(hipMemGenericAllocationHandle_t), # handle + ctypes.c_size_t, # size (64-bit!) + ctypes.POINTER(hipMemAllocationProp), # prop + ctypes.c_ulonglong, # flags + ] + gpu_runtime.hipMemCreate.restype = ctypes.c_int + + gpu_try(gpu_runtime.hipMemCreate(ctypes.byref(handle), size, ctypes.byref(prop), 0)) + + return handle.value + + +def mem_export_to_shareable_handle(handle): + """ + Export a VMem handle as a shareable file descriptor. + + Args: + handle: hipMemGenericAllocationHandle_t + + Returns: + File descriptor (integer) + + Raises: + RuntimeError: If export fails or backend doesn't support VMem + """ + if not _is_amd_backend: + raise RuntimeError("VMem only supported on AMD/HIP backend") + + fd = ctypes.c_int(-1) + + # Set argument types + gpu_runtime.hipMemExportToShareableHandle.argtypes = [ + ctypes.c_void_p, # void* shareableHandle (pointer to fd) + hipMemGenericAllocationHandle_t, # hipMemGenericAllocationHandle_t handle + ctypes.c_int, # hipMemAllocationHandleType handleType + ctypes.c_ulonglong, # unsigned long long flags + ] + gpu_runtime.hipMemExportToShareableHandle.restype = ctypes.c_int + + gpu_try(gpu_runtime.hipMemExportToShareableHandle(ctypes.byref(fd), handle, hipMemHandleTypePosixFileDescriptor, 0)) + + return fd.value + + +def mem_import_from_shareable_handle(fd): + """ + Import a VMem handle from a shareable file descriptor. + + Args: + fd: File descriptor (integer) + + Returns: + hipMemGenericAllocationHandle_t handle + + Raises: + RuntimeError: If import fails or backend doesn't support VMem + """ + if not _is_amd_backend: + raise RuntimeError("VMem only supported on AMD/HIP backend") + + handle = hipMemGenericAllocationHandle_t() + + # Set argument types + gpu_runtime.hipMemImportFromShareableHandle.argtypes = [ + ctypes.POINTER(hipMemGenericAllocationHandle_t), + ctypes.c_void_p, # void* - cast the fd integer to void* + ctypes.c_int, # hipMemAllocationHandleType + ] + gpu_runtime.hipMemImportFromShareableHandle.restype = ctypes.c_int + + # Cast the integer fd to void* (like the C++ tests do) + gpu_try( + gpu_runtime.hipMemImportFromShareableHandle( + ctypes.byref(handle), ctypes.c_void_p(fd), hipMemHandleTypePosixFileDescriptor + ) + ) + + return handle.value + + +def mem_address_reserve(size, alignment=0, addr=0, flags=0): + """ + Reserve a virtual address range. + + Args: + size: Size in bytes + alignment: Alignment requirement (0 for default) + addr: Requested address (0 for automatic) + flags: Flags + + Returns: + Reserved virtual address (integer) + + Raises: + RuntimeError: If reservation fails or backend doesn't support VMem + """ + if not _is_amd_backend: + raise RuntimeError("VMem only supported on AMD/HIP backend") + + ptr = ctypes.c_void_p() + + # Set argument types explicitly + gpu_runtime.hipMemAddressReserve.argtypes = [ + ctypes.POINTER(ctypes.c_void_p), # void** ptr + ctypes.c_size_t, # size_t size + ctypes.c_size_t, # size_t alignment + ctypes.c_void_p, # void* addr + ctypes.c_ulonglong, # unsigned long long flags + ] + gpu_runtime.hipMemAddressReserve.restype = ctypes.c_int + + gpu_try(gpu_runtime.hipMemAddressReserve(ctypes.byref(ptr), size, alignment, ctypes.c_void_p(addr), flags)) + + return ptr.value + + +def mem_map(ptr, size, offset, handle, flags=0): + """ + Map physical memory to virtual address range. + + Args: + ptr: Virtual address (integer) + size: Size in bytes + offset: Offset within physical allocation + handle: hipMemGenericAllocationHandle_t + flags: Flags + + Raises: + RuntimeError: If mapping fails or backend doesn't support VMem + """ + if not _is_amd_backend: + raise RuntimeError("VMem only supported on AMD/HIP backend") + + # Set argument types + gpu_runtime.hipMemMap.argtypes = [ + ctypes.c_void_p, # void* ptr + ctypes.c_size_t, # size_t size + ctypes.c_size_t, # size_t offset + hipMemGenericAllocationHandle_t, # hipMemGenericAllocationHandle_t handle + ctypes.c_ulonglong, # unsigned long long flags + ] + gpu_runtime.hipMemMap.restype = ctypes.c_int + + gpu_try(gpu_runtime.hipMemMap(ctypes.c_void_p(ptr), size, offset, handle, flags)) + + +def mem_unmap(ptr, size): + """ + Unmap virtual address range. + + Args: + ptr: Virtual address (integer) + size: Size in bytes + + Raises: + RuntimeError: If unmapping fails or backend doesn't support VMem + """ + if not _is_amd_backend: + raise RuntimeError("VMem only supported on AMD/HIP backend") + + # Set argument types explicitly + gpu_runtime.hipMemUnmap.argtypes = [ + ctypes.c_void_p, # void* ptr + ctypes.c_size_t, # size_t size + ] + gpu_runtime.hipMemUnmap.restype = ctypes.c_int + + gpu_try(gpu_runtime.hipMemUnmap(ctypes.c_void_p(ptr), size)) + + +def mem_address_free(ptr, size): + """ + Free a reserved virtual address range. + + Args: + ptr: Virtual address (integer) + size: Size in bytes + + Raises: + RuntimeError: If freeing fails or backend doesn't support VMem + """ + if not _is_amd_backend: + raise RuntimeError("VMem only supported on AMD/HIP backend") + + # Set argument types explicitly + gpu_runtime.hipMemAddressFree.argtypes = [ + ctypes.c_void_p, # void* ptr + ctypes.c_size_t, # size_t size + ] + gpu_runtime.hipMemAddressFree.restype = ctypes.c_int + + gpu_try(gpu_runtime.hipMemAddressFree(ctypes.c_void_p(ptr), size)) + + +def mem_release(handle): + """ + Release a physical memory allocation handle. + + Args: + handle: hipMemGenericAllocationHandle_t + + Raises: + RuntimeError: If release fails or backend doesn't support VMem + """ + if not _is_amd_backend: + raise RuntimeError("VMem only supported on AMD/HIP backend") + + # Set argument types + gpu_runtime.hipMemRelease.argtypes = [hipMemGenericAllocationHandle_t] + gpu_runtime.hipMemRelease.restype = ctypes.c_int + + gpu_try(gpu_runtime.hipMemRelease(handle)) + + +def mem_set_access(ptr, size, desc_or_list): + """ + Set access permissions for a virtual address range. + + Args: + ptr: Virtual address (integer) + size: Size in bytes + desc_or_list: hipMemAccessDesc or list of hipMemAccessDesc for multi-device access + + Raises: + RuntimeError: If setting access fails or backend doesn't support VMem + """ + if not _is_amd_backend: + raise RuntimeError("VMem only supported on AMD/HIP backend") + + # Support both single descriptor and list of descriptors + if isinstance(desc_or_list, list): + desc_array = (hipMemAccessDesc * len(desc_or_list))(*desc_or_list) + count = len(desc_or_list) + else: + desc_array = (hipMemAccessDesc * 1)(desc_or_list) + count = 1 + + # Set argument types + gpu_runtime.hipMemSetAccess.argtypes = [ + ctypes.c_void_p, # void* ptr + ctypes.c_size_t, # size_t size + ctypes.POINTER(hipMemAccessDesc), # const hipMemAccessDesc* desc + ctypes.c_size_t, # size_t count + ] + gpu_runtime.hipMemSetAccess.restype = ctypes.c_int - return mapped_base + gpu_try(gpu_runtime.hipMemSetAccess(ctypes.c_void_p(ptr), size, desc_array, count)) diff --git a/iris/iris.py b/iris/iris.py index 239fd0703..43b89d0bc 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -15,15 +15,32 @@ - Memory allocation and deallocation utilities - Built-in logging with rank information - PyTorch distributed integration for distributed computing +- DeviceContext: Object-oriented API for device-side operations (gluon-style) -Example: +Example (Traditional Functional API): >>> import iris >>> ctx = iris.iris(heap_size=2**30) # 1GB heap >>> tensor = ctx.zeros(1024, 1024, dtype=torch.float32) + >>> + >>> @triton.jit + >>> def kernel(buffer, heap_bases, rank, world_size): + >>> data = iris.load(buffer, rank, remote_rank, heap_bases) + +Example (Object-Oriented DeviceContext API): + >>> import iris + >>> from iris import DeviceContext + >>> ctx = iris.iris(heap_size=2**30) + >>> context_tensor = ctx.get_device_context() + >>> + >>> @triton.jit + >>> def kernel(context_tensor, rank: tl.constexpr, world_size: tl.constexpr): + >>> device_ctx = DeviceContext.initialize(context_tensor, rank, world_size) + >>> data = device_ctx.load(buffer, from_rank=remote_rank) """ import triton import triton.language as tl +from triton.language.core import _aggregate as aggregate from iris._distributed_helpers import ( init_distributed, @@ -38,13 +55,19 @@ ) from iris.symmetric_heap import SymmetricHeap import numpy as np -import math import torch import logging # Import logging functionality from the separate logging module from .logging import logger +# Import tracing functionality +from .tracing import Tracing, TraceEvent, DeviceTracing # noqa: F401 re-export for iris.TraceEvent + +# Import shared tensor-creation helpers +from . import tensor_creation +from .util import is_simulation_env + class Iris: """ @@ -55,14 +78,21 @@ class Iris: Args: heap_size (int): Size of the symmetric heap in bytes. Default: 1GB (2^30) + allocator_type (str): Type of allocator to use. Options: "torch" (default), "vmem" Example: - >>> ctx = iris.iris(heap_size=2**31) # 2GB heap + >>> ctx = iris.iris(heap_size=2**31) # 2GB heap with torch allocator >>> print(f"Rank {ctx.cur_rank} of {ctx.num_ranks}") # Rank 0 of 1 >>> tensor = ctx.zeros(1000, 1000, dtype=torch.float32) + + >>> # Use VMem allocator for memory oversubscription + >>> ctx = iris.iris(heap_size=2**31, allocator_type="vmem") """ - def __init__(self, heap_size=1 << 30): + def __init__(self, heap_size=1 << 30, allocator_type="torch"): + if is_simulation_env(): + allocator_type = "torch" + # Initialize distributed environment comm, cur_rank, num_ranks = init_distributed() num_gpus = count_devices() @@ -76,13 +106,26 @@ def __init__(self, heap_size=1 << 30): self.gpu_id = gpu_id self.heap_size = heap_size - # Initialize symmetric heap - self.heap = SymmetricHeap(heap_size, gpu_id, cur_rank, num_ranks) + # Initialize symmetric heap with specified allocator + self.heap = SymmetricHeap(heap_size, gpu_id, cur_rank, num_ranks, allocator_type) self.device = f"cuda:{gpu_id}" self.heap_bases = self.heap.get_heap_bases() - for i in range(num_ranks): - self.debug(f"GPU {i}: Heap base {hex(int(self.heap_bases[i].item()))}") + if is_simulation_env(): + import json + + heap_bases_list = [int(self.heap_bases[r].item()) for r in range(self.num_ranks)] + out_path = f"iris_rank_{self.cur_rank}_heap_bases.json" + with open(out_path, "w") as f: + json.dump( + { + "rank": self.cur_rank, + "num_ranks": self.num_ranks, + "heap_bases": [hex(b) for b in heap_bases_list], + }, + f, + indent=2, + ) distributed_barrier() @@ -92,6 +135,18 @@ def __init__(self, heap_size=1 << 30): # Lazy initialization for ops interface self._ops = None + # Initialize tracing + self.tracing = Tracing(self) + + def __del__(self): + """Cleanup resources on deletion.""" + try: + if hasattr(self, "heap") and hasattr(self.heap, "allocator"): + if hasattr(self.heap.allocator, "close"): + self.heap.allocator.close() + except Exception: + pass # Best effort cleanup in destructor (GC context) + def _log_with_rank(self, level, message): """Helper method to log with rank information injected into the record.""" if logger.isEnabledFor(level): @@ -265,18 +320,6 @@ def broadcast(self, value, source_rank=0): else: return distributed_broadcast_scalar(value, source_rank) - def __allocate(self, num_elements, dtype): - """Allocate memory using the symmetric heap.""" - self.debug(f"allocate: num_elements = {num_elements}, dtype = {dtype}") - return self.heap.allocate(num_elements, dtype) - - def __parse_size(self, size): - # Handle nested tuples/lists by flattening them recursively - while len(size) == 1 and isinstance(size[0], (tuple, list)): - size = size[0] - num_elements = math.prod(size) - return size, num_elements - def zeros_like( self, input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format ): @@ -304,44 +347,17 @@ def zeros_like( >>> zeros_tensor = ctx.zeros_like(input_tensor) >>> print(zeros_tensor.shape) # torch.Size([2, 3]) """ - self.debug( - f"zeros_like: input_shape = {input.shape}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" + return tensor_creation.zeros_like( + self.heap, + self.get_device(), + input, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + memory_format=memory_format, ) - # Use input's properties as defaults if not specified - if dtype is None: - dtype = input.dtype - if layout is None: - layout = input.layout - if device is None: - device = input.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Get the size from input tensor - size = input.size() - num_elements = input.numel() - - # Allocate new tensor with the same size - new_tensor = self.__allocate(num_elements, dtype) - new_tensor.zero_() - - # Reshape to match input size - new_tensor = new_tensor.reshape(size) - - # Apply the requested memory format - new_tensor = self.__apply_memory_format(new_tensor, size, memory_format, input) - - # Apply the requested layout - new_tensor = self.__apply_layout(new_tensor, layout) - - # Set requires_grad if specified - if requires_grad: - new_tensor.requires_grad_() - - return new_tensor - def arange( self, start=0, end=None, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False ): @@ -380,57 +396,22 @@ def arange( >>> tensor = ctx.arange(0, 10, 2) # [0, 2, 4, 6, 8] >>> print(tensor.shape) # torch.Size([5]) """ - self.debug(f"arange: start = {start}, end = {end}, step = {step}, dtype = {dtype}, device = {device}") - # Handle the case where only one argument is provided (end) if end is None: end = start start = 0 - - # Validate inputs - if step == 0: - raise ValueError("step must be non-zero") - - # Validate step direction consistency - if step > 0 and start >= end: - raise ValueError(f"Invalid range: start >= end with positive step (start={start}, end={end}, step={step})") - elif step < 0 and start <= end: - raise ValueError(f"Invalid range: start <= end with negative step (start={start}, end={end}, step={step})") - - # Calculate the number of elements - num_elements = math.ceil((end - start) / step) - - # Infer dtype if not provided - if dtype is None: - if any(isinstance(x, float) for x in [start, end, step]): - dtype = torch.get_default_dtype() - else: - dtype = torch.int64 - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - tensor = out - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - - target_device = tensor.device - arange_tensor = torch.arange(start, end, step, dtype=dtype, device=target_device) - - tensor[:] = arange_tensor - - tensor = self.__apply_layout(tensor, layout) - - if requires_grad: - tensor.requires_grad_() - - return tensor + return tensor_creation.arange( + self.heap, + self.get_device(), + start, + end, + step, + out=out, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) def zeros(self, *size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): """ @@ -458,44 +439,16 @@ def zeros(self, *size, out=None, dtype=None, layout=torch.strided, device=None, >>> print(tensor.shape) # torch.Size([2, 3]) >>> print(tensor[0]) # tensor([0., 0., 0.], device='cuda:0') """ - self.debug(f"zeros: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") - - # Use global default dtype if None is provided - if dtype is None: - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Fill with zeros - out.zero_() - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Fill with zeros - tensor.zero_() - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor + return tensor_creation.zeros( + self.heap, + self.get_device(), + size, + out=out, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) def randn( self, @@ -554,49 +507,18 @@ def randn( >>> print(tensor.shape) # torch.Size([2, 3]) >>> print(tensor[0]) # tensor([ 0.3982, -0.0059, -0.4365], device='cuda:0') """ - self.debug( - f"randn: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" + return tensor_creation.randn( + self.heap, + self.get_device(), + size, + generator=generator, + out=out, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, ) - # Use global default dtype if None is provided - if dtype is None: - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Generate random data and copy to out tensor - random_data = torch.randn(num_elements, generator=generator, dtype=dtype, device=device, layout=layout) - out.copy_(random_data) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Generate random data and copy to tensor - random_data = torch.randn(num_elements, generator=generator, dtype=dtype, device=device, layout=layout) - tensor.copy_(random_data) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - def ones(self, *size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): """ Returns a tensor filled with the scalar value 1, with the shape defined by the variable argument size. @@ -623,44 +545,82 @@ def ones(self, *size, out=None, dtype=None, layout=torch.strided, device=None, r >>> print(tensor.shape) # torch.Size([2, 3]) >>> print(tensor[0]) # tensor([1., 1., 1.], device='cuda:0') """ - self.debug(f"ones: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") + return tensor_creation.ones( + self.heap, + self.get_device(), + size, + out=out, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) - # Use global default dtype if None is provided - if dtype is None: - dtype = torch.get_default_dtype() + def as_symmetric(self, external_tensor: torch.Tensor) -> torch.Tensor: + """ + Import an external PyTorch tensor into the symmetric heap. - # Use current device if none specified - if device is None: - device = self.device + This creates a new tensor in the symmetric heap that shares physical memory + with the external tensor. Any modifications to either tensor will be visible + in both. This is useful for importing pre-allocated tensors (e.g., model weights) + into the symmetric heap for RMA operations. - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) + Note: This feature requires `allocator_type='vmem'`. - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) + Args: + external_tensor (torch.Tensor): External PyTorch tensor to import. + Must be a CUDA tensor. - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Fill with ones - out.fill_(1) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Fill with ones - tensor.fill_(1) - # Reshape to the desired size - tensor = tensor.reshape(size) + Returns: + torch.Tensor: New tensor in symmetric heap sharing memory with external tensor - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) + Raises: + RuntimeError: If allocator doesn't support imports or import fails - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() + Example: + >>> ctx = iris.iris(allocator_type='vmem') + >>> # Create an external tensor + >>> external = torch.randn(1000, 1000, device='cuda') + >>> # Import it into symmetric heap + >>> symmetric = ctx.as_symmetric(external) + >>> # Verify they share memory + >>> external[0, 0] = 999.0 + >>> assert symmetric[0, 0].item() == 999.0 + >>> # Now you can use symmetric in RMA operations + >>> ctx.put(symmetric, peer_rank, remote_buffer) + """ + return self.heap.as_symmetric(external_tensor) - return tensor + def is_symmetric(self, tensor: torch.Tensor) -> bool: + """ + Check if a tensor is allocated on the symmetric heap. + + This method checks whether a tensor resides in the symmetric heap, making it + accessible for RMA operations across ranks. Use this to validate tensors before + performing distributed operations. + + Args: + tensor (torch.Tensor): PyTorch tensor to check + + Returns: + bool: True if tensor is on the symmetric heap, False otherwise + + Example: + >>> ctx = iris.iris(heap_size=2**30) + >>> # Create a symmetric tensor + >>> symmetric_tensor = ctx.zeros(1000, dtype=torch.float32) + >>> ctx.is_symmetric(symmetric_tensor) # True + >>> + >>> # Create an external tensor (not on symmetric heap) + >>> external_tensor = torch.zeros(1000, dtype=torch.float32, device='cuda') + >>> ctx.is_symmetric(external_tensor) # False + >>> + >>> # Import external tensor (only with vmem allocator) + >>> ctx_vmem = iris.iris(allocator_type='vmem') + >>> imported = ctx_vmem.as_symmetric(external_tensor) + >>> ctx_vmem.is_symmetric(imported) # True + """ + return self.heap.is_symmetric(tensor) def full(self, size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): """ @@ -688,54 +648,18 @@ def full(self, size, fill_value, *, out=None, dtype=None, layout=torch.strided, >>> print(tensor.shape) # torch.Size([2, 3]) >>> print(tensor[0]) # tensor([3.1400, 3.1400, 3.1400], device='cuda:0') """ - self.debug( - f"full: size = {size}, fill_value = {fill_value}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" + return tensor_creation.full( + self.heap, + self.get_device(), + size, + fill_value, + out=out, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, ) - # Infer dtype from fill_value if not provided - if dtype is None: - if isinstance(fill_value, (int, float)): - if isinstance(fill_value, float): - dtype = torch.get_default_dtype() - else: - dtype = torch.int64 - else: - # For other types (like tensors), use their dtype - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Fill with the specified value - out.fill_(fill_value) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Fill with the specified value - tensor.fill_(fill_value) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - def uniform(self, size, low=0.0, high=1.0, dtype=torch.float): """ Returns a tensor filled with random numbers from a uniform distribution, allocated on the Iris symmetric heap. @@ -755,11 +679,7 @@ def uniform(self, size, low=0.0, high=1.0, dtype=torch.float): >>> print(tensor.shape) # torch.Size([2, 3]) >>> print(tensor[0]) # tensor([0.1234, 0.5678, 0.9012], device='cuda:0') """ - self.debug(f"uniform: size = {size}, low = {low}, high = {high}, dtype = {dtype}") - size, num_elements = self.__parse_size(size) - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - tensor.uniform_(low, high) - return tensor.reshape(size) + return tensor_creation.uniform(self.heap, self.get_device(), size, low, high, dtype) def empty( self, @@ -805,46 +725,18 @@ def empty( >>> tensor = ctx.empty(2, 3) >>> print(tensor.shape) # torch.Size([2, 3]) """ - self.debug( - f"empty: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" + return tensor_creation.empty( + self.heap, + self.get_device(), + size, + out=out, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + memory_format=memory_format, ) - # Use global default dtype if None is provided - if dtype is None: - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Apply the requested memory format - tensor = self.__apply_memory_format(tensor, size, memory_format) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - def randint( self, *args, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False ): @@ -875,64 +767,27 @@ def randint( >>> print(tensor.shape) # torch.Size([2, 3]) >>> print(tensor[0]) # tensor([7, 2, 9], device='cuda:0') """ - self.debug(f"randint: args = {args}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") - # Parse arguments to determine low, high, and size - # PyTorch randint signatures: - # randint(high, size) - where high is the upper bound and size is the shape - # randint(low, high, size) - where low and high are bounds, size is the shape if len(args) == 2: - # randint(high, size) high, size = args low = 0 elif len(args) == 3: - # randint(low, high, size) low, high, size = args else: raise ValueError(f"randint expects 2 or 3 positional arguments, got {len(args)}") - - # Use default dtype if None is provided - if dtype is None: - dtype = torch.int64 - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Generate random integers using PyTorch's randint - # Use specified device or fall back to current device - target_device = device if device is not None else self.device - - # Handle generator parameter - if generator is not None: - torch.randint(low, high, size, generator=generator, out=tensor, dtype=dtype, device=target_device) - else: - torch.randint(low, high, size, out=tensor, dtype=dtype, device=target_device) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor + return tensor_creation.randint( + self.heap, + self.get_device(), + low, + high, + size, + generator=generator, + out=out, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + ) def linspace(self, start, end, steps, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): """ @@ -961,74 +816,19 @@ def linspace(self, start, end, steps, out=None, dtype=None, layout=torch.strided >>> tensor = ctx.linspace(0, 10, 5) # [0, 2.5, 5, 7.5, 10] >>> print(tensor) # tensor([ 0.0000, 2.5000, 5.0000, 7.5000, 10.0000], device='cuda:0') """ - self.debug( - f"linspace: start = {start}, end = {end}, steps = {steps}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" + return tensor_creation.linspace( + self.heap, + self.get_device(), + start, + end, + steps, + out=out, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, ) - # Use global default dtype if None is provided - if dtype is None: - # Check if start or end are complex numbers - start_is_complex = isinstance(start, complex) or (hasattr(start, "dtype") and torch.is_complex(start)) - end_is_complex = isinstance(end, complex) or (hasattr(end, "dtype") and torch.is_complex(end)) - - if start_is_complex or end_is_complex: - # Infer complex dtype based on default dtype - dtype = torch.complex64 if torch.get_default_dtype() == torch.float32 else torch.complex128 - else: - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse steps and extract the integer value - if isinstance(steps, (tuple, list)): - if len(steps) == 1: - # Single-element tuple/list like (5,) or [5] - steps_int = steps[0] - # Handle nested tuples like ((5,),) - if isinstance(steps_int, (tuple, list)): - steps_int = steps_int[0] - else: - # Multi-element tuple/list - use __parse_size for compatibility - size, num_elements = self.__parse_size(steps) - steps_int = num_elements - else: - # steps is a single integer - steps_int = steps - - # Ensure steps_int is an integer - steps_int = int(steps_int) - size = (steps_int,) - num_elements = steps_int - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Generate linspace using PyTorch's linspace - # Use specified device or fall back to current device - target_device = device if device is not None else self.device - torch.linspace(start, end, steps_int, out=tensor, dtype=dtype, device=target_device) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - def rand( self, *size, @@ -1068,52 +868,18 @@ def rand( >>> print(tensor.shape) # torch.Size([2, 3]) >>> print(tensor[0]) # tensor([0.1234, 0.5678, 0.9012], device='cuda:0') """ - self.debug( - f"rand: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" + return tensor_creation.rand( + self.heap, + self.get_device(), + size, + generator=generator, + out=out, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, ) - # Use global default dtype if None is provided - if dtype is None: - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Generate random numbers using PyTorch's rand - # Use specified device (already validated and set above) - - # Handle generator parameter - if generator is not None: - torch.rand(size, generator=generator, out=tensor, dtype=dtype, device=device) - else: - torch.rand(size, out=tensor, dtype=dtype, device=device) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - def __deallocate(self, pointer): pass @@ -1133,6 +899,69 @@ def get_heap_bases(self): """ return self.heap_bases + def get_device_context(self): + """ + Get the device context tensor for DeviceContext initialization. + + Returns a tensor encoding: [cur_rank, world_size, heap_base_0, heap_base_1, ...] + If tracing is enabled, also includes: [trace_enabled, max_events, trace_counter_ptr, trace_buffer_ptrs...] + + This opaque format allows future extension without breaking the API. + + Returns: + torch.Tensor: Encoded context data as int64 tensor on device + + Example: + >>> import iris + >>> from iris import DeviceContext + >>> import triton + >>> import triton.language as tl + >>> + >>> ctx = iris.iris() + >>> context_tensor = shmem.get_device_context() + >>> + >>> @triton.jit + >>> def my_kernel(context_tensor, rank: tl.constexpr, world_size: tl.constexpr, ...): + >>> ctx = DeviceContext.initialize(context_tensor, rank, world_size) + >>> data = ctx.load(buffer, from_rank=1) + """ + # Convert heap_bases to a list for concatenation + heap_bases_list = self.heap_bases.tolist() + + # Create context tensor: [cur_rank, world_size, heap_base_0, heap_base_1, ...] + context_data = [self.cur_rank, self.num_ranks] + heap_bases_list + + # Add tracing info if enabled + if self.tracing.enabled: + # Explicit buffer ordering (must match DeviceContext.initialize extraction order) + trace_buffer_ptrs = [ + self.tracing.trace_buffers["event_id"].data_ptr(), + self.tracing.trace_buffers["pid"].data_ptr(), + self.tracing.trace_buffers["pid_m"].data_ptr(), + self.tracing.trace_buffers["pid_n"].data_ptr(), + self.tracing.trace_buffers["cur_rank"].data_ptr(), + self.tracing.trace_buffers["target_rank"].data_ptr(), + self.tracing.trace_buffers["xcc_id"].data_ptr(), + self.tracing.trace_buffers["cu_id"].data_ptr(), + self.tracing.trace_buffers["timestamp"].data_ptr(), + self.tracing.trace_buffers["address"].data_ptr(), + self.tracing.trace_buffers["duration_cycles"].data_ptr(), + self.tracing.trace_buffers["op_index"].data_ptr(), + self.tracing.trace_buffers["payload_size"].data_ptr(), + ] + context_data += [ + 1, # trace_enabled = 1 (true) + self.tracing.max_events, + self.tracing.trace_counter.data_ptr(), + self.tracing.op_index_counter.data_ptr(), + ] + trace_buffer_ptrs + else: + context_data += [0] # trace_enabled = 0 (false) + + context_tensor = torch.tensor(context_data, dtype=torch.int64, device=self.device) + + return context_tensor + def barrier(self, stream=None, group=None): """ Synchronize ranks within the specified group and their CUDA devices. @@ -1188,6 +1017,24 @@ def get_cu_count(self): """ return get_cu_count(self.gpu_id) + def get_device_id(self): + """ + Get the device ID used by this Iris instance. + + In simulation mode, this may differ from the local rank if multiple + ranks share a single GPU. This is the device ID that was set during + Iris initialization. + + Returns: + int: The GPU device ID used by this Iris instance. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> device_id = ctx.get_device_id() + >>> print(f"Using GPU {device_id}") # Using GPU 0 + """ + return self.gpu_id + def get_rank(self): """ Get this process's rank id in the distributed communicator. @@ -1216,282 +1063,14 @@ def get_num_ranks(self): """ return self.num_ranks - def __throw_if_invalid_output_tensor(self, tensor: torch.Tensor, num_elements: int, dtype: torch.dtype): - if not self.__tensor_on_device(tensor): - raise RuntimeError( - f"The output tensor is not on the same device as the Iris instance. The Iris instance is on device {self.device} but the output tensor is on device {tensor.device}" - ) - if not self.__on_symmetric_heap(tensor): - raise RuntimeError( - f"The output tensor is not on the symmetric heap. The Iris instance is on heap base {self.heap_bases[self.cur_rank]} but the output tensor is on heap base {tensor.data_ptr()}" - ) - if tensor.numel() != num_elements: - raise RuntimeError(f"The output tensor has {tensor.numel()} elements, but {num_elements} are required") - if tensor.dtype != dtype: - raise RuntimeError(f"The output tensor has dtype {tensor.dtype}, but {dtype} is required") - - def __throw_if_invalid_device(self, device): - """ - Throw a RuntimeError if the requested device is not compatible with this Iris instance. - - Args: - device: The requested device (can be string, torch.device, or None) - - Raises: - RuntimeError: If the device is not compatible - """ - if not self.__is_valid_device(device): - raise RuntimeError( - f"Device mismatch: requested device {device} but Iris instance is on device {self.device}. " - f"Iris only supports tensors on its own device." - ) - - def __apply_memory_format( - self, tensor: torch.Tensor, size: tuple, memory_format: torch.memory_format, input_tensor: torch.Tensor = None - ): - """ - Apply the requested memory format to a tensor by setting appropriate strides. - This keeps the tensor on the symmetric heap while changing how PyTorch interprets the memory layout. - - Args: - tensor: The tensor to modify - size: The tensor's size/dimensions - memory_format: The desired memory format - input_tensor: The original input tensor (needed for preserve_format detection) - """ - if memory_format == torch.contiguous_format: - # Default format, no changes needed - return tensor - elif memory_format == torch.channels_last and len(size) == 4: - # For channels_last format: preserve shape (N, C, H, W) but change strides - # channels_last strides: [C*H*W, 1, C*W, C] for shape (N, C, H, W) - N, C, H, W = size[0], size[1], size[2], size[3] - # Keep the original shape (N, C, H, W) but use channels_last strides - tensor = self.__create_tensor_with_strides(tensor, size, (C * H * W, 1, C * W, C)) - return tensor - elif memory_format == torch.channels_last_3d and len(size) == 5: - # For channels_last_3d format: preserve shape (N, C, D, H, W) but change strides - # channels_last_3d strides: [C*D*H*W, 1, C*D*W, C*W, C] for shape (N, C, D, H, W) - N, C, D, H, W = size[0], size[1], size[2], size[3], size[4] - # Keep the original shape (N, C, D, H, W) but use channels_last_3d strides - tensor = self.__create_tensor_with_strides(tensor, size, (C * D * H * W, 1, C * D * W, C * W, C)) - return tensor - elif memory_format == torch.preserve_format: - # For preserve_format, we need to detect the input tensor's memory format - # and apply the same format to the output - if input_tensor is not None: - # Check the actual memory format of the input tensor - if len(size) == 4: - # Check if input tensor is in channels_last format by examining strides - # channels_last format has strides[1] == 1 (channels dimension is contiguous) - input_strides = input_tensor.stride() - if len(input_strides) == 4 and input_strides[1] == 1: - # Input is in channels_last format, preserve it - # Use the input tensor's actual shape, not the size parameter - input_shape = input_tensor.shape - if len(input_shape) == 4: - # Input is already in channels_last format (N, H, W, C) - new_size = input_shape - # Use the input tensor's strides directly - tensor = self.__create_tensor_with_strides(tensor, new_size, input_strides) - return tensor - elif len(size) == 5: - # Check if input tensor is in channels_last_3d format - input_strides = input_tensor.stride() - if len(input_strides) == 5 and input_strides[1] == 1: - # Input is in channels_last_3d format, preserve it - # Use the input tensor's actual shape, not the size parameter - input_shape = input_tensor.shape - if len(input_shape) == 5: - # Input is already in channels_last_3d format (N, D, H, W, C) - new_size = input_shape - # Use the input tensor's strides directly - tensor = self.__create_tensor_with_strides(tensor, new_size, input_strides) - return tensor - # If no special format detected or no input tensor provided, use contiguous format - return tensor - else: - # Unsupported format or dimension combination - self.debug( - f"Warning: Memory format {memory_format} not supported for {len(size)}D tensor, using contiguous format" - ) - # For unsupported formats, return the tensor as-is (contiguous) - return tensor - - def __create_tensor_with_strides(self, original_tensor: torch.Tensor, size: tuple, strides: tuple) -> torch.Tensor: - """ - Create a new tensor with the specified strides while keeping the data on the symmetric heap. - - Args: - original_tensor: The original tensor (source of data and heap allocation) - size: The tensor's size/dimensions - strides: The desired strides for the new memory format - - Returns: - A new tensor with the specified strides, data copied from original, on the same heap - """ - - # First, create a temporary tensor with the correct strides using PyTorch - temp_tensor = torch.empty_strided(size, strides, dtype=original_tensor.dtype, device=original_tensor.device) - - # Handle different cases based on whether size changes and what the strides indicate - if size != original_tensor.shape: - # Size is different - this might be a format change that requires permutation - # Check if this is a channels_last format by comparing strides - if len(size) == 4: - # For channels_last: expected strides are [H*W*C, 1, W*C, C] for shape (N, H, W, C) - N, H, W, C = size[0], size[1], size[2], size[3] - expected_strides = (H * W * C, 1, W * C, C) - if strides == expected_strides: - permuted = original_tensor.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) - else: - # If the size differs for other reasons, do not permute; just reshape if possible - try: - permuted = original_tensor.reshape(size) - except Exception: - raise ValueError( - "Cannot safely permute or reshape tensor: size differs from original shape for unknown reason." - ) - elif len(size) == 5: - # For channels_last_3d: expected strides are [D*H*W*C, 1, H*W*C, W*C, C] for shape (N, D, H, W, C) - N, D, H, W, C = size[0], size[1], size[2], size[3], size[4] - expected_strides = (D * H * W * C, 1, H * W * C, W * C, C) - if strides == expected_strides: - permuted = original_tensor.permute(0, 2, 3, 4, 1) # (N, C, D, H, W) -> (N, D, H, W, C) - else: - # If the size differs for other reasons, do not permute; just reshape if possible - try: - permuted = original_tensor.reshape(size) - except Exception: - raise ValueError( - "Cannot safely permute or reshape tensor: size differs from original shape for unknown reason." - ) - else: - # For other dimensions, just try to reshape - try: - permuted = original_tensor.reshape(size) - except Exception: - raise ValueError( - "Cannot safely permute or reshape tensor: size differs from original shape for unknown reason." - ) - else: - # Size is the same - this is a stride-only change (like channels_last with preserved shape) - # We need to reorder the data to match the new stride pattern - if len(size) == 4: - # Check if this is channels_last format with preserved shape - N, C, H, W = size[0], size[1], size[2], size[3] - expected_strides = (C * H * W, 1, C * W, C) - if strides == expected_strides: - permuted = original_tensor - else: - permuted = original_tensor - elif len(size) == 5: - # Check if this is channels_last_3d format with preserved shape - N, C, D, H, W = size[0], size[1], size[2], size[3], size[4] - expected_strides = (C * D * H * W, 1, C * D * W, C * W, C) - if strides == expected_strides: - permuted = original_tensor - else: - permuted = original_tensor - else: - permuted = original_tensor - - # Copy the permuted data to the temporary tensor - temp_tensor.copy_(permuted) - - # Now allocate a new tensor on our symmetric heap - num_elements = math.prod(size) - heap_tensor = self.__allocate(num_elements, original_tensor.dtype) - - # Reshape to the desired size - heap_tensor = heap_tensor.reshape(size) - - # Copy the data from the temporary tensor to our heap tensor - heap_tensor.copy_(temp_tensor) - - # Clean up the temporary tensor - del temp_tensor - - # Now we need to create a view with the correct strides - # We can't use as_strided directly on our heap tensor, but we can - # create a new tensor with the right strides and copy the data again - final_tensor = torch.as_strided(heap_tensor, size, strides) - - return final_tensor - - def __apply_layout(self, tensor: torch.Tensor, layout: torch.layout) -> torch.Tensor: - """ - Apply the requested layout to a tensor. - - Args: - tensor: The tensor to modify - layout: The desired layout - - Returns: - Tensor with the requested layout - """ - - if layout == torch.strided: - # Strided layout is the default - no changes needed - return tensor - else: - # Only support strided layout for now - raise ValueError(f"Layout {layout} not supported. Only torch.strided is currently supported.") - - def __tensor_on_device(self, tensor: torch.Tensor): - # Get the Iris device from memory_pool.device - iris_device = self.get_device() - tensor_device = tensor.device - - # For CUDA devices, check if they're compatible - if tensor_device.type == "cuda" and iris_device.type == "cuda": - if iris_device.index is None: - return True - return tensor_device.index == iris_device.index - - # For non-CUDA devices, they must be exactly equal - return tensor_device == iris_device - - def __on_symmetric_heap(self, tensor: torch.Tensor): - """Check if a tensor is allocated on the symmetric heap.""" - return self.heap.on_symmetric_heap(tensor) - - def __is_valid_device(self, device) -> bool: - """ - Check if the requested device is compatible with this Iris instance. - - Args: - device: The requested device (can be string, torch.device, or None) - - Returns: - bool: True if the device is compatible, False otherwise - """ - if device is None: - return True # None means use default device - - # Convert device strings to torch.device objects for proper comparison - requested_device = torch.device(device) if isinstance(device, str) else device - iris_device = self.get_device() - - # Check if both are CUDA devices - if requested_device.type == "cuda" and iris_device.type == "cuda": - # Check if index matches or if requested is "cuda" (any index) - if requested_device.index is None: - return True - else: - return requested_device.index == iris_device.index - - # For non-CUDA devices, always return False - return False - class CCL: """ Collective Communication Library (CCL) interface for Iris. Provides collective operations that can be called as methods on the Iris instance. Example usage: - >>> shmem = iris.iris() - >>> shmem.ccl.all_to_all(output_tensor, input_tensor) + >>> ctx = iris.iris() + >>> ctx.ccl.all_to_all(output_tensor, input_tensor) """ def __init__(self, iris_instance): @@ -1522,16 +1101,16 @@ def all_to_all(self, output_tensor, input_tensor, group=None, async_op=False, co If None, uses default Config values. Example: - >>> shmem = iris.iris() - >>> shmem.ccl.all_to_all(output_tensor, input_tensor) + >>> ctx = iris.iris() + >>> ctx.ccl.all_to_all(output_tensor, input_tensor) >>> # Custom configuration >>> from iris.ccl import Config >>> config = Config(block_size_m=128, block_size_n=32) - >>> shmem.ccl.all_to_all(output_tensor, input_tensor, config=config) + >>> ctx.ccl.all_to_all(output_tensor, input_tensor, config=config) >>> # Async operation (no barrier) - >>> shmem.ccl.all_to_all(output_tensor, input_tensor, async_op=True) + >>> ctx.ccl.all_to_all(output_tensor, input_tensor, async_op=True) """ from iris.ccl.all_to_all import all_to_all as _all_to_all @@ -1556,17 +1135,17 @@ def all_gather(self, output_tensor, input_tensor, group=None, async_op=False, co If None, uses default Config values. Example: - >>> shmem = iris.iris() + >>> ctx = iris.iris() >>> # Input: (M, N), Output: (world_size * M, N) - >>> shmem.ccl.all_gather(output_tensor, input_tensor) + >>> ctx.ccl.all_gather(output_tensor, input_tensor) >>> # Custom configuration >>> from iris.ccl import Config >>> config = Config(block_size_m=128, block_size_n=32) - >>> shmem.ccl.all_gather(output_tensor, input_tensor, config=config) + >>> ctx.ccl.all_gather(output_tensor, input_tensor, config=config) >>> # Async operation (no barrier) - >>> shmem.ccl.all_gather(output_tensor, input_tensor, async_op=True) + >>> ctx.ccl.all_gather(output_tensor, input_tensor, async_op=True) """ from iris.ccl.all_gather import all_gather as _all_gather @@ -1620,20 +1199,20 @@ def all_reduce( reuse internal buffers across invocations. Example: - >>> shmem = iris.iris() - >>> shmem.ccl.all_reduce(output_tensor, input_tensor) + >>> ctx = iris.iris() + >>> ctx.ccl.all_reduce(output_tensor, input_tensor) >>> # Custom configuration with ring variant >>> from iris.ccl import Config >>> config = Config(all_reduce_variant="ring") - >>> shmem.ccl.all_reduce(output_tensor, input_tensor, config=config) + >>> ctx.ccl.all_reduce(output_tensor, input_tensor, config=config) >>> # Two-shot variant with block distribution >>> config = Config(all_reduce_variant="two_shot", all_reduce_distribution=1) - >>> shmem.ccl.all_reduce(output_tensor, input_tensor, config=config) + >>> ctx.ccl.all_reduce(output_tensor, input_tensor, config=config) >>> # Async operation (no barrier) - >>> shmem.ccl.all_reduce(output_tensor, input_tensor, async_op=True) + >>> ctx.ccl.all_reduce(output_tensor, input_tensor, async_op=True) """ from iris.ccl.all_reduce import all_reduce as _all_reduce from iris.ccl import ReduceOp @@ -1675,13 +1254,13 @@ def reduce_scatter(self, output_tensor, input_tensor, op=None, group=None, async Only supports reduce_scatter_variant="two_shot". Example: - >>> shmem = iris.iris() - >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor) + >>> ctx = iris.iris() + >>> ctx.ccl.reduce_scatter(output_tensor, input_tensor) >>> # Custom configuration >>> from iris.ccl import Config >>> config = Config(reduce_scatter_variant="two_shot", all_reduce_distribution=1) - >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor, config=config) + >>> ctx.ccl.reduce_scatter(output_tensor, input_tensor, config=config) """ from iris.ccl.reduce_scatter import reduce_scatter as _reduce_scatter from iris.ccl import ReduceOp @@ -1696,36 +1275,591 @@ def reduce_scatter(self, output_tensor, input_tensor, op=None, group=None, async @triton.jit -def __translate(ptr, from_rank, to_rank, heap_bases): +def __translate(ptr, from_rank, to_rank, heap_bases, hint: tl.constexpr = None): from_base = tl.load(heap_bases + from_rank) to_base = tl.load(heap_bases + to_rank) - # convert to int to compute difference ptr_int = tl.cast(ptr, tl.uint64) - # Find the offset from from_rank heap offset = ptr_int - from_base - # Byte cast for byte offset addition to_base_byte = tl.cast(to_base, tl.pointer_type(tl.int8)) - # Find the offset into the to_rank heap translated_ptr_byte = to_base_byte + offset - # Cast to_base back to pointer type translated_ptr = tl.cast(translated_ptr_byte, ptr.dtype) + if hint is not None: + translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, hint), hint) + return translated_ptr + - # Optimization to vectorize the load/store - # We can't do this in general because we don't know the shape of the tensor or block sizes - # ptr = tl.max_contiguous(tl.multiple_of(ptr, (16, 16)), (16, 32)) +@aggregate +class DeviceContext: + """ + Device-side context that encapsulates rank and heap_bases for ergonomic Iris operations. + + This aggregate provides an object-oriented interface for Iris device operations, + eliminating the need to pass heap_bases to every function call. + + Usage: + import iris + from iris import DeviceContext + + # Host-side: Get encoded context tensor + shmem = iris.iris() + context_tensor = shmem.get_device_context() + + @triton.jit + def my_kernel(context_tensor, rank: tl.constexpr, world_size: tl.constexpr, ...): + # Initialize device context from encoded tensor + ctx = DeviceContext.initialize(context_tensor, rank, world_size) + + # Use object-oriented API + data = ctx.load(buffer + offsets, from_rank=1, mask=mask) + ctx.store(buffer + offsets, data, to_rank=1, mask=mask) + old_val = ctx.atomic_add(counter, 1, to_rank=1) + + Attributes: + rank: Current rank (constexpr) + world_size: Total number of ranks (constexpr) + heap_bases: Heap base pointers for all ranks (tensor) + trace_enabled: Whether tracing is enabled (constexpr) + max_trace_events: Maximum number of trace events (constexpr) + trace_counter: Pointer to atomic event counter (tensor) + trace_buf_pid: Pointer to pid buffer (tensor) + trace_buf_pid_m: Pointer to pid_m buffer (tensor) + trace_buf_pid_n: Pointer to pid_n buffer (tensor) + trace_buf_cur_rank: Pointer to cur_rank buffer (tensor) + trace_buf_target_rank: Pointer to target_rank buffer (tensor) + trace_buf_xcc_id: Pointer to xcc_id buffer (tensor) + trace_buf_cu_id: Pointer to cu_id buffer (tensor) + trace_buf_timestamp: Pointer to timestamp buffer (tensor) + trace_buf_address: Pointer to address buffer (tensor) + """ - # 0 You can use this if your block sizes are multiples of 32. - # Largest vectorized load instruction is dwordx4 (128-bits) - # translated_ptr = tl.multiple_of(translated_ptr, (32, 32)) - # translated_ptr = tl.max_contiguous(translated_ptr, (1, 32)) + rank: tl.constexpr + world_size: tl.constexpr + heap_bases: tl.tensor + tracing: DeviceTracing - # ptr = tl.max_contiguous(tl.multiple_of(ptr, 512), 512) - # translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, 512), 512) - return translated_ptr + @triton.constexpr_function + def __init__(self, rank, world_size, heap_bases, tracing): + """ + Internal constructor - use DeviceContext.initialize() instead. + + Args: + rank: Current rank (constexpr) + world_size: Total number of ranks (constexpr) + heap_bases: Heap base pointers for all ranks (tensor) + tracing: DeviceTracing instance + """ + self.rank = tl.constexpr(rank) + self.world_size = tl.constexpr(world_size) + self.heap_bases = heap_bases + self.tracing = tracing + + @staticmethod + @triton.jit + def initialize(context_tensor, rank, world_size, tracing: tl.constexpr = False): + """ + Initialize DeviceContext from the encoded context tensor. + + The context tensor has the format: + - [cur_rank, num_ranks, heap_base_0, ..., heap_base_N, trace_info...] + - If tracing=True: extracts trace buffer pointers from context_tensor + + Args: + context_tensor: Pointer to encoded context data (from Iris.get_device_context()) + rank: Current rank (must be constexpr in kernel signature) + world_size: Total number of ranks (must be constexpr in kernel signature) + tracing: Enable event tracing (constexpr, default: False) + + Returns: + DeviceContext: Initialized device context + + Example: + >>> import iris + >>> from iris import DeviceContext + >>> + >>> ctx = iris.iris() + >>> ctx.tracing.enable(max_events=1_000_000) + >>> context_tensor = ctx.get_device_context() + >>> + >>> @triton.jit + >>> def kernel(context_tensor, rank: tl.constexpr, world_size: tl.constexpr, ...): + >>> # Without tracing + >>> ctx = DeviceContext.initialize(context_tensor, rank, world_size) + >>> + >>> # With tracing + >>> ctx = DeviceContext.initialize(context_tensor, rank, world_size, tracing=True) + >>> mask = tl.full([64], True, dtype=tl.int1) # Example mask + >>> ctx.tracing.record_event_start(event_id=TraceEvent().put, target_rank=1, address=ptr, pid_m=0, pid_n=0, mask=mask) + """ + # Extract heap bases (from index 2 onwards) + heap_bases = context_tensor + 2 # Offset pointer to start at heap bases + + if tracing: + # Extract tracing info (starts after heap_bases) + trace_info_idx = 2 + world_size + 1 # Skip: cur_rank, num_ranks, heap_bases, trace_enabled flag + max_events = tl.load(context_tensor + trace_info_idx + 0) + trace_counter_ptr = tl.load(context_tensor + trace_info_idx + 1) + op_index_counter_ptr = tl.load(context_tensor + trace_info_idx + 2) + + # Cast counter pointers to pointer type + trace_counter = tl.cast(trace_counter_ptr, tl.pointer_type(tl.int32)) + op_index_counter = tl.cast(op_index_counter_ptr, tl.pointer_type(tl.int32)) + + # Extract trace buffer pointers (13 buffers) + base_idx = trace_info_idx + 3 # Updated: +3 because we now have op_index_counter + trace_buf_event_id = tl.cast(tl.load(context_tensor + base_idx + 0), tl.pointer_type(tl.int32)) + trace_buf_pid = tl.cast(tl.load(context_tensor + base_idx + 1), tl.pointer_type(tl.int32)) + trace_buf_pid_m = tl.cast(tl.load(context_tensor + base_idx + 2), tl.pointer_type(tl.int32)) + trace_buf_pid_n = tl.cast(tl.load(context_tensor + base_idx + 3), tl.pointer_type(tl.int32)) + trace_buf_cur_rank = tl.cast(tl.load(context_tensor + base_idx + 4), tl.pointer_type(tl.int32)) + trace_buf_target_rank = tl.cast(tl.load(context_tensor + base_idx + 5), tl.pointer_type(tl.int32)) + trace_buf_xcc_id = tl.cast(tl.load(context_tensor + base_idx + 6), tl.pointer_type(tl.int32)) + trace_buf_cu_id = tl.cast(tl.load(context_tensor + base_idx + 7), tl.pointer_type(tl.int32)) + trace_buf_timestamp = tl.cast(tl.load(context_tensor + base_idx + 8), tl.pointer_type(tl.int64)) + trace_buf_address = tl.cast(tl.load(context_tensor + base_idx + 9), tl.pointer_type(tl.int64)) + trace_buf_duration_cycles = tl.cast(tl.load(context_tensor + base_idx + 10), tl.pointer_type(tl.int64)) + trace_buf_op_index = tl.cast(tl.load(context_tensor + base_idx + 11), tl.pointer_type(tl.int32)) + trace_buf_payload_size = tl.cast(tl.load(context_tensor + base_idx + 12), tl.pointer_type(tl.int32)) + + # Create DeviceTracing instance + device_tracing = DeviceTracing( + enabled=tracing, + rank=rank, + max_events=max_events, + counter=trace_counter, + op_index_counter=op_index_counter, + buf_event_id=trace_buf_event_id, + buf_pid=trace_buf_pid, + buf_pid_m=trace_buf_pid_m, + buf_pid_n=trace_buf_pid_n, + buf_cur_rank=trace_buf_cur_rank, + buf_target_rank=trace_buf_target_rank, + buf_xcc_id=trace_buf_xcc_id, + buf_cu_id=trace_buf_cu_id, + buf_timestamp=trace_buf_timestamp, + buf_address=trace_buf_address, + buf_duration_cycles=trace_buf_duration_cycles, + buf_op_index=trace_buf_op_index, + buf_payload_size=trace_buf_payload_size, + ) + + return DeviceContext(rank, world_size, heap_bases, device_tracing) + else: + # When tracing disabled, use dummy pointers (never dereferenced; we return early in record_*) + dummy_ptr_i32 = tl.cast(context_tensor, tl.pointer_type(tl.int32)) + dummy_ptr_i64 = tl.cast(context_tensor, tl.pointer_type(tl.int64)) + max_events_zero = tl.full((), 0, dtype=tl.int32) + device_tracing = DeviceTracing( + enabled=False, + rank=rank, + max_events=max_events_zero, + counter=dummy_ptr_i32, + op_index_counter=dummy_ptr_i32, + buf_event_id=dummy_ptr_i32, + buf_pid=dummy_ptr_i32, + buf_pid_m=dummy_ptr_i32, + buf_pid_n=dummy_ptr_i32, + buf_cur_rank=dummy_ptr_i32, + buf_target_rank=dummy_ptr_i32, + buf_xcc_id=dummy_ptr_i32, + buf_cu_id=dummy_ptr_i32, + buf_timestamp=dummy_ptr_i64, + buf_address=dummy_ptr_i64, + buf_duration_cycles=dummy_ptr_i64, + buf_op_index=dummy_ptr_i32, + buf_payload_size=dummy_ptr_i32, + ) + + return DeviceContext(rank, world_size, heap_bases, device_tracing) + + @triton.jit + def _translate(self, ptr, from_rank, to_rank, hint: tl.constexpr = None): + """Internal pointer translation between rank address spaces.""" + return __translate(ptr, from_rank, to_rank, self.heap_bases, hint) + + @triton.jit + def load(self, pointer, from_rank, mask=None, hint: tl.constexpr = None): + """ + Loads a value from the specified rank's memory location. + + This method performs a memory read operation by translating the pointer + from the current rank's address space to the `from_rank`'s address space and loading + data from the target memory location. If the current rank and `from_rank` are the same, + this performs a local load operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `from_rank`'s address space. + from_rank (int): The rank ID from which to read the data. + mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint for the translated pointer. Defaults to None. + + Returns: + Block: The loaded value from the target memory location. + + Example: + >>> data = ctx.load(buffer + offsets, from_rank=1, mask=mask) + """ + translated_ptr = self._translate(pointer, self.rank, from_rank, hint) + result = tl.load(translated_ptr, mask=mask) + return result + + @triton.jit + def store(self, pointer, value, to_rank, mask=None, hint: tl.constexpr = None): + """ + Writes data to the specified rank's memory location. + + This method performs a memory write operation by translating the pointer + from the current rank's address space to the `to_rank`'s address space and storing + the provided data to the target memory location. If the current rank and `to_rank` are the same, + this performs a local store operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `to_rank`'s address space. + value (Block): The tensor of elements to be stored. + to_rank (int): The rank ID to which the data will be written. + mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None. + + Returns: + None + + Example: + >>> ctx.store(buffer + offsets, values, to_rank=1, mask=mask) + """ + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) + tl.store(translated_ptr, value, mask=mask) + + @triton.jit + def get(self, from_ptr, to_ptr, from_rank, mask=None, hint: tl.constexpr = None): + """ + Copies data from the specified rank's memory into current rank's local memory. + + This method performs a remote load operation by translating `from_ptr` from the current + rank's address space to the `from_rank`'s address space, loading the data, and storing + it to `to_ptr` in the current rank's local memory. If the current rank and `from_rank` + are the same, this performs a local copy operation. + + Args: + from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that references memory in `from_rank`. + to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer to local memory in current rank where the data will be written. + from_rank (int): The rank ID from which to read the data. + mask (Block of triton.int1, optional): If mask[idx] is false, do not load from from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + + Returns: + None + + Example: + >>> ctx.get(remote_ptr + offsets, local_ptr + offsets, from_rank=1, mask=mask) + """ + translated_from_ptr = self._translate(from_ptr, self.rank, from_rank, hint) + data = tl.load(translated_from_ptr, mask=mask) + tl.store(to_ptr, data, mask=mask) + + @triton.jit + def put(self, from_ptr, to_ptr, to_rank, mask=None, hint: tl.constexpr = None): + """ + Copies data from current rank's local memory to the specified rank's memory. + + This method performs a remote store operation by loading data from `from_ptr` in the + current rank's local memory, translating `to_ptr` from the current rank's address space + to the `to_rank`'s address space, and storing the data to the target memory location. + If the current rank and `to_rank` are the same, this performs a local copy operation. + + Args: + from_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer to local memory in current rank from which to read data. + to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that references memory in `to_rank`. + to_rank (int): The rank ID to which the data will be written. + mask (Block of triton.int1, optional): If mask[idx] is false, do not load from from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + + Returns: + None + + Example: + >>> ctx.put(local_ptr + offsets, remote_ptr + offsets, to_rank=1, mask=mask) + """ + translated_to_ptr = self._translate(to_ptr, self.rank, to_rank, hint) + data = tl.load(from_ptr, mask=mask) + tl.store(translated_to_ptr, data, mask=mask) + + @triton.jit + def copy(self, src_ptr, dst_ptr, from_rank, to_rank, mask=None, hint: tl.constexpr = None): + """ + Copies data from one rank's memory to another rank's memory. + + This method performs a data transfer by translating `src_ptr` from the current rank's + address space to the `from_rank`'s address space, performing a masked load from the + translated source, translating `dst_ptr` to the `to_rank`'s address space, and storing + the loaded data to the target memory location. If `from_rank` and `to_rank` are the same, + this performs a local copy operation. It is undefined behaviour if the current rank is + neither `from_rank` nor `to_rank`. + + Args: + src_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that references `from_rank`'s local memory. + dst_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that references `to_rank`'s local memory. + from_rank (int): The rank ID that owns `src_ptr` (source rank). + to_rank (int): The rank ID that will receive the data (destination rank). + mask (Block of triton.int1, optional): If mask[idx] is false, do not load from src_ptr[idx] and do not store to dst_ptr[idx]. Defaults to None. + + Returns: + None + + Example: + >>> ctx.copy(src_ptr + offsets, dst_ptr + offsets, from_rank=1, to_rank=0, mask=mask) + """ + cur_base = tl.load(self.heap_bases + self.rank) + from_base = tl.load(self.heap_bases + from_rank) + to_base = tl.load(self.heap_bases + to_rank) + + src_ptr_int = tl.cast(src_ptr, tl.uint64) + src_offset = src_ptr_int - cur_base + + dst_ptr_int = tl.cast(dst_ptr, tl.uint64) + dst_offset = dst_ptr_int - cur_base + + from_base_byte = tl.cast(from_base, tl.pointer_type(tl.int8)) + to_base_byte = tl.cast(to_base, tl.pointer_type(tl.int8)) + + translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype) + translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype) + + if hint is not None: + translated_src = tl.max_contiguous(tl.multiple_of(translated_src, hint), hint) + translated_dst = tl.max_contiguous(tl.multiple_of(translated_dst, hint), hint) + + data = tl.load(translated_src, mask=mask) + tl.store(translated_dst, data, mask=mask) + + @triton.jit + def atomic_add(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): + """ + Performs an atomic add at the specified rank's memory location. + + This method performs an atomic addition operation by translating the pointer + from the current rank's address space to the `to_rank`'s address space and atomically + adding the provided data to the `to_rank` memory location. If the current rank and + `to_rank` are the same, this performs a local atomic addition operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the current rank's address space that will be translated to the `to_rank`'s address space. + val (Block of dtype=pointer.dtype.element_ty): The values with which to perform the atomic operation. + to_rank (int): The rank ID to which the atomic operation will be performed. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. + scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + + Example: + >>> old_val = ctx.atomic_add(counter, 1, to_rank=1) + """ + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) + return tl.atomic_add(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + @triton.jit + def atomic_sub(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): + """ + Atomically subtracts data from the specified rank's memory location. + + This method performs an atomic subtraction operation by translating the pointer + from the current rank's address space to the `to_rank`'s address space and atomically + subtracting the provided data from the `to_rank` memory location. If the current rank + and `to_rank` are the same, this performs a local atomic subtraction operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that will be translated to the `to_rank`'s address space. + val (Block): The tensor of elements to be subtracted atomically. + to_rank (int): The rank ID to which the atomic operation will be performed. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel", and "relaxed". Defaults to "acq_rel". + scope (str, optional): Defines the scope of threads that observe the synchronizing effect. Acceptable values are "gpu" (default), "cta", or "sys". The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + """ + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) + return tl.atomic_sub(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + @triton.jit + def atomic_cas(self, pointer, cmp, val, to_rank, sem=None, scope=None, hint: tl.constexpr = None): + """ + Performs an atomic compare-and-swap at the specified rank's memory location. + + This method performs an atomic compare-and-swap operation by translating the pointer + from the current rank's address space to the `to_rank`'s address space and atomically + comparing the value at the memory location with `cmp`. If they match, it replaces the + value with `val`. If the current rank and `to_rank` are the same, this performs a local + atomic CAS operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory location in the current rank's address space that will be translated to the `to_rank`'s address space. + cmp (Block): The expected value to compare against. + val (Block): The new value to store if comparison succeeds. + to_rank (int): The rank ID to which the atomic operation will be performed. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel", and "relaxed". Defaults to "acq_rel". + scope (str, optional): Defines the scope of threads that observe the synchronizing effect. Acceptable values are "gpu" (default), "cta", or "sys". The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + """ + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) + return tl.atomic_cas(translated_ptr, cmp, val, sem=sem, scope=scope) + + @triton.jit + def atomic_xchg(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): + """ + Performs an atomic exchange at the specified rank's memory location. + + This method performs an atomic exchange operation by translating the pointer + from the current rank's address space to the `to_rank`'s address space and atomically + swapping the value at the memory location with `val`. If the current rank and `to_rank` + are the same, this performs a local atomic exchange operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the current rank's address space that will be translated to the `to_rank`'s address space. + val (Block): The new values to store. + to_rank (int): The rank ID to which the atomic operation will be performed. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel", and "relaxed". Defaults to "acq_rel". + scope (str, optional): Defines the scope of threads that observe the synchronizing effect. Acceptable values are "gpu" (default), "cta", or "sys". The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + """ + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) + return tl.atomic_xchg(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + @triton.jit + def atomic_xor(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): + """ + Performs an atomic XOR at the specified rank's memory location. + + This method performs an atomic bitwise XOR operation by translating the pointer + from the current rank's address space to the `to_rank`'s address space and atomically + XOR'ing the value at the memory location with `val`. If the current rank and `to_rank` + are the same, this performs a local atomic XOR operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the current rank's address space that will be translated to the `to_rank`'s address space. + val (Block): The values to XOR with. + to_rank (int): The rank ID to which the atomic operation will be performed. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel", and "relaxed". Defaults to "acq_rel". + scope (str, optional): Defines the scope of threads that observe the synchronizing effect. Acceptable values are "gpu" (default), "cta", or "sys". The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + """ + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) + return tl.atomic_xor(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + @triton.jit + def atomic_and(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): + """ + Performs an atomic AND at the specified rank's memory location. + + This method performs an atomic bitwise AND operation by translating the pointer + from the current rank's address space to the `to_rank`'s address space and atomically + AND'ing the value at the memory location with `val`. If the current rank and `to_rank` + are the same, this performs a local atomic AND operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the current rank's address space that will be translated to the `to_rank`'s address space. + val (Block): The values to AND with. + to_rank (int): The rank ID to which the atomic operation will be performed. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel", and "relaxed". Defaults to "acq_rel". + scope (str, optional): Defines the scope of threads that observe the synchronizing effect. Acceptable values are "gpu" (default), "cta", or "sys". The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + """ + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) + return tl.atomic_and(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + @triton.jit + def atomic_or(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): + """ + Performs an atomic OR at the specified rank's memory location. + + This method performs an atomic bitwise OR operation by translating the pointer + from the current rank's address space to the `to_rank`'s address space and atomically + OR'ing the value at the memory location with `val`. If the current rank and `to_rank` + are the same, this performs a local atomic OR operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the current rank's address space that will be translated to the `to_rank`'s address space. + val (Block): The values to OR with. + to_rank (int): The rank ID to which the atomic operation will be performed. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel", and "relaxed". Defaults to "acq_rel". + scope (str, optional): Defines the scope of threads that observe the synchronizing effect. Acceptable values are "gpu" (default), "cta", or "sys". The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + """ + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) + return tl.atomic_or(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + @triton.jit + def atomic_min(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): + """ + Performs an atomic minimum at the specified rank's memory location. + + This method performs an atomic minimum operation by translating the pointer + from the current rank's address space to the `to_rank`'s address space and atomically + updating the memory location to the minimum of its current value and `val`. If the + current rank and `to_rank` are the same, this performs a local atomic min operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the current rank's address space that will be translated to the `to_rank`'s address space. + val (Block): The values to compare with. + to_rank (int): The rank ID to which the atomic operation will be performed. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel", and "relaxed". Defaults to "acq_rel". + scope (str, optional): Defines the scope of threads that observe the synchronizing effect. Acceptable values are "gpu" (default), "cta", or "sys". The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + """ + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) + return tl.atomic_min(translated_ptr, val, mask=mask, sem=sem, scope=scope) + + @triton.jit + def atomic_max(self, pointer, val, to_rank, mask=None, sem=None, scope=None, hint: tl.constexpr = None): + """ + Performs an atomic maximum at the specified rank's memory location. + + This method performs an atomic maximum operation by translating the pointer + from the current rank's address space to the `to_rank`'s address space and atomically + updating the memory location to the maximum of its current value and `val`. If the + current rank and `to_rank` are the same, this performs a local atomic max operation. + + Args: + pointer (triton.PointerType, or block of dtype=triton.PointerType): The memory locations in the current rank's address space that will be translated to the `to_rank`'s address space. + val (Block): The values to compare with. + to_rank (int): The rank ID to which the atomic operation will be performed. + mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. + sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel", and "relaxed". Defaults to "acq_rel". + scope (str, optional): Defines the scope of threads that observe the synchronizing effect. Acceptable values are "gpu" (default), "cta", or "sys". The default value is "gpu". + + Returns: + Block: The data stored at pointer before the atomic operation. + """ + translated_ptr = self._translate(pointer, self.rank, to_rank, hint) + return tl.atomic_max(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def load(pointer, to_rank, from_rank, heap_bases, mask=None, other=None, cache_modifier=None, volatile=False): +def load( + pointer, + to_rank, + from_rank, + heap_bases, + mask=None, + other=None, + cache_modifier=None, + volatile=False, + hint: tl.constexpr = None, +): """ Loads a value from the specified rank's memory location. @@ -1757,6 +1891,7 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None, other=None, cache_m volatile (bool, optional): If True, disables compiler optimizations that could reorder or eliminate the load. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). Returns: Block: The loaded value from the target memory location. @@ -1770,13 +1905,22 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None, other=None, cache_m >>> data = iris.load(ptr, cur_rank, remote_rank, heap_bases) >>> return data """ - translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases) + translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases, hint) result = tl.load(translated_ptr, mask=mask, other=other, cache_modifier=cache_modifier, volatile=volatile) return result @triton.jit -def store(pointer, value, from_rank, to_rank, heap_bases, mask=None, cache_modifier=None): +def store( + pointer, + value, + from_rank, + to_rank, + heap_bases, + mask=None, + hint: tl.constexpr = None, + cache_modifier=None, +): """ Writes data to the specified rank's memory location. @@ -1797,7 +1941,8 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None, cache_modif to_rank (int): The rank ID to which the data will be written. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None. - cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). + cache_modifier (str, optional): Controls cache behavior of the store. Only effective for local stores (when `from_rank == to_rank`). Supported values are: - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. @@ -1817,8 +1962,11 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None, cache_modif >>> value = 42 >>> iris.store(ptr, value, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) - tl.store(translated_ptr, value, mask=mask, cache_modifier=cache_modifier) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) + if from_rank == to_rank: + tl.store(translated_ptr, value, mask=mask, cache_modifier=cache_modifier) + else: + tl.store(translated_ptr, value, mask=mask) @triton.jit @@ -1833,6 +1981,7 @@ def copy( other=None, load_cache_modifier=None, store_cache_modifier=None, + hint: tl.constexpr = None, ): """ Copies data from the specified rank's memory into the destination rank's memory. @@ -1851,19 +2000,19 @@ def copy( heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load from the translated src_ptr[idx] and do not store to dst_ptr[idx]. Defaults to None. other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. - load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. - store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + store_cache_modifier (str, optional): Controls cache behavior of the store. Only effective for local stores (when `to_rank == cur_rank`). Supported values are: - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointers. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). Returns: None @@ -1893,8 +2042,15 @@ def copy( translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype) translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype) + if hint is not None: + translated_src = tl.max_contiguous(tl.multiple_of(translated_src, hint), hint) + translated_dst = tl.max_contiguous(tl.multiple_of(translated_dst, hint), hint) + data = tl.load(translated_src, mask=mask, other=other, cache_modifier=load_cache_modifier) - tl.store(translated_dst, data, mask=mask, cache_modifier=store_cache_modifier) + if to_rank == cur_rank: + tl.store(translated_dst, data, mask=mask, cache_modifier=store_cache_modifier) + else: + tl.store(translated_dst, data, mask=mask) @triton.jit @@ -1908,6 +2064,7 @@ def get( other=None, load_cache_modifier=None, store_cache_modifier=None, + hint: tl.constexpr = None, ): """ Copies data from the specified rank's memory to the current rank's local memory. @@ -1925,19 +2082,19 @@ def get( heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. - load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. - store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + store_cache_modifier (str, optional): Controls cache behavior of the store. The store is always to local memory (`to_ptr`), so this is always applied. Supported values are: - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). Returns: None @@ -1949,7 +2106,7 @@ def get( >>> to_rank = 0 >>> iris.get(remote_ptr, local_ptr, from_rank, to_rank, heap_bases) """ - translated_from_ptr = __translate(from_ptr, from_rank, to_rank, heap_bases) + translated_from_ptr = __translate(from_ptr, from_rank, to_rank, heap_bases, hint) data = tl.load(translated_from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier) @@ -1967,6 +2124,7 @@ def put( other=None, load_cache_modifier=None, store_cache_modifier=None, + hint: tl.constexpr = None, ): """ Copies data from the current rank's local memory to the specified rank's memory. @@ -1990,12 +2148,13 @@ def put( - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. - store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + store_cache_modifier (str, optional): Controls cache behavior of the store. Only effective for local stores (when `from_rank == to_rank`). Supported values are: - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Use a scalar for 1-D (e.g. 16) or a tuple for N-D (e.g. (1, 16)). Defaults to None (no hint). Returns: None @@ -2007,15 +2166,20 @@ def put( >>> to_rank = 1 >>> iris.put(local_ptr, remote_ptr, from_rank, to_rank, heap_bases) """ - translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases) + translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases, hint) data = tl.load(from_ptr, mask=mask, other=other, cache_modifier=load_cache_modifier) - tl.store(translated_to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) + if from_rank == to_rank: + tl.store(translated_to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) + else: + tl.store(translated_to_ptr, data, mask=mask) @triton.jit -def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_add( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic add at the specified rank's memory location. @@ -2033,6 +2197,7 @@ def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The data stored at pointer before the atomic operation. @@ -2046,12 +2211,14 @@ def atomic_add(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> increment = 5 >>> old_val = iris.atomic_add(ptr, increment, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_add(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_sub( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Atomically subtracts data from the specified rank's memory location. @@ -2069,6 +2236,7 @@ def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". Defaults to "acq_rel". scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). Defaults to "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The value at the memory location before the atomic subtraction. @@ -2082,12 +2250,12 @@ def atomic_sub(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> decrement = 3 >>> old_val = iris.atomic_sub(ptr, decrement, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_sub(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scope=None): +def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scope=None, hint: tl.constexpr = None): """ Atomically compares and exchanges the specified rank's memory location. @@ -2105,6 +2273,7 @@ def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scop heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". Defaults to "acq_rel". scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). Defaults to "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The value contained at the memory location before the atomic operation attempt. @@ -2119,12 +2288,14 @@ def atomic_cas(pointer, cmp, val, from_rank, to_rank, heap_bases, sem=None, scop >>> new_val = 42 >>> old_val = iris.atomic_cas(ptr, expected, new_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_cas(translated_ptr, cmp, val, sem=sem, scope=scope) @triton.jit -def atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_xchg( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic exchange at the specified rank's memory location. @@ -2142,6 +2313,7 @@ def atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=Non mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The data stored at pointer before the atomic operation. @@ -2155,12 +2327,14 @@ def atomic_xchg(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=Non >>> new_value = 99 >>> old_val = iris.atomic_xchg(ptr, new_value, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_xchg(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_xor(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_xor( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic xor at the specified rank's memory location. @@ -2178,6 +2352,7 @@ def atomic_xor(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The data stored at pointer before the atomic operation. @@ -2191,12 +2366,14 @@ def atomic_xor(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> mask_val = 0xFF >>> old_val = iris.atomic_xor(ptr, mask_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_xor(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_and(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_and( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic and at the specified rank's memory location. @@ -2214,6 +2391,7 @@ def atomic_and(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The data stored at pointer before the atomic operation. @@ -2227,12 +2405,12 @@ def atomic_and(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> mask_val = 0x0F >>> old_val = iris.atomic_and(ptr, mask_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_and(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None): """ Performs an atomic or at the specified rank's memory location. @@ -2250,6 +2428,7 @@ def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The data stored at pointer before the atomic operation. @@ -2263,12 +2442,14 @@ def atomic_or(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, >>> mask_val = 0xF0 >>> old_val = iris.atomic_or(ptr, mask_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_or(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_min(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_min( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic min at the specified rank's memory location. @@ -2286,6 +2467,7 @@ def atomic_min(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The data stored at pointer before the atomic operation. @@ -2299,12 +2481,14 @@ def atomic_min(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> new_val = 10 >>> old_val = iris.atomic_min(ptr, new_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_min(translated_ptr, val, mask=mask, sem=sem, scope=scope) @triton.jit -def atomic_max(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None): +def atomic_max( + pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None, scope=None, hint: tl.constexpr = None +): """ Performs an atomic max at the specified rank's memory location. @@ -2322,6 +2506,7 @@ def atomic_max(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None mask (Block of triton.int1, optional): If mask[idx] is false, do not perform the atomic operation at address pointer[idx]. Defaults to None. sem (str, optional): Specifies the memory semantics for the operation. Acceptable values are "acquire", "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, the function defaults to using "acq_rel" semantics. scope (str, optional): Defines the scope of threads that observe the synchronizing effect of the atomic operation. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None (no hint). Returns: Block: The data stored at pointer before the atomic operation. @@ -2335,23 +2520,29 @@ def atomic_max(pointer, val, from_rank, to_rank, heap_bases, mask=None, sem=None >>> new_val = 100 >>> old_val = iris.atomic_max(ptr, new_val, cur_rank, remote_rank, heap_bases) """ - translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) + translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases, hint) return tl.atomic_max(translated_ptr, val, mask=mask, sem=sem, scope=scope) -def iris(heap_size=1 << 30): +def iris(heap_size=1 << 30, allocator_type="torch"): """ Create and return an Iris instance with the specified heap size. Args: heap_size (int): Size of the heap in bytes. Defaults to 1GB. + allocator_type (str): Type of allocator to use. Options: "torch" (default), "vmem". + Can be overridden with IRIS_ALLOCATOR environment variable. Returns: Iris: An initialized Iris instance. Example: >>> import iris - >>> iris_ctx = iris.iris(2**30) # 1GB heap + >>> iris_ctx = iris.iris(2**30) # 1GB heap with default (torch) allocator + >>> tensor = iris_ctx.zeros(1024, 1024) + + >>> # Use VMem allocator + >>> iris_ctx = iris.iris(2**30, allocator_type="vmem") >>> tensor = iris_ctx.zeros(1024, 1024) """ - return Iris(heap_size) + return Iris(heap_size, allocator_type) diff --git a/iris/ops/all_gather_matmul.py b/iris/ops/all_gather_matmul.py index 49bca5cf6..5d700206c 100644 --- a/iris/ops/all_gather_matmul.py +++ b/iris/ops/all_gather_matmul.py @@ -16,8 +16,7 @@ import iris import iris.x -from tritonblas.kernels.stages.algorithms.binary import add_vector -from tritonblas.kernels.stages.algorithms.unary import convert_dtype +from tritonblas.kernels.stages import GemmContext, ScheduleContext from .config import FusedConfig from .workspace import FusedWorkspace @@ -29,18 +28,18 @@ def _fused_all_gather_matmul_kernel( B, C, bias_ptr, - M: tl.constexpr, - N: tl.constexpr, - K: tl.constexpr, - K_local: tl.constexpr, - stride_am: tl.constexpr, - stride_ak: tl.constexpr, - stride_bk: tl.constexpr, - stride_bn: tl.constexpr, - stride_cm: tl.constexpr, - stride_cn: tl.constexpr, - stride_bias: tl.constexpr, - heap_bases: tl.tensor, + M, + N, + K, + K_local, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bias, + context_tensor: tl.tensor, cur_rank: tl.constexpr, world_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, @@ -49,75 +48,71 @@ def _fused_all_gather_matmul_kernel( GROUP_SIZE_M: tl.constexpr, NUM_SMS: tl.constexpr, NUM_XCDS: tl.constexpr, + NUM_K_BLOCKS_LOCAL: tl.constexpr, BIAS: tl.constexpr, EVEN_K: tl.constexpr, ALLOW_TF32: tl.constexpr, ): """Fused all-gather + GEMM kernel using pull pattern.""" - pid = tl.program_id(0) - - # Handle multi-XCD devices - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) - - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - total_tiles = num_pid_m * num_pid_n - - tl.assume(stride_am > 0) - tl.assume(stride_ak > 0) - tl.assume(stride_bk > 0) - tl.assume(stride_bn > 0) - tl.assume(stride_cm > 0) - tl.assume(stride_cn > 0) - - acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32 - - # Persistent loop over output tiles - for tile_id in range(pid, total_tiles, NUM_SMS): - # Compute tile coordinates with swizzling - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m - - # Compute row and column indices - rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) - rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) - - # Initialize accumulator - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + # ═══════════════════════════════════════════════════════════════════════ + # Create tritonblas context and scheduler for GEMM configuration + # ═══════════════════════════════════════════════════════════════════════ + gemm_ctx = GemmContext( + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_K, + num_sms=NUM_SMS, + num_xcds=NUM_XCDS, + group_size_m=GROUP_SIZE_M, + even_k=EVEN_K, + allow_tf32=ALLOW_TF32, + ) + sched = ScheduleContext(M, N, K, gemm_ctx) + + # Persistent loop over output tiles using scheduler + start, total, stride = sched.persistent_tile_range() + for tile_id in range(start, total, stride): + # Get tile coordinates with swizzling from scheduler + out_tile = sched.get_tile_from_idx(tile_id) + pid_m = out_tile.pid_m + pid_n = out_tile.pid_n + + # Initialize accumulator using GemmContext + acc = gemm_ctx.init_accumulator() # Create DeviceContext and TensorView for gather operations - ctx = iris.x.DeviceContext(cur_rank, world_size, heap_bases) - src_view = iris.x.TensorView(A_sharded, M, K_local, stride_am, stride_ak) + ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size) + src_view = iris.x.make_tensor_view(A_sharded, M, K_local, stride_am, stride_ak) + + # Precompute B column offsets for this output tile (constant across K iterations) + rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + rn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_SIZE_N), BLOCK_SIZE_N) # Loop over all ranks to pull and accumulate + # Note: K = world_size * K_local, so we iterate over each rank's K_local contribution for source_rank_id in range(world_size): - loop_k_local = tl.cdiv(K_local, BLOCK_SIZE_K) - if not EVEN_K: - loop_k_local -= 1 + # Use pre-computed loop bound (constexpr for static unrolling) + loop_k_local = NUM_K_BLOCKS_LOCAL if EVEN_K else NUM_K_BLOCKS_LOCAL - 1 # Loop over K dimension for this rank's shard for k_block_idx in range(0, loop_k_local): k_offset = k_block_idx * BLOCK_SIZE_K # Create tile view for this K block - tile_k = k_offset // BLOCK_SIZE_K + # Promote tile_k to tensor (TileView expects tl.tensor for pid_n) + tile_k = pid_m * 0 + k_offset // BLOCK_SIZE_K k_tile = iris.x.TileView(pid_m, tile_k, BLOCK_SIZE_M, BLOCK_SIZE_K) # Pull A tile from source_rank_id using gather primitive a = iris.x.gather(k_tile, src_view, source_rank_id, ctx) - # Load B tile - rk_local = k_offset + tl.arange(0, BLOCK_SIZE_K) - rk_global = (source_rank_id * K_local) + rk_local - B_ptr = B + rk_global[:, None] * stride_bk + rn[None, :] * stride_bn - b = tl.load(tl.multiple_of(B_ptr, (16, 1))) + # Load B tile using direct pointer arithmetic + # Compute global K row index for B matrix + global_k_offset = source_rank_id * K_local + k_block_idx * BLOCK_SIZE_K + rk = global_k_offset + tl.arange(0, BLOCK_SIZE_K) + rk = tl.max_contiguous(tl.multiple_of(rk % K, BLOCK_SIZE_K), BLOCK_SIZE_K) + B_ptrs = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + b = tl.load(B_ptrs) # Accumulate if ALLOW_TF32: @@ -128,40 +123,38 @@ def _fused_all_gather_matmul_kernel( # Handle remaining K elements if not evenly divisible if not EVEN_K: k_offset = loop_k_local * BLOCK_SIZE_K - tile_k = k_offset // BLOCK_SIZE_K + # Promote tile_k to tensor (TileView expects tl.tensor for pid_n) + tile_k = pid_m * 0 + k_offset // BLOCK_SIZE_K k_tile = iris.x.TileView(pid_m, tile_k, BLOCK_SIZE_M, BLOCK_SIZE_K) # Pull A tile from source_rank_id using gather primitive a = iris.x.gather(k_tile, src_view, source_rank_id, ctx) - rk_local = k_offset + tl.arange(0, BLOCK_SIZE_K) - rk_global = (source_rank_id * K_local) + rk_local - rk_global_mask = rk_global < K - B_ptr = B + rk_global[:, None] * stride_bk + rn[None, :] * stride_bn - b = tl.load(tl.multiple_of(B_ptr, (16, 1)), mask=rk_global_mask[:, None], other=0.0) + # Load B tile with boundary handling + global_k_offset = source_rank_id * K_local + loop_k_local * BLOCK_SIZE_K + rk = global_k_offset + tl.arange(0, BLOCK_SIZE_K) + B_ptrs = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + b_mask = (rk[:, None] < K) & (rn[None, :] < N) + b = tl.load(B_ptrs, mask=b_mask, other=0.0) if ALLOW_TF32: acc = tl.dot(a, b, acc, allow_tf32=True) else: acc += tl.dot(a, b, allow_tf32=False) - # Add bias if provided using tritonBLAS + # Add bias if provided if BIAS: + rm, _ = out_tile.indices() bias_vector = tl.load(bias_ptr + rm * stride_bias, mask=rm < M, other=0.0) - acc = add_vector(acc, bias_vector, QUANTIZED=False) - - # Convert to output dtype using tritonBLAS - c = convert_dtype(acc, C.type.element_ty) - - # Store result (manual for now, tritonBLAS store has issues with our indices) - C_ptr = ( - C - + (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))[:, None] * stride_cm - + (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))[None, :] * stride_cn - ) - mask = ((pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))[:, None] < M) & ( - (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))[None, :] < N - ) + acc = acc + bias_vector[:, None] + + # Convert to output dtype + c = acc.to(C.type.element_ty) + + # Store result using tritonblas Tile + rm, rn = out_tile.indices() + C_ptr = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + mask = (rm[:, None] < M) & (rn[None, :] < N) tl.store(C_ptr, c, mask=mask) @@ -250,6 +243,7 @@ def all_gather_matmul( num_sms = props.multi_processor_count even_k = K_local % config.block_size_k == 0 + num_k_blocks_local = (K_local + config.block_size_k - 1) // config.block_size_k # Launch single fused kernel grid = (num_sms,) @@ -269,7 +263,7 @@ def all_gather_matmul( stride_cm, stride_cn, stride_bias, - shmem.heap_bases, + shmem.get_device_context(), rank, world_size, config.block_size_m, @@ -278,6 +272,7 @@ def all_gather_matmul( config.group_size_m, num_sms, config.num_xcds, + num_k_blocks_local, use_bias, even_k, config.allow_tf32, diff --git a/iris/ops/matmul_all_gather.py b/iris/ops/matmul_all_gather.py index 4e2e36e86..ad42ac041 100644 --- a/iris/ops/matmul_all_gather.py +++ b/iris/ops/matmul_all_gather.py @@ -17,8 +17,7 @@ import iris import iris.x -from tritonblas.kernels.stages.algorithms.binary import add_vector -from tritonblas.kernels.stages.algorithms.unary import convert_dtype +from tritonblas.kernels.stages import GemmContext, ScheduleContext, make_tensor_view from .config import FusedConfig from .workspace import FusedWorkspace @@ -30,18 +29,18 @@ def _fused_matmul_all_gather_kernel( B, # (K, N) - replicated across ranks C_gathered, # (M, N) - gathered output (M = M_local * world_size) bias_ptr, - M_local: tl.constexpr, - M: tl.constexpr, - N: tl.constexpr, - K: tl.constexpr, - stride_am: tl.constexpr, - stride_ak: tl.constexpr, - stride_bk: tl.constexpr, - stride_bn: tl.constexpr, - stride_cm_gathered: tl.constexpr, - stride_cn_gathered: tl.constexpr, - stride_bias: tl.constexpr, - heap_bases: tl.tensor, + M_local, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm_gathered, + stride_cn_gathered, + stride_bias, + context_tensor: tl.tensor, cur_rank: tl.constexpr, world_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, @@ -60,97 +59,45 @@ def _fused_matmul_all_gather_kernel( Computes local GEMM tile and immediately scatters to all ranks. No intermediate buffer needed - direct from registers to remote memory. """ - pid = tl.program_id(0) - - # Handle multi-XCD devices - if NUM_XCDS != 1: - pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS) - - num_pid_m = tl.cdiv(M_local, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - total_tiles = num_pid_m * num_pid_n - - tl.assume(stride_am > 0) - tl.assume(stride_ak > 0) - tl.assume(stride_bk > 0) - tl.assume(stride_bn > 0) - tl.assume(stride_cm_gathered > 0) - tl.assume(stride_cn_gathered > 0) - - acc_dtype = tl.int32 if C_gathered.type.element_ty == tl.int8 else tl.float32 - - # Persistent loop over local tiles - for tile_id in range(pid, total_tiles, NUM_SMS): - # Compute tile coordinates with swizzling - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m - - # Compute row and column indices for local tile - rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M_local - rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) - rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) - - # Initialize accumulator - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) - - # Compute number of K tiles - loop_k = tl.cdiv(K, BLOCK_SIZE_K) - if not EVEN_K: - loop_k -= 1 - - # GEMM loop over K dimension - for k_tile_idx in range(0, loop_k): - k_offset = k_tile_idx * BLOCK_SIZE_K - rk = k_offset + tl.arange(0, BLOCK_SIZE_K) - - # Load A tile - A_ptr = A + rm[:, None] * stride_am + rk[None, :] * stride_ak - a = tl.load(tl.multiple_of(A_ptr, (1, 16))) - - # Load B tile - B_ptr = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn - b = tl.load(tl.multiple_of(B_ptr, (16, 1))) - - # Accumulate - if ALLOW_TF32: - acc = tl.dot(a, b, acc, allow_tf32=True) - else: - acc += tl.dot(a, b, allow_tf32=False) - - # Handle remaining K elements if not evenly divisible - if not EVEN_K: - k_offset = loop_k * BLOCK_SIZE_K - rk = k_offset + tl.arange(0, BLOCK_SIZE_K) - rk_mask = rk < K - - A_ptr = A + rm[:, None] * stride_am + rk[None, :] * stride_ak - a = tl.load(tl.multiple_of(A_ptr, (1, 16)), mask=rk_mask[None, :], other=0.0) - - B_ptr = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn - b = tl.load(tl.multiple_of(B_ptr, (16, 1)), mask=rk_mask[:, None], other=0.0) - - if ALLOW_TF32: - acc = tl.dot(a, b, acc, allow_tf32=True) - else: - acc += tl.dot(a, b, allow_tf32=False) - - # Add bias if provided using tritonBLAS + # ═══════════════════════════════════════════════════════════════════════ + # Create tritonblas views, context, and scheduler for GEMM + # ═══════════════════════════════════════════════════════════════════════ + tensorA = make_tensor_view(A, M_local, K, stride_am, stride_ak) + tensorB = make_tensor_view(B, K, N, stride_bk, stride_bn) + gemm_ctx = GemmContext( + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_K, + num_sms=NUM_SMS, + num_xcds=NUM_XCDS, + group_size_m=GROUP_SIZE_M, + even_k=EVEN_K, + allow_tf32=ALLOW_TF32, + ) + sched = ScheduleContext(M_local, N, K, gemm_ctx) + + # Persistent loop over local tiles using scheduler + start, total, stride = sched.persistent_tile_range() + for tile_id in range(start, total, stride): + # Get tile coordinates with swizzling from scheduler + out_tile = sched.get_tile_from_idx(tile_id) + + # GEMM using tritonblas stages + acc = gemm_ctx.reduce_axis(tensorA, tensorB, out_tile) + + # Add bias if provided if BIAS: + rm, _ = out_tile.indices() bias_vector = tl.load(bias_ptr + rm * stride_bias, mask=rm < M_local, other=0.0) - acc = add_vector(acc, bias_vector, QUANTIZED=False) + acc = acc + bias_vector[:, None] - # Convert to output dtype using tritonBLAS - c = convert_dtype(acc, C_gathered.type.element_ty) + # Convert to output dtype + c = acc.to(C_gathered.type.element_ty) # Create DeviceContext and destination TensorView for all-gather - ctx = iris.x.DeviceContext(cur_rank, world_size, heap_bases) - dst_view = iris.x.TensorView(C_gathered, M, N, stride_cm_gathered, stride_cn_gathered) - tile_obj = iris.x.Tile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, c) + ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size) + dst_view = iris.x.make_tensor_view(C_gathered, M, N, stride_cm_gathered, stride_cn_gathered) + tile_obj = iris.x.Tile(out_tile.pid_m, out_tile.pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, c) # Scatter this tile to all ranks using iris.x.all_gather # dim=0 means scatter along M dimension (rows) @@ -288,7 +235,7 @@ def matmul_all_gather( stride_cm_gathered, stride_cn_gathered, stride_bias, - shmem.heap_bases, + shmem.get_device_context(), rank, world_size, config.block_size_m, diff --git a/iris/ops/matmul_all_reduce.py b/iris/ops/matmul_all_reduce.py index 056ebd91f..73bea92c2 100644 --- a/iris/ops/matmul_all_reduce.py +++ b/iris/ops/matmul_all_reduce.py @@ -13,6 +13,8 @@ import triton import triton.language as tl +from tritonblas.kernels.stages import GemmContext, make_tensor_view, Tile + from .config import FusedConfig from .workspace import FusedWorkspace import iris @@ -26,21 +28,22 @@ def _fused_matmul_all_reduce_kernel( C, aux_buffer, locks, - M: tl.constexpr, - N: tl.constexpr, - K: tl.constexpr, - stride_am: tl.constexpr, - stride_ak: tl.constexpr, - stride_bk: tl.constexpr, - stride_bn: tl.constexpr, - stride_cm: tl.constexpr, - stride_cn: tl.constexpr, - heap_bases: tl.tensor, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + context_tensor: tl.tensor, cur_rank: tl.constexpr, world_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + EVEN_K: tl.constexpr, VARIANT: tl.constexpr, ): """ @@ -57,8 +60,8 @@ def _fused_matmul_all_reduce_kernel( - 'two_shot': Work distribution with reduce-scatter then all-gather pattern The kernel for each output tile: - 1. Computes GEMM: local_tile = A_tile @ B_tile - 2. Uses spinlock-protected read-modify-write to accumulate to all ranks + 1. Computes GEMM using tritonblas GemmContext + 2. Uses the specified variant for all-reduce across ranks Args: A: Pointer to input matrix A of shape (M, K) - local rank's data @@ -71,53 +74,55 @@ def _fused_matmul_all_reduce_kernel( stride_am, stride_ak: Strides for A tensor stride_bk, stride_bn: Strides for B tensor stride_cm, stride_cn: Strides for C tensor - heap_bases: Heap base pointers for all ranks + context_tensor: Device context tensor for RMA operations cur_rank: Current rank world_size: Total number of ranks BLOCK_SIZE_M: Block size for M dimension BLOCK_SIZE_N: Block size for N dimension BLOCK_SIZE_K: Block size for K dimension + EVEN_K: Whether K is evenly divisible by BLOCK_SIZE_K """ - # Get program ID and compute grid dimensions + # Get program ID and compute which tile this program handles pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - - # Compute which tile this program handles - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - - # Compute row and column indices for this tile - rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - # Initialize accumulator for GEMM - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # GEMM loop over K dimension - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - - # Load A tile - A_ptr = A + rm[:, None] * stride_am + rk[None, :] * stride_ak - a = tl.load(A_ptr, mask=(rm[:, None] < M) & (rk[None, :] < K), other=0.0) - - # Load B tile - B_ptr = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn - b = tl.load(B_ptr, mask=(rk[:, None] < K) & (rn[None, :] < N), other=0.0) + num_tiles_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_tiles_n + pid_n = pid % num_tiles_n + + # ═══════════════════════════════════════════════════════════════════════ + # GEMM using tritonblas stages + # ═══════════════════════════════════════════════════════════════════════ + tensorA = make_tensor_view(A, M, K, stride_am, stride_ak) + tensorB = make_tensor_view(B, K, N, stride_bk, stride_bn) + gemm_ctx = GemmContext( + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_K, + num_sms=1, + even_k=EVEN_K, + ) + out_tile = Tile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N) + acc = gemm_ctx.reduce_axis(tensorA, tensorB, out_tile) - # Accumulate - acc += tl.dot(a, b) + # Get row and column indices from tile (needed for one_shot/two_shot variants) + rm, rn = out_tile.indices() # Convert to output dtype - c = acc.to(C.dtype.element_ty) + c = acc.to(C.type.element_ty) # Create views and context - ctx = iris.x.DeviceContext(cur_rank, world_size, heap_bases) - dst_view = iris.x.TensorView(C, M, N, stride_cm, stride_cn) + ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size) + dst_view = iris.x.make_tensor_view(C, M, N, stride_cm, stride_cn) - # For one_shot and two_shot: store tile to aux_buffer and signal ready with lock - if VARIANT == "one_shot" or VARIANT == "two_shot": + # Create tile object once for all variants + tile_obj = iris.x.Tile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, c) + + # Dispatch to appropriate all-reduce variant + if VARIANT == "atomic": + iris.x.all_reduce_atomic(tile_obj, dst_view, ctx) + elif VARIANT == "spinlock": + iris.x.all_reduce_spinlock(tile_obj, dst_view, locks, ctx) + elif VARIANT == "one_shot" or VARIANT == "two_shot": + # For one_shot and two_shot: store tile to aux_buffer and signal ready with lock # Store GEMM result to aux_buffer (avoid race condition with final output) temp_ptr = aux_buffer + rm[:, None] * stride_cm + rn[None, :] * stride_cn tl.store(temp_ptr, c, mask=(rm[:, None] < M) & (rn[None, :] < N), cache_modifier=".wt") @@ -125,35 +130,17 @@ def _fused_matmul_all_reduce_kernel( # Signal tile is ready by unlocking (set lock to 1) # Use atomic_xchg with release semantics to ensure memory ordering - num_tiles_n = tl.cdiv(N, BLOCK_SIZE_N) tile_id = pid_m * num_tiles_n + pid_n lock_ptr = locks + tile_id tl.atomic_xchg(lock_ptr, 1, sem="release", scope="gpu") # Release ensures prior stores visible - # Create src_view pointing to aux_buffer - src_view = iris.x.TensorView(aux_buffer, M, N, stride_cm, stride_cn) + # Create source view only when needed (aux_buffer is not None) + src_view = iris.x.make_tensor_view(aux_buffer, M, N, stride_cm, stride_cn) - # Dispatch to appropriate all-reduce variant - if VARIANT == "atomic": - # Atomic uses tile.data directly (no intermediate store needed) - tile_obj = iris.x.Tile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, c) - iris.x.all_reduce_atomic(tile_obj, dst_view, ctx) - elif VARIANT == "spinlock": - # Spinlock uses tile.data directly and lock for mutual exclusion - tile_obj = iris.x.Tile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, c) - iris.x.all_reduce_spinlock(tile_obj, dst_view, locks, ctx) - elif VARIANT == "one_shot": - # one_shot loads from all ranks (data already in memory, locks signal readiness) - tile_obj = iris.x.Tile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, c) - iris.x.all_reduce_one_shot(tile_obj, src_view, dst_view, locks, ctx) - elif VARIANT == "two_shot": - # two_shot with work distribution (data in memory, locks signal readiness) - tile_obj = iris.x.Tile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, c) - iris.x.all_reduce_two_shot(tile_obj, src_view, dst_view, locks, cur_rank, world_size, ctx) - # elif VARIANT == "ring": - # # Store locally first and signal ready - # tile_obj = iris.x.Tile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, c) - # iris.x.all_reduce_ring(tile_obj, src_view, dst_view, locks, ctx) + if VARIANT == "one_shot": + iris.x.all_reduce_one_shot(tile_obj, src_view, dst_view, locks, ctx) + elif VARIANT == "two_shot": + iris.x.all_reduce_two_shot(tile_obj, src_view, dst_view, locks, ctx) def matmul_all_reduce_preamble( @@ -307,14 +294,16 @@ def matmul_all_reduce( if needs_prepare: workspace = matmul_all_reduce_preamble(shmem, C, A, B, config=config, workspace=workspace) - # Get heap bases for RMA - heap_bases = shmem.get_heap_bases() + # Get device context for RMA + device_context = shmem.get_device_context() # Launch kernel num_pid_m = (M + config.block_size_m - 1) // config.block_size_m num_pid_n = (N + config.block_size_n - 1) // config.block_size_n grid = (num_pid_m * num_pid_n,) + even_k = K % config.block_size_k == 0 + _fused_matmul_all_reduce_kernel[grid]( A, B, @@ -330,12 +319,13 @@ def matmul_all_reduce( stride_bn, stride_cm, stride_cn, - heap_bases, + device_context, rank, world_size, config.block_size_m, config.block_size_n, config.block_size_k, + even_k, config.all_reduce_variant, ) diff --git a/iris/ops/matmul_reduce_scatter.py b/iris/ops/matmul_reduce_scatter.py index e2b2fd7b1..7f74ce749 100644 --- a/iris/ops/matmul_reduce_scatter.py +++ b/iris/ops/matmul_reduce_scatter.py @@ -13,6 +13,8 @@ import triton import triton.language as tl +from tritonblas.kernels.stages import GemmContext, make_tensor_view, Tile + from .config import FusedConfig from .workspace import FusedWorkspace import iris @@ -26,21 +28,22 @@ def _fused_matmul_reduce_scatter_kernel( C, aux_buffer, locks, - M: tl.constexpr, - N: tl.constexpr, - K: tl.constexpr, - stride_am: tl.constexpr, - stride_ak: tl.constexpr, - stride_bk: tl.constexpr, - stride_bn: tl.constexpr, - stride_cm: tl.constexpr, - stride_cn: tl.constexpr, - heap_bases: tl.tensor, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + context_tensor: tl.tensor, cur_rank: tl.constexpr, world_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + EVEN_K: tl.constexpr, ): """ Fused GEMM + Reduce-Scatter kernel. @@ -60,50 +63,56 @@ def _fused_matmul_reduce_scatter_kernel( stride_am, stride_ak: Strides for A tensor stride_bk, stride_bn: Strides for B tensor stride_cm, stride_cn: Strides for C tensor - heap_bases: Heap base pointers for all ranks + context_tensor: Device context tensor for RMA operations cur_rank: Current rank world_size: Total number of ranks BLOCK_SIZE_M: Block size for M dimension BLOCK_SIZE_N: Block size for N dimension BLOCK_SIZE_K: Block size for K dimension + EVEN_K: Whether K is evenly divisible by BLOCK_SIZE_K """ pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - - rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - - A_ptr = A + rm[:, None] * stride_am + rk[None, :] * stride_ak - a = tl.load(A_ptr, mask=(rm[:, None] < M) & (rk[None, :] < K), other=0.0) - - B_ptr = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn - b = tl.load(B_ptr, mask=(rk[:, None] < K) & (rn[None, :] < N), other=0.0) + num_tiles_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_tiles_n + pid_n = pid % num_tiles_n + + # ═══════════════════════════════════════════════════════════════════════ + # GEMM using tritonblas stages + # ═══════════════════════════════════════════════════════════════════════ + tensorA = make_tensor_view(A, M, K, stride_am, stride_ak) + tensorB = make_tensor_view(B, K, N, stride_bk, stride_bn) + gemm_ctx = GemmContext( + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_K, + num_sms=1, + even_k=EVEN_K, + ) + out_tile = Tile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N) + acc = gemm_ctx.reduce_axis(tensorA, tensorB, out_tile) - acc += tl.dot(a, b) + # Get row and column indices from tile + rm, rn = out_tile.indices() - c = acc.to(C.dtype.element_ty) + c = acc.to(C.type.element_ty) + # Store GEMM result to aux_buffer temp_ptr = aux_buffer + rm[:, None] * stride_cm + rn[None, :] * stride_cn tl.store(temp_ptr, c, mask=(rm[:, None] < M) & (rn[None, :] < N), cache_modifier=".wt") tl.debug_barrier() - tile_id = pid_m * num_pid_n + pid_n + # Signal tile is ready + tile_id = pid_m * num_tiles_n + pid_n lock_ptr = locks + tile_id tl.atomic_xchg(lock_ptr, 1, sem="release", scope="gpu") + # Create tile object and context tile_obj = iris.x.Tile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, c) - src_view = iris.x.TensorView(aux_buffer, M, N, stride_cm, stride_cn) - dst_view = iris.x.TensorView(C, M, N, stride_cm, stride_cn) - ctx = iris.x.DeviceContext(cur_rank, world_size, heap_bases) + ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size) + + # Create tensor views for source and destination + src_view = iris.x.make_tensor_view(aux_buffer, M, N, stride_cm, stride_cn) + dst_view = iris.x.make_tensor_view(C, M, N, stride_cm, stride_cn) iris.x.reduce_scatter(tile_obj, src_view, dst_view, locks, ctx) @@ -217,6 +226,8 @@ def matmul_reduce_scatter( num_pid_n = (N + config.block_size_n - 1) // config.block_size_n grid = (num_pid_m * num_pid_n,) + even_k = K % config.block_size_k == 0 + _fused_matmul_reduce_scatter_kernel[grid]( A, B, @@ -232,12 +243,13 @@ def matmul_reduce_scatter( B.stride(1), C.stride(0), C.stride(1), - shmem.get_heap_bases(), + shmem.get_device_context(), rank, world_size, config.block_size_m, config.block_size_n, config.block_size_k, + even_k, ) if not async_op: diff --git a/iris/symmetric_heap.py b/iris/symmetric_heap.py index c545890f4..eef391976 100644 --- a/iris/symmetric_heap.py +++ b/iris/symmetric_heap.py @@ -10,8 +10,9 @@ import numpy as np import torch +import os -from iris.allocators import TorchAllocator +from iris.allocators import TorchAllocator, VMemAllocator from iris.fd_passing import setup_fd_infrastructure from iris._distributed_helpers import distributed_allgather @@ -22,9 +23,18 @@ class SymmetricHeap: Manages distributed memory with symmetric addressing across ranks, handling all allocator coordination and memory sharing internally. + + Supports multiple allocator backends: 'torch' (default) and 'vmem'. """ - def __init__(self, heap_size: int, device_id: int, cur_rank: int, num_ranks: int): + def __init__( + self, + heap_size: int, + device_id: int, + cur_rank: int, + num_ranks: int, + allocator_type: str = "torch", + ): """ Initialize symmetric heap. @@ -33,42 +43,69 @@ def __init__(self, heap_size: int, device_id: int, cur_rank: int, num_ranks: int device_id: GPU device ID cur_rank: Current process rank num_ranks: Total number of ranks + allocator_type: Type of allocator ("torch" or "vmem"); default "torch" + + Raises: + ValueError: If allocator_type is not supported """ self.heap_size = heap_size self.device_id = device_id self.cur_rank = cur_rank self.num_ranks = num_ranks + allocator_type = os.environ.get("IRIS_ALLOCATOR", allocator_type).lower() - # Create allocator - self.allocator = TorchAllocator(heap_size, device_id, cur_rank, num_ranks) + if allocator_type == "torch": + self.allocator = TorchAllocator(heap_size, device_id, cur_rank, num_ranks) + elif allocator_type == "vmem": + self.allocator = VMemAllocator(heap_size, device_id, cur_rank, num_ranks) + else: + raise ValueError(f"Unknown allocator type: {allocator_type}. Supported: 'torch', 'vmem'") - # All-gather heap bases for pointer translation - heap_base = self.allocator.get_base_address() - local_base_arr = np.array([heap_base], dtype=np.uint64) - all_bases_arr = distributed_allgather(local_base_arr).reshape(num_ranks).astype(np.uint64) - all_bases = {rank: int(all_bases_arr[rank]) for rank in range(num_ranks)} + self.fd_conns = setup_fd_infrastructure(cur_rank, num_ranks) + device = self.allocator.get_device() - # Setup FD passing infrastructure - fd_conns = setup_fd_infrastructure(cur_rank, num_ranks) + # Use int64 instead of uint64 for gloo backend compatibility + # Create from numpy array to avoid kernel issue (torch.zeros on small tensors triggers problematic kernel) + heap_bases_array = np.zeros(self.num_ranks, dtype=np.int64) + # Create on CPU first, then move to device to avoid FFM ioctl issue + from iris.util import is_simulation_env - # Establish access to peer memory - self.allocator.establish_peer_access(all_bases, fd_conns) + if is_simulation_env(): + self.heap_bases = torch.tensor(heap_bases_array, device="cpu", dtype=torch.int64) + self.heap_bases = self.heap_bases.to(device) + else: + self.heap_bases = torch.tensor(heap_bases_array, device=device, dtype=torch.int64) - # Get final heap bases - self.heap_bases = self.allocator.get_heap_bases() + self.refresh_peer_access() - def allocate(self, num_elements: int, dtype: torch.dtype) -> torch.Tensor: + def allocate(self, num_elements: int, dtype: torch.dtype, alignment: int = 1024) -> torch.Tensor: """ Allocate a tensor on the symmetric heap. + Always allocates at least the allocator's minimum allocation size so that + even zero-element requests get a buffer on the heap; for num_elements==0 + we return a zero-length slice of that buffer so the tensor is still on heap. + Args: num_elements: Number of elements to allocate dtype: PyTorch data type + alignment: Alignment requirement in bytes (default: 1024) Returns: - Allocated tensor on the symmetric heap + Allocated tensor on the symmetric heap (shape (num_elements,) or (0,) for empty) + + Note: + This should be called collectively across all ranks to maintain + symmetric heap consistency. After allocation, peer access is refreshed. """ - return self.allocator.allocate(num_elements, dtype) + min_bytes = self.allocator.get_minimum_allocation_size() + element_size = torch.tensor([], dtype=dtype).element_size() + min_elements = max(1, (min_bytes + element_size - 1) // element_size) + actual_elements = max(num_elements, min_elements) + tensor = self.allocator.allocate(actual_elements, dtype, alignment) + tensor = tensor[:num_elements] + self.refresh_peer_access() + return tensor def get_device(self) -> torch.device: """Get the torch device for this heap.""" @@ -84,9 +121,178 @@ def on_symmetric_heap(self, tensor: torch.Tensor) -> bool: Returns: True if tensor is on the symmetric heap, False otherwise """ - # Delegate to allocator to check if tensor is in heap return self.allocator.owns_tensor(tensor) + def is_symmetric(self, tensor: torch.Tensor) -> bool: + """ + Check if a tensor is allocated on the symmetric heap. + + This method provides a public API to check whether a tensor resides in the + symmetric heap, making it accessible for RMA operations across ranks. + + Args: + tensor: PyTorch tensor to check + + Returns: + True if tensor is on the symmetric heap, False otherwise + + Example: + >>> ctx = iris.iris(heap_size=2**30) + >>> symmetric_tensor = ctx.zeros(1000, dtype=torch.float32) + >>> external_tensor = torch.zeros(1000, dtype=torch.float32, device='cuda') + >>> ctx.heap.is_symmetric(symmetric_tensor) # True + >>> ctx.heap.is_symmetric(external_tensor) # False + """ + return self.on_symmetric_heap(tensor) + def get_heap_bases(self) -> torch.Tensor: """Get heap base addresses for all ranks as a tensor.""" return self.heap_bases + + def refresh_peer_access(self): + """ + Refresh peer DMA-BUF imports using segmented export/import. + Collective: all ranks must call together. Do not cache heap_bases. + """ + import torch.distributed as dist + from iris.fd_passing import send_fd, recv_fd + from iris.hip import ( + export_dmabuf_handle, + mem_import_from_shareable_handle, + mem_map, + mem_set_access, + mem_address_reserve, + hipMemAccessDesc, + hipMemLocationTypeDevice, + hipMemAccessFlagsProtReadWrite, + ) + + if dist.is_initialized(): + dist.barrier() + + my_base = self.allocator.get_base_address() + # Use int64 instead of uint64 to avoid gloo issues with all_gather_object + local_base_arr = np.array([my_base], dtype=np.int64) + all_bases_arr = distributed_allgather(local_base_arr).reshape(self.num_ranks).astype(np.int64) + self.heap_bases[self.cur_rank] = int(all_bases_arr[self.cur_rank]) + + if self.num_ranks == 1 or self.fd_conns is None: + return + + if not hasattr(self.allocator, "get_allocation_segments"): + if hasattr(self.allocator, "establish_peer_access"): + # In simulation, all ranks share the same device, so skip peer access setup + from iris.util import is_simulation_env + + if is_simulation_env(): + # Just set heap_bases directly from all_bases_arr + for r in range(self.num_ranks): + self.heap_bases[r] = int(all_bases_arr[r]) + else: + all_bases = {r: int(all_bases_arr[r]) for r in range(self.num_ranks)} + self.allocator.establish_peer_access(all_bases, self.fd_conns) + for r in range(self.num_ranks): + self.heap_bases[r] = int(self.allocator.heap_bases_array[r]) + return + + my_segments = self.allocator.get_allocation_segments() + my_exported_fds = [] + for offset, size, va in my_segments: + dmabuf_fd, export_base, export_size = export_dmabuf_handle(va, size) + my_exported_fds.append((dmabuf_fd, export_size, offset)) + + access_desc = hipMemAccessDesc() + access_desc.location.type = hipMemLocationTypeDevice + access_desc.location.id = self.device_id + access_desc.flags = hipMemAccessFlagsProtReadWrite + + for peer, sock in self.fd_conns.items(): + if peer == self.cur_rank: + continue + + if not hasattr(self, "_peer_va_ranges"): + self._peer_va_ranges = {} + + if peer not in self._peer_va_ranges: + peer_va_base = mem_address_reserve(self.heap_size, self.allocator.granularity, 0) + self._peer_va_ranges[peer] = peer_va_base + else: + peer_va_base = self._peer_va_ranges[peer] + + peer_fds = [] + for seg_idx, (my_fd, my_size, my_offset) in enumerate(my_exported_fds): + # Exchange FDs (higher rank sends first to avoid deadlock) + if self.cur_rank > peer: + send_fd(sock, my_fd) + peer_fd, _ = recv_fd(sock) + else: + peer_fd, _ = recv_fd(sock) + send_fd(sock, my_fd) + + peer_fds.append((peer_fd, my_size, my_offset)) + + if not hasattr(self, "_peer_cumulative_sizes"): + self._peer_cumulative_sizes = {} + cumulative_size = self._peer_cumulative_sizes.get(peer, 0) + + if not hasattr(self, "_peer_imported_segments"): + self._peer_imported_segments = {} + if peer not in self._peer_imported_segments: + self._peer_imported_segments[peer] = set() + + for peer_fd, segment_size, offset in peer_fds: + segment_key = (offset, segment_size) + if segment_key in self._peer_imported_segments[peer]: + import os + + os.close(peer_fd) + continue + + imported_handle = mem_import_from_shareable_handle(peer_fd) + import os + + os.close(peer_fd) + + peer_va = peer_va_base + offset + mem_map(peer_va, segment_size, 0, imported_handle) + self._peer_imported_segments[peer].add(segment_key) + + new_cumulative = offset + segment_size + if new_cumulative > cumulative_size: + cumulative_size = new_cumulative + mem_set_access(peer_va_base, cumulative_size, access_desc) + + self._peer_cumulative_sizes[peer] = cumulative_size + self.heap_bases[peer] = peer_va_base + + for fd, _, _ in my_exported_fds: + import os + + os.close(fd) + + if dist.is_initialized(): + dist.barrier() + + def as_symmetric(self, external_tensor: torch.Tensor) -> torch.Tensor: + """ + Place an external PyTorch tensor on the symmetric heap. + + With the torch allocator: allocates on the heap and copies the data; + the returned tensor is independent of the input. With the vmem + allocator: imports the memory so both tensors share the same storage. + + Args: + external_tensor: External PyTorch tensor (must be CUDA, contiguous) + + Returns: + Tensor on the symmetric heap (same shape/dtype; copy or shared per allocator) + + Raises: + RuntimeError: If allocator doesn't support imports or import fails + """ + if not hasattr(self.allocator, "import_external_tensor"): + raise RuntimeError(f"{type(self.allocator).__name__} does not support as_symmetric().") + + imported = self.allocator.import_external_tensor(external_tensor) + self.refresh_peer_access() + return imported diff --git a/iris/tensor_creation.py b/iris/tensor_creation.py new file mode 100644 index 000000000..3cf3f0f82 --- /dev/null +++ b/iris/tensor_creation.py @@ -0,0 +1,881 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Tensor creation abstraction for symmetric-heap tensors. + +Provides shared helpers (parse_size, device validation, allocation wiring, +output-tensor validation, layout and memory-format handling) and the core +creation logic for ``zeros``, ``ones``, ``full``, and ``zeros_like``. + +Both the Triton :class:`~iris.iris.Iris` backend and the Gluon +:class:`~iris.experimental.iris_gluon.IrisGluon` backend delegate to these +functions so that the logic lives in exactly one place. +""" + +import math + +import torch + +from .logging import logger + + +# --------------------------------------------------------------------------- +# Low-level helpers +# --------------------------------------------------------------------------- + + +def allocate(heap, num_elements: int, dtype: torch.dtype) -> torch.Tensor: + """Allocate a flat tensor on *heap*. + + Args: + heap: Symmetric heap exposing ``allocate(num_elements, dtype)``. + num_elements (int): Number of elements to allocate. + dtype (:class:`torch.dtype`): Element type. + + Returns: + :class:`torch.Tensor`: Flat tensor on the symmetric heap. + """ + logger.debug(f"allocate: num_elements = {num_elements}, dtype = {dtype}") + return heap.allocate(num_elements, dtype) + + +def parse_size(size): + """Parse a *size* argument and return ``(size_tuple, num_elements)``. + + Handles the common calling conventions:: + + zeros(2, 3) # *size = (2, 3) + zeros((2, 3)) # *size = ((2, 3),) + zeros([2, 3]) # *size = ([2, 3],) + zeros(((2, 3),)) # nested – flattened once + """ + # Flatten one level of wrapping tuple/list + while len(size) == 1 and isinstance(size[0], (tuple, list)): + size = size[0] + num_elements = math.prod(size) + return size, num_elements + + +def is_valid_device(device, iris_device) -> bool: + """Return *True* when *device* is compatible with *iris_device*. + + Args: + device: Requested device (``str``, :class:`torch.device`, or ``None``). + ``None`` is treated as "use the Iris default" and is always valid. + iris_device (:class:`torch.device`): Device of the Iris symmetric heap. + """ + if device is None: + return True # None means use default device + + requested_device = torch.device(device) if isinstance(device, str) else device + + # Both must be CUDA devices; index must match (or requested has no index) + if requested_device.type == "cuda" and iris_device.type == "cuda": + if requested_device.index is None: + return True + return requested_device.index == iris_device.index + + # Non-CUDA devices are not supported + return False + + +def throw_if_invalid_device(device, iris_device): + """Raise :exc:`RuntimeError` when *device* is incompatible with *iris_device*. + + Args: + device: Requested device (``str``, :class:`torch.device`, or ``None``). + iris_device (:class:`torch.device`): Device of the Iris symmetric heap. + + Raises: + RuntimeError: If the device does not match the Iris instance device. + """ + if not is_valid_device(device, iris_device): + raise RuntimeError( + f"Device mismatch: requested device {device} but Iris instance is on device {iris_device}. " + f"Iris only supports tensors on its own device." + ) + + +def throw_if_invalid_output_tensor(heap, tensor: torch.Tensor, num_elements: int, dtype: torch.dtype): + """Validate that *tensor* is suitable as an output buffer. + + Checks element count, dtype, and symmetric-heap membership in that order. + + Args: + heap: Symmetric heap instance exposing ``is_symmetric(tensor)``. + tensor (:class:`torch.Tensor`): Candidate output tensor. + num_elements (int): Required number of elements. + dtype (:class:`torch.dtype`): Required dtype. + + Raises: + RuntimeError: On any mismatch. + """ + if tensor.numel() != num_elements: + raise RuntimeError(f"The output tensor has {tensor.numel()} elements, but {num_elements} are required") + if tensor.dtype != dtype: + raise RuntimeError(f"The output tensor has dtype {tensor.dtype}, but {dtype} is required") + if not heap.is_symmetric(tensor): + raise RuntimeError("The output tensor is not on the symmetric heap") + + +def apply_layout(tensor: torch.Tensor, layout: torch.layout) -> torch.Tensor: + """Return *tensor* after applying *layout*. + + Only :data:`torch.strided` is currently supported. + + Raises: + ValueError: For unsupported layouts. + """ + if layout == torch.strided: + return tensor + raise ValueError(f"Layout {layout} not supported. Only torch.strided is currently supported.") + + +def _normalize_steps(steps) -> int: + """Normalise *steps* to a plain ``int``. + + Accepts an integer, a single-element tuple/list (possibly nested once), + or a multi-element sequence (where the total number of elements is used). + """ + if isinstance(steps, (tuple, list)): + if len(steps) == 1: + inner = steps[0] + if isinstance(inner, (tuple, list)): + inner = inner[0] + return int(inner) + else: + _, num_elements = parse_size(steps) + return num_elements + return int(steps) + + +# --------------------------------------------------------------------------- +# Memory-format helper (used by zeros_like) +# --------------------------------------------------------------------------- + + +def _create_tensor_with_strides(heap, original_tensor: torch.Tensor, size: tuple, strides: tuple): + """Allocate a symmetric-heap tensor with the given *size* and *strides*. + + Creates a temporary tensor to establish the desired layout, copies data + from *original_tensor* (with any necessary permutation), then returns a + view of a freshly heap-allocated buffer with the requested strides. + + Args: + heap: Symmetric heap exposing ``allocate(num_elements, dtype)``. + original_tensor (:class:`torch.Tensor`): Source tensor (contiguous). + size (tuple): Target shape. + strides (tuple): Target strides. + """ + temp_tensor = torch.empty_strided(size, strides, dtype=original_tensor.dtype, device=original_tensor.device) + + if size != original_tensor.shape: + if len(size) == 4: + N, H, W, C = size[0], size[1], size[2], size[3] + expected_strides = (H * W * C, 1, W * C, C) + if strides == expected_strides: + permuted = original_tensor.permute(0, 2, 3, 1) + else: + try: + permuted = original_tensor.reshape(size) + except Exception: + raise ValueError( + "Cannot safely permute or reshape tensor: size differs from original shape for unknown reason." + ) + elif len(size) == 5: + N, D, H, W, C = size[0], size[1], size[2], size[3], size[4] + expected_strides = (D * H * W * C, 1, H * W * C, W * C, C) + if strides == expected_strides: + permuted = original_tensor.permute(0, 2, 3, 4, 1) + else: + try: + permuted = original_tensor.reshape(size) + except Exception: + raise ValueError( + "Cannot safely permute or reshape tensor: size differs from original shape for unknown reason." + ) + else: + try: + permuted = original_tensor.reshape(size) + except Exception: + raise ValueError( + "Cannot safely permute or reshape tensor: size differs from original shape for unknown reason." + ) + else: + permuted = original_tensor + + temp_tensor.copy_(permuted) + + num_elements = math.prod(size) + heap_tensor = allocate(heap, num_elements, original_tensor.dtype) + heap_tensor = heap_tensor.reshape(size) + heap_tensor.copy_(temp_tensor) + del temp_tensor + + return torch.as_strided(heap_tensor, size, strides) + + +def apply_memory_format( + heap, + tensor: torch.Tensor, + size: tuple, + memory_format: torch.memory_format, + input_tensor: torch.Tensor = None, +) -> torch.Tensor: + """Apply *memory_format* to *tensor*, keeping it on the symmetric heap. + + Args: + heap: Symmetric heap exposing ``allocate(num_elements, dtype)`` (used + when a new stride layout requires a copy). + tensor (:class:`torch.Tensor`): Tensor to reformat. + size (tuple): Shape of *tensor*. + memory_format (:class:`torch.memory_format`): Desired memory format. + input_tensor (:class:`torch.Tensor`, optional): Reference tensor used + to detect the format to preserve when + *memory_format* is :data:`torch.preserve_format`. + + Returns: + :class:`torch.Tensor`: Tensor in the requested memory format. + """ + if memory_format == torch.contiguous_format: + return tensor + + if memory_format == torch.channels_last and len(size) == 4: + N, C, H, W = size[0], size[1], size[2], size[3] + return _create_tensor_with_strides(heap, tensor, size, (C * H * W, 1, C * W, C)) + + if memory_format == torch.channels_last_3d and len(size) == 5: + N, C, D, H, W = size[0], size[1], size[2], size[3], size[4] + return _create_tensor_with_strides(heap, tensor, size, (C * D * H * W, 1, C * D * W, C * W, C)) + + if memory_format == torch.preserve_format: + if input_tensor is not None: + input_strides = input_tensor.stride() + if len(size) == 4 and len(input_strides) == 4 and input_strides[1] == 1: + input_shape = input_tensor.shape + if len(input_shape) == 4: + return _create_tensor_with_strides(heap, tensor, input_shape, input_strides) + elif len(size) == 5 and len(input_strides) == 5 and input_strides[1] == 1: + input_shape = input_tensor.shape + if len(input_shape) == 5: + return _create_tensor_with_strides(heap, tensor, input_shape, input_strides) + return tensor + + # Unsupported format or dimension combination – fall back to contiguous + return tensor + + +# --------------------------------------------------------------------------- +# Tensor creation functions +# --------------------------------------------------------------------------- + + +def zeros(heap, iris_device, size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): + """Allocate a zero-filled tensor on *heap*. + + Args: + heap: Symmetric heap (``allocate`` / ``is_symmetric``). + iris_device (:class:`torch.device`): Device of the heap. + size (tuple): Shape of the tensor. + + Keyword Args: + out (:class:`torch.Tensor`, optional): Pre-allocated output tensor. + dtype (:class:`torch.dtype`, optional): Defaults to + :func:`torch.get_default_dtype`. + layout (:class:`torch.layout`): Default :data:`torch.strided`. + device: Must be compatible with *iris_device* or ``None``. + requires_grad (bool): Default ``False``. + + Returns: + :class:`torch.Tensor`: Zero tensor on the symmetric heap. + """ + logger.debug(f"zeros: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = iris_device + throw_if_invalid_device(device, iris_device) + size, num_elements = parse_size(size) + + # In simulation, avoid GPU kernel operations which trigger HIP errors + from iris.util import is_simulation_env + + if is_simulation_env(): + # Allocate and leave as-is (memory is already zero-initialized) + if out is not None: + throw_if_invalid_output_tensor(heap, out, num_elements, dtype) + # Don't call zero_() - memory is already zeroed, avoid GPU kernel + tensor = out.view(size) + else: + tensor = allocate(heap, num_elements, dtype) + # Don't call zero_() - memory is already zeroed, avoid GPU kernel + tensor = tensor.reshape(size) + else: + if out is not None: + throw_if_invalid_output_tensor(heap, out, num_elements, dtype) + out.zero_() + tensor = out.view(size) + else: + tensor = allocate(heap, num_elements, dtype) + tensor.zero_() + tensor = tensor.reshape(size) + + tensor = apply_layout(tensor, layout) + if requires_grad: + tensor.requires_grad_() + return tensor + + +def ones(heap, iris_device, size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): + """Allocate a ones-filled tensor on *heap*. + + Args: + heap: Symmetric heap (``allocate`` / ``is_symmetric``). + iris_device (:class:`torch.device`): Device of the heap. + size (tuple): Shape of the tensor. + + Keyword Args: + out (:class:`torch.Tensor`, optional): Pre-allocated output tensor. + dtype (:class:`torch.dtype`, optional): Defaults to + :func:`torch.get_default_dtype`. + layout (:class:`torch.layout`): Default :data:`torch.strided`. + device: Must be compatible with *iris_device* or ``None``. + requires_grad (bool): Default ``False``. + + Returns: + :class:`torch.Tensor`: Ones tensor on the symmetric heap. + """ + logger.debug(f"ones: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = iris_device + throw_if_invalid_device(device, iris_device) + size, num_elements = parse_size(size) + + if out is not None: + throw_if_invalid_output_tensor(heap, out, num_elements, dtype) + out.fill_(1) + tensor = out.view(size) + else: + tensor = allocate(heap, num_elements, dtype) + tensor.fill_(1) + tensor = tensor.reshape(size) + + tensor = apply_layout(tensor, layout) + if requires_grad: + tensor.requires_grad_() + return tensor + + +def full( + heap, + iris_device, + size, + fill_value, + *, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, +): + """Allocate a tensor filled with *fill_value* on *heap*. + + Args: + heap: Symmetric heap (``allocate`` / ``is_symmetric``). + iris_device (:class:`torch.device`): Device of the heap. + size (tuple): Shape of the tensor. + fill_value (scalar): Value to fill with. + + Keyword Args: + out (:class:`torch.Tensor`, optional): Pre-allocated output tensor. + dtype (:class:`torch.dtype`, optional): Inferred from *fill_value* + when ``None``. + layout (:class:`torch.layout`): Default :data:`torch.strided`. + device: Must be compatible with *iris_device* or ``None``. + requires_grad (bool): Default ``False``. + + Returns: + :class:`torch.Tensor`: Filled tensor on the symmetric heap. + """ + logger.debug( + f"full: size = {size}, fill_value = {fill_value}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" + ) + if dtype is None: + if isinstance(fill_value, float): + dtype = torch.get_default_dtype() + elif isinstance(fill_value, int): + dtype = torch.int64 + else: + dtype = torch.get_default_dtype() + if device is None: + device = iris_device + throw_if_invalid_device(device, iris_device) + size, num_elements = parse_size(size) + + if out is not None: + throw_if_invalid_output_tensor(heap, out, num_elements, dtype) + out.fill_(fill_value) + tensor = out.view(size) + else: + tensor = allocate(heap, num_elements, dtype) + tensor.fill_(fill_value) + tensor = tensor.reshape(size) + + tensor = apply_layout(tensor, layout) + if requires_grad: + tensor.requires_grad_() + return tensor + + +def zeros_like( + heap, + iris_device, + input: torch.Tensor, + *, + dtype=None, + layout=None, + device=None, + requires_grad=False, + memory_format=torch.preserve_format, +): + """Allocate a zero-filled tensor with the same shape as *input* on *heap*. + + Args: + heap: Symmetric heap (``allocate`` / ``is_symmetric``). + iris_device (:class:`torch.device`): Device of the heap. + input (:class:`torch.Tensor`): Reference tensor. + + Keyword Args: + dtype (:class:`torch.dtype`, optional): Defaults to ``input.dtype``. + layout (:class:`torch.layout`, optional): Defaults to ``input.layout``. + device: Defaults to ``input.device``; must be compatible with + *iris_device*. + requires_grad (bool): Default ``False``. + memory_format (:class:`torch.memory_format`): Default + :data:`torch.preserve_format`. + + Returns: + :class:`torch.Tensor`: Zero tensor on the symmetric heap. + """ + logger.debug( + f"zeros_like: input_shape = {input.shape}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" + ) + if dtype is None: + dtype = input.dtype + if layout is None: + layout = input.layout + if device is None: + device = input.device + throw_if_invalid_device(device, iris_device) + + size = input.size() + num_elements = input.numel() + + new_tensor = allocate(heap, num_elements, dtype) + new_tensor.zero_() + new_tensor = new_tensor.reshape(size) + + new_tensor = apply_memory_format(heap, new_tensor, size, memory_format, input) + new_tensor = apply_layout(new_tensor, layout) + + if requires_grad: + new_tensor.requires_grad_() + return new_tensor + + +def empty( + heap, + iris_device, + size, + *, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + memory_format=torch.contiguous_format, +): + """Allocate an uninitialised tensor on *heap*. + + Args: + heap: Symmetric heap (``allocate`` / ``is_symmetric``). + iris_device (:class:`torch.device`): Device of the heap. + size (tuple): Shape of the tensor. + + Keyword Args: + out (:class:`torch.Tensor`, optional): Pre-allocated output tensor. + dtype (:class:`torch.dtype`, optional): Defaults to + :func:`torch.get_default_dtype`. + layout (:class:`torch.layout`): Default :data:`torch.strided`. + device: Must be compatible with *iris_device* or ``None``. + requires_grad (bool): Default ``False``. + memory_format (:class:`torch.memory_format`): Default + :data:`torch.contiguous_format`. + + Returns: + :class:`torch.Tensor`: Uninitialised tensor on the symmetric heap. + """ + logger.debug(f"empty: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = iris_device + throw_if_invalid_device(device, iris_device) + size, num_elements = parse_size(size) + + if out is not None: + throw_if_invalid_output_tensor(heap, out, num_elements, dtype) + tensor = out.view(size) + else: + tensor = allocate(heap, num_elements, dtype) + tensor = tensor.reshape(size) + + tensor = apply_memory_format(heap, tensor, size, memory_format) + tensor = apply_layout(tensor, layout) + if requires_grad: + tensor.requires_grad_() + return tensor + + +def uniform(heap, iris_device, size, low=0.0, high=1.0, dtype=torch.float): + """Allocate a tensor filled with uniform random values on *heap*. + + Args: + heap: Symmetric heap (``allocate`` / ``is_symmetric``). + iris_device (:class:`torch.device`): Device of the heap. + size: Shape of the tensor. + low (float): Lower bound of the distribution. Default ``0.0``. + high (float): Upper bound of the distribution. Default ``1.0``. + dtype (:class:`torch.dtype`): Default :data:`torch.float`. + + Returns: + :class:`torch.Tensor`: Tensor on the symmetric heap. + """ + logger.debug(f"uniform: size = {size}, low = {low}, high = {high}, dtype = {dtype}") + size, num_elements = parse_size(size) + tensor = allocate(heap, num_elements, dtype) + tensor.uniform_(low, high) + return tensor.reshape(size) + + +def randn( + heap, + iris_device, + size, + *, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, +): + """Allocate a tensor filled with standard-normal random values on *heap*. + + Args: + heap: Symmetric heap (``allocate`` / ``is_symmetric``). + iris_device (:class:`torch.device`): Device of the heap. + size (tuple): Shape of the tensor. + + Keyword Args: + generator (:class:`torch.Generator`, optional): RNG. + out (:class:`torch.Tensor`, optional): Pre-allocated output tensor. + dtype (:class:`torch.dtype`, optional): Defaults to + :func:`torch.get_default_dtype`. + layout (:class:`torch.layout`): Default :data:`torch.strided`. + device: Must be compatible with *iris_device* or ``None``. + requires_grad (bool): Default ``False``. + + Returns: + :class:`torch.Tensor`: Tensor on the symmetric heap. + """ + logger.debug(f"randn: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = iris_device + throw_if_invalid_device(device, iris_device) + size, num_elements = parse_size(size) + + # In simulation, avoid GPU kernel operations which trigger HIP errors + # Create data on CPU and copy to GPU to avoid kernel execution + from iris.util import is_simulation_env + + if is_simulation_env(): + if out is not None: + throw_if_invalid_output_tensor(heap, out, num_elements, dtype) + # Create on CPU and copy to avoid GPU kernels + cpu_data = torch.ones(num_elements, dtype=dtype, device="cpu") + out.copy_(cpu_data) + tensor = out.view(size) + else: + tensor = allocate(heap, num_elements, dtype) + # Create on CPU and copy to avoid GPU kernels + cpu_data = torch.ones(num_elements, dtype=dtype, device="cpu") + tensor.copy_(cpu_data) + tensor = tensor.reshape(size) + else: + if out is not None: + throw_if_invalid_output_tensor(heap, out, num_elements, dtype) + random_data = torch.randn(num_elements, generator=generator, dtype=dtype, device=device, layout=layout) + out.copy_(random_data) + tensor = out.view(size) + else: + tensor = allocate(heap, num_elements, dtype) + random_data = torch.randn(num_elements, generator=generator, dtype=dtype, device=device, layout=layout) + tensor.copy_(random_data) + tensor = tensor.reshape(size) + + tensor = apply_layout(tensor, layout) + if requires_grad: + tensor.requires_grad_() + return tensor + + +def rand( + heap, + iris_device, + size, + *, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, +): + """Allocate a tensor filled with uniform random values in [0, 1) on *heap*. + + Args: + heap: Symmetric heap (``allocate`` / ``is_symmetric``). + iris_device (:class:`torch.device`): Device of the heap. + size (tuple): Shape of the tensor. + + Keyword Args: + generator (:class:`torch.Generator`, optional): RNG. + out (:class:`torch.Tensor`, optional): Pre-allocated output tensor. + dtype (:class:`torch.dtype`, optional): Defaults to + :func:`torch.get_default_dtype`. + layout (:class:`torch.layout`): Default :data:`torch.strided`. + device: Must be compatible with *iris_device* or ``None``. + requires_grad (bool): Default ``False``. + + Returns: + :class:`torch.Tensor`: Tensor on the symmetric heap. + """ + logger.debug(f"rand: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = iris_device + throw_if_invalid_device(device, iris_device) + size, num_elements = parse_size(size) + + if out is not None: + throw_if_invalid_output_tensor(heap, out, num_elements, dtype) + tensor = out.view(size) + else: + tensor = allocate(heap, num_elements, dtype) + tensor = tensor.reshape(size) + + if generator is not None: + torch.rand(size, generator=generator, out=tensor, dtype=dtype, device=device) + else: + torch.rand(size, out=tensor, dtype=dtype, device=device) + + tensor = apply_layout(tensor, layout) + if requires_grad: + tensor.requires_grad_() + return tensor + + +def randint( + heap, + iris_device, + low, + high, + size, + *, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, +): + """Allocate a tensor filled with random integers in [*low*, *high*) on *heap*. + + Args: + heap: Symmetric heap (``allocate`` / ``is_symmetric``). + iris_device (:class:`torch.device`): Device of the heap. + low (int): Lower bound (inclusive). + high (int): Upper bound (exclusive). + size (tuple): Shape of the tensor. + + Keyword Args: + generator (:class:`torch.Generator`, optional): RNG. + out (:class:`torch.Tensor`, optional): Pre-allocated output tensor. + dtype (:class:`torch.dtype`, optional): Defaults to :data:`torch.int64`. + layout (:class:`torch.layout`): Default :data:`torch.strided`. + device: Must be compatible with *iris_device* or ``None``. + requires_grad (bool): Default ``False``. + + Returns: + :class:`torch.Tensor`: Tensor on the symmetric heap. + """ + logger.debug(f"randint: low = {low}, high = {high}, size = {size}, dtype = {dtype}, device = {device}") + if dtype is None: + dtype = torch.int64 + if device is None: + device = iris_device + throw_if_invalid_device(device, iris_device) + size, num_elements = parse_size(size) + + if out is not None: + throw_if_invalid_output_tensor(heap, out, num_elements, dtype) + tensor = out.view(size) + else: + tensor = allocate(heap, num_elements, dtype) + tensor = tensor.reshape(size) + + if generator is not None: + torch.randint(low, high, size, generator=generator, out=tensor, dtype=dtype, device=device) + else: + torch.randint(low, high, size, out=tensor, dtype=dtype, device=device) + + tensor = apply_layout(tensor, layout) + if requires_grad: + tensor.requires_grad_() + return tensor + + +def arange( + heap, + iris_device, + start, + end, + step, + *, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, +): + """Allocate a 1-D tensor with evenly spaced values on *heap*. + + Args: + heap: Symmetric heap (``allocate`` / ``is_symmetric``). + iris_device (:class:`torch.device`): Device of the heap. + start: Starting value. + end: Ending value (exclusive). + step: Step between elements. + + Keyword Args: + out (:class:`torch.Tensor`, optional): Pre-allocated output tensor. + dtype (:class:`torch.dtype`, optional): Inferred from inputs when ``None``. + layout (:class:`torch.layout`): Default :data:`torch.strided`. + device: Must be compatible with *iris_device* or ``None``. + requires_grad (bool): Default ``False``. + + Returns: + :class:`torch.Tensor`: Tensor on the symmetric heap. + """ + logger.debug(f"arange: start = {start}, end = {end}, step = {step}, dtype = {dtype}, device = {device}") + if step == 0: + raise ValueError("step must be non-zero") + if step > 0 and start >= end: + raise ValueError(f"Invalid range: start >= end with positive step (start={start}, end={end}, step={step})") + elif step < 0 and start <= end: + raise ValueError(f"Invalid range: start <= end with negative step (start={start}, end={end}, step={step})") + + num_elements = math.ceil((end - start) / step) + + if dtype is None: + if any(isinstance(x, float) for x in [start, end, step]): + dtype = torch.get_default_dtype() + else: + dtype = torch.int64 + if device is None: + device = iris_device + throw_if_invalid_device(device, iris_device) + + if out is not None: + throw_if_invalid_output_tensor(heap, out, num_elements, dtype) + tensor = out + else: + tensor = allocate(heap, num_elements, dtype) + + values = torch.arange(start, end, step, dtype=dtype, device=tensor.device) + tensor[:] = values + tensor = apply_layout(tensor, layout) + if requires_grad: + tensor.requires_grad_() + return tensor + + +def linspace( + heap, + iris_device, + start, + end, + steps, + *, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, +): + """Allocate a 1-D tensor of *steps* linearly-spaced values on *heap*. + + Args: + heap: Symmetric heap (``allocate`` / ``is_symmetric``). + iris_device (:class:`torch.device`): Device of the heap. + start: Start of the interval. + end: End of the interval (inclusive). + steps (int): Number of points. + + Keyword Args: + out (:class:`torch.Tensor`, optional): Pre-allocated output tensor. + dtype (:class:`torch.dtype`, optional): Defaults to + :func:`torch.get_default_dtype` (or the corresponding complex dtype). + layout (:class:`torch.layout`): Default :data:`torch.strided`. + device: Must be compatible with *iris_device* or ``None``. + requires_grad (bool): Default ``False``. + + Returns: + :class:`torch.Tensor`: Tensor on the symmetric heap. + """ + logger.debug(f"linspace: start = {start}, end = {end}, steps = {steps}, dtype = {dtype}, device = {device}") + if dtype is None: + start_is_complex = isinstance(start, complex) or (hasattr(start, "dtype") and torch.is_complex(start)) + end_is_complex = isinstance(end, complex) or (hasattr(end, "dtype") and torch.is_complex(end)) + if start_is_complex or end_is_complex: + dtype = torch.complex64 if torch.get_default_dtype() == torch.float32 else torch.complex128 + else: + dtype = torch.get_default_dtype() + if device is None: + device = iris_device + throw_if_invalid_device(device, iris_device) + + # Normalise steps to a plain int + steps_int = _normalize_steps(steps) + size = (steps_int,) + num_elements = steps_int + + if out is not None: + throw_if_invalid_output_tensor(heap, out, num_elements, dtype) + tensor = out.view(size) + else: + tensor = allocate(heap, num_elements, dtype) + tensor = tensor.reshape(size) + + torch.linspace(start, end, steps_int, out=tensor, dtype=dtype, device=device) + tensor = apply_layout(tensor, layout) + if requires_grad: + tensor.requires_grad_() + return tensor diff --git a/iris/tensor_utils.py b/iris/tensor_utils.py new file mode 100644 index 000000000..49b308d20 --- /dev/null +++ b/iris/tensor_utils.py @@ -0,0 +1,132 @@ +""" +Utility functions for tensor creation and manipulation. + +Copyright (c) 2026 Advanced Micro Devices, Inc. +""" + +import torch + + +class CUDAArrayInterface: + """ + Wrapper for creating PyTorch tensors from raw GPU pointers using __cuda_array_interface__. + + This provides a clean interface for creating tensors from device pointers, + which is useful for VMem allocations, imported DMA-BUF handles, and other + scenarios where we need to wrap existing GPU memory. + + Args: + ptr: GPU device pointer (integer) + size_bytes: Size of the memory region in bytes + dtype: PyTorch data type (default: torch.uint8) + device: PyTorch device string (default: 'cuda') + shape: Optional explicit shape tuple (default: inferred from size_bytes and dtype) + + Example: + >>> ptr = 0x7f0000000000 # Some GPU pointer + >>> size_bytes = 1024 + >>> wrapper = CUDAArrayInterface(ptr, size_bytes, dtype=torch.float32) + >>> tensor = torch.as_tensor(wrapper, device='cuda') + """ + + def __init__( + self, + ptr: int, + size_bytes: int, + dtype: torch.dtype = torch.uint8, + device: str = "cuda", + shape: tuple = None, + ): + self.ptr = ptr + self.size_bytes = size_bytes + self.dtype = dtype + self.device = device + + if shape is not None: + self.shape = shape + else: + element_size = torch.tensor([], dtype=dtype).element_size() + num_elements = size_bytes // element_size + self.shape = (num_elements,) + + self.typestr = self._get_typestr(dtype) + + @staticmethod + def _get_typestr(dtype: torch.dtype) -> str: + """ + Convert PyTorch dtype to numpy-style typestr for __cuda_array_interface__. + + Format: + - endianness: '<' (little), '>' (big), '|' (not applicable) + - kind: 'f' (float), 'i' (signed int), 'u' (unsigned int), 'b' (bool) + - size: bytes per element + + Reference: https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html + """ + typestr_map = { + torch.float32: " dict: + """ + Provide __cuda_array_interface__ protocol for PyTorch interop. + + This allows PyTorch to create tensors directly from GPU pointers + without copying data. + + Returns: + dict: CUDA array interface dictionary with shape, typestr, data, and version + """ + return { + "shape": self.shape, + "typestr": self.typestr, + "data": (self.ptr, False), # (pointer, read_only=False) + "version": 3, + } + + +def tensor_from_ptr( + ptr: int, + size_bytes: int, + dtype: torch.dtype = torch.float32, + device: str = "cuda", + shape: tuple = None, +) -> torch.Tensor: + """ + Create a PyTorch tensor from a raw GPU pointer. + + This is a convenience function that wraps CUDAArrayInterface and creates + the tensor in one call. + + Args: + ptr: GPU device pointer (integer) + size_bytes: Size of the memory region in bytes + dtype: PyTorch data type (default: torch.float32) + device: PyTorch device string (default: 'cuda') + shape: Optional explicit shape tuple (default: inferred from size_bytes and dtype) + + Returns: + torch.Tensor: Tensor wrapping the GPU memory + + Example: + >>> ptr = 0x7f0000000000 + >>> tensor = tensor_from_ptr(ptr, 4096, dtype=torch.float32) + >>> print(tensor.shape) # torch.Size([1024]) + """ + wrapper = CUDAArrayInterface(ptr, size_bytes, dtype, device, shape) + return torch.as_tensor(wrapper, device=device) diff --git a/iris/tracing/__init__.py b/iris/tracing/__init__.py new file mode 100644 index 000000000..9c3473cce --- /dev/null +++ b/iris/tracing/__init__.py @@ -0,0 +1,12 @@ +""" +Device-side tracing support for Iris. + +Provides tracing functionality to capture and export device-side events +for debugging and performance analysis. +""" + +from .events import EVENT_NAMES, TraceEvent +from .core import Tracing +from .device import DeviceTracing + +__all__ = ["EVENT_NAMES", "TraceEvent", "Tracing", "DeviceTracing"] diff --git a/iris/tracing/core.py b/iris/tracing/core.py new file mode 100644 index 000000000..9b3ca9a62 --- /dev/null +++ b/iris/tracing/core.py @@ -0,0 +1,327 @@ +""" +Device-side tracing core: buffer allocation, capture, and export. +""" + +import torch +import json +import pickle +import sys +import os +import socket + +from .. import hip +from .events import EVENT_NAMES + + +class Tracing: + """ + Manages device-side event tracing for an Iris instance. + + Handles trace buffer allocation, event capture, and export to Perfetto format. + """ + + def __init__(self, iris_instance): + """ + Initialize tracing manager. + + Args: + iris_instance: Parent Iris instance + """ + self.iris = iris_instance + self.enabled = False + self.max_events = 0 + self.trace_buffers = {} + self.trace_counter = None + + def enable(self, max_events=1_000_000): + """ + Enable device-side event tracing. + + Allocates trace buffers to store events recorded by DeviceContext. + + Args: + max_events (int): Maximum number of events to record. Default: 1,000,000 + """ + self.enabled = True + self.max_events = max_events + + device = self.iris.device + + # Allocate trace buffers (Structure of Arrays for better memory access) + self.trace_buffers = { + "event_id": torch.zeros(max_events, dtype=torch.int32, device=device), + "pid": torch.zeros(max_events, dtype=torch.int32, device=device), + "pid_m": torch.zeros(max_events, dtype=torch.int32, device=device), + "pid_n": torch.zeros(max_events, dtype=torch.int32, device=device), + "cur_rank": torch.zeros(max_events, dtype=torch.int32, device=device), + "target_rank": torch.zeros(max_events, dtype=torch.int32, device=device), + "xcc_id": torch.zeros(max_events, dtype=torch.int32, device=device), + "cu_id": torch.zeros(max_events, dtype=torch.int32, device=device), + "timestamp": torch.zeros(max_events, dtype=torch.int64, device=device), + "address": torch.zeros(max_events, dtype=torch.int64, device=device), + "duration_cycles": torch.zeros(max_events, dtype=torch.int64, device=device), + "op_index": torch.zeros(max_events, dtype=torch.int32, device=device), + "payload_size": torch.zeros(max_events, dtype=torch.int32, device=device), + } + + # Atomic counter for event indexing + self.trace_counter = torch.zeros(1, dtype=torch.int32, device=device) + # Atomic counter for operation indexing (tracks operation order) + self.op_index_counter = torch.zeros(1, dtype=torch.int32, device=device) + + self.iris.info(f"Device tracing enabled with max {max_events} events") + + def reset(self): + """ + Reset trace counter to start a new trace capture. + + Clears the event counter and operation index counter but keeps buffers allocated. + """ + if not self.enabled: + self.iris.warning("Tracing not enabled. Call tracing.enable() first.") + return + + self.trace_counter.zero_() + self.op_index_counter.zero_() + self.iris.debug("Trace buffers reset") + + def _collect_system_metadata(self): + """Collect system and GPU metadata.""" + try: + device_name = torch.cuda.get_device_name(self.iris.cur_rank) + except Exception: + device_name = "Unknown GPU" + + try: + total_memory = torch.cuda.get_device_properties(self.iris.cur_rank).total_memory + total_memory_gb = total_memory / (1024**3) + except Exception: + total_memory_gb = 0 + + return { + "process_name": os.path.basename(sys.argv[0]) if sys.argv else "unknown", + "command_line": " ".join(sys.argv), + "hostname": socket.gethostname(), + "gpu_device_name": device_name, + "gpu_total_memory_gb": f"{total_memory_gb:.2f}", + "gpu_arch": hip.get_arch_string(self.iris.cur_rank), + "gpu_cu_count": hip.get_cu_count(self.iris.cur_rank), + "gpu_num_xcc": hip.get_num_xcc(self.iris.cur_rank), + "rocm_version": hip.get_rocm_version(), + } + + def _build_trace_events(self, num_events): + """Build Perfetto trace events from captured data.""" + trace_events = [] + + for i in range(num_events): + event_id = int(self.trace_buffers["event_id"][i].item()) + event_name = EVENT_NAMES.get(event_id, f"unknown_{event_id}") + + pid = int(self.trace_buffers["pid"][i].item()) + cur_rank = int(self.trace_buffers["cur_rank"][i].item()) + target_rank = int(self.trace_buffers["target_rank"][i].item()) + xcc_id = int(self.trace_buffers["xcc_id"][i].item()) + cu_id = int(self.trace_buffers["cu_id"][i].item()) + begin_ts = int(self.trace_buffers["timestamp"][i].item()) + end_ts = int(self.trace_buffers["duration_cycles"][i].item()) + + # Compute duration (0 = instant event) + duration_cycles = (end_ts - begin_ts) if end_ts > 0 else 0 + + # Perfetto event structure + perfetto_event = { + "name": event_name, + "cat": "iris", + "ts": begin_ts, + "pid": cur_rank, + "tid": f"XCC{xcc_id}_CU{cu_id}", + "args": { + "program_id": pid, + "pid_m": int(self.trace_buffers["pid_m"][i].item()), + "pid_n": int(self.trace_buffers["pid_n"][i].item()), + "target_rank": target_rank, + "address": hex(int(self.trace_buffers["address"][i].item())), + "xcc_id": xcc_id, + "cu_id": cu_id, + "op_index": int(self.trace_buffers["op_index"][i].item()), + "payload_size": int(self.trace_buffers["payload_size"][i].item()), + }, + } + + # Duration event or instant event? + if duration_cycles > 0: + perfetto_event["ph"] = "X" # Complete event + perfetto_event["dur"] = duration_cycles + else: + perfetto_event["ph"] = "i" # Instant event + perfetto_event["s"] = "t" + + trace_events.append(perfetto_event) + + # Add metadata event for this rank + metadata = { + "name": "process_name", + "ph": "M", + "pid": self.iris.cur_rank, + "args": {"name": f"Rank {self.iris.cur_rank}"}, + } + trace_events.append(metadata) + + return trace_events + + def export(self, filename="trace.json", merge=False): + """ + Export collected trace events to Perfetto/Chrome Trace Event Format. + + All timestamps are in raw cycles from s_memrealtime (100MHz constant clock). + View the output at: https://ui.perfetto.dev + + Args: + filename (str): Output JSON filename. Default: "trace.json" + merge (bool): If True, rank 0 collects and merges traces from all ranks + with timestamp alignment. If False, each rank exports its own file. + + Returns: + dict: Trace data (merged on rank 0 if merge=True, per-rank otherwise) + """ + import torch.distributed as dist + + if not self.enabled: + self.iris.warning("Tracing not enabled. Call tracing.enable() first.") + return {} + + # Get actual event count + num_events = min(self.trace_counter.item(), self.max_events) + + # Collect metadata + system_metadata = self._collect_system_metadata() + + # Build trace events + trace_events = self._build_trace_events(num_events) + + # Write per-rank file + per_rank_data = { + "traceEvents": trace_events, + "displayTimeUnit": "ns", + "metadata": { + "schema_version": "1.1", + "num_events": num_events, + "rank": self.iris.cur_rank, + "world_size": self.iris.num_ranks, + "time_unit": "raw cycles (s_memrealtime @ 100MHz)", + "fields": { + "name": "Event type name (e.g., 'put', 'get', 'load', 'store')", + "cat": "Event category (always 'iris')", + "ts": "Start timestamp in raw cycles", + "pid": "Process ID (current rank)", + "tid": "Thread ID (XCC{id}_CU{id})", + "ph": "Phase: 'X' for complete events, 'i' for instant events", + "dur": "Duration in cycles (only for complete events)", + "args": { + "program_id": "Triton program ID (block ID)", + "pid_m": "Program ID in M dimension", + "pid_n": "Program ID in N dimension", + "target_rank": "Target rank for the operation", + "address": "Memory address (hex) - min of address block", + "xcc_id": "XCC (chiplet) ID where event occurred", + "cu_id": "Compute Unit ID where event occurred", + "op_index": "Operation index (0, 1, 2, ...) - automatically tracked", + "payload_size": "Payload size in bytes - automatically calculated from mask and datatype", + }, + }, + **system_metadata, + }, + } + per_rank_filename = filename.replace(".json", f"_rank{self.iris.cur_rank}.json") + with open(per_rank_filename, "w") as f: + json.dump(per_rank_data, f, indent=2) + self.iris.info(f"Exported rank {self.iris.cur_rank} trace to {per_rank_filename}") + + # If not merging, return per-rank data + if not merge: + return per_rank_data + + # Merging logic: serialize and gather events from all ranks + events_bytes = pickle.dumps(trace_events) + events_tensor = torch.ByteTensor(list(events_bytes)).cuda() + + # Gather event counts to rank 0 + event_counts = torch.tensor([len(events_bytes)], dtype=torch.int64, device="cuda") + all_event_counts = [torch.zeros(1, dtype=torch.int64, device="cuda") for _ in range(self.iris.num_ranks)] + dist.all_gather(all_event_counts, event_counts) + + # Synchronize before point-to-point communication to ensure proper ordering + dist.barrier() + + # Rank 0: gather and merge all events + if self.iris.cur_rank == 0: + all_events = [] + + for rank_id in range(self.iris.num_ranks): + if rank_id == 0: + all_events.extend(trace_events) + else: + recv_size = all_event_counts[rank_id].item() + recv_tensor = torch.zeros(recv_size, dtype=torch.uint8, device="cuda") + dist.recv(recv_tensor, src=rank_id) + recv_bytes = bytes(recv_tensor.cpu().numpy()) + rank_events = pickle.loads(recv_bytes) + all_events.extend(rank_events) + + # Align timestamps: find minimum timestamp across all events + all_timestamps = [e["ts"] for e in all_events if e.get("ph") != "M"] + if all_timestamps: + min_ts = min(all_timestamps) + # Shift all timestamps to start from 0 + for event in all_events: + if event.get("ph") != "M": + event["ts"] = event["ts"] - min_ts + + merged_data = { + "traceEvents": all_events, + "displayTimeUnit": "ns", + "metadata": { + "schema_version": "1.1", + "total_events": len(all_events), + "max_events": self.max_events, + "time_unit": "cycles (s_memrealtime @ 100MHz)", + "world_size": self.iris.num_ranks, + "timestamp_offset": min_ts if all_timestamps else 0, + "aligned": "minimum timestamp across all ranks", + "fields": { + "name": "Event type name (e.g., 'put', 'get', 'load', 'store')", + "cat": "Event category (always 'iris')", + "ts": "Start timestamp in raw cycles", + "pid": "Process ID (current rank)", + "tid": "Thread ID (XCC{id}_CU{id})", + "ph": "Phase: 'X' for complete events, 'i' for instant events", + "dur": "Duration in cycles (only for complete events)", + "args": { + "program_id": "Triton program ID (block ID)", + "pid_m": "Program ID in M dimension", + "pid_n": "Program ID in N dimension", + "target_rank": "Target rank for the operation", + "address": "Memory address (hex) - min of address block", + "xcc_id": "XCC (chiplet) ID where event occurred", + "cu_id": "Compute Unit ID where event occurred", + "op_index": "Operation index (0, 1, 2, ...) - automatically tracked", + "payload_size": "Payload size in bytes - automatically calculated from mask and datatype", + }, + }, + **system_metadata, + }, + } + + # Write merged file + with open(filename, "w") as f: + json.dump(merged_data, f, indent=2) + + self.iris.info(f"Exported {len(all_events)} merged trace events to {filename} (aligned)") + self.iris.info("View at: https://ui.perfetto.dev") + + return merged_data + else: + # Other ranks: send events to rank 0 + dist.send(events_tensor, dst=0) + return {} diff --git a/iris/tracing/device.py b/iris/tracing/device.py new file mode 100644 index 000000000..25bd12d72 --- /dev/null +++ b/iris/tracing/device.py @@ -0,0 +1,185 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Device-side tracing aggregate for Iris. + +DeviceTracing is used inside Triton kernels to record events into trace buffers. +Bounds check uses Python `if event_idx.item() < max_events:` so we only store +when the event index is in range (avoids buffer overrun when event count exceeds capacity). +""" + +import triton +import triton.language as tl +from triton.language.core import _aggregate as aggregate + +from .. import device_utils + + +class _DeviceTracingCls: + """ + Device-side tracing: records events into SoA buffers from inside Triton kernels. + + Created by DeviceContext.initialize() when tracing=True. Use record_event_start + / record_event_end to bracket operations; events are exported via Tracing.export(). + + Bounds check: we only store when event_idx.item() < max_events to avoid overrun. + """ + + enabled: tl.constexpr + rank: tl.constexpr # current rank (from ctx) + max_events: tl.tensor # scalar (0-dim) + counter: tl.tensor # pointer to int32 (event counter) + op_index_counter: tl.tensor # pointer to int32 (operation index counter) + buf_event_id: tl.tensor + buf_pid: tl.tensor + buf_pid_m: tl.tensor + buf_pid_n: tl.tensor + buf_cur_rank: tl.tensor + buf_target_rank: tl.tensor + buf_xcc_id: tl.tensor + buf_cu_id: tl.tensor + buf_timestamp: tl.tensor + buf_address: tl.tensor + buf_duration_cycles: tl.tensor + buf_op_index: tl.tensor + buf_payload_size: tl.tensor + + def __init__( + self, + enabled, + rank, + max_events, + counter, + op_index_counter, + buf_event_id, + buf_pid, + buf_pid_m, + buf_pid_n, + buf_cur_rank, + buf_target_rank, + buf_xcc_id, + buf_cu_id, + buf_timestamp, + buf_address, + buf_duration_cycles, + buf_op_index, + buf_payload_size, + ): + """Construct DeviceTracing (called from DeviceContext.initialize).""" + self.enabled = enabled + self.rank = rank + self.max_events = max_events + self.counter = counter + self.op_index_counter = op_index_counter + self.buf_event_id = buf_event_id + self.buf_pid = buf_pid + self.buf_pid_m = buf_pid_m + self.buf_pid_n = buf_pid_n + self.buf_cur_rank = buf_cur_rank + self.buf_target_rank = buf_target_rank + self.buf_xcc_id = buf_xcc_id + self.buf_cu_id = buf_cu_id + self.buf_timestamp = buf_timestamp + self.buf_address = buf_address + self.buf_duration_cycles = buf_duration_cycles + self.buf_op_index = buf_op_index + self.buf_payload_size = buf_payload_size + + @triton.jit + def record_event_start( + self, + event_id: tl.constexpr, + target_rank, + address, + pid_m, + pid_n, + mask=None, + ): + """ + Record start of a traced operation. Returns a handle for record_event_end. + + Only stores when event_idx.item() < max_events (bounds check). + cur_rank is taken from the tracing context (ctx.rank). + op_index is automatically tracked internally (0, 1, 2, ...). + payload_size is automatically calculated from mask and datatype: + - Counts True values in mask to get number of elements + - Infers datatype size from address pointer type + - Multiplies elements * bytes_per_element to get total bytes + If mask is None, payload_size is set to 0 (unknown size). + + Args: + event_id: Event type ID (constexpr) + target_rank: Target rank for the operation + address: Memory address(es) - can be 1D or 2D block of pointers. + The element type is inferred from address.type.element_ty + pid_m: Program ID in M dimension + pid_n: Program ID in N dimension + mask: Optional mask tensor (1D or 2D) indicating valid elements. + If provided, payload_size is calculated as: + (count of True values) * (bytes per element from address dtype). + If None, payload_size is set to 0. + """ + if not self.enabled: + # Return dummy handle; record_event_end will no-op (0 < max_events is false when disabled) + return tl.full((), 0, dtype=tl.int32) + + event_idx = tl.atomic_add(self.counter, 1) + op_index = tl.atomic_add(self.op_index_counter, 1) + + # Calculate payload_size from mask and datatype + if mask is not None: + # Count True values in mask (True=1, False=0, so sum gives count of elements) + mask_i32 = tl.cast(mask, tl.int32) + num_elements = tl.sum(mask_i32) + + # Get element type from address pointer and calculate size in bytes + # address can be 1D or 2D block of pointers, all with same element type + # For blocks, use .dtype instead of .type (like in test_atomic_xchg_triton.py) + # address.dtype is the pointer type, address.dtype.element_ty is the element dtype + elem_type = address.dtype.element_ty + # Get size in bytes using primitive_bitwidth (bits / 8 = bytes) + bitwidth = elem_type.primitive_bitwidth + elem_size_bytes = bitwidth // 8 + # Calculate total payload size in bytes + payload_size = num_elements * elem_size_bytes + else: + # No mask provided, set to 0 to indicate unknown size + payload_size = tl.full((), 0, dtype=tl.int32) + + if event_idx.item() < self.max_events.item(): + tl.store(self.buf_event_id + event_idx, event_id) + tl.store(self.buf_pid + event_idx, tl.program_id(0)) + tl.store(self.buf_pid_m + event_idx, pid_m) + tl.store(self.buf_pid_n + event_idx, pid_n) + tl.store(self.buf_cur_rank + event_idx, self.rank) + tl.store(self.buf_target_rank + event_idx, target_rank) + tl.store(self.buf_xcc_id + event_idx, device_utils.get_xcc_id()) + tl.store(self.buf_cu_id + event_idx, device_utils.get_cu_id()) + tl.store(self.buf_timestamp + event_idx, device_utils.read_realtime()) + # Store one address per event: accept block of pointers (2D/1D) and take min as representative + addr_i64 = tl.cast(address, tl.int64) + tl.store(self.buf_address + event_idx, tl.min(addr_i64)) + tl.store(self.buf_duration_cycles + event_idx, tl.full((), 0, dtype=tl.int64)) + tl.store(self.buf_op_index + event_idx, op_index) + tl.store(self.buf_payload_size + event_idx, payload_size) + return event_idx + + @triton.jit + def record_event_end(self, handle): + """ + Record end timestamp for the event started with record_event_start(handle). + + Only stores when handle.item() < max_events (bounds check). + """ + if not self.enabled: + return + + end_ts = device_utils.read_realtime() + if handle.item() < self.max_events.item(): + tl.store(self.buf_duration_cycles + handle, end_ts) + + +# Mark __init__ as Triton builtin so dependency finder accepts it when hashing kernels. +_DeviceTracingCls.__init__.__triton_builtin__ = True +DeviceTracing = aggregate(_DeviceTracingCls) diff --git a/iris/tracing/events.py b/iris/tracing/events.py new file mode 100644 index 000000000..4838c09d6 --- /dev/null +++ b/iris/tracing/events.py @@ -0,0 +1,96 @@ +""" +Trace event type IDs and Triton-side enumeration. + +EVENT_NAMES and TraceEvent must stay in sync: same IDs for the same operations. +""" + +import triton +import triton.language as tl +from triton.language.core import _aggregate as aggregate + + +# Event type IDs to names mapping (used for export / display). +# Keep in sync with TraceEvent below. +EVENT_NAMES = { + 0: "load", + 1: "store", + 2: "get", + 3: "put", + 4: "copy", + 5: "atomic_add", + 6: "atomic_sub", + 7: "atomic_cas", + 8: "atomic_xchg", + 9: "atomic_xor", + 10: "atomic_and", + 11: "atomic_or", + 12: "atomic_min", + 13: "atomic_max", +} + + +@aggregate +class TraceEvent: + """ + Trace event type enumeration for iris remote memory operations. + + Usage: + >>> ctx.record_event(event_id=TraceEvent().put, target_rank=1, address=ptr) + + Available event types: + Data Movement: + - load (0): Remote load operation + - store (1): Remote store operation + - get (2): Remote read (pull from remote to local) + - put (3): Remote write (push from local to remote) + - copy (4): Peer-to-peer copy between ranks + + Atomic Operations: + - atomic_add (5): Atomic addition + - atomic_sub (6): Atomic subtraction + - atomic_cas (7): Atomic compare-and-swap + - atomic_xchg (8): Atomic exchange + - atomic_xor (9): Atomic XOR + - atomic_and (10): Atomic AND + - atomic_or (11): Atomic OR + - atomic_min (12): Atomic minimum + - atomic_max (13): Atomic maximum + """ + + # Data movement operations + load: tl.constexpr + store: tl.constexpr + get: tl.constexpr + put: tl.constexpr + copy: tl.constexpr + + # Atomic operations + atomic_add: tl.constexpr + atomic_sub: tl.constexpr + atomic_cas: tl.constexpr + atomic_xchg: tl.constexpr + atomic_xor: tl.constexpr + atomic_and: tl.constexpr + atomic_or: tl.constexpr + atomic_min: tl.constexpr + atomic_max: tl.constexpr + + @triton.constexpr_function + def __init__(self): + # Data movement + self.load = tl.constexpr(0) + self.store = tl.constexpr(1) + self.get = tl.constexpr(2) + self.put = tl.constexpr(3) + self.copy = tl.constexpr(4) + + # Atomics + self.atomic_add = tl.constexpr(5) + self.atomic_sub = tl.constexpr(6) + self.atomic_cas = tl.constexpr(7) + self.atomic_xchg = tl.constexpr(8) + self.atomic_xor = tl.constexpr(9) + self.atomic_and = tl.constexpr(10) + self.atomic_or = tl.constexpr(11) + self.atomic_min = tl.constexpr(12) + self.atomic_max = tl.constexpr(13) diff --git a/iris/util.py b/iris/util.py index 1e07a0d14..ae259273f 100644 --- a/iris/util.py +++ b/iris/util.py @@ -23,11 +23,70 @@ # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +import os import statistics import math import torch +def is_simulation_env() -> bool: + """ + Return True if running in a simulation environment (e.g. pre-silicon). + + When True, Iris will force the torch allocator regardless of allocator_type. + Set IRIS_SIMULATION=1 (or "true"/"yes") to enable. + """ + val = os.environ.get("IRIS_SIMULATION", "").strip().lower() + return val in ("1", "true", "yes") + + +def get_simulation_device_id(local_rank: int) -> int: + """ + Get the device ID to use in simulation mode. + + In simulation, multiple ranks may need to share the same physical device. + This function wraps the local_rank to ensure it's within available device bounds. + + Args: + local_rank: The local rank from the environment + + Returns: + Device ID that's guaranteed to be valid (wrapped if needed) + """ + import torch + + num_devices = torch.cuda.device_count() + if num_devices == 0: + return 0 # Fallback if no devices detected + # Wrap to available devices - in simulation, multiple ranks can share device 0 + return local_rank % num_devices + + +def get_device_id_for_rank(local_rank: int) -> int: + """ + Get the device ID to use for a given local rank. + + In simulation mode, this automatically wraps the rank to handle multiple ranks + sharing a single GPU. In normal mode, it returns the local_rank as-is. + + Args: + local_rank: The local rank from the environment (typically from LOCAL_RANK env var) + + Returns: + Device ID that's guaranteed to be valid + + Example: + >>> import iris + >>> local_rank = int(os.environ.get("LOCAL_RANK", 0)) + >>> device_id = iris.get_device_id_for_rank(local_rank) + >>> torch.cuda.set_device(device_id) + """ + if is_simulation_env(): + return get_simulation_device_id(local_rank) + else: + return local_rank + + def get_empty_cache_for_benchmark(): cache_size = 256 * 1024 * 1024 return torch.empty(int(cache_size // 4), dtype=torch.int, device="cuda") diff --git a/iris/x/__init__.py b/iris/x/__init__.py index 71e5d453f..7377fbe3b 100644 --- a/iris/x/__init__.py +++ b/iris/x/__init__.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. """ iris-x: Device-side tile-level primitives for fine-grained compute and collective operations. @@ -17,9 +17,9 @@ >>> @triton.jit >>> def my_kernel(input_ptr, output_ptr, ...): >>> tile = iris.x.TileView(pid_m, pid_n, BLOCK_M, BLOCK_N) - >>> src_view = iris.x.TensorView(input_ptr, M, N, stride_m, stride_n) - >>> dst_view = iris.x.TensorView(output_ptr, M, N, stride_m, stride_n) - >>> ctx = iris.x.DeviceContext(rank, world_size, heap_bases) + >>> src_view = iris.x.make_tensor_view(input_ptr, M, N, stride_m, stride_n) + >>> dst_view = iris.x.make_tensor_view(output_ptr, M, N, stride_m, stride_n) + >>> ctx = iris.DeviceContext.initialize(context_tensor, rank, world_size) >>> >>> # Call collectives on ctx directly (default algorithms) >>> ctx.all_reduce(tile, src_view, dst_view) @@ -31,9 +31,9 @@ >>> @triton.jit >>> def my_kernel(input_ptr, output_ptr, locks_ptr, ...): >>> tile = iris.x.TileView(pid_m, pid_n, BLOCK_M, BLOCK_N) - >>> src_view = iris.x.TensorView(input_ptr, M, N, stride_m, stride_n) - >>> dst_view = iris.x.TensorView(output_ptr, M, N, stride_m, stride_n) - >>> ctx = iris.x.DeviceContext(rank, world_size, heap_bases) + >>> src_view = iris.x.make_tensor_view(input_ptr, M, N, stride_m, stride_n) + >>> dst_view = iris.x.make_tensor_view(output_ptr, M, N, stride_m, stride_n) + >>> ctx = iris.DeviceContext.initialize(context_tensor, rank, world_size) >>> >>> # Use ring algorithm >>> config = iris.x.AllReduceConfig("ring") @@ -51,7 +51,16 @@ >>> iris.x.all_gather(tile, src_view, dst_view, dim, ctx) """ -from .core import Tile, TileView, TensorView, DeviceContext, AllReduceConfig, tile_layout, tile_ptr, offset_ptr +from .core import ( + Tile, + TileView, + TensorView, + AllReduceConfig, + tile_layout, + tile_ptr, + offset_ptr, + make_tensor_view, +) from .all_reduce import ( all_reduce_atomic, all_reduce_ring, @@ -69,11 +78,11 @@ "Tile", "TileView", "TensorView", - "DeviceContext", "AllReduceConfig", "tile_layout", "tile_ptr", "offset_ptr", + "make_tensor_view", # Device-side collectives "all_reduce_atomic", "all_reduce_ring", diff --git a/iris/x/all_gather.py b/iris/x/all_gather.py index b2be3ee31..a8c84bde2 100644 --- a/iris/x/all_gather.py +++ b/iris/x/all_gather.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. """ Tile-level all-gather primitive for Iris. @@ -10,7 +10,8 @@ import triton import triton.language as tl import iris -from .core import Tile, TensorView, DeviceContext +from iris.iris import DeviceContext +from .core import Tile, TensorView @triton.jit() @@ -63,7 +64,12 @@ def all_gather( # Scatter along N dimension: write to [:, ctx.rank * N_local : (ctx.rank+1) * N_local] dst_ptr, combined_mask = dst_view.offset_tile_ptr(tile, offset_n=ctx.rank * N_local, src_mask=None) - # Use iris.store to write to dest_rank's memory + # Use iris.store to write to dest_rank's memory. + # hint=(1, tile.block_n) asserts per-row contiguity only (BLOCK_N consecutive + # elements within each row). Using (tile.block_m, tile.block_n) would + # assert cross-row contiguity which is false when BLOCK_N < N (stride_m > BLOCK_N), + # causing getOrderFromContiguity to choose dim-0 for vectorization and emitting + # scalar buffer_store_short writes to wrong addresses. iris.store( dst_ptr, tile.data, @@ -71,4 +77,5 @@ def all_gather( dest_rank, # to_rank (destination rank) ctx.heap_bases, mask=combined_mask, + hint=(1, tile.block_n), ) diff --git a/iris/x/all_reduce.py b/iris/x/all_reduce.py index e6c287ccc..901f5adb6 100644 --- a/iris/x/all_reduce.py +++ b/iris/x/all_reduce.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. """ Tile-level all-reduce primitives for Iris. @@ -11,7 +11,8 @@ import triton import triton.language as tl import iris -from .core import Tile, TensorView, DeviceContext +from iris.iris import DeviceContext +from .core import Tile, TensorView @triton.jit() @@ -34,7 +35,7 @@ def all_reduce_atomic( Example: # After computing a local tile result tile = iris.x.Tile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, local_result) - dst_view = iris.x.TensorView(output_ptr, M, N, stride_m, stride_n) + dst_view = iris.x.make_tensor_view(output_ptr, M, N, stride_m, stride_n) iris.x.all_reduce_atomic(tile, dst_view, ctx) """ # Get destination tile pointer and mask for this tile position @@ -75,7 +76,7 @@ def all_reduce_spinlock( Example: # After computing a local tile result tile = iris.x.Tile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, local_result) - dst_view = iris.x.TensorView(output_ptr, M, N, stride_m, stride_n) + dst_view = iris.x.make_tensor_view(output_ptr, M, N, stride_m, stride_n) iris.x.all_reduce_spinlock(tile, dst_view, locks_ptr, ctx) """ # Compute tile ID for lock indexing @@ -139,8 +140,8 @@ def all_reduce_one_shot( Example: # After computing and storing a local tile result and signaling ready tile = iris.x.Tile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, local_result) - src_view = iris.x.TensorView(input_ptr, M, N, stride_m, stride_n) - dst_view = iris.x.TensorView(output_ptr, M, N, stride_m, stride_n) + src_view = iris.x.make_tensor_view(input_ptr, M, N, stride_m, stride_n) + dst_view = iris.x.make_tensor_view(output_ptr, M, N, stride_m, stride_n) iris.x.all_reduce_one_shot(tile, src_view, dst_view, locks, ctx) """ # Get tile pointers and mask @@ -236,8 +237,6 @@ def all_reduce_two_shot( src_view: TensorView, dst_view: TensorView, locks, - start_tile: tl.constexpr, - stride: tl.constexpr, ctx: DeviceContext, ): """ @@ -247,6 +246,7 @@ def all_reduce_two_shot( the result to all other ranks. Uses locks as ready flags: before loading, wait for remote tiles to be ready (lock == 1). + Uses interleaved distribution: rank handles tiles where tile_id % world_size == rank. Phase 1: If this tile is rank's responsibility, load from all ranks and reduce locally Phase 2: Scatter reduced tile to all ranks using iris.store @@ -256,25 +256,21 @@ def all_reduce_two_shot( src_view: TensorView for source tensor (to load remote data). dst_view: TensorView for output tensor where reduced result will be written. locks: Pointer to lock array (one per tile) used as ready flags. - start_tile: Starting tile ID for this rank's responsibility. - stride: Stride between tiles this rank is responsible for. ctx: DeviceContext with rank, world_size, and heap_bases. - Example (interleaved distribution): - # Rank 0 handles tiles 0, 2, 4, ... (start_tile=0, stride=2) - # Rank 1 handles tiles 1, 3, 5, ... (start_tile=1, stride=2) + Example: tile = iris.x.Tile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, local_result) - src_view = iris.x.TensorView(input_ptr, M, N, stride_m, stride_n) - dst_view = iris.x.TensorView(output_ptr, M, N, stride_m, stride_n) - iris.x.all_reduce_two_shot(tile, src_view, dst_view, locks, rank, world_size, ctx) + src_view = iris.x.make_tensor_view(input_ptr, M, N, stride_m, stride_n) + dst_view = iris.x.make_tensor_view(output_ptr, M, N, stride_m, stride_n) + iris.x.all_reduce_two_shot(tile, src_view, dst_view, locks, ctx) """ # Compute tile ID num_tiles_n = tl.cdiv(dst_view.N, tile.block_n) tile_id = tile.pid_m * num_tiles_n + tile.pid_n # Check if this tile is this rank's responsibility - # Tile is responsible if: (tile_id - start_tile) % stride == 0 and tile_id >= start_tile - is_responsible = (tile_id >= start_tile) and ((tile_id - start_tile) % stride == 0) + # Using interleaved distribution: rank handles tiles where tile_id % world_size == rank + is_responsible = (tile_id % ctx.world_size) == ctx.rank if is_responsible: # Phase 1: Reduce - load from all ranks and accumulate locally @@ -317,6 +313,7 @@ def all_reduce_two_shot( dest_rank, # to_rank (destination rank) ctx.heap_bases, mask=mask, + hint=(1, tile.block_n), ) diff --git a/iris/x/all_to_all.py b/iris/x/all_to_all.py index 56b92f917..55530a8cf 100644 --- a/iris/x/all_to_all.py +++ b/iris/x/all_to_all.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. """ Tile-level all-to-all primitive for Iris. @@ -10,7 +10,8 @@ import triton import triton.language as tl import iris -from .core import Tile, TensorView, DeviceContext +from iris.iris import DeviceContext +from .core import Tile, TensorView @triton.jit() diff --git a/iris/x/core.py b/iris/x/core.py index 492468004..fee50918e 100644 --- a/iris/x/core.py +++ b/iris/x/core.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. """ Core abstractions for iris.x tile-level primitives. @@ -246,48 +246,73 @@ def indices(self): return rm, rn +@triton.jit +def make_tensor_view(ptr, M, N, stride_m, stride_n): + """ + Factory function to create a TensorView inside a JIT context. + + This wrapper is needed because @triton.constexpr_function constructors + require a JIT context for proper semantic handling. It also converts + int/constexpr values to tensors using the +0 trick. + + Args: + ptr: Pointer to tensor data + M: Number of rows + N: Number of columns + stride_m: Stride in M dimension + stride_n: Stride in N dimension + + Returns: + TensorView instance + """ + # Convert to tensor if needed (handles constexpr ints like stride=1) + # The +0 trick promotes int/constexpr to tensor in JIT context + M_t = M + 0 + N_t = N + 0 + stride_m_t = stride_m + 0 + stride_n_t = stride_n + 0 + return TensorView(ptr, M_t, N_t, stride_m_t, stride_n_t) + + @aggregate class TensorView: """ TensorView storing pointer and tensor metadata. - This works when dimensions and strides are marked as tl.constexpr in the kernel signature! + Dimensions and strides are stored directly - when passed from kernel + parameters (non-constexpr), they are already tensors. - Example usage (with constexpr dimensions): + Example usage: @triton.jit - def kernel(ptr, M: tl.constexpr, N: tl.constexpr, - stride_m: tl.constexpr, stride_n: tl.constexpr, ...): - view = TensorView(ptr, M, N, stride_m, stride_n) + def kernel(ptr, M, N, stride_m, stride_n, ...): + view = make_tensor_view(ptr, M, N, stride_m, stride_n) tile = Tile(pid_m, pid_n, BLOCK_M, BLOCK_N) ptr, mask = view.tile_ptr(tile) - - Note: If M, N, strides are NOT constexpr (runtime kernel args), you cannot store them. - In that case, use the device functions directly or pass them as method arguments. """ ptr: tl.tensor - M: tl.constexpr - N: tl.constexpr - stride_m: tl.constexpr - stride_n: tl.constexpr + M: tl.tensor + N: tl.tensor + stride_m: tl.tensor + stride_n: tl.tensor @triton.constexpr_function def __init__(self, ptr, M, N, stride_m, stride_n): """ - Create a tensor view with pointer and constexpr dimensions/strides. + Create a tensor view with pointer and dimensions/strides. Args: - ptr: Pointer to tensor data (runtime tensor) - M: Number of rows (must be constexpr in kernel signature) - N: Number of columns (must be constexpr in kernel signature) - stride_m: Stride in M dimension (must be constexpr in kernel signature) - stride_n: Stride in N dimension (must be constexpr in kernel signature) + ptr: Pointer to tensor data + M: Number of rows (tensor from kernel parameter) + N: Number of columns (tensor from kernel parameter) + stride_m: Stride in M dimension (tensor from kernel parameter) + stride_n: Stride in N dimension (tensor from kernel parameter) """ self.ptr = ptr - self.M = tl.constexpr(M) - self.N = tl.constexpr(N) - self.stride_m = tl.constexpr(stride_m) - self.stride_n = tl.constexpr(stride_n) + self.M = M + self.N = N + self.stride_m = stride_m + self.stride_n = stride_n @triton.jit def tile_ptr(self, tile: Tile): @@ -462,52 +487,13 @@ def __init__(self, variant_code, locks_ptr): self.locks_ptr = locks_ptr -@aggregate -class DeviceContext: - """ - Device context encapsulating distributed system information. - - This class stores the rank, world size, and heap base pointers needed - for multi-GPU operations using iris primitives. - - IMPORTANT: Triton does not allow imports inside @triton.jit functions, - so collective methods cannot be added to this class. Instead, call the - collective primitives directly: - - Usage: - from iris.x.all_gather import all_gather - from iris.x.all_reduce import all_reduce_one_shot - from iris.x.reduce_scatter import reduce_scatter - - @triton.jit - def my_kernel(..., heap_bases, rank, world_size, ...): - ctx = DeviceContext(rank, world_size, heap_bases) - - # Call primitives directly with ctx as the last argument - all_gather(tile, src_view, dst_view, dim, ctx) - all_reduce_one_shot(tile, src_view, dst_view, ctx) - reduce_scatter(tile, src_view, dst_view, ctx) - - Attributes: - rank: Current rank (constexpr) - world_size: Total number of ranks (constexpr) - heap_bases: Heap base pointers for all ranks (tensor) - """ - - rank: tl.constexpr - world_size: tl.constexpr - heap_bases: tl.tensor - - @triton.constexpr_function - def __init__(self, rank, world_size, heap_bases): - """ - Create a device context for distributed operations. - - Args: - rank: Current rank (must be constexpr in kernel signature) - world_size: Total number of ranks (must be constexpr in kernel signature) - heap_bases: Heap base pointers for all ranks (runtime tensor) - """ - self.rank = tl.constexpr(rank) - self.world_size = tl.constexpr(world_size) - self.heap_bases = heap_bases +__all__ = [ + "TileView", + "Tile", + "TensorView", + "AllReduceConfig", + "tile_layout", + "tile_ptr", + "offset_ptr", + "make_tensor_view", +] diff --git a/iris/x/gather.py b/iris/x/gather.py index 456ab7e8d..ca8bd4f9c 100644 --- a/iris/x/gather.py +++ b/iris/x/gather.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. """ Tile-level gather primitive for Iris. @@ -14,7 +14,8 @@ import triton import triton.language as tl import iris -from .core import Tile, TensorView, DeviceContext +from iris.iris import DeviceContext +from .core import Tile, TensorView @triton.jit() diff --git a/iris/x/reduce_scatter.py b/iris/x/reduce_scatter.py index 96ad35e34..74197c9cd 100644 --- a/iris/x/reduce_scatter.py +++ b/iris/x/reduce_scatter.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. """ Tile-level reduce-scatter primitive for Iris. @@ -10,7 +10,8 @@ import triton import triton.language as tl import iris -from .core import Tile, TensorView, DeviceContext +from iris.iris import DeviceContext +from .core import Tile, TensorView @triton.jit() @@ -41,8 +42,8 @@ def reduce_scatter( # Rank 0 handles tiles 0, 1, 2 # Rank 1 handles tiles 3, 4, 5 tile = iris.x.Tile(pid_m, pid_n, BLOCK_SIZE_M, BLOCK_SIZE_N, local_result) - src_view = iris.x.TensorView(temp_buffer, M, N, stride_m, stride_n) - dst_view = iris.x.TensorView(output_ptr, M, N, stride_m, stride_n) + src_view = iris.x.make_tensor_view(temp_buffer, M, N, stride_m, stride_n) + dst_view = iris.x.make_tensor_view(output_ptr, M, N, stride_m, stride_n) iris.x.reduce_scatter(tile, src_view, dst_view, locks, ctx) """ num_tiles_n = tl.cdiv(dst_view.N, tile.block_n) diff --git a/pyproject.toml b/pyproject.toml index 377e6f61a..18e71badb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,8 @@ dependencies = [ "numpy", "requests", "ruff", + "tritonblas @ git+https://github.com/ROCm/tritonBLAS.git@df58476a4520b72495a3f03f911368a184126568", + ] [project.urls] diff --git a/scripts/roccap_wrapper.py b/scripts/roccap_wrapper.py new file mode 100644 index 000000000..a3f76fb5c --- /dev/null +++ b/scripts/roccap_wrapper.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +import os +import sys +import argparse +import shutil + +# Example: +# cd examples/29_ops_all_gather_matmul +# torchrun --nproc_per_node=4 --standalone ../../scripts/roccap_wrapper.py -k _fused_all_gather_matmul_kernel example.py -m 1024 -n 128 + +parser = argparse.ArgumentParser() +parser.add_argument("-k", "--kernel", type=str, default="_fused_all_gather_matmul_kernel") +parser.add_argument("--skip-roccap", action="store_true", help="Skip roccap and run script directly") + +# Everything else: first is the script to run, rest are passed through to it +parsed, unknown = parser.parse_known_args() +if not unknown: + sys.exit("Usage: roccap_wrapper.py [-k KERNEL] [--skip-roccap]