Skip to content

Arm backend: Make ConstantFoldingPass preserve Learnable Parameters#20372

Open
gggekov wants to merge 1 commit into
pytorch:mainfrom
gggekov:fix_learnable_param_bug
Open

Arm backend: Make ConstantFoldingPass preserve Learnable Parameters#20372
gggekov wants to merge 1 commit into
pytorch:mainfrom
gggekov:fix_learnable_param_bug

Conversation

@gggekov

@gggekov gggekov commented Jun 18, 2026

Copy link
Copy Markdown
Collaborator

Avoid folding expressions that depend on trainable parameters during QAT preparation. Folding these values can turn learnable tensors into fixed constants, leaving the prepared model without the parameters needed for QAT training.

Raised by #20148

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani

Avoid folding expressions that depend on trainable parameters during
QAT preparation. Folding these values can turn learnable tensors into
fixed constants, leaving the prepared model without the parameters
needed for QAT training.

Raised by pytorch#20148

Signed-off-by: George Gekov <george.gekov@arm.com>
Change-Id: Idc9c089c87e7291567e73a0b48298c8bcea04081
Copilot AI review requested due to automatic review settings June 18, 2026 08:02
@gggekov gggekov requested a review from digantdesai as a code owner June 18, 2026 08:02
@pytorch-bot

pytorch-bot Bot commented Jun 18, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20372

Note: Links to docs will display an error until the docs builds have been completed.

❌ 7 New Failures, 2 Unrelated Failures

As of commit ce52f9f with merge base a6305cb (image):

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 18, 2026
@gggekov gggekov added the partner: arm For backend delegation, kernels, demo, etc. from the 3rd-party partner, Arm label Jun 18, 2026
@github-actions github-actions Bot added the module: arm Issues related to arm backend label Jun 18, 2026
@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR updates the Arm backend’s constant-folding pass to avoid folding subgraphs that depend on trainable parameters, preventing QAT preparation from accidentally turning learnable tensors into fixed constants (addressing the missing-parameters failure reported in #20148).

Changes:

  • Teach ConstantFoldingPass to preserve get_attr parameters with requires_grad=True (and their dependent nodes) during constant_fold.
  • Add/extend Arm backend QAT training-loop tests, including a regression case for preserving a learnable parameter through QAT prepare.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 7 comments.

File Description
backends/arm/_passes/constant_folding_pass.py Adds preservation logic to prevent constant folding of trainable-parameter-dependent subgraphs during QAT prep.
backends/arm/test/misc/test_qat_training_loop.py Adds QAT training loop regression tests (learnable-param preservation + additional QAT smoke tests).

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +64 to +68
if node.op == "get_attr":
try:
attr = getattr(graph_module, node.target, None)
except AttributeError:
continue
Comment on lines +56 to +60
# Returns a set of the graph nodes with learnable
# parameters and any nodes depending on the learnable
# parameter. Needed for the case when doing
# QAT and we should not do constatnt folding
# for the learnable parameters and their dependends.
Comment on lines +134 to +136
def train_step(
model: torch.nn.Module, optimizer: torch.optim.Optimizer, input_output, loss_fn
) -> None:
Comment on lines +167 to +169
optimizer = torch.optim.SGD(prepared_model.parameters(), lr=1.0, momentum=0.3)
for _ in range(3):
train_step(model, optimizer, _make_batch, torch.nn.functional.mse_loss)
Comment on lines +228 to +230
loss = train_step(
model, optimizer, _make_image_batch, torch.nn.functional.cross_entropy
)
Comment on lines +271 to +273
quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
prepared_model = prepare_qat_pt2e(exported_model.module(), quantizer)
prepared_model = move_exported_model_to_train(prepared_model)
prepared_model,
optimizer,
_make_mha_batch,
torch.nn.functional.cross_entropy,
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. module: arm Issues related to arm backend partner: arm For backend delegation, kernels, demo, etc. from the 3rd-party partner, Arm

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants