feat(ptodsl): implement simt_allreduce_sum for SIMT cross-workitem all-reduce#3
Open
kuri780 wants to merge 2 commits into
Open
feat(ptodsl): implement simt_allreduce_sum for SIMT cross-workitem all-reduce#3kuri780 wants to merge 2 commits into
kuri780 wants to merge 2 commits into
Conversation
…l-reduce Implement the pto.simt_allreduce_sum frontend interface as designed in mission/483/483_docs.md. Pure Python MLIR IR emission with three dispatch strategies: warp_reduce (<=32 threads, pow2), cross_warp_reduce (>32, pow2), ub_reduce (fallback). Supports f32 and f16. - ptodsl/ptodsl/_allreduce.py: new — 674 lines - ptodsl/ptodsl/pto.py: export simt_allreduce_sum (+3 lines) - ptodsl/tests/test_allreduce.py: new — 533 lines, all passing Co-Authored-By: Claude <noreply@anthropic.com>
Refactor _allreduce.py from helper-function outline to inline emission, add max/min reducer support, and create VPTO simulator test cases. Core changes (_allreduce.py): - Replace func.call outline with inline emission (_emit_inline) - Add simt_allreduce_max and simt_allreduce_min APIs - Add reducer dispatch tables (IDENTITY, COMBINE, REDUX) - Convert control flow to PTODSL if_() context manager - Convert ops to PTODSL wrappers (pto.const, scalar.*, redux_*, ...) - Keep raw arith only for unsigned ops (DivUIOp, RemUIOp, ShRUIOp, ult) - Auto-attach pto.simt_entry attribute for syncthreads verifier Export (pto.py): - Export simt_allreduce_sum, simt_allreduce_max, simt_allreduce_min Tests (test_allreduce.py): - Add IR structure tests for all 4 paths × 3 reducers - Add ptoas lowering verification for warp paths - Document bisheng stack-smashing bug on cross-warp scratch paths VPTO simulator tests (test/vpto/cases/micro-op/simt/allreduce_*): - 6 cases: warp_sum/max/min (32 lanes) + cross_sum/max/min (128 lanes) - kernel.pto + launch.cpp + main.cpp + golden.py + compare.py - All 6 cases verified on Ascend950PR_9599 simulator (DEVICE=SIM) Co-Authored-By: Claude <noreply@anthropic.com>
761e9a1 to
a5ab39f
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Implements
pto.simt_allreduce_sumas designed inmission/483/483_docs.md.API
pto.simt_allreduce_sum(value, *, threads, scale=1, thread_offset=0, scratch=None, scratch_offset=0) -> ScalarType
Implementation
Files
Test
python3 ptodsl/tests/test_allreduce.py
Expected: ptodsl_allreduce: PASS