Skip to content
This repository was archived by the owner on Jan 6, 2026. It is now read-only.

Commit 739cc40

Browse files
committed
Add check for legacy default=None optional input
1 parent 81fcc92 commit 739cc40

File tree

2 files changed

+92
-0
lines changed

2 files changed

+92
-0
lines changed

python/coglet/check.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import ast
2+
import sys
3+
from typing import Optional, Type, TypeVar
4+
5+
T = TypeVar('T', covariant=True)
6+
7+
8+
def find(nodes: list, tpe: Type[T], attr: str, name: str) -> Optional[T]:
9+
for n in nodes:
10+
if type(n) is tpe and getattr(n, attr) == name:
11+
return n
12+
return None
13+
14+
15+
def check(file: str, predictor: str) -> None:
16+
with open(file, 'r') as f:
17+
content = f.read()
18+
lines = content.splitlines()
19+
root = ast.parse(content)
20+
21+
p = find(root.body, ast.ClassDef, 'name', predictor)
22+
if p is None:
23+
return
24+
fn = find(p.body, ast.FunctionDef, 'name', 'predict')
25+
if fn is None:
26+
fn = find(p.body, ast.AsyncFunctionDef, 'name', 'predict') # type: ignore
27+
args_and_defaults = zip(fn.args.args[-len(fn.args.defaults) :], fn.args.defaults) # type: ignore
28+
none_defaults = []
29+
for a, d in args_and_defaults:
30+
if type(a.annotation) is not ast.Name:
31+
continue
32+
if type(d) is not ast.Call or d.func.id != 'Input': # type: ignore
33+
continue
34+
v = find(d.keywords, ast.keyword, 'arg', 'default')
35+
if v is None or type(v.value) is not ast.Constant:
36+
continue
37+
if v.value.value is None:
38+
pos = f'{file}:{a.lineno}:{a.col_offset}'
39+
40+
# Add `Optional[]` to type annotation
41+
# No need to remove `default=None` since `x: Optional[T] = Input(default=None)` is valid
42+
ta = a.annotation
43+
l = lines[ta.lineno - 1]
44+
parts = l[:ta.col_offset], l[ta.col_offset:ta.end_col_offset], l[ta.end_col_offset:]
45+
l = f'{parts[0]}Optional[{parts[1]}]{parts[2]}'
46+
lines[ta.lineno - 1] = l
47+
48+
none_defaults.append(f'{pos}: {a.arg}: {ta.id}={ast.unparse(d)}')
49+
50+
if len(none_defaults) > 0:
51+
print('Default value of None without explicit Optional[T] type hint is ambiguous and deprecated, for example:', file=sys.stderr)
52+
print('- x: str=Input(default=None)', file=sys.stderr)
53+
print('+ x: Optional[str]=Input(default=None)', file=sys.stderr)
54+
print(file=sys.stderr)
55+
for l in none_defaults:
56+
print(l, file=sys.stderr)
57+
58+
# Check for `from typing import Optional`
59+
imports = find(root.body, ast.ImportFrom, 'module', 'typing')
60+
if imports is None or 'Optional' not in [n.name for n in imports.names]:
61+
# Missing import, add it at beginning of file or before first import
62+
# Skip `#!/usr/bin/env python3` or comments
63+
lno = 1
64+
while lines[lno - 1].startswith('#'):
65+
lno += 1
66+
for n in root.body:
67+
if type(n) in {ast.Import, ast.ImportFrom}:
68+
lno = n.lineno
69+
break
70+
lines = lines[:lno - 1] + ['from typing import Optional'] + lines[lno - 1:]
71+
print('\n'.join(lines))
72+
73+
74+
check(sys.argv[1], sys.argv[2])
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from typing import Optional
2+
3+
from cog import BasePredictor, Input
4+
5+
6+
class Predictor(BasePredictor):
7+
def predict(
8+
self,
9+
s1: str,
10+
s2: Optional[str],
11+
s3: str = Input(),
12+
s4: Optional[str] = Input(),
13+
s5: str = Input(default=None),
14+
s6: str = Input(default=None, description='s6'),
15+
s7: str = Input(description='s7', default=None),
16+
s8: Optional[str] = Input(default=None),
17+
) -> str:
18+
return f'{s1}:{s2}:{s3}:{s4}:{s5}:{s6}'

0 commit comments

Comments
 (0)