|
| 1 | +import ast |
| 2 | +import sys |
| 3 | +from typing import Optional, Type, TypeVar |
| 4 | + |
| 5 | +T = TypeVar('T', covariant=True) |
| 6 | + |
| 7 | + |
| 8 | +def find(nodes: list, tpe: Type[T], attr: str, name: str) -> Optional[T]: |
| 9 | + for n in nodes: |
| 10 | + if type(n) is tpe and getattr(n, attr) == name: |
| 11 | + return n |
| 12 | + return None |
| 13 | + |
| 14 | + |
| 15 | +def check(file: str, predictor: str) -> None: |
| 16 | + with open(file, 'r') as f: |
| 17 | + content = f.read() |
| 18 | + lines = content.splitlines() |
| 19 | + root = ast.parse(content) |
| 20 | + |
| 21 | + p = find(root.body, ast.ClassDef, 'name', predictor) |
| 22 | + if p is None: |
| 23 | + return |
| 24 | + fn = find(p.body, ast.FunctionDef, 'name', 'predict') |
| 25 | + if fn is None: |
| 26 | + fn = find(p.body, ast.AsyncFunctionDef, 'name', 'predict') # type: ignore |
| 27 | + args_and_defaults = zip(fn.args.args[-len(fn.args.defaults) :], fn.args.defaults) # type: ignore |
| 28 | + none_defaults = [] |
| 29 | + for a, d in args_and_defaults: |
| 30 | + if type(a.annotation) is not ast.Name: |
| 31 | + continue |
| 32 | + if type(d) is not ast.Call or d.func.id != 'Input': # type: ignore |
| 33 | + continue |
| 34 | + v = find(d.keywords, ast.keyword, 'arg', 'default') |
| 35 | + if v is None or type(v.value) is not ast.Constant: |
| 36 | + continue |
| 37 | + if v.value.value is None: |
| 38 | + pos = f'{file}:{a.lineno}:{a.col_offset}' |
| 39 | + |
| 40 | + # Add `Optional[]` to type annotation |
| 41 | + # No need to remove `default=None` since `x: Optional[T] = Input(default=None)` is valid |
| 42 | + ta = a.annotation |
| 43 | + l = lines[ta.lineno - 1] |
| 44 | + parts = l[:ta.col_offset], l[ta.col_offset:ta.end_col_offset], l[ta.end_col_offset:] |
| 45 | + l = f'{parts[0]}Optional[{parts[1]}]{parts[2]}' |
| 46 | + lines[ta.lineno - 1] = l |
| 47 | + |
| 48 | + none_defaults.append(f'{pos}: {a.arg}: {ta.id}={ast.unparse(d)}') |
| 49 | + |
| 50 | + if len(none_defaults) > 0: |
| 51 | + print('Default value of None without explicit Optional[T] type hint is ambiguous and deprecated, for example:', file=sys.stderr) |
| 52 | + print('- x: str=Input(default=None)', file=sys.stderr) |
| 53 | + print('+ x: Optional[str]=Input(default=None)', file=sys.stderr) |
| 54 | + print(file=sys.stderr) |
| 55 | + for l in none_defaults: |
| 56 | + print(l, file=sys.stderr) |
| 57 | + |
| 58 | + # Check for `from typing import Optional` |
| 59 | + imports = find(root.body, ast.ImportFrom, 'module', 'typing') |
| 60 | + if imports is None or 'Optional' not in [n.name for n in imports.names]: |
| 61 | + # Missing import, add it at beginning of file or before first import |
| 62 | + # Skip `#!/usr/bin/env python3` or comments |
| 63 | + lno = 1 |
| 64 | + while lines[lno - 1].startswith('#'): |
| 65 | + lno += 1 |
| 66 | + for n in root.body: |
| 67 | + if type(n) in {ast.Import, ast.ImportFrom}: |
| 68 | + lno = n.lineno |
| 69 | + break |
| 70 | + lines = lines[:lno - 1] + ['from typing import Optional'] + lines[lno - 1:] |
| 71 | + print('\n'.join(lines)) |
| 72 | + |
| 73 | + |
| 74 | +check(sys.argv[1], sys.argv[2]) |
0 commit comments