Bug
In iris/x/all_reduce.py, all_reduce_one_shot and all_reduce_two_shot use a hardcoded value of 1 to signal "tile ready":
- Producers write: tl.atomic_xchg(lock_ptr, 1, sem="release")
- Consumers spin: while iris.atomic_add(lock_ptr, 0, ...) != 1: pass
Between calls, the lock array must be reset to 0 via a collective shmem.zeros followed by a barrier, which adds overhead to every kernel invocation. If the lock array is not zeroed (e.g., because a workspace is reused or an error skips the reset), consumers see lock == 1 left over from a previous call, stop spinning before the producer has written the new tile, and read stale data.
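The hazard can be illustrated with a small host-side model. This is a minimal sketch in plain Python, not the Triton kernel code: `produce`, `is_ready`, and the `locks`/`tiles` lists are hypothetical stand-ins for the atomic exchange, the consumer's spin condition, and the shared buffers.

```python
# Pure-Python model of the "tile ready" flag; no GPU or iris calls are involved.
READY = 1  # hardcoded signal value, as in the current kernels


def produce(locks, tiles, idx, value):
    """Write the tile payload, then publish it (stands in for the release xchg)."""
    tiles[idx] = value
    locks[idx] = READY


def is_ready(locks, idx):
    """The condition the consumer spins on."""
    return locks[idx] == READY


locks = [0]
tiles = [None]

# Call 1: the producer publishes, the consumer sees the flag and reads fresh data.
produce(locks, tiles, 0, "call-1 data")
assert is_ready(locks, 0) and tiles[0] == "call-1 data"

# Call 2 without zeroing: the flag is still 1 before the producer has run,
# so a consumer would stop spinning and read the previous call's tile.
stale_ready = is_ready(locks, 0)
stale_value = tiles[0]
```

In this model the second call observes the flag as ready even though no producer has written anything yet, which is the silent stale read described above.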
Impact
- Per-call overhead from mandatory lock zeroing + barrier between invocations
- Fragile: skipping the zeroing step silently produces wrong results
- Prevents efficient workspace reuse across calls
Fix
Add a call_number parameter to both functions:
- Producers signal with: tl.atomic_xchg(lock_ptr, call_number, sem="release", scope="sys")
- Consumers spin on: while iris.atomic_add(lock_ptr, 0, ...) != call_number: pass
Add a monotonically increasing call_counter field to FusedWorkspace, incremented on every matmul_all_reduce call. Each call uses a new version number, so stale locks from previous calls are automatically ignored without zeroing.
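The proposed versioning can be sketched the same way on the host side. This is an illustrative model under stated assumptions, not the actual FusedWorkspace: `WorkspaceSketch`, `next_call_number`, `signal`, and `is_ready` are hypothetical names, and the counter is assumed to start at 1 so that a freshly zeroed lock array never matches a valid call number.

```python
from dataclasses import dataclass, field


@dataclass
class WorkspaceSketch:
    """Hypothetical stand-in for FusedWorkspace; only the counter logic is modeled."""
    num_tiles: int
    call_counter: int = 0
    locks: list = field(default_factory=list)

    def __post_init__(self):
        self.locks = [0] * self.num_tiles  # zeroed once, at allocation

    def next_call_number(self):
        # Assumption: numbering starts at 1 so zero-initialized locks never match.
        self.call_counter += 1
        return self.call_counter


def signal(ws, idx, call_number):
    """Producer side: publish the tile for this call (models the atomic xchg)."""
    ws.locks[idx] = call_number


def is_ready(ws, idx, call_number):
    """Consumer side: the spin condition, compared against this call's number."""
    return ws.locks[idx] == call_number


ws = WorkspaceSketch(num_tiles=2)

first = ws.next_call_number()
signal(ws, 0, first)
signal(ws, 1, first)

# Second call reuses the workspace without zeroing the locks.
second = ws.next_call_number()
stale_seen = [is_ready(ws, i, second) for i in range(2)]

signal(ws, 0, second)
after_signal = [is_ready(ws, i, second) for i in range(2)]
```

Locks left at the first call's number do not satisfy the second call's spin condition, so no reset is needed between calls. One thing this sketch does not model is counter wraparound: if the lock element type is a fixed-width integer, a real implementation would presumably need to decide how to handle overflow.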
Component
iris/x/all_reduce.py, iris/ops/workspace.py