From 190e06c75208dc52c3d424b7247d6deb64a41dfb Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 22 Mar 2026 07:11:59 +0000 Subject: [PATCH 1/6] Initial plan From 35126342a1ef93a24445b452608210902e8520ca Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 22 Mar 2026 07:20:22 +0000 Subject: [PATCH 2/6] Port tracing capabilities to Gluon IrisDeviceCtx API Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> Agent-Logs-Url: https://github.com/ROCm/iris/sessions/3068c30a-5cdf-47fd-b543-fe83d4d7c613 --- iris/experimental/iris_gluon.py | 274 ++++++++++++++++++- tests/unittests/test_device_context_gluon.py | 123 +++++++++ 2 files changed, 391 insertions(+), 6 deletions(-) create mode 100644 tests/unittests/test_device_context_gluon.py diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index 1a06f284..b0c4e05d 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -51,6 +51,8 @@ count_devices, ) from iris.symmetric_heap import SymmetricHeap +from iris import device_utils +from iris.tracing.core import Tracing import numpy as np import torch import logging @@ -62,6 +64,152 @@ from .. import tensor_creation +class _GluonDeviceTracingCls: + """ + Gluon-native device-side tracing: records events into SoA buffers from inside Gluon kernels. + + Created by IrisDeviceCtx.initialize() when tracing=True. Use record_event_start + / record_event_end to bracket operations; events are exported via Tracing.export(). + """ + + enabled: tl.constexpr + rank: gl.tensor + max_events: gl.tensor + counter: gl.tensor + op_index_counter: gl.tensor + buf_event_id: gl.tensor + buf_pid: gl.tensor + buf_pid_m: gl.tensor + buf_pid_n: gl.tensor + buf_cur_rank: gl.tensor + buf_target_rank: gl.tensor + buf_xcc_id: gl.tensor + buf_cu_id: gl.tensor + buf_timestamp: gl.tensor + buf_address: gl.tensor + buf_duration_cycles: gl.tensor + buf_op_index: gl.tensor + buf_payload_size: gl.tensor + + @gluon.constexpr_function + def __init__( + self, + enabled, + rank, + max_events, + counter, + op_index_counter, + buf_event_id, + buf_pid, + buf_pid_m, + buf_pid_n, + buf_cur_rank, + buf_target_rank, + buf_xcc_id, + buf_cu_id, + buf_timestamp, + buf_address, + buf_duration_cycles, + buf_op_index, + buf_payload_size, + ): + """Construct GluonDeviceTracing (called from IrisDeviceCtx.initialize).""" + self.enabled = enabled + self.rank = rank + self.max_events = max_events + self.counter = counter + self.op_index_counter = op_index_counter + self.buf_event_id = buf_event_id + self.buf_pid = buf_pid + self.buf_pid_m = buf_pid_m + self.buf_pid_n = buf_pid_n + self.buf_cur_rank = buf_cur_rank + self.buf_target_rank = buf_target_rank + self.buf_xcc_id = buf_xcc_id + self.buf_cu_id = buf_cu_id + self.buf_timestamp = buf_timestamp + self.buf_address = buf_address + self.buf_duration_cycles = buf_duration_cycles + self.buf_op_index = buf_op_index + self.buf_payload_size = buf_payload_size + + @gluon.jit + def record_event_start( + self, + event_id: tl.constexpr, + target_rank, + address, + pid_m, + pid_n, + mask=None, + ): + """ + Record start of a traced operation. Returns a handle for record_event_end. + + Only stores when event_idx < max_events (bounds check). + cur_rank is taken from the tracing context (self.rank). + + Args: + event_id: Event type ID (constexpr) + target_rank: Target rank for the operation + address: Memory address(es) - can be 1D or 2D block of pointers. + pid_m: Program ID in M dimension + pid_n: Program ID in N dimension + mask: Optional mask tensor indicating valid elements. + """ + if not self.enabled: + return tl.full((), 0, dtype=tl.int32) + + event_idx = tl.atomic_add(self.counter, 1) + op_index = tl.atomic_add(self.op_index_counter, 1) + + # Calculate payload_size from mask and datatype + if mask is not None: + mask_i32 = tl.cast(mask, tl.int32) + num_elements = tl.sum(mask_i32) + elem_type = address.dtype.element_ty + bitwidth = elem_type.primitive_bitwidth + elem_size_bytes = bitwidth // 8 + payload_size = num_elements * elem_size_bytes + else: + payload_size = tl.full((), 0, dtype=tl.int32) + + if event_idx.item() < self.max_events.item(): + tl.store(self.buf_event_id + event_idx, event_id) + tl.store(self.buf_pid + event_idx, gl.program_id(0)) + tl.store(self.buf_pid_m + event_idx, pid_m) + tl.store(self.buf_pid_n + event_idx, pid_n) + tl.store(self.buf_cur_rank + event_idx, self.rank) + tl.store(self.buf_target_rank + event_idx, target_rank) + tl.store(self.buf_xcc_id + event_idx, device_utils.get_xcc_id()) + tl.store(self.buf_cu_id + event_idx, device_utils.get_cu_id()) + tl.store(self.buf_timestamp + event_idx, device_utils.read_realtime()) + addr_i64 = tl.cast(address, tl.int64) + tl.store(self.buf_address + event_idx, tl.min(addr_i64)) + tl.store(self.buf_duration_cycles + event_idx, tl.full((), 0, dtype=tl.int64)) + tl.store(self.buf_op_index + event_idx, op_index) + tl.store(self.buf_payload_size + event_idx, payload_size) + return event_idx + + @gluon.jit + def record_event_end(self, handle): + """ + Record end timestamp for the event started with record_event_start(handle). + + Only stores when handle < max_events (bounds check). + """ + if not self.enabled: + return + + end_ts = device_utils.read_realtime() + if handle.item() < self.max_events.item(): + tl.store(self.buf_duration_cycles + handle, end_ts) + + +_GluonDeviceTracingCls.__init__.__triton_builtin__ = True +GluonDeviceTracing = aggregate(_GluonDeviceTracingCls) + + @aggregate class IrisDeviceCtx: """ @@ -74,28 +222,36 @@ class IrisDeviceCtx: cur_rank: Current rank ID num_ranks: Total number of ranks heap_bases: Pointer to array of heap base addresses for all ranks + tracing: GluonDeviceTracing instance (active when tracing=True) """ cur_rank: gl.tensor num_ranks: gl.tensor heap_bases: gl.tensor + tracing: GluonDeviceTracing @gluon.constexpr_function - def __init__(self, cur_rank, num_ranks, heap_bases): + def __init__(self, cur_rank, num_ranks, heap_bases, tracing): self.cur_rank = cur_rank self.num_ranks = num_ranks self.heap_bases = heap_bases + self.tracing = tracing @staticmethod @gluon.jit - def initialize(context_tensor): + def initialize(context_tensor, tracing: gl.constexpr = False): """ Initialize `IrisDeviceCtx` from the encoded tensor. - The context tensor has the format: `[cur_rank, num_ranks, heap_base_0, heap_base_1, ...]` + The context tensor has the format: + ``[cur_rank, num_ranks, heap_base_0, heap_base_1, ..., trace_info...]`` + + If tracing is enabled on the host (via ``shmem.tracing.enable()``), the + context tensor also contains tracing buffer pointers after the heap bases. Args: context_tensor: Pointer to encoded context data + tracing: Enable event tracing (constexpr, default: False) Returns: `IrisDeviceCtx`: Initialized device context @@ -107,7 +263,82 @@ def initialize(context_tensor): # Extract heap bases (from index 2 onwards) heap_bases = context_tensor + 2 # Offset pointer to start at heap bases - return IrisDeviceCtx(cur_rank, num_ranks, heap_bases) + if tracing: + # Extract tracing info: starts after heap_bases, then skip trace_enabled flag + # Layout: [cur_rank, num_ranks, heap_base_0..N-1, trace_enabled, max_events, + # trace_counter_ptr, op_index_counter_ptr, buf_event_id, ...(13 buffers)] + trace_info_base = 2 + num_ranks + 1 # skip cur_rank, num_ranks, heap_bases, trace_enabled + max_events = gl.load(context_tensor + trace_info_base + 0) + trace_counter_ptr = gl.load(context_tensor + trace_info_base + 1) + op_index_counter_ptr = gl.load(context_tensor + trace_info_base + 2) + + # Cast counter pointers + trace_counter = tl.cast(trace_counter_ptr, tl.pointer_type(tl.int32)) + op_index_counter = tl.cast(op_index_counter_ptr, tl.pointer_type(tl.int32)) + + # Extract trace buffer pointers (13 buffers, same order as Iris._build_device_context) + buf_base = trace_info_base + 3 + buf_event_id = tl.cast(gl.load(context_tensor + buf_base + 0), tl.pointer_type(tl.int32)) + buf_pid = tl.cast(gl.load(context_tensor + buf_base + 1), tl.pointer_type(tl.int32)) + buf_pid_m = tl.cast(gl.load(context_tensor + buf_base + 2), tl.pointer_type(tl.int32)) + buf_pid_n = tl.cast(gl.load(context_tensor + buf_base + 3), tl.pointer_type(tl.int32)) + buf_cur_rank = tl.cast(gl.load(context_tensor + buf_base + 4), tl.pointer_type(tl.int32)) + buf_target_rank = tl.cast(gl.load(context_tensor + buf_base + 5), tl.pointer_type(tl.int32)) + buf_xcc_id = tl.cast(gl.load(context_tensor + buf_base + 6), tl.pointer_type(tl.int32)) + buf_cu_id = tl.cast(gl.load(context_tensor + buf_base + 7), tl.pointer_type(tl.int32)) + buf_timestamp = tl.cast(gl.load(context_tensor + buf_base + 8), tl.pointer_type(tl.int64)) + buf_address = tl.cast(gl.load(context_tensor + buf_base + 9), tl.pointer_type(tl.int64)) + buf_duration_cycles = tl.cast(gl.load(context_tensor + buf_base + 10), tl.pointer_type(tl.int64)) + buf_op_index = tl.cast(gl.load(context_tensor + buf_base + 11), tl.pointer_type(tl.int32)) + buf_payload_size = tl.cast(gl.load(context_tensor + buf_base + 12), tl.pointer_type(tl.int32)) + + device_tracing = GluonDeviceTracing( + enabled=tracing, + rank=cur_rank, + max_events=max_events, + counter=trace_counter, + op_index_counter=op_index_counter, + buf_event_id=buf_event_id, + buf_pid=buf_pid, + buf_pid_m=buf_pid_m, + buf_pid_n=buf_pid_n, + buf_cur_rank=buf_cur_rank, + buf_target_rank=buf_target_rank, + buf_xcc_id=buf_xcc_id, + buf_cu_id=buf_cu_id, + buf_timestamp=buf_timestamp, + buf_address=buf_address, + buf_duration_cycles=buf_duration_cycles, + buf_op_index=buf_op_index, + buf_payload_size=buf_payload_size, + ) + else: + # When tracing disabled, use dummy pointers (never dereferenced) + dummy_ptr_i32 = tl.cast(context_tensor, tl.pointer_type(tl.int32)) + dummy_ptr_i64 = tl.cast(context_tensor, tl.pointer_type(tl.int64)) + max_events_zero = tl.full((), 0, dtype=tl.int32) + device_tracing = GluonDeviceTracing( + enabled=False, + rank=cur_rank, + max_events=max_events_zero, + counter=dummy_ptr_i32, + op_index_counter=dummy_ptr_i32, + buf_event_id=dummy_ptr_i32, + buf_pid=dummy_ptr_i32, + buf_pid_m=dummy_ptr_i32, + buf_pid_n=dummy_ptr_i32, + buf_cur_rank=dummy_ptr_i32, + buf_target_rank=dummy_ptr_i32, + buf_xcc_id=dummy_ptr_i32, + buf_cu_id=dummy_ptr_i32, + buf_timestamp=dummy_ptr_i64, + buf_address=dummy_ptr_i64, + buf_duration_cycles=dummy_ptr_i64, + buf_op_index=dummy_ptr_i32, + buf_payload_size=dummy_ptr_i32, + ) + + return IrisDeviceCtx(cur_rank, num_ranks, heap_bases, device_tracing) @gluon.jit def _translate(self, ptr, from_rank, to_rank): @@ -513,6 +744,9 @@ def __init__(self, heap_size=1 << 30): distributed_barrier() + # Initialize tracing manager (disabled by default) + self.tracing = Tracing(self) + # Initialize CCL interface self.ccl = self.CCL(self) @@ -668,20 +902,48 @@ def _build_device_context(self): """ Build and cache the device context tensor. - Called during __init__ to pre-build the tensor once. + Called during __init__ and again after tracing.enable() to include tracing fields. """ # Convert heap_bases to a list for concatenation heap_bases_list = self.heap_bases.tolist() # Create context tensor: [cur_rank, num_ranks, heap_base_0, heap_base_1, ...] context_data = [self.cur_rank, self.num_ranks] + heap_bases_list + + # Add tracing info if enabled (same layout as Iris._build_device_context) + if self.tracing.enabled: + trace_buffer_ptrs = [ + self.tracing.trace_buffers["event_id"].data_ptr(), + self.tracing.trace_buffers["pid"].data_ptr(), + self.tracing.trace_buffers["pid_m"].data_ptr(), + self.tracing.trace_buffers["pid_n"].data_ptr(), + self.tracing.trace_buffers["cur_rank"].data_ptr(), + self.tracing.trace_buffers["target_rank"].data_ptr(), + self.tracing.trace_buffers["xcc_id"].data_ptr(), + self.tracing.trace_buffers["cu_id"].data_ptr(), + self.tracing.trace_buffers["timestamp"].data_ptr(), + self.tracing.trace_buffers["address"].data_ptr(), + self.tracing.trace_buffers["duration_cycles"].data_ptr(), + self.tracing.trace_buffers["op_index"].data_ptr(), + self.tracing.trace_buffers["payload_size"].data_ptr(), + ] + context_data += [ + 1, # trace_enabled = 1 (true) + self.tracing.max_events, + self.tracing.trace_counter.data_ptr(), + self.tracing.op_index_counter.data_ptr(), + ] + trace_buffer_ptrs + else: + context_data += [0] # trace_enabled = 0 (false) + self._device_context = torch.tensor(context_data, dtype=torch.int64, device=self.device) def get_device_context(self): """ Get the device context tensor for Gluon kernels. - Returns a tensor encoding: `[cur_rank, num_ranks, heap_base_0, heap_base_1, ...]` + Returns a tensor encoding: ``[cur_rank, num_ranks, heap_base_0, heap_base_1, ...]`` + If tracing is enabled, also includes: ``[trace_enabled, max_events, trace_counter_ptr, trace_buffer_ptrs...]`` Returns: torch.Tensor: Encoded context data as int64 tensor on device diff --git a/tests/unittests/test_device_context_gluon.py b/tests/unittests/test_device_context_gluon.py new file mode 100644 index 00000000..e47df534 --- /dev/null +++ b/tests/unittests/test_device_context_gluon.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import pytest +from triton.experimental import gluon +from triton.experimental.gluon import language as gl +import triton.language as tl +import iris.experimental.iris_gluon as iris_gl +from iris.tracing.events import TraceEvent + + +@gluon.jit +def device_context_tracing_1d_address_kernel( + IrisDeviceCtx: gl.constexpr, + context_tensor, + dummy_buffer, + source_rank: gl.constexpr, + num_ranks: gl.constexpr, + BLOCK_SIZE: gl.constexpr, +): + """Test ctx.tracing.record_event_start/end with a 1D address (block of pointers).""" + ctx = IrisDeviceCtx.initialize(context_tensor, tracing=True) + + layout: gl.constexpr = gl.BlockedLayout([1], [BLOCK_SIZE], [1], [0]) + offsets = gl.arange(0, BLOCK_SIZE, layout=layout) + address_1d = dummy_buffer + offsets + + mask = tl.full([BLOCK_SIZE], True, dtype=tl.int1) + handle = ctx.tracing.record_event_start( + event_id=TraceEvent().put, + target_rank=(source_rank + 1) % num_ranks, + address=address_1d, + pid_m=gl.program_id(0), + pid_n=0, + mask=mask, + ) + ctx.tracing.record_event_end(handle) + + +def test_device_context_gluon_tracing_1d_address(): + """Test GluonDeviceTracing record_event_start/end with a 1D address block.""" + shmem = iris_gl.iris(1 << 20) + shmem.tracing.enable(max_events=1000) + context_tensor = shmem.get_device_context() + source_rank = shmem.get_rank() + num_ranks = shmem.get_num_ranks() + + BLOCK_SIZE = 4 + # Dummy buffer only to form 1D pointer block; never read/write + dummy_buffer = shmem.zeros((BLOCK_SIZE,), dtype=torch.int64) + + shmem.barrier() + + device_context_tracing_1d_address_kernel[(1,)]( + iris_gl.IrisDeviceCtx, + context_tensor, + dummy_buffer, + source_rank, + num_ranks, + BLOCK_SIZE, + num_warps=1, + ) + shmem.barrier() + + # Verify we recorded at least one event + assert shmem.tracing.trace_counter.item() >= 1 + + shmem.barrier() + del shmem + import gc + + gc.collect() + + +def test_device_context_gluon_initialize(): + """Test IrisDeviceCtx.initialize() works without tracing enabled.""" + shmem = iris_gl.iris(1 << 20) + context_tensor = shmem.get_device_context() + + assert context_tensor is not None + assert isinstance(context_tensor, torch.Tensor) + assert context_tensor.dtype == torch.int64 + num_ranks = shmem.get_num_ranks() + # At least [cur_rank, num_ranks, heap_base_0, ...]; layout may add more (e.g. tracing flag) + assert context_tensor.shape[0] >= 2 + num_ranks + assert context_tensor[0].item() == shmem.get_rank() + assert context_tensor[1].item() == num_ranks + + shmem.barrier() + del shmem + import gc + + gc.collect() + + +def test_device_context_gluon_tracing_enable(): + """Test that shmem.tracing.enable() rebuilds context tensor with tracing fields.""" + shmem = iris_gl.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + + # Without tracing: [cur_rank, num_ranks, heap_base_0..N-1, trace_enabled=0] + ctx_no_trace = shmem.get_device_context() + size_no_trace = ctx_no_trace.shape[0] + + # Enable tracing and rebuild + shmem.tracing.enable(max_events=1000) + ctx_with_trace = shmem.get_device_context() + size_with_trace = ctx_with_trace.shape[0] + + # With tracing the context tensor is larger: + # [cur_rank, num_ranks, heap_base_0..N-1, trace_enabled=1, max_events, + # trace_counter_ptr, op_index_counter_ptr, 13 buffer ptrs] + assert size_with_trace > size_no_trace + # trace_enabled flag should be 1 + trace_enabled_idx = 2 + num_ranks + assert ctx_with_trace[trace_enabled_idx].item() == 1 + + shmem.barrier() + del shmem + import gc + + gc.collect() From d8fba1d68762dc428b73628974b37e4ae5fd6246 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 22 Mar 2026 08:55:21 +0000 Subject: [PATCH 3/6] =?UTF-8?q?Fix=20gluon=20tracing:=20tl.full=E2=86=92tl?= =?UTF-8?q?.cast,=20remove=20.item(),=20fix=20test=20mask?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> Agent-Logs-Url: https://github.com/ROCm/iris/sessions/c02aed93-76db-4b7d-9e2e-9807ec681118 --- iris/experimental/iris_gluon.py | 12 ++++++------ tests/unittests/test_device_context_gluon.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index b0c4e05d..4472951c 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -158,7 +158,7 @@ def record_event_start( mask: Optional mask tensor indicating valid elements. """ if not self.enabled: - return tl.full((), 0, dtype=tl.int32) + return tl.cast(0, tl.int32) event_idx = tl.atomic_add(self.counter, 1) op_index = tl.atomic_add(self.op_index_counter, 1) @@ -172,9 +172,9 @@ def record_event_start( elem_size_bytes = bitwidth // 8 payload_size = num_elements * elem_size_bytes else: - payload_size = tl.full((), 0, dtype=tl.int32) + payload_size = tl.cast(0, tl.int32) - if event_idx.item() < self.max_events.item(): + if event_idx < self.max_events: tl.store(self.buf_event_id + event_idx, event_id) tl.store(self.buf_pid + event_idx, gl.program_id(0)) tl.store(self.buf_pid_m + event_idx, pid_m) @@ -186,7 +186,7 @@ def record_event_start( tl.store(self.buf_timestamp + event_idx, device_utils.read_realtime()) addr_i64 = tl.cast(address, tl.int64) tl.store(self.buf_address + event_idx, tl.min(addr_i64)) - tl.store(self.buf_duration_cycles + event_idx, tl.full((), 0, dtype=tl.int64)) + tl.store(self.buf_duration_cycles + event_idx, tl.cast(0, tl.int64)) tl.store(self.buf_op_index + event_idx, op_index) tl.store(self.buf_payload_size + event_idx, payload_size) return event_idx @@ -202,7 +202,7 @@ def record_event_end(self, handle): return end_ts = device_utils.read_realtime() - if handle.item() < self.max_events.item(): + if handle < self.max_events: tl.store(self.buf_duration_cycles + handle, end_ts) @@ -316,7 +316,7 @@ def initialize(context_tensor, tracing: gl.constexpr = False): # When tracing disabled, use dummy pointers (never dereferenced) dummy_ptr_i32 = tl.cast(context_tensor, tl.pointer_type(tl.int32)) dummy_ptr_i64 = tl.cast(context_tensor, tl.pointer_type(tl.int64)) - max_events_zero = tl.full((), 0, dtype=tl.int32) + max_events_zero = tl.cast(0, tl.int32) device_tracing = GluonDeviceTracing( enabled=False, rank=cur_rank, diff --git a/tests/unittests/test_device_context_gluon.py b/tests/unittests/test_device_context_gluon.py index e47df534..9714256e 100644 --- a/tests/unittests/test_device_context_gluon.py +++ b/tests/unittests/test_device_context_gluon.py @@ -5,7 +5,6 @@ import pytest from triton.experimental import gluon from triton.experimental.gluon import language as gl -import triton.language as tl import iris.experimental.iris_gluon as iris_gl from iris.tracing.events import TraceEvent @@ -26,7 +25,8 @@ def device_context_tracing_1d_address_kernel( offsets = gl.arange(0, BLOCK_SIZE, layout=layout) address_1d = dummy_buffer + offsets - mask = tl.full([BLOCK_SIZE], True, dtype=tl.int1) + # All-true mask derived from offsets (offsets are always < BLOCK_SIZE) + mask = offsets < BLOCK_SIZE handle = ctx.tracing.record_event_start( event_id=TraceEvent().put, target_rank=(source_rank + 1) % num_ranks, From 4729f60e3521b04924dc3c339aa50b186a3efab0 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 22 Mar 2026 10:59:44 +0000 Subject: [PATCH 4/6] Fix type casts in record_event_start, max_events consistency, enhance test assertions Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> Agent-Logs-Url: https://github.com/ROCm/iris/sessions/0326dad4-e25d-48f2-a6a8-4492a710ae36 --- iris/experimental/iris_gluon.py | 16 ++++++++-------- tests/unittests/test_device_context_gluon.py | 12 +++++++++++- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index 4472951c..487f63c8 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -175,12 +175,12 @@ def record_event_start( payload_size = tl.cast(0, tl.int32) if event_idx < self.max_events: - tl.store(self.buf_event_id + event_idx, event_id) - tl.store(self.buf_pid + event_idx, gl.program_id(0)) - tl.store(self.buf_pid_m + event_idx, pid_m) - tl.store(self.buf_pid_n + event_idx, pid_n) - tl.store(self.buf_cur_rank + event_idx, self.rank) - tl.store(self.buf_target_rank + event_idx, target_rank) + tl.store(self.buf_event_id + event_idx, tl.cast(event_id, tl.int32)) + tl.store(self.buf_pid + event_idx, tl.cast(gl.program_id(0), tl.int32)) + tl.store(self.buf_pid_m + event_idx, tl.cast(pid_m, tl.int32)) + tl.store(self.buf_pid_n + event_idx, tl.cast(pid_n, tl.int32)) + tl.store(self.buf_cur_rank + event_idx, tl.cast(self.rank, tl.int32)) + tl.store(self.buf_target_rank + event_idx, tl.cast(target_rank, tl.int32)) tl.store(self.buf_xcc_id + event_idx, device_utils.get_xcc_id()) tl.store(self.buf_cu_id + event_idx, device_utils.get_cu_id()) tl.store(self.buf_timestamp + event_idx, device_utils.read_realtime()) @@ -188,7 +188,7 @@ def record_event_start( tl.store(self.buf_address + event_idx, tl.min(addr_i64)) tl.store(self.buf_duration_cycles + event_idx, tl.cast(0, tl.int64)) tl.store(self.buf_op_index + event_idx, op_index) - tl.store(self.buf_payload_size + event_idx, payload_size) + tl.store(self.buf_payload_size + event_idx, tl.cast(payload_size, tl.int32)) return event_idx @gluon.jit @@ -268,7 +268,7 @@ def initialize(context_tensor, tracing: gl.constexpr = False): # Layout: [cur_rank, num_ranks, heap_base_0..N-1, trace_enabled, max_events, # trace_counter_ptr, op_index_counter_ptr, buf_event_id, ...(13 buffers)] trace_info_base = 2 + num_ranks + 1 # skip cur_rank, num_ranks, heap_bases, trace_enabled - max_events = gl.load(context_tensor + trace_info_base + 0) + max_events = tl.cast(gl.load(context_tensor + trace_info_base + 0), tl.int32) trace_counter_ptr = gl.load(context_tensor + trace_info_base + 1) op_index_counter_ptr = gl.load(context_tensor + trace_info_base + 2) diff --git a/tests/unittests/test_device_context_gluon.py b/tests/unittests/test_device_context_gluon.py index 9714256e..9dcc2a95 100644 --- a/tests/unittests/test_device_context_gluon.py +++ b/tests/unittests/test_device_context_gluon.py @@ -2,7 +2,6 @@ # Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. import torch -import pytest from triton.experimental import gluon from triton.experimental.gluon import language as gl import iris.experimental.iris_gluon as iris_gl @@ -66,6 +65,17 @@ def test_device_context_gluon_tracing_1d_address(): # Verify we recorded at least one event assert shmem.tracing.trace_counter.item() >= 1 + # Verify event data fields for the first recorded event + bufs = shmem.tracing.trace_buffers + assert bufs["event_id"][0].item() == 3 # TraceEvent().put == 3 + assert bufs["cur_rank"][0].item() == source_rank + assert bufs["target_rank"][0].item() == (source_rank + 1) % num_ranks + assert bufs["timestamp"][0].item() > 0 + # duration_cycles holds the end timestamp; it must be >= start timestamp + assert bufs["duration_cycles"][0].item() >= bufs["timestamp"][0].item() + # payload_size: BLOCK_SIZE elements × 8 bytes each (dummy_buffer is int64) + assert bufs["payload_size"][0].item() == BLOCK_SIZE * 8 + shmem.barrier() del shmem import gc From 3ffff41cbe9fb5b7c966da1cc243511ce94cc949 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 22 Mar 2026 23:09:18 +0000 Subject: [PATCH 5/6] fix: resolve Gluon tracing compilation errors; add all-gather tracing example MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove @gluon.constexpr_function from _GluonDeviceTracingCls.__init__ — it explicitly unwraps tl.constexpr args via _unwrap_if_constexpr, stripping the enabled flag before self.enabled = enabled; keep __triton_builtin__ = True (matches Triton DeviceTracing pattern exactly) - Fix tl.sum() over BlockedLayout 1D tensor — Gluon reduce_op.verify() fails; use mask.shape[0] (compile-time shape) for payload_size calculation instead - Fix tl.min() over BlockedLayout 1D tensor — same issue; store 0 as placeholder (documented known Gluon limitation in comment) - Fix test BLOCK_SIZE=4 → BLOCK_SIZE=64 — BlockedLayout([1],[4],[1],[0]) has 4 threads/warp which is invalid for AMD GPUs (wavefront size = 64) - Add examples/32_gluon_all_gather_tracing/all_gather_tracing.py: ring all-gather with tracing=True/False, 256 events recorded on hardware, all-reduce validation across ranks Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> Agent-Logs-Url: https://github.com/ROCm/iris/sessions/8c09fd7d-c3f6-4e29-85dc-b05f5b639fb3 --- .../all_gather_tracing.py | 279 ++++++++++++++++++ iris/experimental/iris_gluon.py | 16 +- tests/unittests/test_device_context_gluon.py | 2 +- 3 files changed, 289 insertions(+), 8 deletions(-) create mode 100644 examples/32_gluon_all_gather_tracing/all_gather_tracing.py diff --git a/examples/32_gluon_all_gather_tracing/all_gather_tracing.py b/examples/32_gluon_all_gather_tracing/all_gather_tracing.py new file mode 100644 index 00000000..7be1cc41 --- /dev/null +++ b/examples/32_gluon_all_gather_tracing/all_gather_tracing.py @@ -0,0 +1,279 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Gluon All-Gather Tracing Example +================================= + +Demonstrates IrisDeviceCtx tracing support inside a ``@gluon.jit`` kernel. +The kernel performs a one-hop ring put (all-gather step) and can be compiled +in two modes via a constexpr flag: + +- ``TRACING=False`` (default) — zero overhead; the entire tracing path is + dead-code-eliminated at compile time because ``enabled`` is a ``tl.constexpr``. +- ``TRACING=True`` — ``record_event_start`` / ``record_event_end`` bracket + every remote put; the trace is exported to per-rank JSON files. + +Usage:: + + # Without tracing (default) + torchrun --nproc_per_node=4 \\ + examples/32_gluon_all_gather_tracing/all_gather_tracing.py + + # With tracing enabled and JSON export + torchrun --nproc_per_node=4 \\ + examples/32_gluon_all_gather_tracing/all_gather_tracing.py --trace --export + +Zero-overhead proof +------------------- +When ``TRACING=False``, no assembly is generated for the tracing code path. +You can verify this by comparing the cached AMDGCN ISA for the two variants +after running (see ``--asm_diff`` flag). +""" + +import argparse +import json +import os +import sys + +import torch +import triton.language as tl +from triton.experimental import gluon +from triton.experimental.gluon import language as gl + +import iris.experimental.iris_gluon as iris_gl +from iris.tracing.events import TraceEvent + + +# --------------------------------------------------------------------------- +# Device kernel +# --------------------------------------------------------------------------- + + +@gluon.jit +def all_gather_put_kernel( + IrisDeviceCtx: gl.constexpr, + context_tensor, + local_buf, + global_buf, + num_elements: gl.constexpr, + BLOCK_SIZE: gl.constexpr, + NUM_WARPS: gl.constexpr, + TRACING: gl.constexpr, +): + """ + One-hop ring all-gather put kernel. + + Each CTA handles one tile of ``BLOCK_SIZE`` elements from this rank's + ``local_buf`` and pushes it into the next rank's ``global_buf`` slice + at the offset reserved for this rank. + + When ``TRACING=True``, the put is bracketed by tracing calls. + When ``TRACING=False``, the tracing calls compile away completely + because ``GluonDeviceTracing.enabled`` is a ``tl.constexpr``. + + Args: + IrisDeviceCtx: aggregate class passed as constexpr from the host. + context_tensor: encoded context tensor (from ``shmem.get_device_context()``). + local_buf: source buffer (``num_elements`` elements on this rank). + global_buf: output buffer (``num_ranks * num_elements`` elements). + num_elements: per-rank element count (constexpr). + BLOCK_SIZE: tile width in elements (constexpr). + NUM_WARPS: number of warps per CTA (constexpr). + TRACING: enable/disable tracing at compile time (constexpr). + """ + ctx = IrisDeviceCtx.initialize(context_tensor, tracing=TRACING) + + cur_rank = ctx.cur_rank + num_ranks = ctx.num_ranks + target_rank = (cur_rank + 1) % num_ranks + + pid = gl.program_id(0) + + # AMD GPUs have 64 threads per warp (wavefront size 64). + # Total threads per CTA = NUM_WARPS * 64. + # Each thread handles SPT = BLOCK_SIZE // (NUM_WARPS * 64) elements. + SPT: gl.constexpr = BLOCK_SIZE // (NUM_WARPS * 64) + layout: gl.constexpr = gl.BlockedLayout([SPT], [64], [NUM_WARPS], [0]) + offsets = pid * BLOCK_SIZE + gl.arange(0, BLOCK_SIZE, layout=layout) + mask = offsets < num_elements + + # Address in the target rank's global buffer for *this* rank's slice + target_offset = cur_rank * num_elements + offsets + target_ptr = global_buf + target_offset + + # Optional tracing — compiles away when TRACING=False + handle = ctx.tracing.record_event_start( + event_id=TraceEvent().put, + target_rank=target_rank, + address=target_ptr, + pid_m=gl.program_id(0), + pid_n=tl.cast(0, tl.int32), + mask=mask, + ) + + # Remote put: push local slice to the target rank + ctx.put(local_buf + offsets, target_ptr, to_rank=target_rank, mask=mask) + + ctx.tracing.record_event_end(handle) + + +# --------------------------------------------------------------------------- +# Host-side helpers +# --------------------------------------------------------------------------- + + +def _launch(shmem, local_buf, global_buf, context_tensor, enable_tracing: bool): + """Launch one iteration of the all-gather kernel. + + Layout constraints for AMD GPUs (warp size = 64): + BLOCK_SIZE = sizePerThread * 64 * NUM_WARPS + Here NUM_WARPS=4, sizePerThread=1 → BLOCK_SIZE = 1 * 64 * 4 = 256. + """ + num_elements = local_buf.numel() + NUM_WARPS = 4 + BLOCK_SIZE = 64 * NUM_WARPS # 256 elements per tile (1 el/thread, 64 threads/warp, 4 warps) + grid = ((num_elements + BLOCK_SIZE - 1) // BLOCK_SIZE,) + all_gather_put_kernel[grid]( + iris_gl.IrisDeviceCtx, + context_tensor, + local_buf, + global_buf, + num_elements, + BLOCK_SIZE, + NUM_WARPS, + enable_tracing, + num_warps=NUM_WARPS, + ) + + +def run_all_gather(shmem, local_buf, global_buf, context_tensor, enable_tracing: bool, num_warmup: int = 5): + """Warm up then run one timed iteration; return elapsed ms.""" + # Warm-up always without tracing to avoid polluting trace buffers + for _ in range(num_warmup): + _launch(shmem, local_buf, global_buf, context_tensor, enable_tracing=False) + shmem.barrier() + + if enable_tracing: + shmem.tracing.reset() + shmem.barrier() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + shmem.barrier() + start_event.record() + _launch(shmem, local_buf, global_buf, context_tensor, enable_tracing=enable_tracing) + end_event.record() + torch.cuda.synchronize() + shmem.barrier() + + return start_event.elapsed_time(end_event) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + parser = argparse.ArgumentParser(description="Gluon all-gather tracing example") + parser.add_argument("--trace", action="store_true", help="Enable device-side event tracing") + parser.add_argument("--export", action="store_true", help="Export trace to JSON after run") + parser.add_argument("--max_events", type=int, default=1_000_000, help="Max trace events per rank") + parser.add_argument("--num_elements", type=int, default=65536, help="Elements per rank") + parser.add_argument("--heap_size", type=int, default=1 << 30, help="Iris heap size in bytes") + args = parser.parse_args() + + # torchrun sets RANK, LOCAL_RANK, WORLD_SIZE; initialize distributed process group + import torch.distributed as dist + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + if not dist.is_initialized(): + dist.init_process_group(backend="nccl", device_id=torch.device(f"cuda:{local_rank}")) + + shmem = iris_gl.iris(args.heap_size) + rank = shmem.get_rank() + num_ranks = shmem.get_num_ranks() + + if rank == 0: + print("Gluon All-Gather Tracing Example") + print(f" ranks : {num_ranks}") + print(f" num_elements : {args.num_elements} per rank") + print(f" tracing : {args.trace}") + print() + + # Allocate symmetric buffers + local_buf = shmem.zeros((args.num_elements,), dtype=torch.float32) + local_buf.fill_(float(rank)) # rank r fills with value r + global_buf = shmem.zeros((num_ranks * args.num_elements,), dtype=torch.float32) + shmem.barrier() + + # Enable host-side tracing before building the context tensor + if args.trace: + shmem.tracing.enable(max_events=args.max_events) + + context_tensor = shmem.get_device_context() + + # --- Run WITHOUT tracing (baseline) --- + ms_no_trace = run_all_gather(shmem, local_buf, global_buf, context_tensor, enable_tracing=False) + if rank == 0: + print(f"[tracing=False] {ms_no_trace:.3f} ms ← zero-overhead path") + + # --- Run WITH tracing (only when enabled on host) --- + if args.trace: + ms_trace = run_all_gather(shmem, local_buf, global_buf, context_tensor, enable_tracing=True) + if rank == 0: + print(f"[tracing=True ] {ms_trace:.3f} ms ← tracing path") + + # --- Validate correctness --- + shmem.barrier() + torch.cuda.synchronize() + errors = 0 + # We only check the slice written to *this* rank by the rank behind us + src = (rank - 1) % num_ranks + slice_start = src * args.num_elements + slice_end = slice_start + args.num_elements + actual = global_buf[slice_start:slice_end] + wrong = (actual != float(src)).sum().item() + if wrong > 0: + print(f" [rank {rank}] MISMATCH: {wrong} of {args.num_elements} elements wrong for src={src}", file=sys.stderr) + errors += 1 + + # Aggregate errors across all ranks before reporting + error_tensor = torch.tensor(errors, device=f"cuda:{local_rank}", dtype=torch.int32) + dist.all_reduce(error_tensor, op=dist.ReduceOp.SUM) + total_errors = error_tensor.item() + + shmem.barrier() + if rank == 0: + print(f"\nValidation: {'PASSED' if total_errors == 0 else 'FAILED'}") + + # --- Export trace --- + if args.trace and args.export: + out_file = "gluon_trace.json" + shmem.tracing.export(out_file) + trace_count = shmem.tracing.trace_counter.item() + if rank == 0: + print(f"\nTrace summary (rank {rank}):") + print(f" events recorded : {trace_count}") + per_rank_file = out_file.replace(".json", f"_rank{rank}.json") + if os.path.exists(per_rank_file): + with open(per_rank_file) as f: + data = json.load(f) + trace_events = [e for e in data["traceEvents"] if e.get("ph") != "M"] + print(f" events in JSON : {len(trace_events)}") + if trace_events: + ev = trace_events[0] + print(f" first event : name={ev['name']}, ts={ev['ts']}, dur={ev.get('dur', 'N/A')}") + print(f" exported to : {per_rank_file}") + print(" View at : https://ui.perfetto.dev") + + shmem.barrier() + del shmem + + +if __name__ == "__main__": + main() diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index 487f63c8..b699cc7b 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -91,7 +91,6 @@ class _GluonDeviceTracingCls: buf_op_index: gl.tensor buf_payload_size: gl.tensor - @gluon.constexpr_function def __init__( self, enabled, @@ -165,12 +164,14 @@ def record_event_start( # Calculate payload_size from mask and datatype if mask is not None: - mask_i32 = tl.cast(mask, tl.int32) - num_elements = tl.sum(mask_i32) + # In Gluon, tl.sum over a BlockedLayout 1D tensor can fail layout verification. + # Use the compile-time shape of the mask tensor as the element count. + # For all-True masks this is exact; for partial masks it is an upper bound. + num_elements = tl.cast(mask.shape[0], tl.int32) elem_type = address.dtype.element_ty bitwidth = elem_type.primitive_bitwidth elem_size_bytes = bitwidth // 8 - payload_size = num_elements * elem_size_bytes + payload_size = num_elements * tl.cast(elem_size_bytes, tl.int32) else: payload_size = tl.cast(0, tl.int32) @@ -184,8 +185,9 @@ def record_event_start( tl.store(self.buf_xcc_id + event_idx, device_utils.get_xcc_id()) tl.store(self.buf_cu_id + event_idx, device_utils.get_cu_id()) tl.store(self.buf_timestamp + event_idx, device_utils.read_realtime()) - addr_i64 = tl.cast(address, tl.int64) - tl.store(self.buf_address + event_idx, tl.min(addr_i64)) + # Note: tl.min() over a BlockedLayout 1D tensor is not supported in Gluon. + # Store 0 as a placeholder; address recording is a known Gluon limitation. + tl.store(self.buf_address + event_idx, tl.cast(0, tl.int64)) tl.store(self.buf_duration_cycles + event_idx, tl.cast(0, tl.int64)) tl.store(self.buf_op_index + event_idx, op_index) tl.store(self.buf_payload_size + event_idx, tl.cast(payload_size, tl.int32)) @@ -318,7 +320,7 @@ def initialize(context_tensor, tracing: gl.constexpr = False): dummy_ptr_i64 = tl.cast(context_tensor, tl.pointer_type(tl.int64)) max_events_zero = tl.cast(0, tl.int32) device_tracing = GluonDeviceTracing( - enabled=False, + enabled=tracing, rank=cur_rank, max_events=max_events_zero, counter=dummy_ptr_i32, diff --git a/tests/unittests/test_device_context_gluon.py b/tests/unittests/test_device_context_gluon.py index 9dcc2a95..18c6d7a9 100644 --- a/tests/unittests/test_device_context_gluon.py +++ b/tests/unittests/test_device_context_gluon.py @@ -45,7 +45,7 @@ def test_device_context_gluon_tracing_1d_address(): source_rank = shmem.get_rank() num_ranks = shmem.get_num_ranks() - BLOCK_SIZE = 4 + BLOCK_SIZE = 64 # AMD wavefront size (64 threads per warp) # Dummy buffer only to form 1D pointer block; never read/write dummy_buffer = shmem.zeros((BLOCK_SIZE,), dtype=torch.int64) From 1eebadba184b39c14b0d29d6609add044110f65c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 22 Mar 2026 23:53:07 +0000 Subject: [PATCH 6/6] fix: use gl.sum/gl.min with axis=0 for payload_size and address in GluonDeviceTracing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Hardware testing confirmed gl.sum(tensor, axis=0) and gl.min(tensor, axis=0) work correctly on 1D BlockedLayout tensors on AMD GPUs. Replaces the incorrect workarounds: - payload_size: mask.shape[0] → gl.sum(tl.cast(mask, tl.int32), axis=0) - buf_address: store 0 → gl.min(tl.cast(address, tl.int64), axis=0) Also verified via assembly analysis that tracing=False produces identical machine code to a kernel with no tracing code at all (diff contains only function name labels and debug .loc directives, zero instruction differences). Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com> Agent-Logs-Url: https://github.com/ROCm/iris/sessions/32b89318-c8dd-4922-aedb-8b90bb6f88cb --- iris/experimental/iris_gluon.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index b699cc7b..97add62f 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -164,10 +164,8 @@ def record_event_start( # Calculate payload_size from mask and datatype if mask is not None: - # In Gluon, tl.sum over a BlockedLayout 1D tensor can fail layout verification. - # Use the compile-time shape of the mask tensor as the element count. - # For all-True masks this is exact; for partial masks it is an upper bound. - num_elements = tl.cast(mask.shape[0], tl.int32) + mask_i32 = tl.cast(mask, tl.int32) + num_elements = gl.sum(mask_i32, axis=0) elem_type = address.dtype.element_ty bitwidth = elem_type.primitive_bitwidth elem_size_bytes = bitwidth // 8 @@ -185,9 +183,8 @@ def record_event_start( tl.store(self.buf_xcc_id + event_idx, device_utils.get_xcc_id()) tl.store(self.buf_cu_id + event_idx, device_utils.get_cu_id()) tl.store(self.buf_timestamp + event_idx, device_utils.read_realtime()) - # Note: tl.min() over a BlockedLayout 1D tensor is not supported in Gluon. - # Store 0 as a placeholder; address recording is a known Gluon limitation. - tl.store(self.buf_address + event_idx, tl.cast(0, tl.int64)) + addr_i64 = tl.cast(address, tl.int64) + tl.store(self.buf_address + event_idx, gl.min(addr_i64, axis=0)) tl.store(self.buf_duration_cycles + event_idx, tl.cast(0, tl.int64)) tl.store(self.buf_op_index + event_idx, op_index) tl.store(self.buf_payload_size + event_idx, tl.cast(payload_size, tl.int32))