From f035b4377b14c389122ed268f96da786b19156b1 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Tue, 29 Jul 2025 14:01:30 -0700 Subject: [PATCH] equip etrecord class with save method Differential Revision: [D79205242](https://our.internmc.facebook.com/intern/diff/D79205242/) [ghstack-poisoned] --- devtools/etrecord/_etrecord.py | 316 ++++++++++++----------- devtools/etrecord/tests/etrecord_test.py | 115 +++++++++ 2 files changed, 287 insertions(+), 144 deletions(-) diff --git a/devtools/etrecord/_etrecord.py b/devtools/etrecord/_etrecord.py index 014148f2a13..1bb1cbf4d03 100644 --- a/devtools/etrecord/_etrecord.py +++ b/devtools/etrecord/_etrecord.py @@ -55,96 +55,137 @@ class ETRecordReservedFileNames(StrEnum): REPRESENTATIVE_INPUTS = "representative_inputs" -@dataclass class ETRecord: - exported_program: Optional[ExportedProgram] = None - export_graph_id: Optional[int] = None - edge_dialect_program: Optional[ExportedProgram] = None - graph_map: Optional[Dict[str, ExportedProgram]] = None - _debug_handle_map: Optional[Dict[int, Union[int, List[int]]]] = None - _delegate_map: Optional[ - Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]] - ] = None - _reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None - _representative_inputs: Optional[List[ProgramOutput]] = None - - -def _handle_exported_program( - etrecord_zip: ZipFile, module_name: str, method_name: str, ep: ExportedProgram -) -> None: - assert isinstance(ep, ExportedProgram) - serialized_artifact = serialize(ep) - assert isinstance(serialized_artifact.exported_program, bytes) + def __init__( + self, + exported_program: Optional[ExportedProgram] = None, + export_graph_id: Optional[int] = None, + edge_dialect_program: Optional[ExportedProgram] = None, + graph_map: Optional[Dict[str, ExportedProgram]] = None, + _debug_handle_map: Optional[Dict[int, Union[int, List[int]]]] = None, + _delegate_map: Optional[ + Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]] + ] = None, + _reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None, + _representative_inputs: Optional[List[ProgramOutput]] = None, + ): + self.exported_program = exported_program + self.export_graph_id = export_graph_id + self.edge_dialect_program = edge_dialect_program + self.graph_map = graph_map + self._debug_handle_map = _debug_handle_map + self._delegate_map = _delegate_map + self._reference_outputs = _reference_outputs + self._representative_inputs = _representative_inputs + + def save(self, path: Union[str, os.PathLike, BinaryIO, IO[bytes]]) -> None: + """ + Serialize and save the ETRecord to the specified path. + + Args: + path: Path where the ETRecord file will be saved to. + """ + if isinstance(path, (str, os.PathLike)): + path = os.fspath(path) + + etrecord_zip = ZipFile(path, "w") + + try: + # Write the magic file identifier + etrecord_zip.writestr(ETRecordReservedFileNames.ETRECORD_IDENTIFIER, "") + + # Save exported program if present + if self.exported_program is not None: + self._save_exported_program( + etrecord_zip, + ETRecordReservedFileNames.EXPORTED_PROGRAM, + "", + self.exported_program, + ) - method_name = f"/{method_name}" if method_name != "" else "" + # Save edge dialect program if present + if self.edge_dialect_program is not None: + self._save_edge_dialect_program(etrecord_zip, self.edge_dialect_program) + + # Save graph map if present + if self.graph_map is not None: + for module_name, export_module in self.graph_map.items(): + # Extract method name from module_name if it contains "/" + if "/" in module_name: + base_name, method_name = module_name.rsplit("/", 1) + self._save_exported_program( + etrecord_zip, base_name, method_name, export_module + ) + else: + self._save_exported_program( + etrecord_zip, module_name, "forward", export_module + ) + + # Save debug handle map + if self._debug_handle_map is not None: + etrecord_zip.writestr( + ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME, + json.dumps(self._debug_handle_map), + ) - etrecord_zip.writestr( - f"{module_name}{method_name}", serialized_artifact.exported_program - ) - etrecord_zip.writestr( - f"{module_name}{method_name}_state_dict", serialized_artifact.state_dict - ) - etrecord_zip.writestr( - f"{module_name}{method_name}_constants", serialized_artifact.constants - ) - etrecord_zip.writestr( - f"{module_name}{method_name}_example_inputs", - serialized_artifact.example_inputs, - ) + # Save delegate map + if self._delegate_map is not None: + etrecord_zip.writestr( + ETRecordReservedFileNames.DELEGATE_MAP_NAME, + json.dumps(self._delegate_map), + ) + # Save reference outputs + if self._reference_outputs is not None: + etrecord_zip.writestr( + ETRecordReservedFileNames.REFERENCE_OUTPUTS, + pickle.dumps(self._reference_outputs), + ) -def _handle_export_module( - etrecord_zip: ZipFile, - export_module: Union[ - ExirExportedProgram, - EdgeProgramManager, - ExportedProgram, - ], - module_name: str, -) -> None: - if isinstance(export_module, ExirExportedProgram): - _handle_exported_program( - etrecord_zip, module_name, "forward", export_module.exported_program - ) - elif isinstance(export_module, ExportedProgram): - _handle_exported_program(etrecord_zip, module_name, "forward", export_module) - elif isinstance( - export_module, - (EdgeProgramManager, exir.program._program.EdgeProgramManager), - ): - for method in export_module.methods: - _handle_exported_program( - etrecord_zip, - module_name, - method, - export_module.exported_program(method), - ) - else: - raise RuntimeError(f"Unsupported graph module type. {type(export_module)}") + # Save representative inputs + if self._representative_inputs is not None: + etrecord_zip.writestr( + ETRecordReservedFileNames.REPRESENTATIVE_INPUTS, + pickle.dumps(self._representative_inputs), + ) + # Save export graph id + if self.export_graph_id is not None: + etrecord_zip.writestr( + ETRecordReservedFileNames.EXPORT_GRAPH_ID, + json.dumps(self.export_graph_id), + ) -def _handle_edge_dialect_exported_program( - etrecord_zip: ZipFile, edge_dialect_exported_program: ExportedProgram -) -> None: - serialized_artifact = serialize(edge_dialect_exported_program) - assert isinstance(serialized_artifact.exported_program, bytes) + finally: + etrecord_zip.close() - etrecord_zip.writestr( - ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM, - serialized_artifact.exported_program, - ) - etrecord_zip.writestr( - f"{ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM}_state_dict", - serialized_artifact.state_dict, - ) - etrecord_zip.writestr( - f"{ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM}_constants", - serialized_artifact.constants, - ) - etrecord_zip.writestr( - f"{ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM}_example_inputs", - serialized_artifact.example_inputs, - ) + def _save_exported_program( + self, etrecord_zip: ZipFile, module_name: str, method_name: str, ep: ExportedProgram + ) -> None: + """Save an exported program to the ETRecord zip file.""" + serialized_artifact = serialize(ep) + assert isinstance(serialized_artifact.exported_program, bytes) + + method_name = f"/{method_name}" if method_name != "" else "" + base_name = f"{module_name}{method_name}" + + etrecord_zip.writestr(base_name, serialized_artifact.exported_program) + etrecord_zip.writestr(f"{base_name}_state_dict", serialized_artifact.state_dict) + etrecord_zip.writestr(f"{base_name}_constants", serialized_artifact.constants) + etrecord_zip.writestr(f"{base_name}_example_inputs", serialized_artifact.example_inputs) + + def _save_edge_dialect_program( + self, etrecord_zip: ZipFile, edge_dialect_program: ExportedProgram + ) -> None: + """Save the edge dialect program to the ETRecord zip file.""" + serialized_artifact = serialize(edge_dialect_program) + assert isinstance(serialized_artifact.exported_program, bytes) + + base_name = ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM + etrecord_zip.writestr(base_name, serialized_artifact.exported_program) + etrecord_zip.writestr(f"{base_name}_state_dict", serialized_artifact.state_dict) + etrecord_zip.writestr(f"{base_name}_constants", serialized_artifact.constants) + etrecord_zip.writestr(f"{base_name}_example_inputs", serialized_artifact.example_inputs) def _get_reference_outputs( @@ -231,32 +272,27 @@ def generate_etrecord( Returns: None """ - - if isinstance(et_record, (str, os.PathLike)): - et_record = os.fspath(et_record) # pyre-ignore - - etrecord_zip = ZipFile(et_record, "w") - # Write the magic file identifier that will be used to verify that this file - # is an etrecord when it's used later in the Developer Tools. - etrecord_zip.writestr(ETRecordReservedFileNames.ETRECORD_IDENTIFIER, "") - - # Calculate export_graph_id before modifying exported_program + # Prepare data for ETRecord construction + processed_exported_program = None export_graph_id = 0 + processed_edge_dialect_program = None + graph_map = {} + debug_handle_map = None + delegate_map = None + reference_outputs = None + representative_inputs = None + # Process exported program if exported_program is not None: - # If multiple exported programs are provided, only save forward method if isinstance(exported_program, dict) and "forward" in exported_program: - exported_program = exported_program["forward"] - - if isinstance(exported_program, ExportedProgram): - export_graph_id = id(exported_program.graph) - _handle_exported_program( - etrecord_zip, - ETRecordReservedFileNames.EXPORTED_PROGRAM, - "", - exported_program, - ) + processed_exported_program = exported_program["forward"] + elif isinstance(exported_program, ExportedProgram): + processed_exported_program = exported_program + + if processed_exported_program is not None: + export_graph_id = id(processed_exported_program.graph) + # Process extra recorded export modules if extra_recorded_export_modules is not None: for module_name, export_module in extra_recorded_export_modules.items(): contains_reserved_name = any( @@ -267,57 +303,49 @@ def generate_etrecord( raise RuntimeError( f"The name {module_name} provided in the extra_recorded_export_modules dict is a reserved name in the ETRecord namespace." ) - _handle_export_module(etrecord_zip, export_module, module_name) - if isinstance( - edge_dialect_program, - (EdgeProgramManager, exir.program._program.EdgeProgramManager), - ): - _handle_edge_dialect_exported_program( - etrecord_zip, - edge_dialect_program.exported_program(), - ) + # Process different types of export modules + if isinstance(export_module, ExirExportedProgram): + graph_map[f"{module_name}/forward"] = export_module.exported_program + elif isinstance(export_module, ExportedProgram): + graph_map[f"{module_name}/forward"] = export_module + elif isinstance(export_module, (EdgeProgramManager, exir.program._program.EdgeProgramManager)): + for method in export_module.methods: + graph_map[f"{module_name}/{method}"] = export_module.exported_program(method) + else: + raise RuntimeError(f"Unsupported graph module type. {type(export_module)}") + + # Process edge dialect program + if isinstance(edge_dialect_program, (EdgeProgramManager, exir.program._program.EdgeProgramManager)): + processed_edge_dialect_program = edge_dialect_program.exported_program() elif isinstance(edge_dialect_program, ExirExportedProgram): - _handle_edge_dialect_exported_program( - etrecord_zip, - edge_dialect_program.exported_program, - ) + processed_edge_dialect_program = edge_dialect_program.exported_program else: - raise RuntimeError( - f"Unsupported type of edge_dialect_program passed in {type(edge_dialect_program)}." - ) + raise RuntimeError(f"Unsupported type of edge_dialect_program passed in {type(edge_dialect_program)}.") - # When a BundledProgram is passed in, extract the reference outputs and save in a file + # Process executorch program if isinstance(executorch_program, BundledProgram): reference_outputs = _get_reference_outputs(executorch_program) - etrecord_zip.writestr( - ETRecordReservedFileNames.REFERENCE_OUTPUTS, - # @lint-ignore PYTHONPICKLEISBAD - pickle.dumps(reference_outputs), - ) - representative_inputs = _get_representative_inputs(executorch_program) - etrecord_zip.writestr( - ETRecordReservedFileNames.REPRESENTATIVE_INPUTS, - # @lint-ignore PYTHONPICKLEISBAD - pickle.dumps(representative_inputs), - ) - executorch_program = executorch_program.executorch_program - - etrecord_zip.writestr( - ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME, - json.dumps(executorch_program.debug_handle_map), - ) + debug_handle_map = executorch_program.executorch_program.debug_handle_map + delegate_map = executorch_program.executorch_program.delegate_map + else: + debug_handle_map = executorch_program.debug_handle_map + delegate_map = executorch_program.delegate_map - etrecord_zip.writestr( - ETRecordReservedFileNames.DELEGATE_MAP_NAME, - json.dumps(executorch_program.delegate_map), + # Create ETRecord instance and save + etrecord = ETRecord( + exported_program=processed_exported_program, + export_graph_id=export_graph_id, + edge_dialect_program=processed_edge_dialect_program, + graph_map=graph_map if graph_map else None, + _debug_handle_map=debug_handle_map, + _delegate_map=delegate_map, + _reference_outputs=reference_outputs, + _representative_inputs=representative_inputs, ) - etrecord_zip.writestr( - ETRecordReservedFileNames.EXPORT_GRAPH_ID, - json.dumps(export_graph_id), - ) + etrecord.save(et_record) def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901 diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index 432397347a5..ec55bb0cb8e 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -20,6 +20,7 @@ from executorch.devtools.etrecord._etrecord import ( _get_reference_outputs, _get_representative_inputs, + ETRecord, ETRecordReservedFileNames, ) from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge @@ -251,6 +252,120 @@ def test_etrecord_generation_with_exported_program(self): # Validate that export_graph_id matches the expected value self.assertEqual(etrecord.export_graph_id, expected_graph_id) + def test_etrecord_class_constructor_and_save(self): + """Test that ETRecord class constructor and save method work correctly.""" + captured_output, edge_output, et_output = self.get_test_model() + original_exported_program = captured_output.exported_program + expected_graph_id = id(original_exported_program.graph) + + # Create ETRecord instance directly using constructor + etrecord = ETRecord( + exported_program=original_exported_program, + export_graph_id=expected_graph_id, + edge_dialect_program=edge_output.exported_program, + graph_map={"test_module/forward": original_exported_program}, + _debug_handle_map=et_output.debug_handle_map, + _delegate_map=et_output.delegate_map, + ) + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_direct.bin" + + # Use the save method + etrecord.save(etrecord_path) + + # Parse ETRecord back and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Validate that all components are preserved + self.assertIsNotNone(parsed_etrecord.exported_program) + self.check_graph_closeness( + parsed_etrecord.exported_program, + original_exported_program.graph_module, + ) + + self.assertIsNotNone(parsed_etrecord.edge_dialect_program) + self.check_graph_closeness( + parsed_etrecord.edge_dialect_program, + edge_output.exported_program.graph_module, + ) + + # Validate graph map + self.assertIsNotNone(parsed_etrecord.graph_map) + self.assertIn("test_module/forward", parsed_etrecord.graph_map) + self.check_graph_closeness( + parsed_etrecord.graph_map["test_module/forward"], + original_exported_program.graph_module, + ) + + # Validate debug and delegate maps + self.assertEqual( + parsed_etrecord._debug_handle_map, + json.loads(json.dumps(et_output.debug_handle_map)), + ) + self.assertEqual( + parsed_etrecord._delegate_map, + json.loads(json.dumps(et_output.delegate_map)), + ) + + # Validate export graph id + self.assertEqual(parsed_etrecord.export_graph_id, expected_graph_id) + + def test_etrecord_class_with_bundled_program_data(self): + """Test ETRecord class with bundled program data.""" + ( + captured_output, + edge_output, + bundled_program, + ) = self.get_test_model_with_bundled_program() + + # Extract bundled program data + reference_outputs = _get_reference_outputs(bundled_program) + representative_inputs = _get_representative_inputs(bundled_program) + + # Create ETRecord instance with bundled program data + etrecord = ETRecord( + exported_program=captured_output.exported_program, + export_graph_id=id(captured_output.exported_program.graph), + edge_dialect_program=edge_output.exported_program, + _debug_handle_map=bundled_program.executorch_program.debug_handle_map, + _delegate_map=bundled_program.executorch_program.delegate_map, + _reference_outputs=reference_outputs, + _representative_inputs=representative_inputs, + ) + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_bundled.bin" + + # Save using the save method + etrecord.save(etrecord_path) + + # Parse and verify + parsed_etrecord = parse_etrecord(etrecord_path) + + # Validate bundled program specific data + self.assertIsNotNone(parsed_etrecord._reference_outputs) + self.assertIsNotNone(parsed_etrecord._representative_inputs) + + # Compare reference outputs + expected_outputs = parsed_etrecord._reference_outputs + self.assertTrue( + torch.equal( + expected_outputs["forward"][0][0], reference_outputs["forward"][0][0] + ) + ) + self.assertTrue( + torch.equal( + expected_outputs["forward"][1][0], reference_outputs["forward"][1][0] + ) + ) + + # Compare representative inputs + expected_inputs = parsed_etrecord._representative_inputs + for expected, actual in zip(expected_inputs, representative_inputs): + self.assertTrue(torch.equal(expected[0], actual[0])) + self.assertTrue(torch.equal(expected[1], actual[1])) + def test_etrecord_generation_with_exported_program_dict(self): """Test that exported program dictionary can be recorded and parsed back correctly.""" captured_output, edge_output, et_output = self.get_test_model()