Arm backend: Make ConstantFoldingPass preserve Learnable Parameters#20372
Arm backend: Make ConstantFoldingPass preserve Learnable Parameters#20372gggekov wants to merge 1 commit into
Conversation
Avoid folding expressions that depend on trainable parameters during QAT preparation. Folding these values can turn learnable tensors into fixed constants, leaving the prepared model without the parameters needed for QAT training. Raised by pytorch#20148 Signed-off-by: George Gekov <george.gekov@arm.com> Change-Id: Idc9c089c87e7291567e73a0b48298c8bcea04081
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20372
Note: Links to docs will display an error until the docs builds have been completed. ❌ 7 New Failures, 2 Unrelated FailuresAs of commit ce52f9f with merge base a6305cb ( NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
There was a problem hiding this comment.
Pull request overview
This PR updates the Arm backend’s constant-folding pass to avoid folding subgraphs that depend on trainable parameters, preventing QAT preparation from accidentally turning learnable tensors into fixed constants (addressing the missing-parameters failure reported in #20148).
Changes:
- Teach
ConstantFoldingPassto preserveget_attrparameters withrequires_grad=True(and their dependent nodes) duringconstant_fold. - Add/extend Arm backend QAT training-loop tests, including a regression case for preserving a learnable parameter through QAT prepare.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
backends/arm/_passes/constant_folding_pass.py |
Adds preservation logic to prevent constant folding of trainable-parameter-dependent subgraphs during QAT prep. |
backends/arm/test/misc/test_qat_training_loop.py |
Adds QAT training loop regression tests (learnable-param preservation + additional QAT smoke tests). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| if node.op == "get_attr": | ||
| try: | ||
| attr = getattr(graph_module, node.target, None) | ||
| except AttributeError: | ||
| continue |
| # Returns a set of the graph nodes with learnable | ||
| # parameters and any nodes depending on the learnable | ||
| # parameter. Needed for the case when doing | ||
| # QAT and we should not do constatnt folding | ||
| # for the learnable parameters and their dependends. |
| def train_step( | ||
| model: torch.nn.Module, optimizer: torch.optim.Optimizer, input_output, loss_fn | ||
| ) -> None: |
| optimizer = torch.optim.SGD(prepared_model.parameters(), lr=1.0, momentum=0.3) | ||
| for _ in range(3): | ||
| train_step(model, optimizer, _make_batch, torch.nn.functional.mse_loss) |
| loss = train_step( | ||
| model, optimizer, _make_image_batch, torch.nn.functional.cross_entropy | ||
| ) |
| quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT")) | ||
| prepared_model = prepare_qat_pt2e(exported_model.module(), quantizer) | ||
| prepared_model = move_exported_model_to_train(prepared_model) |
| prepared_model, | ||
| optimizer, | ||
| _make_mha_batch, | ||
| torch.nn.functional.cross_entropy, |
Avoid folding expressions that depend on trainable parameters during QAT preparation. Folding these values can turn learnable tensors into fixed constants, leaving the prepared model without the parameters needed for QAT training.
Raised by #20148
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani