Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions personal_python_ast_optimizer/futures.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,28 @@
class Futures:
class Future:
__slots__ = ("mandatory_version", "name")

def __init__(self, name: str, mandatory_version: tuple[int, int]) -> None:
self.name: str = name
self.mandatory_version: tuple[int, int] = mandatory_version


futures_to_mandatory_version: list[Futures] = [
Futures("nested_scopes", (2, 2)),
Futures("generators", (2, 3)),
Futures("with_statement", (2, 6)),
Futures("division", (3, 0)),
Futures("absolute_import", (3, 0)),
Futures("print_function", (3, 0)),
Futures("unicode_literals", (3, 0)),
Futures("generator_stop", (3, 7)),
future_to_mandatory_versions: list[Future] = [
Future("nested_scopes", (2, 2)),
Future("generators", (2, 3)),
Future("with_statement", (2, 6)),
Future("division", (3, 0)),
Future("absolute_import", (3, 0)),
Future("print_function", (3, 0)),
Future("unicode_literals", (3, 0)),
Future("generator_stop", (3, 7)),
]


def get_unneeded_futures(python_version: tuple[int, int]) -> list[str]:
"""Returns list of __future__ imports that are unneeded in provided
"""Returns __future__ imports that are unneeded in provided
python version"""
unneeded_futures: list[str] = [
return [
future.name
for future in futures_to_mandatory_version
for future in future_to_mandatory_versions
if python_version >= future.mandatory_version
]

return unneeded_futures
23 changes: 23 additions & 0 deletions personal_python_ast_optimizer/parser/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import ast


class AstNodeTransformerBase(ast.NodeTransformer):
"""Base class for ast node transformers. Intended for internal use."""

__slots__ = ()

# Nodes that do not need to be fully visited
def visit_alias(self, node: ast.alias) -> ast.alias:
return node

def visit_Break(self, node: ast.Break) -> ast.Break:
return node

def visit_Constant(self, node: ast.Constant) -> ast.Constant:
return node

def visit_Continue(self, node: ast.Continue) -> ast.Continue:
return node

def visit_Pass(self, node: ast.Pass) -> ast.Pass:
return node
37 changes: 5 additions & 32 deletions personal_python_ast_optimizer/parser/skipper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from enum import Enum

from personal_python_ast_optimizer.futures import get_unneeded_futures
from personal_python_ast_optimizer.parser._base import AstNodeTransformerBase
from personal_python_ast_optimizer.parser.config import (
OptimizationsConfig,
SkipConfig,
Expand Down Expand Up @@ -37,7 +38,7 @@ class _NodeContext(Enum):
FUNCTION = 2


class AstNodeSkipper(ast.NodeTransformer):
class AstNodeSkipper(AstNodeTransformerBase):
__slots__ = (
"_has_imports",
"_node_context_skippable_futures",
Expand Down Expand Up @@ -468,9 +469,6 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom | None:
self._has_imports = True
return node

def visit_alias(self, node: ast.alias) -> ast.alias:
return node

def visit_Name(self, node: ast.Name) -> ast.AST:
"""Extends super's implementation by adding constant folding"""
if node.id in self.optimizations_config.vars_to_fold:
Expand Down Expand Up @@ -576,17 +574,11 @@ def visit_Assert(self, node: ast.Assert) -> ast.AST | None:
None if self.token_types_config.skip_asserts else self.generic_visit(node)
)

def visit_Pass(self, node: ast.Pass) -> None:
def visit_Pass(self, node: ast.Pass) -> None: # type: ignore[override]
"""Always returns None. Caller responsible for ensuring empty bodies
are populated with a Pass node."""
return # This could be toggleable

def visit_Break(self, node: ast.Break) -> ast.Break:
return node

def visit_Continue(self, node: ast.Continue) -> ast.Continue:
return node

def visit_Call(self, node: ast.Call) -> ast.AST | None:
if (
self.optimizations_config.assume_this_machine
Expand All @@ -608,9 +600,6 @@ def visit_Call(self, node: ast.Call) -> ast.AST | None:

return self.generic_visit(node)

def visit_Constant(self, node: ast.Constant) -> ast.Constant:
return node

def visit_Expr(self, node: ast.Expr) -> ast.AST | None:
if (
isinstance(node.value, ast.Call)
Expand Down Expand Up @@ -833,7 +822,7 @@ def _ast_constants_operation( # noqa: C901, PLR0912
return ast.Constant(result)


class UnusedImportSkipper(ast.NodeTransformer):
class UnusedImportSkipper(AstNodeTransformerBase):
__slots__ = ("names_and_attrs",)

def __init__(self, imports_to_preserve: Iterable[str]) -> None:
Expand Down Expand Up @@ -886,24 +875,8 @@ def visit_Attribute(self, node: ast.Attribute) -> ast.AST:
self.names_and_attrs.add(node.attr)
return self.generic_visit(node)

# Nodes that do not need to be fully visited
def visit_alias(self, node: ast.alias) -> ast.alias:
return node

def visit_Pass(self, node: ast.Pass) -> ast.Pass:
return node

def visit_Break(self, node: ast.Break) -> ast.Break:
return node

def visit_Continue(self, node: ast.Continue) -> ast.Continue:
return node

def visit_Constant(self, node: ast.Constant) -> ast.Constant:
return node


class _DanglingExprCallFinder(ast.NodeTransformer):
class _DanglingExprCallFinder(AstNodeTransformerBase):
"""Finds all calls in a given dangling expression
except for a subset of builtin functions that have
no side effects."""
Expand Down
7 changes: 4 additions & 3 deletions personal_python_ast_optimizer/regex/replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def re_replace(
regex_replacement: RegexReplacement | Iterable[RegexReplacement],
raise_if_not_applied: bool = False,
) -> str:
"""Runs a series of regex on given source.
Passing warning_id enabled warnings when patterns are not found"""
"""Runs a series of regex on given source."""

if isinstance(regex_replacement, RegexReplacement):
regex_replacement = (regex_replacement,)

Expand All @@ -59,7 +59,8 @@ def re_replace_file(
encoding: str = "utf-8",
raise_if_not_applied: bool = False,
):
"""Wraps apply_regex with opening and writing to a file"""
"""Wraps apply_regex with opening and writing to a file."""

with open(path, encoding=encoding) as fp:
source: str = fp.read()

Expand Down
Loading