Skip to content

Commit 45a1a41

Browse files
committed
pyp: use new type annotations
1 parent e154541 commit 45a1a41

File tree

3 files changed

+39
-35
lines changed

3 files changed

+39
-35
lines changed

pyp.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#!/usr/bin/env python3
2+
from __future__ import annotations
3+
24
import argparse
35
import ast
46
import importlib
@@ -9,7 +11,7 @@
911
import textwrap
1012
import traceback
1113
from collections import defaultdict
12-
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, cast
14+
from typing import Any, Iterator, cast
1315

1416
__all__ = ["pypprint"]
1517
__version__ = "1.3.0"
@@ -51,17 +53,17 @@ class NameFinder(ast.NodeVisitor):
5153
"""
5254

5355
def __init__(self, *trees: ast.AST) -> None:
54-
self._scopes: List[Set[str]] = [set()]
55-
self._comprehension_scopes: List[int] = []
56+
self._scopes: list[set[str]] = [set()]
57+
self._comprehension_scopes: list[int] = []
5658

57-
self.undefined: Set[str] = set()
58-
self.wildcard_imports: List[str] = []
59+
self.undefined: set[str] = set()
60+
self.wildcard_imports: list[str] = []
5961
for tree in trees:
6062
self.visit(tree)
6163
assert len(self._scopes) == 1
6264

6365
@property
64-
def top_level_defined(self) -> Set[str]:
66+
def top_level_defined(self) -> set[str]:
6567
return self._scopes[0]
6668

6769
def flexible_visit(self, value: Any) -> None:
@@ -73,7 +75,7 @@ def flexible_visit(self, value: Any) -> None:
7375
self.visit(value)
7476

7577
def generic_visit(self, node: ast.AST) -> None:
76-
def order(f_v: Tuple[str, Any]) -> int:
78+
def order(f_v: tuple[str, Any]) -> int:
7779
# This ordering fixes comprehensions, dict comps, loops, assignments
7880
return {"generators": -3, "iter": -3, "key": -2, "value": -1}.get(f_v[0], 0)
7981

@@ -138,7 +140,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
138140
# Classes are not okay with self-reference, so define ``name`` afterwards
139141
self._scopes[-1].add(node.name)
140142

141-
def visit_function_helper(self, node: Any, name: Optional[str] = None) -> None:
143+
def visit_function_helper(self, node: Any, name: str | None = None) -> None:
142144
# Functions are okay with recursion, but not self-reference while defining default values
143145
self.flexible_visit(node.args)
144146
if name is not None:
@@ -247,22 +249,22 @@ def __init__(self) -> None:
247249
raise PypError(f"Config has invalid syntax{error}") from e
248250

249251
# List of config parts
250-
self.parts: List[ast.stmt] = config_ast.body
252+
self.parts: list[ast.stmt] = config_ast.body
251253
# Maps from a name to index of config part that defines it
252-
self.name_to_def: Dict[str, int] = {}
253-
self.def_to_names: Dict[int, List[str]] = defaultdict(list)
254+
self.name_to_def: dict[str, int] = {}
255+
self.def_to_names: dict[int, list[str]] = defaultdict(list)
254256
# Maps from index of config part to undefined names it needs
255-
self.requires: Dict[int, Set[str]] = defaultdict(set)
257+
self.requires: dict[int, set[str]] = defaultdict(set)
256258
# Modules from which automatic imports work without qualification, ordered by AST encounter
257-
self.wildcard_imports: List[str] = []
259+
self.wildcard_imports: list[str] = []
258260

259261
self.shebang: str = "#!/usr/bin/env python3"
260262
if config_contents.startswith("#!"):
261263
self.shebang = "\n".join(
262264
itertools.takewhile(lambda line: line.startswith("#"), config_contents.splitlines())
263265
)
264266

265-
top_level: Tuple[Any, ...] = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
267+
top_level: tuple[Any, ...] = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
266268
top_level += (ast.Import, ast.ImportFrom, ast.Assign, ast.AnnAssign, ast.If, ast.Try)
267269
for index, part in enumerate(self.parts):
268270
if not isinstance(part, top_level):
@@ -298,13 +300,13 @@ class PypTransform:
298300

299301
def __init__(
300302
self,
301-
before: List[str],
302-
code: List[str],
303-
after: List[str],
303+
before: list[str],
304+
code: list[str],
305+
after: list[str],
304306
define_pypprint: bool,
305307
config: PypConfig,
306308
) -> None:
307-
def parse_input(code: List[str]) -> ast.Module:
309+
def parse_input(code: list[str]) -> ast.Module:
308310
try:
309311
return ast.parse(textwrap.dedent("\n".join(code).strip()))
310312
except SyntaxError as e:
@@ -326,9 +328,9 @@ def parse_input(code: List[str]) -> ast.Module:
326328
raise PypError("Config __pyp_after__ not supported")
327329

328330
f = NameFinder(self.before_tree, self.tree, self.after_tree)
329-
self.defined: Set[str] = f.top_level_defined
330-
self.undefined: Set[str] = f.undefined
331-
self.wildcard_imports: List[str] = f.wildcard_imports
331+
self.defined: set[str] = f.top_level_defined
332+
self.undefined: set[str] = f.undefined
333+
self.wildcard_imports: list[str] = f.wildcard_imports
332334
# We'll always use sys in ``build_input``, so add it to undefined.
333335
# This lets config define it or lets us automatically import it later
334336
# (If before defines it, we'll just let it override the import...)
@@ -338,11 +340,11 @@ def parse_input(code: List[str]) -> ast.Module:
338340
self.config = config
339341

340342
# The print statement ``build_output`` will add, if it determines it needs to.
341-
self.implicit_print: Optional[ast.Call] = None
343+
self.implicit_print: ast.Call | None = None
342344

343345
def build_missing_config(self) -> None:
344346
"""Modifies the AST to define undefined names defined in config."""
345-
config_definitions: Set[str] = set()
347+
config_definitions: set[str] = set()
346348
attempt_to_define = set(self.undefined)
347349
while attempt_to_define:
348350
can_define = attempt_to_define & set(self.config.name_to_def)
@@ -406,7 +408,7 @@ def build_output(self) -> None:
406408
if self.undefined & {"print", "pprint", "pp", "pypprint"}: # has an explicit print
407409
return
408410

409-
def inner(body: List[ast.stmt], use_pypprint: bool = False) -> bool:
411+
def inner(body: list[ast.stmt], use_pypprint: bool = False) -> bool:
410412
if not body:
411413
return False
412414
if isinstance(body[-1], ast.Pass):
@@ -642,7 +644,7 @@ def run_pyp(args: argparse.Namespace) -> None:
642644
# On error, reconstruct a traceback into the generated code
643645
# Also add some diagnostics for ModuleNotFoundError and NameError
644646
try:
645-
line_to_node: Dict[int, ast.AST] = {}
647+
line_to_node: dict[int, ast.AST] = {}
646648
for node in dfs_walk(tree):
647649
line_to_node.setdefault(getattr(node, "lineno", -1), node)
648650

@@ -699,7 +701,7 @@ def code_for_line(lineno: int) -> str:
699701
) from e
700702

701703

702-
def parse_options(args: List[str]) -> argparse.Namespace:
704+
def parse_options(args: list[str]) -> argparse.Namespace:
703705
parser = argparse.ArgumentParser(
704706
prog="pyp",
705707
formatter_class=argparse.RawDescriptionHelpFormatter,

tests/test_find_names.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
from __future__ import annotations
2+
13
import ast
24
import re
35
import sys
4-
from typing import List, Optional, Set
56

67
import pytest
78

@@ -10,9 +11,9 @@
1011

1112
def check_find_names(
1213
code: str,
13-
defined: Set[str],
14-
undefined: Set[str],
15-
wildcard_imports: Optional[List[str]] = None,
14+
defined: set[str],
15+
undefined: set[str],
16+
wildcard_imports: list[str] | None = None,
1617
confirm: bool = True,
1718
) -> None:
1819
names = NameFinder(ast.parse(code))

tests/test_pyp.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import ast
24
import contextlib
35
import io
@@ -7,7 +9,6 @@
79
import sys
810
import tempfile
911
import traceback
10-
from typing import List, Optional, Union
1112
from unittest.mock import patch
1213

1314
import pytest
@@ -24,7 +25,7 @@ def delete_config_env_var(monkeypatch):
2425
monkeypatch.delenv("PYP_CONFIG_PATH", raising=False)
2526

2627

27-
def run_cmd(cmd: str, input: Union[str, bytes, None] = None, check: bool = True) -> str:
28+
def run_cmd(cmd: str, input: str | bytes | None = None, check: bool = True) -> str:
2829
if isinstance(input, str):
2930
input = input.encode("utf-8")
3031
proc = subprocess.run(
@@ -33,7 +34,7 @@ def run_cmd(cmd: str, input: Union[str, bytes, None] = None, check: bool = True)
3334
return proc.stdout.decode("utf-8")
3435

3536

36-
def run_pyp(cmd: Union[str, List[str]], input: Optional[str] = None) -> str:
37+
def run_pyp(cmd: str | list[str], input: str | None = None) -> str:
3738
"""Run pyp in process. It's quicker and allows us to mock and so on."""
3839
if isinstance(cmd, str):
3940
cmd = shlex.split(cmd)
@@ -52,7 +53,7 @@ def run_pyp(cmd: Union[str, List[str]], input: Optional[str] = None) -> str:
5253

5354

5455
def compare_command(
55-
example_cmd: str, pyp_cmd: str, input: Optional[str] = None, use_subprocess: bool = False
56+
example_cmd: str, pyp_cmd: str, input: str | None = None, use_subprocess: bool = False
5657
) -> None:
5758
"""Compares running command example_cmd with the output of pyp_cmd.
5859
@@ -184,7 +185,7 @@ def test_tracebacks():
184185
def effect(*args, **kwargs):
185186
nonlocal count
186187
if count == 0:
187-
assert args[0] == ZeroDivisionError
188+
assert args[0] is ZeroDivisionError
188189
count += 1
189190
raise Exception
190191
return TBE(*args, **kwargs)

0 commit comments

Comments
 (0)