diff --git a/personal_python_ast_optimizer/futures.py b/personal_python_ast_optimizer/futures.py index f59a6e8..e271fcc 100644 --- a/personal_python_ast_optimizer/futures.py +++ b/personal_python_ast_optimizer/futures.py @@ -1,4 +1,4 @@ -class Futures: +class Future: __slots__ = ("mandatory_version", "name") def __init__(self, name: str, mandatory_version: tuple[int, int]) -> None: @@ -6,25 +6,23 @@ def __init__(self, name: str, mandatory_version: tuple[int, int]) -> None: self.mandatory_version: tuple[int, int] = mandatory_version -futures_to_mandatory_version: list[Futures] = [ - Futures("nested_scopes", (2, 2)), - Futures("generators", (2, 3)), - Futures("with_statement", (2, 6)), - Futures("division", (3, 0)), - Futures("absolute_import", (3, 0)), - Futures("print_function", (3, 0)), - Futures("unicode_literals", (3, 0)), - Futures("generator_stop", (3, 7)), +future_to_mandatory_versions: list[Future] = [ + Future("nested_scopes", (2, 2)), + Future("generators", (2, 3)), + Future("with_statement", (2, 6)), + Future("division", (3, 0)), + Future("absolute_import", (3, 0)), + Future("print_function", (3, 0)), + Future("unicode_literals", (3, 0)), + Future("generator_stop", (3, 7)), ] def get_unneeded_futures(python_version: tuple[int, int]) -> list[str]: - """Returns list of __future__ imports that are unneeded in provided + """Returns __future__ imports that are unneeded in provided python version""" - unneeded_futures: list[str] = [ + return [ future.name - for future in futures_to_mandatory_version + for future in future_to_mandatory_versions if python_version >= future.mandatory_version ] - - return unneeded_futures diff --git a/personal_python_ast_optimizer/parser/_base.py b/personal_python_ast_optimizer/parser/_base.py new file mode 100644 index 0000000..d76525e --- /dev/null +++ b/personal_python_ast_optimizer/parser/_base.py @@ -0,0 +1,23 @@ +import ast + + +class AstNodeTransformerBase(ast.NodeTransformer): + """Base class for ast node transformers. Intended for internal use.""" + + __slots__ = () + + # Nodes that do not need to be fully visited + def visit_alias(self, node: ast.alias) -> ast.alias: + return node + + def visit_Break(self, node: ast.Break) -> ast.Break: + return node + + def visit_Constant(self, node: ast.Constant) -> ast.Constant: + return node + + def visit_Continue(self, node: ast.Continue) -> ast.Continue: + return node + + def visit_Pass(self, node: ast.Pass) -> ast.Pass: + return node diff --git a/personal_python_ast_optimizer/parser/skipper.py b/personal_python_ast_optimizer/parser/skipper.py index 753ab5b..8e07222 100644 --- a/personal_python_ast_optimizer/parser/skipper.py +++ b/personal_python_ast_optimizer/parser/skipper.py @@ -4,6 +4,7 @@ from enum import Enum from personal_python_ast_optimizer.futures import get_unneeded_futures +from personal_python_ast_optimizer.parser._base import AstNodeTransformerBase from personal_python_ast_optimizer.parser.config import ( OptimizationsConfig, SkipConfig, @@ -37,7 +38,7 @@ class _NodeContext(Enum): FUNCTION = 2 -class AstNodeSkipper(ast.NodeTransformer): +class AstNodeSkipper(AstNodeTransformerBase): __slots__ = ( "_has_imports", "_node_context_skippable_futures", @@ -468,9 +469,6 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom | None: self._has_imports = True return node - def visit_alias(self, node: ast.alias) -> ast.alias: - return node - def visit_Name(self, node: ast.Name) -> ast.AST: """Extends super's implementation by adding constant folding""" if node.id in self.optimizations_config.vars_to_fold: @@ -576,17 +574,11 @@ def visit_Assert(self, node: ast.Assert) -> ast.AST | None: None if self.token_types_config.skip_asserts else self.generic_visit(node) ) - def visit_Pass(self, node: ast.Pass) -> None: + def visit_Pass(self, node: ast.Pass) -> None: # type: ignore[override] """Always returns None. Caller responsible for ensuring empty bodies are populated with a Pass node.""" return # This could be toggleable - def visit_Break(self, node: ast.Break) -> ast.Break: - return node - - def visit_Continue(self, node: ast.Continue) -> ast.Continue: - return node - def visit_Call(self, node: ast.Call) -> ast.AST | None: if ( self.optimizations_config.assume_this_machine @@ -608,9 +600,6 @@ def visit_Call(self, node: ast.Call) -> ast.AST | None: return self.generic_visit(node) - def visit_Constant(self, node: ast.Constant) -> ast.Constant: - return node - def visit_Expr(self, node: ast.Expr) -> ast.AST | None: if ( isinstance(node.value, ast.Call) @@ -833,7 +822,7 @@ def _ast_constants_operation( # noqa: C901, PLR0912 return ast.Constant(result) -class UnusedImportSkipper(ast.NodeTransformer): +class UnusedImportSkipper(AstNodeTransformerBase): __slots__ = ("names_and_attrs",) def __init__(self, imports_to_preserve: Iterable[str]) -> None: @@ -886,24 +875,8 @@ def visit_Attribute(self, node: ast.Attribute) -> ast.AST: self.names_and_attrs.add(node.attr) return self.generic_visit(node) - # Nodes that do not need to be fully visited - def visit_alias(self, node: ast.alias) -> ast.alias: - return node - - def visit_Pass(self, node: ast.Pass) -> ast.Pass: - return node - - def visit_Break(self, node: ast.Break) -> ast.Break: - return node - - def visit_Continue(self, node: ast.Continue) -> ast.Continue: - return node - - def visit_Constant(self, node: ast.Constant) -> ast.Constant: - return node - -class _DanglingExprCallFinder(ast.NodeTransformer): +class _DanglingExprCallFinder(AstNodeTransformerBase): """Finds all calls in a given dangling expression except for a subset of builtin functions that have no side effects.""" diff --git a/personal_python_ast_optimizer/regex/replace.py b/personal_python_ast_optimizer/regex/replace.py index 132c10d..cef3ef4 100644 --- a/personal_python_ast_optimizer/regex/replace.py +++ b/personal_python_ast_optimizer/regex/replace.py @@ -31,8 +31,8 @@ def re_replace( regex_replacement: RegexReplacement | Iterable[RegexReplacement], raise_if_not_applied: bool = False, ) -> str: - """Runs a series of regex on given source. - Passing warning_id enabled warnings when patterns are not found""" + """Runs a series of regex on given source.""" + if isinstance(regex_replacement, RegexReplacement): regex_replacement = (regex_replacement,) @@ -59,7 +59,8 @@ def re_replace_file( encoding: str = "utf-8", raise_if_not_applied: bool = False, ): - """Wraps apply_regex with opening and writing to a file""" + """Wraps apply_regex with opening and writing to a file.""" + with open(path, encoding=encoding) as fp: source: str = fp.read()