scan + pulse: keep Scan body source facts consistent with outer wires #2337
Closed
JulienBalianSonos wants to merge 5 commits into
Summary
Five commits that together let pulse-mode handle Scan ops whose body source facts depend on the streaming symbol (canonical example: DPRNN-style chunked GRUs):
- pulse/model: extend blockify gate to Scan body with stream-derived dim. `PulsedModel::new` already runs the blockify symbolic-chunk substitution when the outer graph has quadratic sections; this commit extends the gate to fire whenever any Scan body has a source fact whose shape mentions the streaming symbol. Without it, the warmup turn binds 0-length tensors against literal dims and bails with `Clashing resolution for expression. 2=2 != 0`.
- pulse/model: substitute chunk symbol before declutter for Scan-body path. Declutter aggressively folds stream-derived expressions to literals once the streaming symbol becomes concrete-ish (e.g. `(STREAM − n_fft)/hop + 1 → Val(2)`). For the Scan-body case the substitution `STREAM → S · pulse` has to run first, so the same expression folds to `2·S − 1` and stays symbolic.
- core/ops/scan: resync body source facts when outer inputs drift. This is the headline change. The Scan body is a separate `TypedModel` whose Source facts are set at construction time. When the outer graph runs through declutter or axis-change passes that mutate the upstream wires (canonical: an EinSum `NIHW,OI->OHW` projecting a literal-1 batch axis on the chain feeding the Scan), the body's source facts drift out of sync. `Scan::output_facts` reads from `body.output_fact(...)`, which traces back to the body source, so the drift is silent until a runtime warmup or downstream invariant check trips on it.

  New `declutter_resync_body_source_facts` rule rebuilds the body via `wire_node` so `output_facts` re-propagates the new shapes. Two documented bail-outs return `Ok(None)`:
  - `TypedModelPatch::replace_single_op` cannot propagate: keep the original body, log a warning.

  One hard error (`bail!`):
  - `Scan::output_facts`' state-equality check on the next pass. Failing here surfaces any future upstream bug at its origin instead of letting it propagate as a confusing downstream symptom.

  Covered by two new unit tests:
  - `tests::test_declutter_resync_body_source_facts` constructs a Scan with a multi-input body whose Scan slot drifted (`(1, 2, 4)` while the outer chain feeding it collapsed to `(T, 1, 4)`) and asserts that the rule's patch resyncs the slot to `(1, 1, 4)` while leaving the matching State slot untouched.
  - `tests::test_declutter_resync_no_drift_no_patch` locks in the early-exit when source facts already match.
- core/ops/source: TypedSource::change_axes returns Ok(None) for inapplicable Rm. `change_axes`' contract is to return `Ok(None)` when the change cannot be applied; `TypedSource` was bubbling the underlying `change_shape` error for `Rm` on a non-trivial axis, which aborts the surrounding declutter pass instead of letting `ChangeAxes` skip the proposal. Now maps the targeted failure mode to `Ok(None)`.

  Covered by two new unit tests:
  - `tests::change_axes_rm_non_trivial_returns_none` locks in the contract for the previously-failing case.
  - `tests::change_axes_rm_trivial_still_applies` sanity-checks that `Rm` on a `1`-dim still applies.
- core/ops/array: implement set_symbols on DynSlice and Topk. Both ops carry a `TDim` field that needs to ride the chunk-symbol substitution. The impl follows the existing `Slice`/`Tile` pattern: build a fresh op with `substitute_all(subs)?` on the TDim fields, then `wire_node` it into the target.

Why bundle them
Each commit on its own is a no-op for model classes already covered by the existing scan-warmup gate. The bundle unlocks a new class (Scan body whose source shape is a `chunk_sym`-derived expression) end to end. The chunk substitution is meaningless without the gate, the gate is meaningless without the substitution running pre-declutter, the `set_symbols` impls on `DynSlice`/`Topk` make the substitution carry through ops the pulse path actually hits, and the resync is the runtime-correctness half of the same story.

Test
`tract-core` (249) and `tract-pulse` (37) tests pass on top of current `main`.
- `tests::test_scan_body_with_stream_derived_dim_uses_chunk_symbol` (pulse) covers the gate + substitution.
- `tests::test_declutter_resync_body_source_facts` and `test_declutter_resync_no_drift_no_patch` (core) cover the resync rule.
- `tests::change_axes_rm_non_trivial_returns_none` and `tests::change_axes_rm_trivial_still_applies` (core) cover the `TypedSource` contract fix.

Performance
A/B'd `PulsedModel::new` declutter time on the models we have at hand, with the resync rule enabled vs disabled:
- (`Ok(None)` no-drift): 0.628s vs 0.638s averaged over 3 runs each. The delta sits inside run-to-run noise (±15 ms / ~2%), so the no-drift fast path costs nothing measurable on models that don't need the rule.
- `PulsedModel::new` time is ~4.1s; the rule's rebuilds are a small fraction of that.

Resync overhead is below the noise floor on the models we tested.
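The no-drift fast path can be pictured as a plain shape comparison that returns before any patch is built. The following is a minimal sketch of that idea, not tract's implementation: `Dim`, `Slot`, `expected_body_shape` and `resync_body_sources` are hypothetical names, and it assumes that a Scan slot's body shape is the outer shape with the scanned axis replaced by the chunk size, while a State slot mirrors the outer shape unchanged.

```rust
/// Hypothetical stand-in for a symbolic dimension (tract uses `TDim`).
#[derive(Clone, Debug, PartialEq)]
enum Dim {
    Sym(&'static str),
    Val(usize),
}

/// Hypothetical description of how an outer input maps into the body.
#[derive(Clone, Copy, Debug)]
enum Slot {
    /// The body sees `chunk` elements along `axis` per iteration.
    Scan { axis: usize, chunk: usize },
    /// The body sees the outer shape as-is (assumption of this sketch).
    State,
}

/// Shape the body source should have, given the current outer shape.
fn expected_body_shape(outer: &[Dim], slot: Slot) -> Vec<Dim> {
    let mut shape = outer.to_vec();
    if let Slot::Scan { axis, chunk } = slot {
        shape[axis] = Dim::Val(chunk);
    }
    shape
}

/// Returns `None` when every body source already matches (the fast path),
/// otherwise the full list of shapes the body sources should be reset to.
fn resync_body_sources(
    body: &[Vec<Dim>],
    outer: &[Vec<Dim>],
    slots: &[Slot],
) -> Option<Vec<Vec<Dim>>> {
    let expected: Vec<Vec<Dim>> = outer
        .iter()
        .zip(slots)
        .map(|(shape, slot)| expected_body_shape(shape, *slot))
        .collect();
    if expected.as_slice() == body {
        None
    } else {
        Some(expected)
    }
}
```

With an outer Scan input of `(T, 1, 4)` scanned along axis 0 with chunk 1 and a stale body source of `(1, 2, 4)`, the sketch proposes `(1, 1, 4)` and leaves a matching State slot as it was; calling it again on the resynced shapes returns `None`.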
Supersedes the open `feat/scan-warmup-symbolic-chunk` branch (commit 1 here is the same gate, plus the rest).
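To make the ordering argument behind the gate and the pre-declutter substitution concrete, here is a toy model of the frame-count expression `(STREAM − n_fft)/hop + 1`. It is only a sketch: `Affine` and `frames` are hypothetical helpers, and the values `n_fft = 4`, `hop = 2`, `pulse = 4` and a concrete `STREAM = 6` are assumptions picked so that the results line up with the `Val(2)` and `2·S − 1` examples in the summary; they are not taken from any real model.

```rust
/// Toy affine expression `coef * S + constant` in the chunk symbol `S`.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Affine {
    coef: i64,
    constant: i64,
}

/// Computes `(stream - n_fft) / hop + 1` where `stream` is affine in `S`.
/// Returns `None` when the division is not exact on both terms
/// (this toy model simply gives up in that case).
fn frames(stream: Affine, n_fft: i64, hop: i64) -> Option<Affine> {
    let coef = stream.coef;
    let constant = stream.constant - n_fft;
    if hop == 0 || coef % hop != 0 || constant % hop != 0 {
        return None;
    }
    Some(Affine {
        coef: coef / hop,
        constant: constant / hop + 1,
    })
}

/// Substitution first: STREAM becomes `pulse * S`, so the result keeps `S`.
fn substituted_first(pulse: i64, n_fft: i64, hop: i64) -> Option<Affine> {
    frames(Affine { coef: pulse, constant: 0 }, n_fft, hop)
}

/// Folding first: STREAM is already a literal, so the result is a literal.
fn folded_first(stream_len: i64, n_fft: i64, hop: i64) -> Option<Affine> {
    frames(Affine { coef: 0, constant: stream_len }, n_fft, hop)
}
```

Under these assumed values, `substituted_first(4, 4, 2)` gives coefficient 2 and constant −1 (that is, `2·S − 1`), whereas `folded_first(6, 4, 2)` gives the bare literal 2 with no symbol left for later passes to work with.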