diff --git a/personal_python_ast_optimizer/parser/config.py b/personal_python_ast_optimizer/parser/config.py index eb960c7..ccc298e 100644 --- a/personal_python_ast_optimizer/parser/config.py +++ b/personal_python_ast_optimizer/parser/config.py @@ -140,6 +140,7 @@ class OptimizationsConfig(_Config): "remove_unused_imports", "remove_useless_else", "simplify_named_tuples", + "unused_imports_to_preserve", "vars_to_fold", ) @@ -153,6 +154,7 @@ def __init__( # noqa: PLR0913 enums_to_fold: Iterable[EnumType] | None = None, functions_safe_to_exclude_in_test_expr: set[str] | None = None, remove_unused_imports: bool = True, + unused_imports_to_preserve: Iterable[str] | None = None, remove_useless_else: bool = True, remove_typing_cast: bool = True, collection_concat_to_unpack: bool = False, @@ -160,6 +162,9 @@ def __init__( # noqa: PLR0913 assume_this_machine: bool = False, simplify_named_tuples: bool = False, ) -> None: + if unused_imports_to_preserve and not remove_unused_imports: + raise ValueError("Can't preserve imports if remove_unused_imports is False") + self.vars_to_fold: dict[ str, str | bytes | bool | int | float | complex | None | EllipsisType ] = {} if vars_to_fold is None else vars_to_fold @@ -168,10 +173,15 @@ def __init__( # noqa: PLR0913 if enums_to_fold is None else self._format_enums_to_fold_as_dict(enums_to_fold) ) + self.functions_safe_to_exclude_in_test_expr: set[str] = ( functions_safe_to_exclude_in_test_expr or default_functions_safe_to_exclude_in_test_expr ) + self.unused_imports_to_preserve: Iterable[str] = ( + unused_imports_to_preserve or [] + ) + self.remove_unused_imports: bool = remove_unused_imports self.remove_useless_else: bool = remove_useless_else self.remove_typing_cast: bool = remove_typing_cast diff --git a/personal_python_ast_optimizer/parser/skipper.py b/personal_python_ast_optimizer/parser/skipper.py index 1806920..753ab5b 100644 --- a/personal_python_ast_optimizer/parser/skipper.py +++ b/personal_python_ast_optimizer/parser/skipper.py @@ -1,5 +1,6 @@ import ast import warnings +from collections.abc import Iterable from enum import Enum from personal_python_ast_optimizer.futures import get_unneeded_futures @@ -181,7 +182,9 @@ def visit_Module(self, node: ast.Module) -> ast.AST: import_to_update.names.append(alias) if self.optimizations_config.remove_unused_imports and self._has_imports: - import_filter = UnusedImportSkipper() + import_filter = UnusedImportSkipper( + self.optimizations_config.unused_imports_to_preserve + ) import_filter.visit(node) self._warn_unused_skips() @@ -833,8 +836,8 @@ def _ast_constants_operation( # noqa: C901, PLR0912 class UnusedImportSkipper(ast.NodeTransformer): __slots__ = ("names_and_attrs",) - def __init__(self) -> None: - self.names_and_attrs: set[str] = set() + def __init__(self, imports_to_preserve: Iterable[str]) -> None: + self.names_and_attrs: set[str] = set(imports_to_preserve) def generic_visit(self, node: ast.AST) -> ast.AST: for field, old_value in ast.iter_fields(node): diff --git a/tests/parser/test_imports.py b/tests/parser/test_imports.py index 2c1f4d7..c8a1a6b 100644 --- a/tests/parser/test_imports.py +++ b/tests/parser/test_imports.py @@ -120,9 +120,16 @@ def asdf(a):return a """.strip(), "bar()", ), + BeforeAndAfter( + "import asdf", + "import asdf", + ), ] @pytest.mark.parametrize("before_and_after", _unused_import_test_cases) def test_remove_unused_import(before_and_after: BeforeAndAfter): - run_minifier_and_assert_correct(before_and_after) + run_minifier_and_assert_correct( + before_and_after, + optimizations_config=OptimizationsConfig(unused_imports_to_preserve=["asdf"]), + ) diff --git a/version.txt b/version.txt index 8b22a32..8104cab 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -8.0.2 +8.1.0