Arm backend: Add MXFP4/MXFP6 linear support #20370
Open
martinlsm wants to merge 3 commits into
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20370
Note: Links to docs will display an error until the docs builds have been completed.
❌ 7 New Failures, 3 Cancelled Jobs, 7 Unrelated Failures, 29 Unclassified Failures
As of commit 65128df with merge base 0da9ca3:
NEW FAILURES - The following jobs have failed:
UNCLASSIFIED FAILURES - DrCI could not classify the following jobs because the workflow did not run on the merge base. The failures may be pre-existing on trunk or introduced by this PR:
CANCELLED JOBS - The following jobs were cancelled. Please retry:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but was present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label ciflow/trunk
@pytorchbot label "partner: arm"
@pytorchbot label "release notes: arm"
zingo approved these changes on Jun 18, 2026
OK to merge if tests pass and the Mypy error is fixed.
The nss test failure is unrelated.
Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com>
Change-Id: I29ca5ed16db5fe15331402e5139aead1040ce1b6
Add support for running `torch.nn.Linear` modules in MXFP6E3M2 and MXFP6E2M3. Update the `MXFPOpConfig` to support these data types.

Since `torch` lacks FP6 datatypes, the string-based definitions in `torchao` are used as a workaround.

The custom TOSA op receives a new argument called `str weight_payload_dtype`, which specifies the dtype used for the weights. The weight tensor itself does not carry this information, since the FP4 and FP6 formats are stored in uint8 tensors.

The CAST_TO_BLOCK_SCALED required to transform activations from/to MXFP is also updated to support the new datatype. Its custom TOSA op gets a `str output_dtype` similar to the `weight_payload_dtype` for the MXFP linear op.

Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com>
Change-Id: Iea06b859429c458b1921ee3e736a65f597b37239
Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com>
Change-Id: Ic972b1b93224bb03bca03599268c3a2e09f5ec4f
9463dee to 65128df
Add support for running `torch.nn.Linear` modules in MXFP4E2M1, MXFP6E3M2 and MXFP6E2M3. Update the `MXFPOpConfig` to support these data types.

Since `torch` lacks FP6 datatypes, the string-based definitions in `torchao` are used as a workaround.

The custom TOSA op receives a new argument called `str weight_payload_dtype`, which specifies the dtype used for the weights. The weight tensor itself does not carry this information, since the FP4 and FP6 formats are stored in uint8 tensors.

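
To illustrate why the payload dtype has to travel as a separate argument, here is a minimal sketch (not the code in this PR) that decodes the same raw byte under each element format. The bit widths and exponent biases follow the OCP Microscaling (MX) element formats; the string keys, the `FORMATS` table, and `decode_element` are hypothetical names, and the sketch assumes one element code per byte in the low bits, which may differ from the packing the backend actually uses.

```python
# Hypothetical format table: (exponent bits, mantissa bits, exponent bias).
# The string keys are placeholders, not necessarily the names torchao uses.
FORMATS = {
    "fp4_e2m1": (2, 1, 1),
    "fp6_e2m3": (2, 3, 1),
    "fp6_e3m2": (3, 2, 3),
}


def decode_element(code: int, payload_dtype: str) -> float:
    """Decode one sign/exponent/mantissa code held in the low bits of a byte."""
    ebits, mbits, bias = FORMATS[payload_dtype]
    sign = -1.0 if (code >> (ebits + mbits)) & 1 else 1.0
    exponent = (code >> mbits) & ((1 << ebits) - 1)
    mantissa = code & ((1 << mbits) - 1)
    if exponent == 0:
        # Subnormal (or zero): no implicit leading one.
        return sign * 2.0 ** (1 - bias) * (mantissa / (1 << mbits))
    return sign * 2.0 ** (exponent - bias) * (1 + mantissa / (1 << mbits))


raw = 0b000101  # identical uint8 content in all three cases
for name in FORMATS:
    print(name, decode_element(raw, name))
```

The same byte decodes to 3.0, 0.625, and 0.3125 respectively, so a uint8 tensor alone cannot tell a consumer which format it holds.
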
The CAST_TO_BLOCK_SCALED required to transform activations from/to MXFP is also updated to support the new datatype. Its custom TOSA op gets a `str output_dtype` similar to the `weight_payload_dtype` for the MXFP linear op.

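
As a rough illustration of why a block-scaled cast depends on the target element format, the sketch below computes a shared power-of-two scale for one block, following the conversion approach described in the OCP MX specification. It is an assumption-laden sketch rather than the PR's implementation: `EMAX`, `block_scale`, and the string keys are hypothetical, and rounding the scaled elements to the element format is omitted.

```python
import math

# Largest exponent of a normal value in each element format (OCP MX spec).
# The string keys are hypothetical stand-ins for the `output_dtype` argument.
EMAX = {"fp4_e2m1": 2, "fp6_e2m3": 2, "fp6_e3m2": 4}


def block_scale(block: list[float], output_dtype: str) -> float:
    """Return a shared power-of-two scale for one block of values."""
    max_abs = max(abs(v) for v in block)
    if max_abs == 0.0:
        return 1.0  # an all-zero block needs no scaling
    shared_exp = math.floor(math.log2(max_abs)) - EMAX[output_dtype]
    return 2.0 ** shared_exp


block = [0.5, -3.0, 12.0, 1.25]
for name in EMAX:
    scale = block_scale(block, name)
    # Dividing by the scale brings the block into the element format's range.
    print(name, scale, [v / scale for v in block])
```

For the same input block, the scale differs between formats with different exponent ranges, which is why the cast op needs to know the output dtype explicitly.
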
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani