Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions backends/cadence/aot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
)
from executorch.exir.passes import ToOutVarPass
from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
from executorch.exir.program._program import to_edge_with_preserved_ops
from executorch.exir.program._program import to_edge
from torch._inductor.decomposition import remove_decompositions

from torch.export.exported_program import ExportedProgram
Expand Down Expand Up @@ -219,9 +219,9 @@ def quantize_pt2(
torch.ops.aten.angle.default,
torch.ops.aten.rms_norm.default,
]
TO_EDGE_PRESERVE_OPS: tuple[torch._ops.OpOverload, ...] = (
TO_EDGE_PRESERVE_OPS: list[torch._ops.OpOverload, ...] = [
torch.ops.aten.rms_norm.default,
)
]


def _lower_ep_to_edge(
Expand All @@ -233,18 +233,18 @@ def _lower_ep_to_edge(
"""
Lower an ExportedProgram to an EdgeProgramManager (in edge IR).
"""
# Call to_edge_with_preserved_ops to convert the graph to edge IR.
# Call to_edge to convert the graph to edge IR.
# Note: dim_order is skipped (https://github.com/pytorch/executorch/issues/3704)
edge_prog_manager = to_edge_with_preserved_ops(
edge_prog_manager = to_edge(
expo_program,
compile_config=EdgeCompileConfig(
_skip_dim_order=True,
# Allow specific non-core aten ops in the IR.
_core_aten_ops_exception_list=TO_EDGE_OP_EXCEPTION_LIST
+ (core_aten_exceptions or []),
_preserve_ops=TO_EDGE_PRESERVE_OPS,
),
constant_methods=constant_methods,
preserve_ops=TO_EDGE_PRESERVE_OPS,
)

if dump_graphs:
Expand Down
2 changes: 1 addition & 1 deletion backends/nxp/nxp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def preprocess(
# Otherwise, we get violation that this op is not part of ATen Core ops.
edge_program._verifiers = [
EXIREdgeDialectVerifier(
class_only=True, exception_list=[torch.ops.aten.max_pool2d.default]
class_only=True, core_aten_ops_exception_list=[torch.ops.aten.max_pool2d.default]
)
]

Expand Down
16 changes: 8 additions & 8 deletions examples/apple/coreml/llama/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.exir.program._program import to_edge_with_preserved_ops
from executorch.exir.program._program import to_edge
from executorch.extension.export_util.utils import save_pte_program


Expand Down Expand Up @@ -196,17 +196,17 @@ def main() -> None:
print("Exported program")
print(ep)

edge_manager = to_edge_with_preserved_ops(
edge_manager = to_edge(
ep,
preserve_ops=[
torch.ops.aten.scaled_dot_product_attention.default,
# preserve norm op for numerical stability
torch.ops.aten.linalg_vector_norm.default,
torch.ops.aten.reciprocal.default,
],
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
_skip_dim_order=True,
_preserve_ops=[
torch.ops.aten.scaled_dot_product_attention.default,
# preserve norm op for numerical stability
torch.ops.aten.linalg_vector_norm.default,
torch.ops.aten.reciprocal.default,
],
),
)
print("Edge program")
Expand Down
5 changes: 5 additions & 0 deletions exir/capture/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,14 @@ class EdgeCompileConfig:
# TODO(larryliu): remove this
_use_edge_ops: bool = True
# Allow core ATen ops check to be skipped for certain ops, but continue with the rest of the checks.
# Note: only use this for core ATen ops that are missing decompositions. This is temporary,
# enabling verification on the rest of the program until decomposition coverage is improved.
_core_aten_ops_exception_list: List[torch._ops.OpOverload] = field(
default_factory=list
)
# Allow ops to be preserved in the graph, i.e., prevent them from being decomposed.
# These may be core or non-core ATen ops; custom ops should not be here.
_preserve_ops: List[torch.torch._ops.OpOverload] = field(default_factory=list)
# TODO(gasoonjia): remove this
_skip_dim_order: bool = False

Expand Down
106 changes: 39 additions & 67 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import io
import logging
import os
from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Type, Union

import torch
import torch._export
Expand All @@ -22,7 +22,6 @@
)
from executorch.exir._serialize._serialize import serialize_for_executorch
from executorch.exir._serialize.data_serializer import DataSerializer
from executorch.exir._warnings import experimental
from executorch.exir.backend.backend_api import (
MethodProgramsPartitionerSpec,
to_backend,
Expand Down Expand Up @@ -795,9 +794,19 @@ def _generate_edge_program(
name: str,
config: EdgeCompileConfig,
program: ExportedProgram,
ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
) -> ExportedProgram:

"""
Args:
name: The name of the program.
config: The configuration for the edge program.
program: The exported program to be converted to an edge program.
core_aten_ops_exception_list: A list of aten ops that are missing decompositions to core aten.
preserve_ops: A list of aten ops that should not be decomposed.
Returns:
An ExportedProgram in edge dialect.
"""
# Remove invalid assert ops, such as _assert_tensor_metadata
gm = program.graph_module
gm_res = RemoveNonCoreAtenOpGraphAssertsPass()(gm)
Expand All @@ -812,7 +821,8 @@ def _generate_edge_program(
EXIRATenDialectVerifier(
edge_compile_config=config,
class_only=False,
exception_list=ops_set_to_not_decompose,
core_aten_ops_exception_list=core_aten_ops_exception_list,
preserve_ops=preserve_ops,
)(gm)
except ExportError as e:
logging.info(f"Input program {name} is not in ATen dialect.")
Expand Down Expand Up @@ -848,7 +858,8 @@ def _generate_edge_program(
EXIREdgeDialectVerifier(
edge_compile_config=config,
class_only=True,
exception_list=ops_set_to_not_decompose,
core_aten_ops_exception_list=core_aten_ops_exception_list,
preserve_ops=preserve_ops,
)
],
)
Expand All @@ -864,7 +875,7 @@ def _replace_aten_ops_with_transformed_ops(
program: ExportedProgram,
partitioner,
):
ops_to_not_decompose = set()
preserve_ops = set()
partitioners = partitioner.get(name)
if partitioners is None:
return
Expand All @@ -889,7 +900,7 @@ def _replace_aten_ops_with_transformed_ops(
and node.target in ops_set_to_not_decompose
and is_op_supported
):
ops_to_not_decompose.add(node.target)
preserve_ops.add(node.target)
node.target = aten_op_to_transform_op[node.target]

for _, submod, _ in get_control_flow_submodules(program.graph_module):
Expand All @@ -900,10 +911,10 @@ def _replace_aten_ops_with_transformed_ops(
and node.target in ops_set_to_not_decompose
and is_op_supported
):
ops_to_not_decompose.add(node.target)
preserve_ops.add(node.target)
node.target = aten_op_to_transform_op[node.target]

return ops_to_not_decompose
return preserve_ops


def _restore_transformed_ops_to_aten_ops(program: ExportedProgram):
Expand Down Expand Up @@ -1014,7 +1025,7 @@ def _sanity_check_graph_for_non_decomp_ops(


def _remove_invalid_ops_for_not_decompose(
ops_to_not_decompose: List[torch._ops.OpOverload],
preserve_ops: List[torch._ops.OpOverload],
) -> List[torch._ops.OpOverload]:
_logged_warnings = set()

Expand Down Expand Up @@ -1079,7 +1090,7 @@ def keep(op):
return False
return True

return list(filter(keep, ops_to_not_decompose))
return list(filter(keep, preserve_ops))


def _gen_edge_manager_for_partitioners(
Expand Down Expand Up @@ -1136,7 +1147,7 @@ def _gen_edge_manager_for_partitioners(
name,
config,
program,
list(ops_set_to_not_decompose_by_program.get(name, [])),
preserve_ops=list(ops_set_to_not_decompose_by_program.get(name, [])),
)

edge_manager = EdgeProgramManager(
Expand Down Expand Up @@ -1281,61 +1292,12 @@ def to_edge_transform_and_lower(
EXIREdgeDialectVerifier(
edge_compile_config=config,
class_only=True,
exception_list=list(ops_set_to_not_decompose),
preserve_ops=list(ops_set_to_not_decompose),
)()(program.graph_module)

return edge_manager


@experimental(
"""
This is an experimental API which overloads to_edge by preserving specified ops to not be decomposed.
This function will be combined with to_edge in the future.
"""
)
def to_edge_with_preserved_ops(
programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
constant_methods: Optional[Dict[str, Any]] = None,
compile_config: Optional[EdgeCompileConfig] = None,
preserve_ops: Tuple[torch._ops.OpOverload, ...] = (),
) -> "EdgeProgramManager":
"""
:func:`to_edge` constructs an EdgeProgramManager from a set of exported programs in
ATen dialect. Upon construction those programs are transformed into edge dialect.

Args:
programs: Can be a single ExportedProgram or a dictionary mapping function names to their corresponding ExportedPrograms. If only a single ExportedProgram is provided it will be assigned the name "forward".
constant_methods: An optional dictionary of method name to the constant value returned by that method in eager mode. Often used to store config information on Edge models.
compile_config: An optional argument used to provide greater control over the transformation to edge dialect process.
preserve_ops: An argument used to specify ops that should not be decomposed.

Returns:
EdgeProgramManager
"""
assert not isinstance(constant_methods, EdgeCompileConfig)
config = compile_config or EdgeCompileConfig()
if not isinstance(programs, dict):
aten_programs = {"forward": programs}
else:
aten_programs = programs

edge_programs: Dict[str, ExportedProgram] = {}

for name, program in aten_programs.items():
# Decompose to Core ATen
table = _default_decomposition_table()
for op in preserve_ops:
table.pop(op, None)
program = program.run_decompositions(table)
edge_programs[name] = _generate_edge_program(
name, config, program, list(preserve_ops)
)

return EdgeProgramManager(
edge_programs, constant_methods, config, list(preserve_ops)
)


@et_logger("to_edge")
def to_edge(
programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
Expand Down Expand Up @@ -1367,8 +1329,16 @@ def to_edge(

for name, program in aten_programs.items():
# Decompose to Core ATen
program = program.run_decompositions(_default_decomposition_table())
edge_programs[name] = _generate_edge_program(name, config, program)
table = _default_decomposition_table()
preserve_ops = []
if compile_config:
preserve_ops = compile_config._preserve_ops
for op in compile_config._preserve_ops:
table.pop(op, None)
program = program.run_decompositions(table)
edge_programs[name] = _generate_edge_program(
name, config, program, preserve_ops=preserve_ops
)

return EdgeProgramManager(edge_programs, constant_methods, config)

Expand All @@ -1389,7 +1359,8 @@ def __init__(
edge_programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
constant_methods: Optional[Dict[str, Any]] = None,
compile_config: Optional[EdgeCompileConfig] = None,
ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
):
"""
Should not be called directly by users. User should use :func:'to_edge' instead.
Expand All @@ -1404,7 +1375,8 @@ def __init__(
try:
EXIREdgeDialectVerifier(
edge_compile_config=self.compile_config,
exception_list=ops_set_to_not_decompose,
core_aten_ops_exception_list=core_aten_ops_exception_list,
preserve_ops=preserve_ops,
)(program.graph_module)
except ExportError as e:
logging.info(f"Input program {name} is not in aten dialect.")
Expand Down
5 changes: 3 additions & 2 deletions exir/program/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
ExecutorchProgramManager,
to_edge,
to_edge_transform_and_lower,
to_edge_with_preserved_ops,
)
from executorch.exir.tracer import _default_decomposition_table
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
Expand Down Expand Up @@ -784,7 +783,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
def _test_to_edge_with_preserved_ops(
self, program, preserved_ops, expected_preserved_ops
):
edge = to_edge_with_preserved_ops(program, preserve_ops=preserved_ops)
edge = to_edge(
program, compile_config=EdgeCompileConfig(_preserve_ops=preserved_ops)
)

def count_nodes(graph_module, target):
count = 0
Expand Down
14 changes: 14 additions & 0 deletions exir/verification/test/test_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,17 @@ def forward(self, input, label):
edge_verifier = EXIREdgeDialectVerifier()

edge_verifier(edge.exported_program())

def test_verifier_preserve_ops_view(self) -> None:
class TestExpand(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return x.expand(2, 2, 2, 2)

model = TestExpand()
config = EdgeCompileConfig(_preserve_ops=[torch.ops.aten.expand.default])
export_model = export(model, (torch.randn(2, 2, 2, 2),), strict=True)
with self.assertRaises(RuntimeError):
to_edge(export_model, compile_config=config)
Loading
Loading