feat(ptodsl): implement simt_allreduce_sum for SIMT cross-workitem all-reduce#3

Open
kuri780 wants to merge 2 commits into and0d0:rmsNorm from kuri780:feature-PTODSL-allreduce-clean

kuri780 commented Jun 23, 2026

Summary

Implements pto.simt_allreduce_sum as designed in mission/483/483_docs.md.

API

pto.simt_allreduce_sum(value, *, threads, scale=1, thread_offset=0, scratch=None, scratch_offset=0) -> ScalarType
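As a reading aid for the signature above, here is a host-side model of what an all-reduce sum means: every participating lane ends up holding the sum over all lanes. It is a sketch in plain Python, not code from this PR. The function name `allreduce_sum_model` and the XOR-butterfly exchange pattern are illustrative assumptions, and the `scale`, `thread_offset`, `scratch`, and `scratch_offset` parameters are not modeled because their semantics are not described here.

```python
def allreduce_sum_model(values):
    """Hypothetical reference model: each lane receives the sum of all lanes."""
    n = len(values)
    if n == 0 or n & (n - 1):
        raise ValueError("this sketch only handles a power-of-two lane count")
    lanes = list(values)
    stride = 1
    while stride < n:
        # Each lane adds the value held by its XOR partner at this stride.
        lanes = [lanes[i] + lanes[i ^ stride] for i in range(n)]
        stride *= 2
    return lanes


result = allreduce_sum_model([1.0, 2.0, 3.0, 4.0])
print(result)  # every entry is the same total
```

After log2(n) exchange steps all lanes agree, which is the property that distinguishes an all-reduce from a reduce to a single lane.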

Implementation

  • Pure Python MLIR IR emission — no C++ extension needed
  • Three dispatch strategies: warp_reduce (≤32 threads, power of two), cross_warp_reduce (>32 threads, power of two), ub_reduce (fallback)
  • Lazy helper function deduplication
  • Supports f32 and f16
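The dispatch rule in the list above might be sketched as follows. This is one reading of the description, not the code in `_allreduce.py`: the helper names `is_pow2` and `select_strategy` are hypothetical, and the sketch assumes that any non-power-of-two thread count, including counts of 32 or fewer, takes the `ub_reduce` fallback.

```python
def is_pow2(n: int) -> bool:
    """True when n is a positive power of two."""
    return n > 0 and (n & (n - 1)) == 0


def select_strategy(threads: int) -> str:
    """Hypothetical dispatch: map a thread count to a strategy name."""
    if threads <= 0:
        raise ValueError("threads must be positive")
    if is_pow2(threads):
        # Power-of-two counts split on the 32-thread boundary.
        return "warp_reduce" if threads <= 32 else "cross_warp_reduce"
    # Assumption: everything else uses the fallback path.
    return "ub_reduce"


for t in (32, 128, 48):
    print(t, select_strategy(t))
```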

Files

| File | Lines |
| --- | --- |
| ptodsl/ptodsl/_allreduce.py | +674 (new) |
| ptodsl/ptodsl/pto.py | +3 |
| ptodsl/tests/test_allreduce.py | +533 (new) |

Test

python3 ptodsl/tests/test_allreduce.py

Expected: ptodsl_allreduce: PASS

wenxuekun and others added 2 commits June 23, 2026 17:05
…l-reduce

Implement the pto.simt_allreduce_sum frontend interface as designed in
mission/483/483_docs.md.  Pure Python MLIR IR emission with three
dispatch strategies: warp_reduce (<=32 threads, pow2), cross_warp_reduce
(>32, pow2), ub_reduce (fallback).  Supports f32 and f16.

- ptodsl/ptodsl/_allreduce.py: new — 674 lines
- ptodsl/ptodsl/pto.py: export simt_allreduce_sum (+3 lines)
- ptodsl/tests/test_allreduce.py: new — 533 lines, all passing

Co-Authored-By: Claude <noreply@anthropic.com>

Refactor _allreduce.py from helper-function outline to inline emission,
add max/min reducer support, and create VPTO simulator test cases.

Core changes (_allreduce.py):
- Replace func.call outline with inline emission (_emit_inline)
- Add simt_allreduce_max and simt_allreduce_min APIs
- Add reducer dispatch tables (IDENTITY, COMBINE, REDUX)
- Convert control flow to PTODSL if_() context manager
- Convert ops to PTODSL wrappers (pto.const, scalar.*, redux_*, ...)
- Keep raw arith only for unsigned ops (DivUIOp, RemUIOp, ShRUIOp, ult)
- Auto-attach pto.simt_entry attribute for syncthreads verifier
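To illustrate how reducer dispatch tables let sum, max, and min share one code path, here is a minimal host-side sketch. The table names `IDENTITY` and `COMBINE` come from the commit message, but their contents below are assumptions, as is the `allreduce` helper. The `REDUX` table is omitted because its contents are not described here.

```python
import math
import operator
from functools import reduce

# Assumed contents: the neutral element and the binary combiner per reducer.
IDENTITY = {"sum": 0.0, "max": -math.inf, "min": math.inf}
COMBINE = {"sum": operator.add, "max": max, "min": min}


def allreduce(kind, values):
    """Hypothetical model: fold all lanes, then give every lane the result."""
    total = reduce(COMBINE[kind], values, IDENTITY[kind])
    return [total] * len(values)


lanes = [3.0, -1.0, 7.0, 2.0]
for kind in ("sum", "max", "min"):
    print(kind, allreduce(kind, lanes)[0])
```

Starting each fold from the identity element means a lane that contributes nothing leaves the result unchanged, which is why each reducer needs its own entry.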

Export (pto.py):
- Export simt_allreduce_sum, simt_allreduce_max, simt_allreduce_min

Tests (test_allreduce.py):
- Add IR structure tests for all 4 paths × 3 reducers
- Add ptoas lowering verification for warp paths
- Document bisheng stack-smashing bug on cross-warp scratch paths

VPTO simulator tests (test/vpto/cases/micro-op/simt/allreduce_*):
- 6 cases: warp_sum/max/min (32 lanes) + cross_sum/max/min (128 lanes)
- kernel.pto + launch.cpp + main.cpp + golden.py + compare.py
- All 6 cases verified on Ascend950PR_9599 simulator (DEVICE=SIM)
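For the 128-lane cross cases, a golden check might compare the kernel output against a two-stage reduction like the one below. This is a sketch only and is not taken from the golden.py or compare.py files in this PR. It assumes a warp width of 32 lanes, models the scratch buffer as a Python list of per-warp partials, and uses the hypothetical name `cross_warp_allreduce`.

```python
import operator

WARP_SIZE = 32  # assumption: 32 lanes per warp, so 128 lanes form 4 warps


def cross_warp_allreduce(values, combine, identity):
    """Hypothetical two-stage model: per-warp partials, then a final combine."""
    n = len(values)
    if n % WARP_SIZE:
        raise ValueError("lane count must be a multiple of WARP_SIZE")
    # Stage 1: each warp reduces its own lanes to a single partial,
    # standing in for a write to a scratch slot.
    partials = []
    for start in range(0, n, WARP_SIZE):
        acc = identity
        for v in values[start:start + WARP_SIZE]:
            acc = combine(acc, v)
        partials.append(acc)
    # Stage 2: the partials are combined and the result goes to every lane.
    total = identity
    for p in partials:
        total = combine(total, p)
    return [total] * n


inputs = [float(i % 7) for i in range(128)]
out = cross_warp_allreduce(inputs, operator.add, 0.0)
print(out[0] == sum(inputs))
```

The inputs are small integers stored as floats so that the two summation orders agree exactly; with arbitrary f16 or f32 data a tolerance-based comparison would be the safer choice.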

Co-Authored-By: Claude <noreply@anthropic.com>
kuri780 force-pushed the feature-PTODSL-allreduce-clean branch from 761e9a1 to a5ab39f on June 30, 2026 08:02