Arm backend: Generalize RemovePermutesAroundElementwiseTosaOps#20238
Arm backend: Generalize RemovePermutesAroundElementwiseTosaOps#20238AdrianLundell wants to merge 5 commits into
Conversation
- Use is_param_node for finding constant placeholders - Ensure ops are not modified by multiple subgraphs Signed-off-by: Adrian Lundell <adrian.lundell@arm.com> Change-Id: Ibb972d89c9c4125dc09bc918eae5f3fe81186cc0
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20238
Note: Links to docs will display an error until the docs builds have been completed. ❌ 8 New Failures, 6 Unrelated FailuresAs of commit 601922d with merge base 0da9ca3 ( NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Pull request overview
This PR generalizes the Arm TOSA pass that removes redundant permutes around elementwise ops by improving constant-placeholder detection and preventing invalid rewrites when multiple candidate subgraphs overlap.
Changes:
- Add a “stale subgraph” guard to skip applying rewrites when a previously-applied rewrite has already changed a candidate’s boundary edges.
- Update
RemovePermutesAroundElementwiseTosaOpsto useis_param_node(exported_program, node)for identifying constant placeholders, and threadexported_programthrough the Arm pass pipeline. - Expand and update Arm backend tests to cover folded scalar constants and shared-boundary/stale-subgraph behavior; adjust transpose count expectations accordingly.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| backends/transforms/remove_permutes_around_elementwise_ops.py | Adds _subgraph_edges_are_current() and checks it before rewriting a candidate subgraph. |
| backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py | Requires ExportedProgram, improves constant detection via is_param_node, and keeps TABLE constant inputs from being permuted. |
| backends/arm/_passes/arm_pass_manager.py | Wires exported_program into the RemovePermutesAroundElementwiseTosaOps pass construction. |
| backends/arm/test/passes/test_remove_permutes_around_elementwise_tosa_ops.py | Adds new unit tests for folded scalar constants and stale shared-boundary subgraphs; updates pass construction to provide an exported program. |
| backends/arm/test/misc/test_transpose_counts.py | Updates expected transpose count for a model affected by improved permute-removal behavior. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| PERMUTE_TARGET = exir_ops.edge.aten.permute_copy.default | ||
| RESCALE_TARGET = exir_ops.backend.tosa.RESCALE.default | ||
| TABLE_TARGET = exir_ops.backend.tosa.TABLE.default | ||
| MUL_TARGET = exir_ops.edge.aten.mul.Tensor | ||
| ADD_TARGET = exir_ops.edge.aten.add.Tensor | ||
| ERF_TARGET = exir_ops.edge.aten.erf.default |
| result = RemovePermutesAroundElementwiseTosaOps(_fake_exported_program()).call( | ||
| graph_module | ||
| ) |
There was a problem hiding this comment.
I'm not sure I get that part, can you expand on it?
There was a problem hiding this comment.
Since it builds the graph directly instead of exporting it it needs to fake an exported program as the pass now requires it.
…tream/change-1265674 Change-Id: I1e041d0ed886b0122ef2cb1917416ebc979499fe
Signed-off-by: Adrian Lundell <adrian.lundell@arm.com> Change-Id: Ic1d87ceb42e425628c50d95b27eeb0fe7ac0e6ff
| def _is_constant(self, node: torch.fx.Node) -> bool: | ||
| # Override fragile string match check with exported program check | ||
| return super()._is_constant(node) or is_param_node(self.exported_program, node) |
| def _is_constant(self, node: torch.fx.Node) -> bool: | ||
| # Override fragile string match check with exported program check | ||
| return super()._is_constant(node) or is_param_node(self.exported_program, node) |
| def _subgraph_edges_are_current(self, subgraph: Subgraph) -> bool: | ||
| """Return false if an earlier rewrite invalidated this candidate.""" | ||
| for inp, out in subgraph.edges_in: |
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani