Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 268 additions & 6 deletions iris/experimental/iris_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@
count_devices,
)
from iris.symmetric_heap import SymmetricHeap
from iris import device_utils
from iris.tracing.core import Tracing
import numpy as np
import torch
import logging
Expand All @@ -62,6 +64,152 @@
from .. import tensor_creation


class _GluonDeviceTracingCls:
    """
    Gluon-native device-side tracing: records events into SoA buffers from inside Gluon kernels.

    Created by IrisDeviceCtx.initialize() when tracing=True. Use record_event_start
    / record_event_end to bracket operations; events are exported via Tracing.export().
    """

    # Compile-time flag: when False, record_event_start/record_event_end are no-ops.
    enabled: tl.constexpr
    # Rank of the recording device; stamped into buf_cur_rank for every event.
    rank: gl.tensor
    # Capacity of each trace buffer; slots at or past this index are dropped
    # (the counter still advances, so overflow is detectable host-side).
    max_events: gl.tensor
    # Pointer to the shared event-slot counter (advanced with tl.atomic_add).
    counter: gl.tensor
    # Pointer to the shared operation-index counter (advanced with tl.atomic_add).
    op_index_counter: gl.tensor
    # Structure-of-arrays trace buffers, one entry per recorded event.
    buf_event_id: gl.tensor          # int32 event type ID
    buf_pid: gl.tensor               # int32 program id along axis 0
    buf_pid_m: gl.tensor             # int32 caller-supplied M-dim program id
    buf_pid_n: gl.tensor             # int32 caller-supplied N-dim program id
    buf_cur_rank: gl.tensor          # int32 recording rank (copied from self.rank)
    buf_target_rank: gl.tensor       # int32 target rank of the traced operation
    buf_xcc_id: gl.tensor            # int32 XCC id (device_utils.get_xcc_id)
    buf_cu_id: gl.tensor             # int32 CU id (device_utils.get_cu_id)
    buf_timestamp: gl.tensor         # int64 start timestamp (device_utils.read_realtime)
    buf_address: gl.tensor           # int64 lowest address of the traced access
    buf_duration_cycles: gl.tensor   # int64 end timestamp written by record_event_end
    buf_op_index: gl.tensor          # int32 monotonically increasing op index
    buf_payload_size: gl.tensor      # int32 payload size in bytes (0 when mask is None)

    @gluon.constexpr_function
    def __init__(
        self,
        enabled,
        rank,
        max_events,
        counter,
        op_index_counter,
        buf_event_id,
        buf_pid,
        buf_pid_m,
        buf_pid_n,
        buf_cur_rank,
        buf_target_rank,
        buf_xcc_id,
        buf_cu_id,
        buf_timestamp,
        buf_address,
        buf_duration_cycles,
        buf_op_index,
        buf_payload_size,
    ):
        """Construct GluonDeviceTracing (called from IrisDeviceCtx.initialize)."""
        self.enabled = enabled
        self.rank = rank
        self.max_events = max_events
        self.counter = counter
        self.op_index_counter = op_index_counter
        self.buf_event_id = buf_event_id
        self.buf_pid = buf_pid
        self.buf_pid_m = buf_pid_m
        self.buf_pid_n = buf_pid_n
        self.buf_cur_rank = buf_cur_rank
        self.buf_target_rank = buf_target_rank
        self.buf_xcc_id = buf_xcc_id
        self.buf_cu_id = buf_cu_id
        self.buf_timestamp = buf_timestamp
        self.buf_address = buf_address
        self.buf_duration_cycles = buf_duration_cycles
        self.buf_op_index = buf_op_index
        self.buf_payload_size = buf_payload_size

    @gluon.jit
    def record_event_start(
        self,
        event_id: tl.constexpr,
        target_rank,
        address,
        pid_m,
        pid_n,
        mask=None,
    ):
        """
        Record start of a traced operation. Returns a handle for record_event_end.

        Only stores when event_idx < max_events (bounds check).
        cur_rank is taken from the tracing context (self.rank).

        Args:
            event_id: Event type ID (constexpr)
            target_rank: Target rank for the operation
            address: Memory address(es) - can be 1D or 2D block of pointers.
            pid_m: Program ID in M dimension
            pid_n: Program ID in N dimension
            mask: Optional mask tensor indicating valid elements.

        Returns:
            int32 handle (the claimed event slot index) to pass to
            record_event_end; 0 when tracing is compiled out.
        """
        if not self.enabled:
            # Tracing disabled at compile time: no side effects, dummy handle.
            return tl.cast(0, tl.int32)

        # Claim a slot and an op index atomically (counters are shared across
        # all programs writing to these buffers).
        event_idx = tl.atomic_add(self.counter, 1)
        op_index = tl.atomic_add(self.op_index_counter, 1)

        # Calculate payload_size from mask and datatype:
        # number of valid (mask-true) elements times the pointee element size.
        if mask is not None:
            mask_i32 = tl.cast(mask, tl.int32)
            num_elements = tl.sum(mask_i32)  # assumes mask matches address's block shape — TODO confirm
            elem_type = address.dtype.element_ty
            bitwidth = elem_type.primitive_bitwidth
            elem_size_bytes = bitwidth // 8
            payload_size = num_elements * elem_size_bytes
        else:
            # No mask provided: payload size is recorded as 0, not inferred.
            payload_size = tl.cast(0, tl.int32)

        # Bounds check: only slots below capacity are stored; the handle is
        # returned either way so record_event_end can apply the same check.
        if event_idx < self.max_events:
            tl.store(self.buf_event_id + event_idx, tl.cast(event_id, tl.int32))
            tl.store(self.buf_pid + event_idx, tl.cast(gl.program_id(0), tl.int32))
            tl.store(self.buf_pid_m + event_idx, tl.cast(pid_m, tl.int32))
            tl.store(self.buf_pid_n + event_idx, tl.cast(pid_n, tl.int32))
            tl.store(self.buf_cur_rank + event_idx, tl.cast(self.rank, tl.int32))
            tl.store(self.buf_target_rank + event_idx, tl.cast(target_rank, tl.int32))
            tl.store(self.buf_xcc_id + event_idx, device_utils.get_xcc_id())
            tl.store(self.buf_cu_id + event_idx, device_utils.get_cu_id())
            tl.store(self.buf_timestamp + event_idx, device_utils.read_realtime())
            addr_i64 = tl.cast(address, tl.int64)
            # For a block of pointers, record the lowest address of the access.
            tl.store(self.buf_address + event_idx, tl.min(addr_i64))
            # Duration slot is zeroed here; record_event_end overwrites it.
            tl.store(self.buf_duration_cycles + event_idx, tl.cast(0, tl.int64))
            tl.store(self.buf_op_index + event_idx, op_index)
            tl.store(self.buf_payload_size + event_idx, tl.cast(payload_size, tl.int32))
        return event_idx

    @gluon.jit
    def record_event_end(self, handle):
        """
        Record end timestamp for the event started with record_event_start(handle).

        Only stores when handle < max_events (bounds check).

        NOTE(review): this writes the raw end timestamp into
        buf_duration_cycles, not (end - start); presumably the host-side
        export computes duration as stored_value - buf_timestamp — confirm
        against Tracing.export().
        """
        if not self.enabled:
            return

        end_ts = device_utils.read_realtime()
        if handle < self.max_events:
            tl.store(self.buf_duration_cycles + handle, end_ts)


# Mark the constructor as a Triton builtin so the Gluon frontend accepts
# constructing this aggregate from inside jitted code.
_GluonDeviceTracingCls.__init__.__triton_builtin__ = True
# Public aggregate type consumed by IrisDeviceCtx.
GluonDeviceTracing = aggregate(_GluonDeviceTracingCls)


@aggregate
class IrisDeviceCtx:
"""
Expand All @@ -74,28 +222,36 @@ class IrisDeviceCtx:
cur_rank: Current rank ID
num_ranks: Total number of ranks
heap_bases: Pointer to array of heap base addresses for all ranks
tracing: GluonDeviceTracing instance (active when tracing=True)
"""

cur_rank: gl.tensor
num_ranks: gl.tensor
heap_bases: gl.tensor
tracing: GluonDeviceTracing

@gluon.constexpr_function
def __init__(self, cur_rank, num_ranks, heap_bases, tracing):
    """
    Construct an IrisDeviceCtx (compile-time aggregate constructor).

    Args:
        cur_rank: Current rank ID.
        num_ranks: Total number of ranks.
        heap_bases: Pointer to the array of heap base addresses for all ranks.
        tracing: GluonDeviceTracing instance (its ``enabled`` flag is False
            when tracing is off, so record_* calls compile to no-ops).
    """
    self.cur_rank = cur_rank
    self.num_ranks = num_ranks
    self.heap_bases = heap_bases
    self.tracing = tracing

@staticmethod
@gluon.jit
def initialize(context_tensor):
def initialize(context_tensor, tracing: gl.constexpr = False):
"""
Initialize `IrisDeviceCtx` from the encoded tensor.

The context tensor has the format: `[cur_rank, num_ranks, heap_base_0, heap_base_1, ...]`
The context tensor has the format:
``[cur_rank, num_ranks, heap_base_0, heap_base_1, ..., trace_info...]``

If tracing is enabled on the host (via ``shmem.tracing.enable()``), the
context tensor also contains tracing buffer pointers after the heap bases.

Args:
context_tensor: Pointer to encoded context data
tracing: Enable event tracing (constexpr, default: False)

Returns:
`IrisDeviceCtx`: Initialized device context
Expand All @@ -107,7 +263,82 @@ def initialize(context_tensor):
# Extract heap bases (from index 2 onwards)
heap_bases = context_tensor + 2 # Offset pointer to start at heap bases

return IrisDeviceCtx(cur_rank, num_ranks, heap_bases)
if tracing:
# Extract tracing info: starts after heap_bases, then skip trace_enabled flag
# Layout: [cur_rank, num_ranks, heap_base_0..N-1, trace_enabled, max_events,
# trace_counter_ptr, op_index_counter_ptr, buf_event_id, ...(13 buffers)]
trace_info_base = 2 + num_ranks + 1 # skip cur_rank, num_ranks, heap_bases, trace_enabled
max_events = tl.cast(gl.load(context_tensor + trace_info_base + 0), tl.int32)
trace_counter_ptr = gl.load(context_tensor + trace_info_base + 1)
op_index_counter_ptr = gl.load(context_tensor + trace_info_base + 2)

# Cast counter pointers
trace_counter = tl.cast(trace_counter_ptr, tl.pointer_type(tl.int32))
op_index_counter = tl.cast(op_index_counter_ptr, tl.pointer_type(tl.int32))

# Extract trace buffer pointers (13 buffers, same order as Iris._build_device_context)
buf_base = trace_info_base + 3
buf_event_id = tl.cast(gl.load(context_tensor + buf_base + 0), tl.pointer_type(tl.int32))
buf_pid = tl.cast(gl.load(context_tensor + buf_base + 1), tl.pointer_type(tl.int32))
buf_pid_m = tl.cast(gl.load(context_tensor + buf_base + 2), tl.pointer_type(tl.int32))
buf_pid_n = tl.cast(gl.load(context_tensor + buf_base + 3), tl.pointer_type(tl.int32))
buf_cur_rank = tl.cast(gl.load(context_tensor + buf_base + 4), tl.pointer_type(tl.int32))
buf_target_rank = tl.cast(gl.load(context_tensor + buf_base + 5), tl.pointer_type(tl.int32))
buf_xcc_id = tl.cast(gl.load(context_tensor + buf_base + 6), tl.pointer_type(tl.int32))
buf_cu_id = tl.cast(gl.load(context_tensor + buf_base + 7), tl.pointer_type(tl.int32))
buf_timestamp = tl.cast(gl.load(context_tensor + buf_base + 8), tl.pointer_type(tl.int64))
buf_address = tl.cast(gl.load(context_tensor + buf_base + 9), tl.pointer_type(tl.int64))
buf_duration_cycles = tl.cast(gl.load(context_tensor + buf_base + 10), tl.pointer_type(tl.int64))
buf_op_index = tl.cast(gl.load(context_tensor + buf_base + 11), tl.pointer_type(tl.int32))
buf_payload_size = tl.cast(gl.load(context_tensor + buf_base + 12), tl.pointer_type(tl.int32))

device_tracing = GluonDeviceTracing(
enabled=tracing,
rank=cur_rank,
max_events=max_events,
counter=trace_counter,
op_index_counter=op_index_counter,
buf_event_id=buf_event_id,
buf_pid=buf_pid,
buf_pid_m=buf_pid_m,
buf_pid_n=buf_pid_n,
buf_cur_rank=buf_cur_rank,
buf_target_rank=buf_target_rank,
buf_xcc_id=buf_xcc_id,
buf_cu_id=buf_cu_id,
buf_timestamp=buf_timestamp,
buf_address=buf_address,
buf_duration_cycles=buf_duration_cycles,
buf_op_index=buf_op_index,
buf_payload_size=buf_payload_size,
)
else:
# When tracing disabled, use dummy pointers (never dereferenced)
dummy_ptr_i32 = tl.cast(context_tensor, tl.pointer_type(tl.int32))
dummy_ptr_i64 = tl.cast(context_tensor, tl.pointer_type(tl.int64))
max_events_zero = tl.cast(0, tl.int32)
device_tracing = GluonDeviceTracing(
enabled=False,
rank=cur_rank,
max_events=max_events_zero,
counter=dummy_ptr_i32,
op_index_counter=dummy_ptr_i32,
buf_event_id=dummy_ptr_i32,
buf_pid=dummy_ptr_i32,
buf_pid_m=dummy_ptr_i32,
buf_pid_n=dummy_ptr_i32,
buf_cur_rank=dummy_ptr_i32,
buf_target_rank=dummy_ptr_i32,
buf_xcc_id=dummy_ptr_i32,
buf_cu_id=dummy_ptr_i32,
buf_timestamp=dummy_ptr_i64,
buf_address=dummy_ptr_i64,
buf_duration_cycles=dummy_ptr_i64,
buf_op_index=dummy_ptr_i32,
buf_payload_size=dummy_ptr_i32,
)

return IrisDeviceCtx(cur_rank, num_ranks, heap_bases, device_tracing)

@gluon.jit
def _translate(self, ptr, from_rank, to_rank):
Expand Down Expand Up @@ -513,6 +744,9 @@ def __init__(self, heap_size=1 << 30):

distributed_barrier()

# Initialize tracing manager (disabled by default)
self.tracing = Tracing(self)

# Initialize CCL interface
self.ccl = self.CCL(self)

Expand Down Expand Up @@ -668,20 +902,48 @@ def _build_device_context(self):
"""
Build and cache the device context tensor.

Called during __init__ to pre-build the tensor once.
Called during __init__ and again after tracing.enable() to include tracing fields.
"""
# Convert heap_bases to a list for concatenation
heap_bases_list = self.heap_bases.tolist()

# Create context tensor: [cur_rank, num_ranks, heap_base_0, heap_base_1, ...]
context_data = [self.cur_rank, self.num_ranks] + heap_bases_list

# Add tracing info if enabled (same layout as Iris._build_device_context)
if self.tracing.enabled:
trace_buffer_ptrs = [
self.tracing.trace_buffers["event_id"].data_ptr(),
self.tracing.trace_buffers["pid"].data_ptr(),
self.tracing.trace_buffers["pid_m"].data_ptr(),
self.tracing.trace_buffers["pid_n"].data_ptr(),
self.tracing.trace_buffers["cur_rank"].data_ptr(),
self.tracing.trace_buffers["target_rank"].data_ptr(),
self.tracing.trace_buffers["xcc_id"].data_ptr(),
self.tracing.trace_buffers["cu_id"].data_ptr(),
self.tracing.trace_buffers["timestamp"].data_ptr(),
self.tracing.trace_buffers["address"].data_ptr(),
self.tracing.trace_buffers["duration_cycles"].data_ptr(),
self.tracing.trace_buffers["op_index"].data_ptr(),
self.tracing.trace_buffers["payload_size"].data_ptr(),
]
context_data += [
1, # trace_enabled = 1 (true)
self.tracing.max_events,
self.tracing.trace_counter.data_ptr(),
self.tracing.op_index_counter.data_ptr(),
] + trace_buffer_ptrs
else:
context_data += [0] # trace_enabled = 0 (false)

self._device_context = torch.tensor(context_data, dtype=torch.int64, device=self.device)

def get_device_context(self):
"""
Get the device context tensor for Gluon kernels.

Returns a tensor encoding: `[cur_rank, num_ranks, heap_base_0, heap_base_1, ...]`
Returns a tensor encoding: ``[cur_rank, num_ranks, heap_base_0, heap_base_1, ...]``
If tracing is enabled, also includes: ``[trace_enabled, max_events, trace_counter_ptr, trace_buffer_ptrs...]``

Returns:
torch.Tensor: Encoded context data as int64 tensor on device
Expand Down
Loading