diff --git a/examples/32_gluon_all_gather_tracing/all_gather_tracing.py b/examples/32_gluon_all_gather_tracing/all_gather_tracing.py new file mode 100644 index 00000000..7be1cc41 --- /dev/null +++ b/examples/32_gluon_all_gather_tracing/all_gather_tracing.py @@ -0,0 +1,279 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Gluon All-Gather Tracing Example +================================= + +Demonstrates IrisDeviceCtx tracing support inside a ``@gluon.jit`` kernel. +The kernel performs a one-hop ring put (all-gather step) and can be compiled +in two modes via a constexpr flag: + +- ``TRACING=False`` (default) — zero overhead; the entire tracing path is + dead-code-eliminated at compile time because ``enabled`` is a ``tl.constexpr``. +- ``TRACING=True`` — ``record_event_start`` / ``record_event_end`` bracket + every remote put; the trace is exported to per-rank JSON files. + +Usage:: + + # Without tracing (default) + torchrun --nproc_per_node=4 \\ + examples/32_gluon_all_gather_tracing/all_gather_tracing.py + + # With tracing enabled and JSON export + torchrun --nproc_per_node=4 \\ + examples/32_gluon_all_gather_tracing/all_gather_tracing.py --trace --export + +Zero-overhead proof +------------------- +When ``TRACING=False``, no assembly is generated for the tracing code path. +You can verify this by comparing the cached AMDGCN ISA for the two variants +after running (see ``--asm_diff`` flag). 
+""" + +import argparse +import json +import os +import sys + +import torch +import triton.language as tl +from triton.experimental import gluon +from triton.experimental.gluon import language as gl + +import iris.experimental.iris_gluon as iris_gl +from iris.tracing.events import TraceEvent + + +# --------------------------------------------------------------------------- +# Device kernel +# --------------------------------------------------------------------------- + + +@gluon.jit +def all_gather_put_kernel( + IrisDeviceCtx: gl.constexpr, + context_tensor, + local_buf, + global_buf, + num_elements: gl.constexpr, + BLOCK_SIZE: gl.constexpr, + NUM_WARPS: gl.constexpr, + TRACING: gl.constexpr, +): + """ + One-hop ring all-gather put kernel. + + Each CTA handles one tile of ``BLOCK_SIZE`` elements from this rank's + ``local_buf`` and pushes it into the next rank's ``global_buf`` slice + at the offset reserved for this rank. + + When ``TRACING=True``, the put is bracketed by tracing calls. + When ``TRACING=False``, the tracing calls compile away completely + because ``GluonDeviceTracing.enabled`` is a ``tl.constexpr``. + + Args: + IrisDeviceCtx: aggregate class passed as constexpr from the host. + context_tensor: encoded context tensor (from ``shmem.get_device_context()``). + local_buf: source buffer (``num_elements`` elements on this rank). + global_buf: output buffer (``num_ranks * num_elements`` elements). + num_elements: per-rank element count (constexpr). + BLOCK_SIZE: tile width in elements (constexpr). + NUM_WARPS: number of warps per CTA (constexpr). + TRACING: enable/disable tracing at compile time (constexpr). + """ + ctx = IrisDeviceCtx.initialize(context_tensor, tracing=TRACING) + + cur_rank = ctx.cur_rank + num_ranks = ctx.num_ranks + target_rank = (cur_rank + 1) % num_ranks + + pid = gl.program_id(0) + + # AMD GPUs have 64 threads per warp (wavefront size 64). + # Total threads per CTA = NUM_WARPS * 64. 
+ # Each thread handles SPT = BLOCK_SIZE // (NUM_WARPS * 64) elements. + SPT: gl.constexpr = BLOCK_SIZE // (NUM_WARPS * 64) + layout: gl.constexpr = gl.BlockedLayout([SPT], [64], [NUM_WARPS], [0]) + offsets = pid * BLOCK_SIZE + gl.arange(0, BLOCK_SIZE, layout=layout) + mask = offsets < num_elements + + # Address in the target rank's global buffer for *this* rank's slice + target_offset = cur_rank * num_elements + offsets + target_ptr = global_buf + target_offset + + # Optional tracing — compiles away when TRACING=False + handle = ctx.tracing.record_event_start( + event_id=TraceEvent().put, + target_rank=target_rank, + address=target_ptr, + pid_m=gl.program_id(0), + pid_n=tl.cast(0, tl.int32), + mask=mask, + ) + + # Remote put: push local slice to the target rank + ctx.put(local_buf + offsets, target_ptr, to_rank=target_rank, mask=mask) + + ctx.tracing.record_event_end(handle) + + +# --------------------------------------------------------------------------- +# Host-side helpers +# --------------------------------------------------------------------------- + + +def _launch(shmem, local_buf, global_buf, context_tensor, enable_tracing: bool): + """Launch one iteration of the all-gather kernel. + + Layout constraints for AMD GPUs (warp size = 64): + BLOCK_SIZE = sizePerThread * 64 * NUM_WARPS + Here NUM_WARPS=4, sizePerThread=1 → BLOCK_SIZE = 1 * 64 * 4 = 256. 
+ """ + num_elements = local_buf.numel() + NUM_WARPS = 4 + BLOCK_SIZE = 64 * NUM_WARPS # 256 elements per tile (1 el/thread, 64 threads/warp, 4 warps) + grid = ((num_elements + BLOCK_SIZE - 1) // BLOCK_SIZE,) + all_gather_put_kernel[grid]( + iris_gl.IrisDeviceCtx, + context_tensor, + local_buf, + global_buf, + num_elements, + BLOCK_SIZE, + NUM_WARPS, + enable_tracing, + num_warps=NUM_WARPS, + ) + + +def run_all_gather(shmem, local_buf, global_buf, context_tensor, enable_tracing: bool, num_warmup: int = 5): + """Warm up then run one timed iteration; return elapsed ms.""" + # Warm-up always without tracing to avoid polluting trace buffers + for _ in range(num_warmup): + _launch(shmem, local_buf, global_buf, context_tensor, enable_tracing=False) + shmem.barrier() + + if enable_tracing: + shmem.tracing.reset() + shmem.barrier() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + shmem.barrier() + start_event.record() + _launch(shmem, local_buf, global_buf, context_tensor, enable_tracing=enable_tracing) + end_event.record() + torch.cuda.synchronize() + shmem.barrier() + + return start_event.elapsed_time(end_event) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + parser = argparse.ArgumentParser(description="Gluon all-gather tracing example") + parser.add_argument("--trace", action="store_true", help="Enable device-side event tracing") + parser.add_argument("--export", action="store_true", help="Export trace to JSON after run") + parser.add_argument("--max_events", type=int, default=1_000_000, help="Max trace events per rank") + parser.add_argument("--num_elements", type=int, default=65536, help="Elements per rank") + parser.add_argument("--heap_size", type=int, default=1 << 30, help="Iris heap size in bytes") + args = parser.parse_args() + + # torchrun sets RANK, LOCAL_RANK, 
WORLD_SIZE; initialize distributed process group + import torch.distributed as dist + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + if not dist.is_initialized(): + dist.init_process_group(backend="nccl", device_id=torch.device(f"cuda:{local_rank}")) + + shmem = iris_gl.iris(args.heap_size) + rank = shmem.get_rank() + num_ranks = shmem.get_num_ranks() + + if rank == 0: + print("Gluon All-Gather Tracing Example") + print(f" ranks : {num_ranks}") + print(f" num_elements : {args.num_elements} per rank") + print(f" tracing : {args.trace}") + print() + + # Allocate symmetric buffers + local_buf = shmem.zeros((args.num_elements,), dtype=torch.float32) + local_buf.fill_(float(rank)) # rank r fills with value r + global_buf = shmem.zeros((num_ranks * args.num_elements,), dtype=torch.float32) + shmem.barrier() + + # Enable host-side tracing before building the context tensor + if args.trace: + shmem.tracing.enable(max_events=args.max_events) + + context_tensor = shmem.get_device_context() + + # --- Run WITHOUT tracing (baseline) --- + ms_no_trace = run_all_gather(shmem, local_buf, global_buf, context_tensor, enable_tracing=False) + if rank == 0: + print(f"[tracing=False] {ms_no_trace:.3f} ms ← zero-overhead path") + + # --- Run WITH tracing (only when enabled on host) --- + if args.trace: + ms_trace = run_all_gather(shmem, local_buf, global_buf, context_tensor, enable_tracing=True) + if rank == 0: + print(f"[tracing=True ] {ms_trace:.3f} ms ← tracing path") + + # --- Validate correctness --- + shmem.barrier() + torch.cuda.synchronize() + errors = 0 + # We only check the slice written to *this* rank by the rank behind us + src = (rank - 1) % num_ranks + slice_start = src * args.num_elements + slice_end = slice_start + args.num_elements + actual = global_buf[slice_start:slice_end] + wrong = (actual != float(src)).sum().item() + if wrong > 0: + print(f" [rank {rank}] MISMATCH: {wrong} of {args.num_elements} elements wrong for 
src={src}", file=sys.stderr) + errors += 1 + + # Aggregate errors across all ranks before reporting + error_tensor = torch.tensor(errors, device=f"cuda:{local_rank}", dtype=torch.int32) + dist.all_reduce(error_tensor, op=dist.ReduceOp.SUM) + total_errors = error_tensor.item() + + shmem.barrier() + if rank == 0: + print(f"\nValidation: {'PASSED' if total_errors == 0 else 'FAILED'}") + + # --- Export trace --- + if args.trace and args.export: + out_file = "gluon_trace.json" + shmem.tracing.export(out_file) + trace_count = shmem.tracing.trace_counter.item() + if rank == 0: + print(f"\nTrace summary (rank {rank}):") + print(f" events recorded : {trace_count}") + per_rank_file = out_file.replace(".json", f"_rank{rank}.json") + if os.path.exists(per_rank_file): + with open(per_rank_file) as f: + data = json.load(f) + trace_events = [e for e in data["traceEvents"] if e.get("ph") != "M"] + print(f" events in JSON : {len(trace_events)}") + if trace_events: + ev = trace_events[0] + print(f" first event : name={ev['name']}, ts={ev['ts']}, dur={ev.get('dur', 'N/A')}") + print(f" exported to : {per_rank_file}") + print(" View at : https://ui.perfetto.dev") + + shmem.barrier() + del shmem + + +if __name__ == "__main__": + main() diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index 1a06f284..97add62f 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -51,6 +51,8 @@ count_devices, ) from iris.symmetric_heap import SymmetricHeap +from iris import device_utils +from iris.tracing.core import Tracing import numpy as np import torch import logging @@ -62,6 +64,151 @@ from .. import tensor_creation +class _GluonDeviceTracingCls: + """ + Gluon-native device-side tracing: records events into SoA buffers from inside Gluon kernels. + + Created by IrisDeviceCtx.initialize() when tracing=True. Use record_event_start + / record_event_end to bracket operations; events are exported via Tracing.export(). 
+    """
+
+    enabled: tl.constexpr
+    rank: gl.tensor
+    max_events: gl.tensor
+    counter: gl.tensor
+    op_index_counter: gl.tensor
+    buf_event_id: gl.tensor
+    buf_pid: gl.tensor
+    buf_pid_m: gl.tensor
+    buf_pid_n: gl.tensor
+    buf_cur_rank: gl.tensor
+    buf_target_rank: gl.tensor
+    buf_xcc_id: gl.tensor
+    buf_cu_id: gl.tensor
+    buf_timestamp: gl.tensor
+    buf_address: gl.tensor
+    buf_duration_cycles: gl.tensor
+    buf_op_index: gl.tensor
+    buf_payload_size: gl.tensor
+
+    def __init__(
+        self,
+        enabled,
+        rank,
+        max_events,
+        counter,
+        op_index_counter,
+        buf_event_id,
+        buf_pid,
+        buf_pid_m,
+        buf_pid_n,
+        buf_cur_rank,
+        buf_target_rank,
+        buf_xcc_id,
+        buf_cu_id,
+        buf_timestamp,
+        buf_address,
+        buf_duration_cycles,
+        buf_op_index,
+        buf_payload_size,
+    ):
+        """Construct GluonDeviceTracing (called from IrisDeviceCtx.initialize)."""
+        self.enabled = enabled
+        self.rank = rank
+        self.max_events = max_events
+        self.counter = counter
+        self.op_index_counter = op_index_counter
+        self.buf_event_id = buf_event_id
+        self.buf_pid = buf_pid
+        self.buf_pid_m = buf_pid_m
+        self.buf_pid_n = buf_pid_n
+        self.buf_cur_rank = buf_cur_rank
+        self.buf_target_rank = buf_target_rank
+        self.buf_xcc_id = buf_xcc_id
+        self.buf_cu_id = buf_cu_id
+        self.buf_timestamp = buf_timestamp
+        self.buf_address = buf_address
+        self.buf_duration_cycles = buf_duration_cycles
+        self.buf_op_index = buf_op_index
+        self.buf_payload_size = buf_payload_size
+
+    @gluon.jit
+    def record_event_start(
+        self,
+        event_id: tl.constexpr,
+        target_rank,
+        address,
+        pid_m,
+        pid_n,
+        mask=None,
+    ):
+        """
+        Record start of a traced operation. Returns a handle for record_event_end.
+
+        Only stores when event_idx < max_events (bounds check).
+        cur_rank is taken from the tracing context (self.rank).
+
+        Args:
+            event_id: Event type ID (constexpr)
+            target_rank: Target rank for the operation
+            address: Memory address(es) - can be 1D or 2D block of pointers.
+            pid_m: Program ID in M dimension
+            pid_n: Program ID in N dimension
+            mask: Optional mask tensor indicating valid elements.
+        """
+        if not self.enabled:
+            return tl.cast(0, tl.int32)
+
+        event_idx = tl.atomic_add(self.counter, 1)
+        op_index = tl.atomic_add(self.op_index_counter, 1)
+
+        # Calculate payload_size from mask and datatype
+        if mask is not None:
+            mask_i32 = tl.cast(mask, tl.int32)
+            num_elements = gl.sum(mask_i32, axis=0)
+            elem_type = address.dtype.element_ty
+            bitwidth = elem_type.primitive_bitwidth
+            elem_size_bytes = bitwidth // 8
+            payload_size = num_elements * tl.cast(elem_size_bytes, tl.int32)
+        else:
+            payload_size = tl.cast(0, tl.int32)
+
+        # NOTE(review): the counters still advance when the buffer is full
+        # (event_idx >= max_events), so the host-side counter can exceed the
+        # number of events actually stored — it doubles as a dropped-event count.
+        if event_idx < self.max_events:
+            tl.store(self.buf_event_id + event_idx, tl.cast(event_id, tl.int32))
+            tl.store(self.buf_pid + event_idx, tl.cast(gl.program_id(0), tl.int32))
+            tl.store(self.buf_pid_m + event_idx, tl.cast(pid_m, tl.int32))
+            tl.store(self.buf_pid_n + event_idx, tl.cast(pid_n, tl.int32))
+            tl.store(self.buf_cur_rank + event_idx, tl.cast(self.rank, tl.int32))
+            tl.store(self.buf_target_rank + event_idx, tl.cast(target_rank, tl.int32))
+            tl.store(self.buf_xcc_id + event_idx, device_utils.get_xcc_id())
+            tl.store(self.buf_cu_id + event_idx, device_utils.get_cu_id())
+            tl.store(self.buf_timestamp + event_idx, device_utils.read_realtime())
+            addr_i64 = tl.cast(address, tl.int64)
+            tl.store(self.buf_address + event_idx, gl.min(addr_i64, axis=0))
+            tl.store(self.buf_duration_cycles + event_idx, tl.cast(0, tl.int64))
+            tl.store(self.buf_op_index + event_idx, op_index)
+            tl.store(self.buf_payload_size + event_idx, tl.cast(payload_size, tl.int32))
+        return event_idx
+
+    @gluon.jit
+    def record_event_end(self, handle):
+        """
+        Record end timestamp for the event started with record_event_start(handle).
+
+        Only stores when handle < max_events (bounds check).
+        """
+        if not self.enabled:
+            return
+
+        end_ts = device_utils.read_realtime()
+        if handle < self.max_events:
+            # Stores the raw end timestamp; the host computes the duration
+            # (end - start) from buf_timestamp at export time.
+            tl.store(self.buf_duration_cycles + handle, end_ts)
+
+
+_GluonDeviceTracingCls.__init__.__triton_builtin__ = True
+GluonDeviceTracing = aggregate(_GluonDeviceTracingCls)
+
+
 @aggregate
 class IrisDeviceCtx:
     """
