From af246961b84a60d91408d507c97980638f85002b Mon Sep 17 00:00:00 2001 From: Abhinay Kukkadapu Date: Thu, 31 Jul 2025 20:47:48 -0700 Subject: [PATCH 1/2] [Executorch][Export][2/N] Add to_edge and to_backend stages Pull Request resolved: https://github.com/pytorch/executorch/pull/12937 Address (6) in the rfc: https://github.com/pytorch/executorch/issues/12660 1. Adds stage implementations for `to_edge` and `to_backend` 2. Adds unit tests for the two stages 3. Adds these two stages in the validation pipeline. Fixes #12932 ghstack-source-id: 300019403 @exported-using-ghexport Differential Revision: [D79120576](https://our.internmc.facebook.com/intern/diff/D79120576/) --- export/export.py | 36 +++++++-- export/stages.py | 116 +++++++++++++++++++++++++++- export/tests/test_export_session.py | 11 ++- export/tests/test_export_stages.py | 104 +++++++++++++++++++++++++ export/types.py | 2 + 5 files changed, 259 insertions(+), 10 deletions(-) diff --git a/export/export.py b/export/export.py index f5b0c6149d0..ac9d894fea1 100644 --- a/export/export.py +++ b/export/export.py @@ -24,6 +24,8 @@ QuantizeStage, SourceTransformStage, Stage, + ToBackendStage, + ToEdgeStage, TorchExportStage, ) from .types import StageType @@ -147,7 +149,9 @@ def __init__( ) # Stage registry: map of StageType to Stage instances - self._stage_registry: Dict[StageType, Stage] = self._build_default_stages() + self._stage_registry: Dict[StageType, Stage] = self._build_stages( + self._pipeline_stages + ) # Intialize run context self._run_context: Dict[str, Any] = { @@ -170,10 +174,12 @@ def _get_default_pipeline(self) -> List[StageType]: StageType.TO_EXECUTORCH, ] - def _build_default_stages(self) -> Dict[StageType, Stage]: + def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]: + """Build the stage registry from the given stages.""" stage_registry: Dict[StageType, Stage] = {} - for stage_type in self._get_default_pipeline(): + stage = None + for stage_type in stages or self._get_default_pipeline(): if stage_type == StageType.SOURCE_TRANSFORM: stage = SourceTransformStage(self._quant_recipe) elif stage_type == StageType.QUANTIZE: @@ -191,12 +197,24 @@ def _build_default_stages(self) -> Dict[StageType, Stage]: transform_passes=self._export_recipe.edge_transform_passes, compile_config=self._export_recipe.edge_compile_config, ) + elif stage_type == StageType.TO_EDGE: + stage = ToEdgeStage( + edge_compile_config=self._export_recipe.edge_compile_config + ) + elif stage_type == StageType.TO_BACKEND: + stage = ToBackendStage( + partitioners=self._export_recipe.partitioners, + transform_passes=self._export_recipe.edge_transform_passes, + ) elif stage_type == StageType.TO_EXECUTORCH: stage = ExecutorchStage(self._export_recipe.executorch_backend_config) else: - raise ValueError(f"Unknown stage type: {stage_type}") + logging.info( + f"{stage_type} is unknown, you have to register it before executing export()" + ) - stage_registry[stage_type] = stage + if stage: + stage_registry[stage_type] = stage return stage_registry def register_stage(self, stage_type: StageType, stage: Stage) -> None: @@ -241,7 +259,9 @@ def _validate_pipeline_sequence( first_stage = stages[0] first_stage_instance = self._stage_registry.get(first_stage) if first_stage_instance is None: - raise ValueError(f"Stage {first_stage} not found in registry") + raise ValueError( + f"Stage {first_stage} not found in registry, register it using session.register_stage()" + ) if not first_stage_instance.can_start_pipeline: raise ValueError(f"Stage {first_stage} cannot start a pipeline. ") @@ -254,7 +274,9 @@ def _validate_pipeline_sequence( # Get the stage instance to check its valid predecessors stage_instance = self._stage_registry.get(current_stage) if stage_instance is None: - raise ValueError(f"Stage {current_stage} not found in registry") + raise ValueError( + f"Stage {current_stage} not found in registry, , register it using session.register_stage()" + ) valid_predecessors = stage_instance.valid_predecessor_stages diff --git a/export/stages.py b/export/stages.py index 61672e55bb7..fd27c298028 100644 --- a/export/stages.py +++ b/export/stages.py @@ -10,8 +10,9 @@ import torch from executorch.devtools.backend_debug import get_delegation_info +from executorch.exir import EdgeCompileConfig from executorch.exir.backend.backend_api import validation_disabled -from executorch.exir.program import to_edge_transform_and_lower +from executorch.exir.program import to_edge, to_edge_transform_and_lower from executorch.exir.program._program import _transform from executorch.export.recipe import QuantizationRecipe from executorch.export.types import StageType @@ -223,7 +224,7 @@ def stage_type(self) -> str: @property def valid_predecessor_stages(self) -> List["StageType"]: - return [StageType.TO_EDGE_TRANSFORM_AND_LOWER] + return [StageType.TO_EDGE_TRANSFORM_AND_LOWER, StageType.TO_BACKEND] @property def can_start_pipeline(self) -> bool: @@ -354,3 +355,114 @@ def run(self, artifact: PipelineArtifact) -> None: quantized_models[method_name] = quantized_model self._artifact = artifact.copy_with_new_data(quantized_models) + + +class ToEdgeStage(Stage): + """ + Stage: Convert ExportedProgram to EdgeProgramManager. + """ + + def __init__( + self, + edge_compile_config: Optional[EdgeCompileConfig] = None, # pyre-ignore + ) -> None: + super().__init__() + self._edge_compile_config = edge_compile_config + + @property + def stage_type(self) -> str: + return StageType.TO_EDGE + + @property + def valid_predecessor_stages(self) -> List["StageType"]: + return [StageType.TORCH_EXPORT] + + @property + def can_start_pipeline(self) -> bool: + return False + + def run(self, artifact: PipelineArtifact) -> None: + """ + Convert ExportedProgram to EdgeProgramManager. + + Args: + artifact: Contains exported programs and context + """ + exported_programs = artifact.data + constant_methods = artifact.get_context("constant_methods") + + # Convert to edge program manager + edge_program_manager = to_edge( + exported_programs, + constant_methods=constant_methods, + compile_config=self._edge_compile_config, + ) + + self._artifact = artifact.copy_with_new_data(edge_program_manager) + + +class ToBackendStage(Stage): + """ + Stage: Apply transformations and partitioning to EdgeProgramManager. + """ + + def __init__( + self, + partitioners: Optional[List[Any]] = None, + transform_passes: Optional[Sequence[Callable[[Any], Optional[Any]]]] = None, + ) -> None: + super().__init__() + self._partitioners = partitioners + self._transform_passes = transform_passes + + @property + def stage_type(self) -> str: + return StageType.TO_BACKEND + + @property + def valid_predecessor_stages(self) -> List["StageType"]: + return [StageType.TO_EDGE] + + @property + def can_start_pipeline(self) -> bool: + return False + + def run(self, artifact: PipelineArtifact) -> None: + """ + Apply transformations and partitioning to EdgeProgramManager. + + Args: + artifact: Contains edge program manager and context + """ + edge_program_manager = artifact.data + + if edge_program_manager is None: + raise RuntimeError("Edge program manager is not set.") + + # Apply transform passes if available + if self._transform_passes: + edge_program_manager = edge_program_manager.transform( + self._transform_passes + ) + + # Apply partitioners if available + if self._partitioners is not None and len(self._partitioners) > 0: + with validation_disabled(): + # pyre-ignore + for partitioner in self._partitioners: + edge_program_manager = edge_program_manager.to_backend(partitioner) + + # Get delegation info + delegation_info = get_delegation_info( + edge_program_manager.exported_program().graph_module + ) + + self._artifact = artifact.copy_with_new_data(edge_program_manager) + self._artifact.add_context("delegation_info", delegation_info) + + @property + def delegation_info(self) -> Any: + """ + Returns the delegation info. + """ + return self._artifact.get_context("delegation_info") diff --git a/export/tests/test_export_session.py b/export/tests/test_export_session.py index cc9f2a74062..7bef0d01876 100644 --- a/export/tests/test_export_session.py +++ b/export/tests/test_export_session.py @@ -249,7 +249,7 @@ def _get_export_session(self, stages: List[StageType]): def test_valid_pipeline_sequences(self) -> None: """Test various valid pipeline sequences.""" valid_sequences = [ - # Full pipeline + # Full pipeline with to_edge_transform_lower [ StageType.SOURCE_TRANSFORM, StageType.QUANTIZE, @@ -257,6 +257,15 @@ def test_valid_pipeline_sequences(self) -> None: StageType.TO_EDGE_TRANSFORM_AND_LOWER, StageType.TO_EXECUTORCH, ], + # Full pipeline with to_edge, to_backend + [ + StageType.SOURCE_TRANSFORM, + StageType.QUANTIZE, + StageType.TORCH_EXPORT, + StageType.TO_EDGE, + StageType.TO_BACKEND, + StageType.TO_EXECUTORCH, + ], # Skip quantize [ StageType.SOURCE_TRANSFORM, diff --git a/export/tests/test_export_stages.py b/export/tests/test_export_stages.py index 5d83b4f9046..2b3e533723a 100644 --- a/export/tests/test_export_stages.py +++ b/export/tests/test_export_stages.py @@ -19,6 +19,8 @@ QuantizeStage, SourceTransformStage, StageType, + ToBackendStage, + ToEdgeStage, TorchExportStage, ) from torch.export import ExportedProgram @@ -282,3 +284,105 @@ def test_run_empty_example_inputs(self) -> None: self.assertIn( "Example inputs for method forward not found or empty", str(cm.exception) ) + + +class TestToEdgeStage(unittest.TestCase): + def setUp(self) -> None: + self.mock_exported_program = Mock(spec=ExportedProgram) + self.exported_programs = {"forward": self.mock_exported_program} + self.context = {"constant_methods": None} + + @patch("executorch.export.stages.to_edge") + def test_run_success(self, mock_to_edge: Mock) -> None: + mock_edge_manager = Mock(spec=EdgeProgramManager) + mock_to_edge.return_value = mock_edge_manager + mock_config = Mock() + + stage = ToEdgeStage(edge_compile_config=mock_config) + artifact = PipelineArtifact(data=self.exported_programs, context=self.context) + stage.run(artifact) + + # Verify to_edge was called with correct parameters + mock_to_edge.assert_called_once_with( + self.exported_programs, + constant_methods=None, + compile_config=mock_config, + ) + + # Verify artifacts are set correctly + result_artifact = stage.get_artifacts() + self.assertEqual(result_artifact.data, mock_edge_manager) + + +class TestToBackendStage(unittest.TestCase): + def setUp(self) -> None: + self.mock_edge_manager = Mock(spec=EdgeProgramManager) + self.context = {} + + @patch("executorch.export.stages.get_delegation_info") + def test_run_success_no_transforms_or_partitioners( + self, mock_get_delegation_info: Mock + ) -> None: + # Test successful execution without transforms or partitioners + mock_delegation_info = {"delegation": "info"} + mock_get_delegation_info.return_value = mock_delegation_info + mock_exported_program = Mock() + mock_graph_module = Mock() + mock_exported_program.graph_module = mock_graph_module + self.mock_edge_manager.exported_program.return_value = mock_exported_program + + stage = ToBackendStage() + artifact = PipelineArtifact(data=self.mock_edge_manager, context=self.context) + stage.run(artifact) + + # Verify get_delegation_info was called + mock_get_delegation_info.assert_called_once_with(mock_graph_module) + + # Verify artifacts are set correctly + result_artifact = stage.get_artifacts() + self.assertEqual(result_artifact.data, self.mock_edge_manager) + self.assertEqual( + result_artifact.get_context("delegation_info"), mock_delegation_info + ) + + @patch("executorch.export.stages.get_delegation_info") + def test_run_with_partitioners_and_passes( + self, mock_get_delegation_info: Mock + ) -> None: + mock_delegation_info = {"delegation": "info"} + mock_get_delegation_info.return_value = mock_delegation_info + mock_exported_program = Mock() + mock_graph_module = Mock() + mock_exported_program.graph_module = mock_graph_module + + mock_edge_program_manager = Mock(spec=EdgeProgramManager) + mock_edge_program_manager.transform.return_value = mock_edge_program_manager + mock_edge_program_manager.to_backend.return_value = mock_edge_program_manager + + mock_partitioner = Mock() + mock_transform_passes = [Mock(), Mock()] + stage = ToBackendStage( + partitioners=[mock_partitioner], transform_passes=mock_transform_passes + ) + artifact = PipelineArtifact( + data=mock_edge_program_manager, context=self.context + ) + stage.run(artifact) + + # Verify transform and to_backend called correctly + mock_edge_program_manager.transform.assert_called_once_with( + mock_transform_passes + ) + mock_edge_program_manager.to_backend.assert_called_once_with(mock_partitioner) + + # Verify artifacts contain the backend manager + result_artifact = stage.get_artifacts() + self.assertEqual(result_artifact.data, mock_edge_program_manager) + + def test_run_edge_manager_none(self) -> None: + stage = ToBackendStage() + artifact = PipelineArtifact(data=None, context=self.context) + + with self.assertRaises(RuntimeError) as cm: + stage.run(artifact) + self.assertIn("Edge program manager is not set", str(cm.exception)) diff --git a/export/types.py b/export/types.py index 8ffa287f91a..760f8461d41 100644 --- a/export/types.py +++ b/export/types.py @@ -16,4 +16,6 @@ class StageType(str, Enum): QUANTIZE = "quantize" TORCH_EXPORT = "torch_export" TO_EDGE_TRANSFORM_AND_LOWER = "to_edge_transform_and_lower" + TO_EDGE = "to_edge" + TO_BACKEND = "to_backend" TO_EXECUTORCH = "to_executorch" From 8471ded174e7674216069938367140515880bcb7 Mon Sep 17 00:00:00 2001 From: Abhinay Kukkadapu Date: Thu, 31 Jul 2025 20:47:49 -0700 Subject: [PATCH 2/2] [Executorch][Export][3/N] Modularize export recipes Pull Request resolved: https://github.com/pytorch/executorch/pull/12938 Addresses (7) in the rfc: https://github.com/pytorch/executorch/issues/12660 Changes: 1. Add data class called `LoweringRecipe` 2. Modify current xnnpack recipes to use lowering recipes Fixes: #12933 ghstack-source-id: 300019402 Differential Revision: [D79120575](https://our.internmc.facebook.com/intern/diff/D79120575/) --- .../recipes/xnnpack_recipe_provider.py | 18 +++++--- export/__init__.py | 3 +- export/export.py | 21 ++++----- export/recipe.py | 32 +++++++++---- export/stages.py | 36 ++++++++++++++- export/tests/test_export_session.py | 46 +++++++++++++++++++ 6 files changed, 125 insertions(+), 31 deletions(-) diff --git a/backends/xnnpack/recipes/xnnpack_recipe_provider.py b/backends/xnnpack/recipes/xnnpack_recipe_provider.py index 9d00c3c9c98..8fba58c12c3 100644 --- a/backends/xnnpack/recipes/xnnpack_recipe_provider.py +++ b/backends/xnnpack/recipes/xnnpack_recipe_provider.py @@ -27,6 +27,7 @@ from executorch.export import ( BackendRecipeProvider, ExportRecipe, + LoweringRecipe, QuantizationRecipe, RecipeType, ) @@ -88,12 +89,19 @@ def create_recipe( ) return None + def _get_xnnpack_lowering_recipe( + self, precision_type: Optional[ConfigPrecisionType] = None + ) -> LoweringRecipe: + return LoweringRecipe( + partitioners=[XnnpackPartitioner(precision_type=precision_type)], + edge_compile_config=get_xnnpack_edge_compile_config(), + ) + def _build_fp32_recipe(self, recipe_type: RecipeType) -> ExportRecipe: return ExportRecipe( name=recipe_type.value, - edge_compile_config=get_xnnpack_edge_compile_config(), + lowering_recipe=self._get_xnnpack_lowering_recipe(), executorch_backend_config=get_xnnpack_executorch_backend_config(), - partitioners=[XnnpackPartitioner()], ) def _build_quantized_recipe( @@ -120,9 +128,8 @@ def _build_quantized_recipe( return ExportRecipe( name=recipe_type.value, quantization_recipe=quant_recipe, - edge_compile_config=get_xnnpack_edge_compile_config(), + lowering_recipe=self._get_xnnpack_lowering_recipe(precision_type), executorch_backend_config=get_xnnpack_executorch_backend_config(), - partitioners=[XnnpackPartitioner(config_precision=precision_type)], ) def _build_int8da_intx_weight_recipe( @@ -150,9 +157,8 @@ def _build_int8da_intx_weight_recipe( return ExportRecipe( name=recipe_type.value, quantization_recipe=quant_recipe, - edge_compile_config=get_xnnpack_edge_compile_config(), + lowering_recipe=self._get_xnnpack_lowering_recipe(), executorch_backend_config=get_xnnpack_executorch_backend_config(), - partitioners=[XnnpackPartitioner()], ) def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None: diff --git a/export/__init__.py b/export/__init__.py index 2ee5026d320..d5f3826ab90 100644 --- a/export/__init__.py +++ b/export/__init__.py @@ -15,7 +15,7 @@ """ from .export import export, ExportSession -from .recipe import ExportRecipe, QuantizationRecipe, RecipeType +from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe, RecipeType from .recipe_provider import BackendRecipeProvider from .recipe_registry import recipe_registry from .types import StageType @@ -23,6 +23,7 @@ __all__ = [ "StageType", "ExportRecipe", + "LoweringRecipe", "QuantizationRecipe", "ExportSession", "export", diff --git a/export/export.py b/export/export.py index ac9d894fea1..e5c3b793ccd 100644 --- a/export/export.py +++ b/export/export.py @@ -16,7 +16,7 @@ from tabulate import tabulate from torch import nn -from .recipe import ExportRecipe, QuantizationRecipe +from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe from .stages import ( EdgeTransformAndLowerStage, ExecutorchStage, @@ -143,6 +143,10 @@ def __init__( self._export_recipe.quantization_recipe ) + self._lowering_recipe: Optional[LoweringRecipe] = ( + self._export_recipe.lowering_recipe + ) + # Stages to run self._pipeline_stages = ( self._export_recipe.pipeline_stages or self._get_default_pipeline() @@ -192,20 +196,11 @@ def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]: ) stage = TorchExportStage(pre_edge_passes) elif stage_type == StageType.TO_EDGE_TRANSFORM_AND_LOWER: - stage = EdgeTransformAndLowerStage( - partitioners=self._export_recipe.partitioners, - transform_passes=self._export_recipe.edge_transform_passes, - compile_config=self._export_recipe.edge_compile_config, - ) + stage = EdgeTransformAndLowerStage.from_recipe(self._lowering_recipe) elif stage_type == StageType.TO_EDGE: - stage = ToEdgeStage( - edge_compile_config=self._export_recipe.edge_compile_config - ) + stage = ToEdgeStage.from_recipe(self._lowering_recipe) elif stage_type == StageType.TO_BACKEND: - stage = ToBackendStage( - partitioners=self._export_recipe.partitioners, - transform_passes=self._export_recipe.edge_transform_passes, - ) + stage = ToBackendStage.from_recipe(self._lowering_recipe) elif stage_type == StageType.TO_EXECUTORCH: stage = ExecutorchStage(self._export_recipe.executorch_backend_config) else: diff --git a/export/recipe.py b/export/recipe.py index 315404c54af..8f7251cd419 100644 --- a/export/recipe.py +++ b/export/recipe.py @@ -89,6 +89,26 @@ def get_quantizers(self) -> Optional[List[Quantizer]]: return self.quantizers +@dataclass +class LoweringRecipe: + """ + Configuration recipe for lowering and partitioning. + + This class holds the configuration parameters for lowering a model + to backend-specific representations. + + Attributes: + partitioners: Optional list of partitioners for model partitioning + edge_transform_passes: Optional sequence of transformation passes to apply + edge_compile_config: Optional edge compilation configuration + """ + + partitioners: Optional[List[Partitioner]] = None + edge_transform_passes: Optional[Sequence[PassType]] = None + # pyre-ignore[11]: Type not defined + edge_compile_config: Optional[EdgeCompileConfig] = None + + @experimental( "This API and all of its related functionality such as ExportSession and ExportRecipe are experimental." ) @@ -103,13 +123,9 @@ class ExportRecipe: Attributes: name: Optional name for the recipe quantization_recipe: Optional quantization recipe for model quantization - edge_compile_config: Optional edge compilation configuration pre_edge_transform_passes: Optional function to apply transformation passes before edge lowering - edge_transform_passes: Optional sequence of transformation passes to apply - during edge lowering - transform_check_ir_validity: Whether to check IR validity during transformation - partitioners: Optional list of partitioners for model partitioning + lowering_recipe: Optional lowering recipe for model lowering and partitioning executorch_backend_config: Optional backend configuration for ExecuTorch pipeline_stages: Optional list of stages to execute, defaults to a standard pipeline. mode: Export mode (debug or release) @@ -117,12 +133,8 @@ class ExportRecipe: name: Optional[str] = None quantization_recipe: Optional[QuantizationRecipe] = None - # pyre-ignore[11]: Type not defined - edge_compile_config: Optional[EdgeCompileConfig] = None pre_edge_transform_passes: Optional[Sequence[PassType]] = None - edge_transform_passes: Optional[Sequence[PassType]] = None - transform_check_ir_validity: bool = True - partitioners: Optional[List[Partitioner]] = None + lowering_recipe: Optional[LoweringRecipe] = None # pyre-ignore[11]: Type not defined executorch_backend_config: Optional[ExecutorchBackendConfig] = None pipeline_stages: Optional[List[StageType]] = None diff --git a/export/stages.py b/export/stages.py index fd27c298028..dd22155e929 100644 --- a/export/stages.py +++ b/export/stages.py @@ -14,7 +14,7 @@ from executorch.exir.backend.backend_api import validation_disabled from executorch.exir.program import to_edge, to_edge_transform_and_lower from executorch.exir.program._program import _transform -from executorch.export.recipe import QuantizationRecipe +from executorch.export.recipe import LoweringRecipe, QuantizationRecipe from executorch.export.types import StageType from torch import nn from torch._export.pass_base import PassType @@ -168,6 +168,19 @@ def __init__( self._transform_passes = transform_passes self._compile_config = compile_config + @classmethod + def from_recipe( + cls, lowering_recipe: Optional["LoweringRecipe"] + ) -> "EdgeTransformAndLowerStage": + if lowering_recipe is None: + return cls() + + return cls( + partitioners=lowering_recipe.partitioners, + transform_passes=lowering_recipe.edge_transform_passes, + compile_config=lowering_recipe.edge_compile_config, + ) + @property def stage_type(self) -> str: return StageType.TO_EDGE_TRANSFORM_AND_LOWER @@ -369,6 +382,15 @@ def __init__( super().__init__() self._edge_compile_config = edge_compile_config + @classmethod + def from_recipe(cls, lowering_recipe: Optional["LoweringRecipe"]) -> "ToEdgeStage": + if lowering_recipe is None: + return cls() + + return cls( + edge_compile_config=lowering_recipe.edge_compile_config, + ) + @property def stage_type(self) -> str: return StageType.TO_EDGE @@ -415,6 +437,18 @@ def __init__( self._partitioners = partitioners self._transform_passes = transform_passes + @classmethod + def from_recipe( + cls, lowering_recipe: Optional["LoweringRecipe"] + ) -> "ToBackendStage": + if lowering_recipe is None: + return cls() + + return cls( + partitioners=lowering_recipe.partitioners, + transform_passes=lowering_recipe.edge_transform_passes, + ) + @property def stage_type(self) -> str: return StageType.TO_BACKEND diff --git a/export/tests/test_export_session.py b/export/tests/test_export_session.py index 7bef0d01876..92aeebb7304 100644 --- a/export/tests/test_export_session.py +++ b/export/tests/test_export_session.py @@ -12,6 +12,7 @@ import torch from executorch.export import ExportRecipe, ExportSession +from executorch.export.recipe import LoweringRecipe, QuantizationRecipe from executorch.export.stages import PipelineArtifact from executorch.export.types import StageType @@ -434,3 +435,48 @@ def test_save_to_pte_invalid_name(self) -> None: with self.assertRaises(AssertionError): session.save_to_pte(None) # pyre-ignore + + +class TestExportSessionPipelineBuilding(unittest.TestCase): + """Test pipeline building and stage configuration.""" + + def setUp(self) -> None: + self.model = SimpleTestModel() + self.example_inputs = [(torch.randn(2, 10),)] + + def test_pipeline_building_with_all_recipes(self) -> None: + """Test pipeline building with quantization and lowering recipes.""" + # Create comprehensive recipes + quant_recipe = QuantizationRecipe( + ao_base_config=[Mock()], + quantizers=[Mock()], + ) + lowering_recipe = LoweringRecipe( + partitioners=[Mock()], + edge_transform_passes=[Mock()], + edge_compile_config=Mock(), + ) + recipe = ExportRecipe( + name="comprehensive_test", + quantization_recipe=quant_recipe, + lowering_recipe=lowering_recipe, + executorch_backend_config=Mock(), + ) + + session = ExportSession( + model=self.model, + example_inputs=self.example_inputs, + export_recipe=recipe, + ) + + registered_stages = session.get_all_registered_stages() + + self.assertEqual(len(registered_stages), 5) + expected_types = [ + StageType.SOURCE_TRANSFORM, + StageType.QUANTIZE, + StageType.TORCH_EXPORT, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + StageType.TO_EXECUTORCH, + ] + self.assertListEqual(list(registered_stages.keys()), expected_types)