all_reduce_one_shot / all_reduce_two_shot use hardcoded lock value 1 #465

@aamarnat

Description

Bug

In iris/x/all_reduce.py, all_reduce_one_shot and all_reduce_two_shot use a hardcoded value of 1 to signal "tile ready":

  • Producers write: tl.atomic_xchg(lock_ptr, 1, sem="release")
  • Consumers spin: while iris.atomic_add(lock_ptr, 0, ...) != 1: pass

Between calls, the lock array must be reset to 0 via a collective shmem.zeros followed by a barrier, which adds overhead to every kernel invocation. If the lock array is not zeroed (e.g., because a workspace is reused or the reset is skipped by mistake), consumers see lock == 1 left over from the previous call, pass the spin-wait immediately, and read stale data.
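The failure mode can be reproduced with a small host-side model. This is only an illustrative sketch in plain Python, not the Triton kernel: `run_call`, `READY`, and the list-based `locks`/`tiles` buffers are hypothetical stand-ins, and the consumer is modelled as polling once before the producer has written (the worst-case interleaving).

```python
READY = 1  # the hardcoded "tile ready" value described above


def run_call(locks, tiles, payload, zero_first):
    """Model one invocation in which consumers poll before producers publish.

    Returns what each consumer would read at that moment: the tile contents
    if the lock already equals READY, or None if it would keep spinning.
    """
    if zero_first:
        # Stand-in for the collective zeroing + barrier between calls.
        for i in range(len(locks)):
            locks[i] = 0

    # Consumers check the lock before this call's producer has written.
    observed = [tiles[i] if locks[i] == READY else None for i in range(len(locks))]

    # Producers now publish this call's data and signal readiness.
    for i in range(len(locks)):
        tiles[i] = payload
        locks[i] = READY
    return observed


locks, tiles = [0, 0], [None, None]
first = run_call(locks, tiles, "call-1", zero_first=False)   # locks start at 0
second = run_call(locks, tiles, "call-2", zero_first=False)  # zeroing skipped
third = run_call(locks, tiles, "call-3", zero_first=True)    # zeroing performed
```

In this model `second` contains the previous call's payload, because the lock still holds 1 from the first call, while `first` and `third` correctly show the consumers still waiting.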

Impact

  • Per-call overhead from mandatory lock zeroing + barrier between invocations
  • Fragile: skipping the zeroing step silently produces wrong results
  • Prevents efficient workspace reuse across calls

Fix

Add a call_number parameter to both functions:

  • Producers signal with: tl.atomic_xchg(lock_ptr, call_number, sem="release", scope="sys")
  • Consumers spin on: while iris.atomic_add(lock_ptr, 0, ...) != call_number: pass

Add a monotonically increasing call_counter field to FusedWorkspace, incremented on every matmul_all_reduce call, and pass its current value as call_number. Each call then signals with a value no earlier call has used, so locks left over from previous calls never match and are ignored without zeroing.
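A minimal host-side sketch of the versioned scheme, again in plain Python rather than Triton. `WorkspaceModel`, `next_call_number`, `publish`, and `tile_ready` are hypothetical names used for illustration; only the `call_counter` field and the `call_number` parameter come from the proposal. The sketch assumes the counter starts at 0 and is incremented before use, so the first call signals with 1 and a zero-initialized lock never matches.

```python
from dataclasses import dataclass, field


@dataclass
class WorkspaceModel:
    """Hypothetical stand-in for a workspace holding locks and a call counter."""
    num_tiles: int
    call_counter: int = 0
    locks: list = field(init=False)
    tiles: list = field(init=False)

    def __post_init__(self):
        self.locks = [0] * self.num_tiles   # zeroed once, at allocation only
        self.tiles = [None] * self.num_tiles

    def next_call_number(self):
        # Incremented on every call; the new value is that call's version.
        self.call_counter += 1
        return self.call_counter


def publish(ws, i, payload, call_number):
    # Producer side: write the tile, then signal with this call's version.
    ws.tiles[i] = payload
    ws.locks[i] = call_number


def tile_ready(ws, i, call_number):
    # Consumer side: the tile is ready only if the lock holds this call's version.
    return ws.locks[i] == call_number


ws = WorkspaceModel(num_tiles=2)

n1 = ws.next_call_number()
publish(ws, 0, "call-1", n1)
ready_in_call_1 = tile_ready(ws, 0, n1)

# Second call reuses the workspace without zeroing the locks.
n2 = ws.next_call_number()
stale_accepted = tile_ready(ws, 0, n2)   # lock still holds n1, so this is False
publish(ws, 0, "call-2", n2)
ready_in_call_2 = tile_ready(ws, 0, n2)
```

One design point this sketch does not address is the width of the lock integer: a real implementation would need the counter to fit the lock dtype or handle wraparound, which is an assumption outside what the issue states.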

Component

iris/x/all_reduce.py, iris/ops/workspace.py


Labels

bug, iris