@@ -74,28 +221,36 @@
         cur_rank: Current rank ID
         num_ranks: Total number of ranks
         heap_bases: Pointer to array of heap base addresses for all ranks
+        tracing: GluonDeviceTracing instance (active when tracing=True)
     """
 
     cur_rank: gl.tensor
     num_ranks: gl.tensor
     heap_bases: gl.tensor
+    tracing: GluonDeviceTracing
 
     @gluon.constexpr_function
-    def __init__(self, cur_rank, num_ranks, heap_bases):
+    def __init__(self, cur_rank, num_ranks, heap_bases, tracing):
         self.cur_rank = cur_rank
         self.num_ranks = num_ranks
         self.heap_bases = heap_bases
+        self.tracing = tracing
 
     @staticmethod
     @gluon.jit
-    def initialize(context_tensor):
+    def initialize(context_tensor, tracing: gl.constexpr = False):
         """
         Initialize `IrisDeviceCtx` from the encoded tensor.
 
-        The context tensor has the format: `[cur_rank, num_ranks, heap_base_0, heap_base_1, ...]`
+        The context tensor has the format:
+        ``[cur_rank, num_ranks, heap_base_0, heap_base_1, ..., trace_info...]``
+
+        If tracing is enabled on the host (via ``shmem.tracing.enable()``), the
+        context tensor also contains tracing buffer pointers after the heap bases.
 
         Args:
             context_tensor: Pointer to encoded context data
+            tracing: Enable event tracing (constexpr, default: False)
 
         Returns:
             `IrisDeviceCtx`: Initialized device context
@@ -107,7 +262,82 @@
         # Extract heap bases (from index 2 onwards)
         heap_bases = context_tensor + 2  # Offset pointer to start at heap bases
 
-        return IrisDeviceCtx(cur_rank, num_ranks, heap_bases)
+        if tracing:
+            # Extract tracing info: starts after heap_bases, then skip trace_enabled flag
+            # Layout: [cur_rank, num_ranks, heap_base_0..N-1, trace_enabled, max_events,
+            #          trace_counter_ptr, op_index_counter_ptr, buf_event_id, ...(13 buffers)]
+            trace_info_base = 2 + num_ranks + 1  # skip cur_rank, num_ranks, heap_bases, trace_enabled
+            max_events = tl.cast(gl.load(context_tensor + trace_info_base + 0), tl.int32)
+            trace_counter_ptr = gl.load(context_tensor + trace_info_base + 1)
+            op_index_counter_ptr = gl.load(context_tensor + trace_info_base + 2)
+
+            # Cast counter pointers
+            trace_counter = tl.cast(trace_counter_ptr, tl.pointer_type(tl.int32))
+            op_index_counter = tl.cast(op_index_counter_ptr, tl.pointer_type(tl.int32))
+
+            # Extract trace buffer pointers (13 buffers, same order as Iris._build_device_context)
+            buf_base = trace_info_base + 3
+            buf_event_id = tl.cast(gl.load(context_tensor + buf_base + 0), tl.pointer_type(tl.int32))
+            buf_pid = tl.cast(gl.load(context_tensor + buf_base + 1), tl.pointer_type(tl.int32))
+            buf_pid_m = tl.cast(gl.load(context_tensor + buf_base + 2), tl.pointer_type(tl.int32))
+            buf_pid_n = tl.cast(gl.load(context_tensor + buf_base + 3), tl.pointer_type(tl.int32))
+            buf_cur_rank = tl.cast(gl.load(context_tensor + buf_base + 4), tl.pointer_type(tl.int32))
+            buf_target_rank = tl.cast(gl.load(context_tensor + buf_base + 5), tl.pointer_type(tl.int32))
+            buf_xcc_id = tl.cast(gl.load(context_tensor + buf_base + 6), tl.pointer_type(tl.int32))
+            buf_cu_id = tl.cast(gl.load(context_tensor + buf_base + 7), tl.pointer_type(tl.int32))
+            buf_timestamp = tl.cast(gl.load(context_tensor + buf_base + 8), tl.pointer_type(tl.int64))
+            buf_address = tl.cast(gl.load(context_tensor + buf_base + 9), tl.pointer_type(tl.int64))
+            buf_duration_cycles = tl.cast(gl.load(context_tensor + buf_base + 10), tl.pointer_type(tl.int64))
+            buf_op_index = tl.cast(gl.load(context_tensor + buf_base + 11), tl.pointer_type(tl.int32))
+            buf_payload_size = tl.cast(gl.load(context_tensor + buf_base + 12), tl.pointer_type(tl.int32))
+
+            device_tracing = GluonDeviceTracing(
+                enabled=tracing,
+                rank=cur_rank,
+                max_events=max_events,
+                counter=trace_counter,
+                op_index_counter=op_index_counter,
+                buf_event_id=buf_event_id,
+                buf_pid=buf_pid,
+                buf_pid_m=buf_pid_m,
+                buf_pid_n=buf_pid_n,
+                buf_cur_rank=buf_cur_rank,
+                buf_target_rank=buf_target_rank,
+                buf_xcc_id=buf_xcc_id,
+                buf_cu_id=buf_cu_id,
+                buf_timestamp=buf_timestamp,
+                buf_address=buf_address,
+                buf_duration_cycles=buf_duration_cycles,
+                buf_op_index=buf_op_index,
+                buf_payload_size=buf_payload_size,
+            )
+        else:
+            # When tracing disabled, use dummy pointers (never dereferenced)
+            dummy_ptr_i32 = tl.cast(context_tensor, tl.pointer_type(tl.int32))
+            dummy_ptr_i64 = tl.cast(context_tensor, tl.pointer_type(tl.int64))
+            max_events_zero = tl.cast(0, tl.int32)
+            device_tracing = GluonDeviceTracing(
+                enabled=tracing,
+                rank=cur_rank,
+                max_events=max_events_zero,
+                counter=dummy_ptr_i32,
+                op_index_counter=dummy_ptr_i32,
+                buf_event_id=dummy_ptr_i32,
+                buf_pid=dummy_ptr_i32,
+                buf_pid_m=dummy_ptr_i32,
+                buf_pid_n=dummy_ptr_i32,
+                buf_cur_rank=dummy_ptr_i32,
+                buf_target_rank=dummy_ptr_i32,
+                buf_xcc_id=dummy_ptr_i32,
+                buf_cu_id=dummy_ptr_i32,
+                buf_timestamp=dummy_ptr_i64,
+                buf_address=dummy_ptr_i64,
+                buf_duration_cycles=dummy_ptr_i64,
+                buf_op_index=dummy_ptr_i32,
+                buf_payload_size=dummy_ptr_i32,
+            )
+
+        return IrisDeviceCtx(cur_rank, num_ranks, heap_bases, device_tracing)
 
     @gluon.jit
     def _translate(self, ptr, from_rank, to_rank):
@@ -513,6 +743,9 @@ def __init__(self, heap_size=1 << 30):
distributed_barrier() + # Initialize tracing manager (disabled by default) + self.tracing = Tracing(self) + # Initialize CCL interface self.ccl = self.CCL(self) @@ -668,20 +901,48 @@ def _build_device_context(self): """ Build and cache the device context tensor. - Called during __init__ to pre-build the tensor once. + Called during __init__ and again after tracing.enable() to include tracing fields. """ # Convert heap_bases to a list for concatenation heap_bases_list = self.heap_bases.tolist() # Create context tensor: [cur_rank, num_ranks, heap_base_0, heap_base_1, ...] context_data = [self.cur_rank, self.num_ranks] + heap_bases_list + + # Add tracing info if enabled (same layout as Iris._build_device_context) + if self.tracing.enabled: + trace_buffer_ptrs = [ + self.tracing.trace_buffers["event_id"].data_ptr(), + self.tracing.trace_buffers["pid"].data_ptr(), + self.tracing.trace_buffers["pid_m"].data_ptr(), + self.tracing.trace_buffers["pid_n"].data_ptr(), + self.tracing.trace_buffers["cur_rank"].data_ptr(), + self.tracing.trace_buffers["target_rank"].data_ptr(), + self.tracing.trace_buffers["xcc_id"].data_ptr(), + self.tracing.trace_buffers["cu_id"].data_ptr(), + self.tracing.trace_buffers["timestamp"].data_ptr(), + self.tracing.trace_buffers["address"].data_ptr(), + self.tracing.trace_buffers["duration_cycles"].data_ptr(), + self.tracing.trace_buffers["op_index"].data_ptr(), + self.tracing.trace_buffers["payload_size"].data_ptr(), + ] + context_data += [ + 1, # trace_enabled = 1 (true) + self.tracing.max_events, + self.tracing.trace_counter.data_ptr(), + self.tracing.op_index_counter.data_ptr(), + ] + trace_buffer_ptrs + else: + context_data += [0] # trace_enabled = 0 (false) + self._device_context = torch.tensor(context_data, dtype=torch.int64, device=self.device) def get_device_context(self): """ Get the device context tensor for Gluon kernels. 
- Returns a tensor encoding: `[cur_rank, num_ranks, heap_base_0, heap_base_1, ...]` + Returns a tensor encoding: ``[cur_rank, num_ranks, heap_base_0, heap_base_1, ...]`` + If tracing is enabled, also includes: ``[trace_enabled, max_events, trace_counter_ptr, trace_buffer_ptrs...]`` Returns: torch.Tensor: Encoded context data as int64 tensor on device diff --git a/tests/unittests/test_device_context_gluon.py b/tests/unittests/test_device_context_gluon.py new file mode 100644 index 00000000..18c6d7a9 --- /dev/null +++ b/tests/unittests/test_device_context_gluon.py @@ -0,0 +1,133 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +import torch +from triton.experimental import gluon +from triton.experimental.gluon import language as gl +import iris.experimental.iris_gluon as iris_gl +from iris.tracing.events import TraceEvent + + +@gluon.jit +def device_context_tracing_1d_address_kernel( + IrisDeviceCtx: gl.constexpr, + context_tensor, + dummy_buffer, + source_rank: gl.constexpr, + num_ranks: gl.constexpr, + BLOCK_SIZE: gl.constexpr, +): + """Test ctx.tracing.record_event_start/end with a 1D address (block of pointers).""" + ctx = IrisDeviceCtx.initialize(context_tensor, tracing=True) + + layout: gl.constexpr = gl.BlockedLayout([1], [BLOCK_SIZE], [1], [0]) + offsets = gl.arange(0, BLOCK_SIZE, layout=layout) + address_1d = dummy_buffer + offsets + + # All-true mask derived from offsets (offsets are always < BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + handle = ctx.tracing.record_event_start( + event_id=TraceEvent().put, + target_rank=(source_rank + 1) % num_ranks, + address=address_1d, + pid_m=gl.program_id(0), + pid_n=0, + mask=mask, + ) + ctx.tracing.record_event_end(handle) + + +def test_device_context_gluon_tracing_1d_address(): + """Test GluonDeviceTracing record_event_start/end with a 1D address block.""" + shmem = iris_gl.iris(1 << 20) + shmem.tracing.enable(max_events=1000) + context_tensor = 
shmem.get_device_context()
+    source_rank = shmem.get_rank()
+    num_ranks = shmem.get_num_ranks()
+
+    BLOCK_SIZE = 64  # AMD wavefront size (64 threads per warp)
+    # Dummy buffer only to form 1D pointer block; never read/write
+    dummy_buffer = shmem.zeros((BLOCK_SIZE,), dtype=torch.int64)
+
+    shmem.barrier()
+
+    device_context_tracing_1d_address_kernel[(1,)](
+        iris_gl.IrisDeviceCtx,
+        context_tensor,
+        dummy_buffer,
+        source_rank,
+        num_ranks,
+        BLOCK_SIZE,
+        num_warps=1,
+    )
+    shmem.barrier()
+
+    # Verify we recorded at least one event
+    assert shmem.tracing.trace_counter.item() >= 1
+
+    # Verify event data fields for the first recorded event
+    bufs = shmem.tracing.trace_buffers
+    # FIX: compare against the enum value the kernel actually passed instead of
+    # hard-coding 3, so this stays valid if TraceEvent IDs are renumbered.
+    assert bufs["event_id"][0].item() == TraceEvent().put
+    assert bufs["cur_rank"][0].item() == source_rank
+    assert bufs["target_rank"][0].item() == (source_rank + 1) % num_ranks
+    assert bufs["timestamp"][0].item() > 0
+    # duration_cycles holds the end timestamp; it must be >= start timestamp
+    assert bufs["duration_cycles"][0].item() >= bufs["timestamp"][0].item()
+    # payload_size: BLOCK_SIZE elements × 8 bytes each (dummy_buffer is int64)
+    assert bufs["payload_size"][0].item() == BLOCK_SIZE * 8
+
+    shmem.barrier()
+    del shmem
+    import gc
+
+    gc.collect()
+
+
+def test_device_context_gluon_initialize():
+    """Test IrisDeviceCtx.initialize() works without tracing enabled."""
+    shmem = iris_gl.iris(1 << 20)
+    context_tensor = shmem.get_device_context()
+
+    assert context_tensor is not None
+    assert isinstance(context_tensor, torch.Tensor)
+    assert context_tensor.dtype == torch.int64
+    num_ranks = shmem.get_num_ranks()
+    # At least [cur_rank, num_ranks, heap_base_0, ...]; layout may add more (e.g. tracing flag)
+    assert context_tensor.shape[0] >= 2 + num_ranks
+    assert context_tensor[0].item() == shmem.get_rank()
+    assert context_tensor[1].item() == num_ranks
+
+    shmem.barrier()
+    del shmem
+    import gc
+
+    gc.collect()
+
+
+def test_device_context_gluon_tracing_enable():
+    """Test that shmem.tracing.enable() rebuilds context tensor with tracing fields."""
+    shmem = iris_gl.iris(1 << 20)
+    num_ranks = shmem.get_num_ranks()
+
+    # Without tracing: [cur_rank, num_ranks, heap_base_0..N-1, trace_enabled=0]
+    ctx_no_trace = shmem.get_device_context()
+    size_no_trace = ctx_no_trace.shape[0]
+
+    # Enable tracing and rebuild
+    shmem.tracing.enable(max_events=1000)
+    ctx_with_trace = shmem.get_device_context()
+    size_with_trace = ctx_with_trace.shape[0]
+
+    # With tracing the context tensor is larger:
+    # [cur_rank, num_ranks, heap_base_0..N-1, trace_enabled=1, max_events,
+    #  trace_counter_ptr, op_index_counter_ptr, 13 buffer ptrs]
+    assert size_with_trace > size_no_trace
+    # trace_enabled flag should be 1
+    trace_enabled_idx = 2 + num_ranks
+    assert ctx_with_trace[trace_enabled_idx].item() == 1
+
+    shmem.barrier()
+    del shmem
+    import gc
+
+    gc.collect()