
Port tracing capabilities to Gluon IrisDeviceCtx API #472

@mawad-amd


Summary

The Triton-based DeviceContext API has full tracing/instrumentation support via DeviceTracing, but the Gluon-based IrisDeviceCtx API (iris/experimental/iris_gluon.py) has none. This makes it impossible to profile and instrument gluon kernels the same way we do with Triton kernels.

Current State

Triton API (has tracing):

  • DeviceContext.initialize(context_tensor, rank, world_size, tracing=True) enables tracing
  • ctx.tracing.record_event_start(event_id, target_rank, address, pid_m, pid_n, mask) records the start of an event and returns a handle
  • ctx.tracing.record_event_end(handle) records the end of that event along with its duration
  • Tracing.export("trace.json") exports to Perfetto/Chrome Trace format
  • Full example: examples/23_gemm_all_scatter_tracing/
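To make the start/end contract concrete for the port, here is a host-side Python mock of the semantics as we read them from the Triton API. `record_event_start` claims a slot in fixed-size per-field buffers and returns the slot index as the handle, and `record_event_end` writes the elapsed cycles into that slot. The class name `MockDeviceTracing`, the fake clock, the "handle = slot index, -1 when full" policy, and the omission of `mask`, `pid_m`, and `pid_n` are all assumptions for illustration; the real `DeviceTracing` in `iris/tracing/device.py` runs on device with atomics and may differ in detail.

```python
import itertools


class MockDeviceTracing:
    """Hypothetical host-side stand-in for a device tracer (illustration only)."""

    def __init__(self, max_events, clock=None):
        self.max_events = max_events
        # One list per trace field, mirroring a subset of the 13 device buffers.
        self.event_id = [0] * max_events
        self.target_rank = [0] * max_events
        self.address = [0] * max_events
        self.timestamp = [0] * max_events
        self.duration_cycles = [0] * max_events
        # On device this counter would be incremented atomically by each program.
        self._next_slot = 0
        # Fake monotonically increasing cycle counter; a kernel would read a hardware clock.
        self._clock = clock if clock is not None else itertools.count(start=1000, step=7)

    def record_event_start(self, event_id, target_rank, address):
        slot = self._next_slot
        self._next_slot += 1
        if slot >= self.max_events:
            return -1  # buffer full: drop the event (assumed policy)
        self.event_id[slot] = event_id
        self.target_rank[slot] = target_rank
        self.address[slot] = address
        self.timestamp[slot] = next(self._clock)
        return slot  # the handle is simply the slot index

    def record_event_end(self, handle):
        if handle < 0:
            return  # the start was dropped, so there is nothing to close
        self.duration_cycles[handle] = next(self._clock) - self.timestamp[handle]


# event_id=3 is an arbitrary value, not a real TraceEvent member.
tracer = MockDeviceTracing(max_events=2)
handle = tracer.record_event_start(event_id=3, target_rank=1, address=0x1000)
tracer.record_event_end(handle)
print(handle, tracer.duration_cycles[handle])  # prints "0 7"
```

Returning the slot index keeps the handle a single integer, which is cheap to carry in a register between the start and end calls on device.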

Gluon API (missing tracing):

  • IrisDeviceCtx has no tracing attribute
  • No record_event_start / record_event_end methods
  • initialize() only decodes [cur_rank, num_ranks, heap_bases...] — no tracing buffer pointers
  • IrisGluon.get_device_context() doesn't encode tracing buffer info

What Needs to Be Done

  1. Add tracing parameter to IrisDeviceCtx.initialize(): tracing: gl.constexpr = False
  2. Create gluon-native DeviceTracing class: Convert iris/tracing/device.py (DeviceTracing) to use @gluon.jit methods instead of @triton.jit
  3. Update context tensor encoding/decoding: When tracing is enabled, include the 13 trace buffer pointers (event_id, pid, pid_m, pid_n, cur_rank, target_rank, xcc_id, cu_id, timestamp, address, duration_cycles, op_index, payload_size) in the context tensor, as DeviceContext already does
  4. Add tracing attribute to IrisDeviceCtx: Store the gluon DeviceTracing instance
  5. Update IrisGluon._build_device_context(): Include tracing buffer pointers when tracing is enabled
  6. Add tracing property to IrisGluon: Expose the host-side Tracing class for enable(), reset(), and export()
  7. Add tests: Port test_device_context_tracing_1d_address to gluon
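Steps 3 and 5 come down to agreeing on a flat layout that the host encodes and the device decodes. The sketch below assumes one plausible layout: `[cur_rank, num_ranks, heap_bases..., then (only when tracing is on) the 13 trace buffer addresses in the order listed in step 3]`. The helper names `encode_context`/`decode_context`, the field order, and the use of plain Python ints for addresses are assumptions; the actual order and any additional fields should be taken from `DeviceContext` in `iris/iris.py`, since the Gluon encoding must match whatever the device-side decoder reads.

```python
# Trace buffer names in the order listed in step 3 (assumed to match DeviceContext).
TRACE_FIELDS = (
    "event_id", "pid", "pid_m", "pid_n", "cur_rank", "target_rank",
    "xcc_id", "cu_id", "timestamp", "address", "duration_cycles",
    "op_index", "payload_size",
)


def encode_context(cur_rank, num_ranks, heap_bases, trace_buffers=None):
    """Hypothetical host-side encoder: flatten ranks, heap bases, optional trace buffers."""
    if len(heap_bases) != num_ranks:
        raise ValueError("need exactly one heap base per rank")
    ctx = [cur_rank, num_ranks, *heap_bases]
    if trace_buffers is not None:
        # Append the 13 buffer addresses in a fixed order so the device can index them.
        ctx.extend(trace_buffers[name] for name in TRACE_FIELDS)
    return ctx


def decode_context(ctx, tracing=False):
    """Hypothetical inverse of encode_context; `tracing` stands in for the constexpr flag."""
    cur_rank, num_ranks = ctx[0], ctx[1]
    heap_bases = ctx[2 : 2 + num_ranks]
    trace_buffers = None
    if tracing:
        offset = 2 + num_ranks
        if len(ctx) < offset + len(TRACE_FIELDS):
            raise ValueError("context was encoded without trace buffers")
        trace_buffers = {name: ctx[offset + i] for i, name in enumerate(TRACE_FIELDS)}
    return cur_rank, num_ranks, heap_bases, trace_buffers


# Example: rank 1 of 2, with made-up buffer addresses.
bufs = {name: 0x10000 * (i + 1) for i, name in enumerate(TRACE_FIELDS)}
ctx = encode_context(1, 2, [0xA000, 0xB000], bufs)
```

Because `tracing` is a `gl.constexpr` on the device side, the Gluon decoder can branch on it at compile time, so kernels compiled with `tracing=False` never read past the heap bases.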

Reference Files

  • Tracing core: iris/tracing/core.py (host-side Tracing class)
  • Device tracing: iris/tracing/device.py (device-side DeviceTracing class, needs gluon port)
  • Event types: iris/tracing/events.py (TraceEvent enum — reuse as-is)
  • Triton integration: iris/iris.py (DeviceContext.initialize() with tracing)
  • Gluon API (target): iris/experimental/iris_gluon.py (IrisDeviceCtx)
  • Example: examples/23_gemm_all_scatter_tracing/gemm_all_scatter.py

Usage After Implementation

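```python
import iris.experimental.iris_gluon as iris_gluon
from iris.tracing.events import TraceEvent
from triton.experimental import gluon
from triton.experimental.gluon import language as gl

shmem = iris_gluon.iris(2**30)
# Enable tracing before building the context so the trace buffer pointers get encoded.
shmem.tracing.enable(max_events=1_000_000)
context_tensor = shmem.get_device_context()

@gluon.jit
def my_kernel(IrisDeviceCtx: gl.constexpr, context_tensor, ...):
    # tracing=True assumes shmem.tracing.enable() ran before get_device_context().
    ctx = IrisDeviceCtx.initialize(context_tensor, tracing=True)

    # ptr, data, target_rank, pid_m, pid_n, and mask come from the elided kernel body.
    handle = ctx.tracing.record_event_start(
        event_id=TraceEvent().store,
        target_rank=target_rank,
        address=ptr,
        pid_m=pid_m, pid_n=pid_n,
        mask=mask,
    )
    ctx.store(ptr, data, target_rank, mask=mask)
    ctx.tracing.record_event_end(handle)

# Launch my_kernel here (grid and remaining arguments omitted), then export the trace.
shmem.tracing.export("gluon_trace.json")
```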
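`Tracing.export` targets the Chrome Trace Event format, which Perfetto also loads. As a rough picture of that conversion, the sketch below turns per-event records (timestamp and duration in GPU cycles) into complete (`"ph": "X"`) events, whose `ts` and `dur` are in microseconds. The record dict layout, the `cycles_per_us` clock rate, the `event_names` lookup, and the mapping of rank to `pid` and CU to `tid` are assumptions for illustration, not necessarily what `iris/tracing/core.py` does.

```python
import json


def to_chrome_trace(records, cycles_per_us, event_names):
    """Hypothetical converter from raw trace records to a Chrome Trace Event document."""
    events = []
    for rec in records:
        events.append({
            "name": event_names.get(rec["event_id"], f"event_{rec['event_id']}"),
            "ph": "X",  # complete event: a start time plus a duration
            "ts": rec["timestamp"] / cycles_per_us,         # cycles -> microseconds
            "dur": rec["duration_cycles"] / cycles_per_us,  # cycles -> microseconds
            "pid": rec["cur_rank"],  # one process track per rank (assumed mapping)
            "tid": rec["cu_id"],     # one thread track per CU (assumed mapping)
            "args": {
                "target_rank": rec["target_rank"],
                "address": hex(rec["address"]),
                "pid_m": rec["pid_m"],
                "pid_n": rec["pid_n"],
            },
        })
    return {"traceEvents": events}


# One made-up record: a store from rank 0 to rank 1, at an assumed 100 cycles per microsecond.
records = [{"event_id": 3, "timestamp": 200, "duration_cycles": 100, "cur_rank": 0,
            "cu_id": 5, "target_rank": 1, "address": 0x1000, "pid_m": 0, "pid_n": 2}]
doc = to_chrome_trace(records, cycles_per_us=100.0, event_names={3: "store"})
text = json.dumps(doc, indent=2)  # write this string to a .json file to open it in Perfetto
```

Mapping ranks to `pid` puts each GPU on its own process track in the viewer, which makes cross-rank stores easy to line up in time.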

Metadata

Labels: core (Core Iris library development), examples (Examples showcasing Iris APIs and usage), gluon, iris (Iris project issue)