[Feat] Support dataclass in magi_register_custom_op #32
Open
themistbeforedawn wants to merge 1 commit into
Conversation
jiahy0825 reviewed May 17, 2026
Comment on lines +58 to +95
```
Part B. Runtime paths -- the three pipelines
============================================

Three pipelines are possible; the decorator returns whichever object sits
at the end of the path:

1. simple            fn -> torch_registered_op
   Returned: ``torch._ops.OpOverload`` (slot 2).
   Runtime: zero magi-level overhead -- straight into torch.library's
   dispatcher.

2. sig-only-rewrite  fn -> lowered_fn -> torch_registered_op
   Returned: ``torch._ops.OpOverload`` (slot 2).
   Runtime: same as simple -- ``lowered_fn`` is a transparent
   forwarding shim (the rewrite is registration-time only).

3. dataclass-flatten fn -> lowered_fn -> torch_registered_op
                     -> magi_exposed_op
   Returned: a Python callable carrying the
   ``_magi_torch_registered_op`` attribute (slot 3).
   Runtime forward (per call):
       user code calls magi_exposed_op(x, cfg=...)
       -> _flatten_call_args          (original kwargs -> flat tuple)
       -> _flatten_value_into         (DFS over param_mapping_tree)
       -> torch_registered_op(*flat)  (slot 2 -- enters dispatcher)
       -> lowered_fn(*flat)           (slot 1 -- still in lowered shape)
       -> _reassemble_kwargs          (flat tuple -> original kwargs)
       -> _build_value_from_node      (rebuilds dataclass instances)
       -> fn(**original_kwargs)       (slot 0 -- user code finally sees
                                       its original dataclass-bearing
                                       signature)
   Runtime backward (when backward_fn is supplied):
       autograd calls _bridged_backward(ctx, *grads)
       -> user_backward(ctx, *grads)  (returns one grad per ORIGINAL
                                       input, possibly a dataclass-shaped
                                       grad object)
       -> _flatten_grads              (original grads -> flat grads)
       -> _flatten_grad_into          (DFS over param_mapping_tree)
```
Collaborator
Make these comments self-explanatory through code instead of writing docs.
Or move them into a markdown document that clarifies the design of `_register_custom_op`.
Both are acceptable to me haha~
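To make path 3 concrete, here is a minimal, self-contained sketch of the flatten/reassemble round trip described above. `AttnCfg` and the two helpers are illustrative stand-ins, not the module's actual `_flatten_call_args` / `_reassemble_kwargs` internals:

```python
import dataclasses

# Illustrative config; a hypothetical frozen dataclass, not from the PR.
@dataclasses.dataclass(frozen=True)
class AttnCfg:
    scale: float
    causal: bool

def flatten_call_args(cfg: AttnCfg) -> tuple:
    # Forward direction: dataclass -> flat primitive leaves (the shape the
    # dispatcher-facing lowered signature receives).
    return tuple(getattr(cfg, f.name) for f in dataclasses.fields(cfg))

def reassemble_kwargs(flat: tuple) -> AttnCfg:
    # Inverse direction: flat leaves -> original dataclass instance
    # (the shape the user's fn sees again inside the op).
    return AttnCfg(*flat)

flat = flatten_call_args(AttnCfg(scale=0.125, causal=True))  # (0.125, True)
assert reassemble_kwargs(flat) == AttnCfg(scale=0.125, causal=True)
```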
Comment on lines +546 to 579
| """Lower ``fn``'s signature into a form ``torch.library.infer_schema`` accepts. | ||
|
|
||
| "Lower" is used in the compiler sense (high-level -> low-level): we walk | ||
| ``fn``'s parameters once and do six things at the same time -- they all | ||
| need the same resolved annotations and the same iteration: | ||
|
|
||
| 1. VALIDATE -- reject variadics, missing annotations, mutable dataclasses, | ||
| unsupported containers, dataclass returns (sec 1). | ||
| 2. RESOLVE -- turn stringified annotations into real types via | ||
| ``_resolve_annotations``, so dataclass detection works. | ||
| 3. NORMALIZE -- collapse parameter kinds to POSITIONAL_OR_KEYWORD, | ||
| downgrade Literal/Enum to ``str``, scrub unsupported defaults. | ||
| 4. FLATTEN -- expand each frozen-dataclass parameter (recursively) into | ||
| its primitive leaves via ``_build_dataclass_sub_mapping_tree``. | ||
| 5. PYTREE -- side effect of step 4: register every dataclass as a pytree | ||
| node so Dynamo / AOTAutograd can trace through it. | ||
| 6. EMIT -- assemble ``(original_sig, lowered_sig, param_mapping_tree)``. | ||
|
|
||
| A single pass is intentional: splitting concerns would force re-resolving | ||
| annotations and threading accumulator state. When the input is already | ||
| schema-compatible the lowered signature is bit-identical to the original, | ||
| and the caller's ``_signatures_differ`` check restores the zero-overhead path. | ||
|
|
||
| Returns: | ||
| - 1 if the return type is a single Tensor | ||
| - N if the return type is tuple[Tensor, Tensor, ...] with N elements | ||
| - 1 if no annotation or unrecognized annotation (default to single output) | ||
| original_sig (inspect.Signature): the user's un-flattened signature. | ||
| lowered_sig (inspect.Signature): what ``infer_schema`` will see. | ||
| param_mapping_tree (list[tuple]): the bridge between the two; a list | ||
| of root nodes (one per original parameter), each of which is: | ||
| * ``("primitive", attr_name, lowered_name, None)``, or | ||
| * ``("dataclass", attr_name, cls, [child_nodes...])``. | ||
| ``attr_name`` is the parameter name at top level / field name | ||
| deeper down. The same tree drives both runtime translation | ||
| directions (sec 7). | ||
| """ |
Collaborator
These comments confuse me. Use human-readable sentences instead of AI-style explanatory comments.
It takes me a long time to understand what these comments say. Please provide simplified comments and code examples. I guess you are trying to say the return value follows rules like these; just paste examples like the ones below:
```
original_sig:  (q: Tensor, cfg: AttnCfg, mode: Literal['a','b'] = 'a') -> Tensor
lowered_sig:   (q: Tensor, cfg__scale: float, cfg__causal: bool, mode: str = 'a') -> Tensor
param_mapping_tree ≈
[
    ("primitive", "q", "q", None),
    ("dataclass", "cfg", AttnCfg, [
        ("primitive", "scale", "cfg__scale", None),
        ("primitive", "causal", "cfg__causal", None),
    ]),
    ("primitive", "mode", "mode", None),
]
```
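A mapping tree shaped like this example could be produced by a recursive walk along the following lines. This is a sketch only; the real `_build_dataclass_sub_mapping_tree` may differ in its details, and the `cfg__scale` leaf naming simply follows the example above:

```python
import dataclasses
from typing import get_type_hints

@dataclasses.dataclass(frozen=True)
class AttnCfg:  # hypothetical config, as in the example above
    scale: float
    causal: bool

def build_mapping_node(name, annotation, prefix=""):
    # Primitives keep a "<prefix><name>" lowered name; frozen dataclasses
    # recurse, prefixing children with "<lowered_name>__".
    lowered_name = f"{prefix}{name}"
    if dataclasses.is_dataclass(annotation):
        hints = get_type_hints(annotation)  # resolve stringified annotations
        children = [
            build_mapping_node(f.name, hints[f.name], prefix=f"{lowered_name}__")
            for f in dataclasses.fields(annotation)
        ]
        return ("dataclass", name, annotation, children)
    return ("primitive", name, lowered_name, None)

build_mapping_node("cfg", AttnCfg)
# -> ("dataclass", "cfg", AttnCfg,
#     [("primitive", "scale", "cfg__scale", None),
#      ("primitive", "causal", "cfg__causal", None)])
```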
🗂️ PR Category
📝 Description
What's new
`@magi_register_custom_op` now accepts frozen-dataclass parameters (recursively nested), so users can group config / flags as a single `@dataclass(frozen=True)` while `torch.library`'s schema continues to see only primitives. The same lower-signature pass also handles (transparent to users):

- `Literal[...]` / string-`Enum` annotations → auto-downgraded to `str`
- `mutates_args` accepts either the dataclass-level name (expands to all Tensor leaves) or any lowered leaf name
- `backward_fn` returns one grad per original parameter (not per leaf); a whole non-differentiable dataclass arg collapses to a single `None`
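A minimal usage sketch of what this enables; the import path, op name, and exact decorator arguments are assumptions for illustration, not taken from the real API:

```python
import dataclasses
from typing import Literal

import torch

from magi_compiler import magi_register_custom_op  # import path assumed

@dataclasses.dataclass(frozen=True)
class AttnCfg:  # hypothetical user config
    scale: float
    causal: bool

# The op name and mutates_args usage below are assumptions for illustration.
@magi_register_custom_op("magi::scaled_mul", mutates_args=())
def scaled_mul(q: torch.Tensor, cfg: AttnCfg,
               mode: Literal["a", "b"] = "a") -> torch.Tensor:
    # fn keeps its dataclass-bearing signature; torch.library's schema only
    # sees flattened primitives (e.g. cfg__scale, cfg__causal) and mode as str.
    return q * cfg.scale

out = scaled_mul(torch.randn(4, 8), cfg=AttnCfg(scale=0.125, causal=True))
```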
Architecture — 4-slot pipeline

Each registration owns up to 4 named objects:

- slot 0: `fn` (the user's original function)
- slot 1: `lowered_fn`
- slot 2: `torch_registered_op` (lives in `torch.library`'s dispatcher)
- slot 3: `magi_exposed_op`

The naming is deliberately dual: `torch_registered_op` is registered into `torch.library`'s dispatcher; `magi_exposed_op` is exposed out of Magi to the user.

Architecture — 3 runtime paths
The slot set produced at registration time selects one of three runtime paths:
1. `fn → torch_registered_op` — zero per-call overhead; returns the `OpOverload` directly
2. `fn → lowered_fn → torch_registered_op` — e.g. `Literal` downgrade only
3. `fn → lowered_fn → torch_registered_op → magi_exposed_op` — the wrapper flattens / unflattens on every call; the underlying `OpOverload` is accessible via `op._magi_torch_registered_op`

`magi_compiler/_magi_register_custom_op.py` is laid out 1:1 against this model — 8 numbered sections grouped into registration-time helpers · runtime helpers · main pipeline.
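Because paths 1 and 2 return the `OpOverload` itself while path 3 returns a wrapper, code that needs the raw op can normalize across all three paths. A sketch, relying only on the documented `_magi_torch_registered_op` attribute:

```python
import torch

def unwrap_registered_op(op):
    # Paths 1-2 hand back the torch._ops.OpOverload directly; path 3 returns
    # a Python wrapper carrying it as `_magi_torch_registered_op`.
    if isinstance(op, torch._ops.OpOverload):
        return op
    return op._magi_torch_registered_op
```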
Tests

83 new tests in `tests/api_tests/test_register_custom_op.py` cover all three runtime paths, autograd bridging through dataclass inputs, nested dataclasses, `Optional` / `Literal` / `Enum` / dtype / device fields, `torch.compile` integration, and full error-path coverage.