Fix XNNPACK channels-last reshape for per-channel binary ops under dynamic quant #20376
Open
Hyungkeun-Park-Nota wants to merge 2 commits into
…namic quant

ChannelsLastTaggedReshapePass.input_to_nhwc has a dynamic-quant branch that calls input_node.replace_all_uses_with(input_node_nhwc), globally redirecting a shared activation to its NHWC copy. The main traversal visits nodes in topological order, so this can retroactively switch a broadcasting binary op's activation operand to NHWC after that op was already processed, while its other operand (e.g. a per-channel constant) stays NCHW. The two operands then have incompatible logical shapes (e.g. [1, H, W, C] vs [1, C, 1, 1]), and XNNPACK fails at runtime in xnn_reshape_binary_elementwise_nd with xnn_status_invalid_parameter. Lowering succeeds; only execution fails (observed on DETR with w8a8 dynamic quantization).

Add a re-convergence sweep after the main traversal that restores the pass's own invariant (all operands of a node share one memory format) for broadcasting binary ops (add/mul/sub/div), converging to NHWC when possible (else NCHW). It runs after the retrace so it observes the settled graph, and retraces once more only if anything changed. Graphs with no diverged binary operands are untouched.

Add a regression test (DynamicQuantPerChannelBinaryChain) whose per-channel binary op is the first consumer of an input activation, with the convolution chain running through it; it fails to execute (xnn_status_invalid_parameter) without this fix and passes with it.
@pytorchbot label "release notes: xnnpack"
Summary
ChannelsLastTaggedReshapePass.input_to_nhwc's dynamic-quant branch calls input_node.replace_all_uses_with(input_node_nhwc) on a shared activation. Because the main traversal runs in topological order, this can switch a broadcasting binary
op's activation operand to NHWC after that op was already processed, while its
other operand (e.g. a per-channel constant) stays NCHW. The operands then have
incompatible shapes (e.g. [1, H, W, C] vs [1, C, 1, 1]), and XNNPACK fails at runtime in
xnn_reshape_binary_elementwise_nd with xnn_status_invalid_parameter. Lowering succeeds; only
execute() fails. Found on a ResNet-50-backbone detection model lowered with w8a8 dynamic quantization.
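The shape mismatch can be reproduced on paper. The sketch below is a small pure-Python check, assuming NumPy-style broadcasting rules (align from the trailing dimension; sizes must match or one of them must be 1). The helper `broadcast_shape` and the sizes `H`, `W`, `C` are illustrative and do not come from XNNPACK or ExecuTorch.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the NumPy-style broadcast of two shapes, or None if incompatible."""
    out = []
    # Walk both shapes from the trailing dimension; missing leading dims count as 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            return None  # sizes differ and neither is 1
    return tuple(reversed(out))


H, W, C = 8, 8, 3             # illustrative sizes, not taken from the PR
const_nchw = (1, C, 1, 1)     # per-channel constant that stayed NCHW
act_nchw = (1, C, H, W)       # activation as the binary op originally saw it
act_nhwc = (1, H, W, C)       # activation after being redirected to its NHWC copy

ok = broadcast_shape(act_nchw, const_nchw)
bad = broadcast_shape(act_nhwc, const_nchw)
print(ok)   # both operands NCHW: shapes broadcast
print(bad)  # mixed formats: H is compared against C and the check fails
```

With both operands in NCHW the constant broadcasts over H and W. Once only the activation is NHWC, the second dimension compares H against C, which is the kind of mismatch a broadcasting shape check would be expected to reject.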
Fix
After the main traversal, re-converge any broadcasting binary op (add/mul/sub/div) whose operands ended up in different memory formats, reusing the existing input_to_nhwc/input_to_nchw helpers. It runs after the retrace so it sees the settled graph, and retraces again only if something changed; it is a no-op for
graphs without diverged operands.
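The idea of the sweep can be sketched on a toy graph. This is not the code in the PR: the `Node` class, the `can_be_nhwc` flag, and `reconverge` are hypothetical stand-ins, and the real pass works on an FX graph and reuses input_to_nhwc/input_to_nchw rather than building copy nodes by hand. The sketch only shows the invariant (all operands of a binary op share one memory format) and the "report whether anything changed" return value that would gate a second retrace.

```python
from dataclasses import dataclass, field

BINARY_OPS = {"add", "mul", "sub", "div"}


@dataclass
class Node:
    name: str
    op: str
    fmt: str                                   # "NCHW" or "NHWC"
    inputs: list = field(default_factory=list)
    can_be_nhwc: bool = True                   # hypothetical: NHWC copy is allowed


def reconverge(nodes):
    """Give every binary op operands in one format; return True if anything changed."""
    changed = False
    for node in nodes:
        if node.op not in BINARY_OPS:
            continue
        if len({inp.fmt for inp in node.inputs}) <= 1:
            continue  # operands already agree, leave the node untouched
        # Prefer NHWC when every operand allows it, otherwise fall back to NCHW.
        target = "NHWC" if all(i.can_be_nhwc for i in node.inputs) else "NCHW"
        for idx, inp in enumerate(node.inputs):
            if inp.fmt != target:
                # Stand-in for a per-use layout conversion of this operand only.
                node.inputs[idx] = Node(
                    f"{inp.name}_{target.lower()}", "copy", target, [inp], inp.can_be_nhwc
                )
        node.fmt = target
        changed = True
    return changed


# Diverged state: the activation was redirected to NHWC, the constant stayed NCHW.
x_nhwc = Node("x_nhwc", "copy", "NHWC")
scale = Node("scale", "const", "NCHW")
mul = Node("mul", "mul", "NCHW", [x_nhwc, scale])
graph = [x_nhwc, scale, mul]

first = reconverge(graph)    # fixes the diverged operands
second = reconverge(graph)   # nothing left to fix
print(first, second, [i.name for i in mul.inputs])
```

Converting only the mismatched use, rather than redirecting every use of the operand, is the design point of the sketch: a global redirect is what caused the divergence described in the summary.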
Test
test_dynamic_quant_per_channel_binary_chain_lowers_and_runs builds a graph where a per-channel binary op is the first consumer of an input activation and the
convolution chain runs through it. It fails to execute without this fix and passes
with it.