Skip to content
This repository was archived by the owner on Apr 15, 2026. It is now read-only.

Commit 911594f

Browse files
committed
Add check for legacy default=None optional input
1 parent 81fcc92 commit 911594f

2 files changed

Lines changed: 77 additions & 0 deletions

File tree

python/coglet/check.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import ast
2+
import sys
3+
from typing import Optional, Type, TypeVar
4+
5+
T = TypeVar('T', covariant=True)
6+
7+
8+
def find(nodes: list, tpe: Type[T], attr: str, name: str) -> Optional[T]:
9+
for n in nodes:
10+
if type(n) is tpe and getattr(n, attr) == name:
11+
return n
12+
return None
13+
14+
15+
def check(file: str, predictor: str) -> None:
16+
with open(file, 'r') as f:
17+
content = f.read()
18+
lines = content.splitlines()
19+
root = ast.parse(content)
20+
21+
p = find(root.body, ast.ClassDef, 'name', predictor)
22+
if p is None:
23+
return
24+
fn = find(p.body, ast.FunctionDef, 'name', 'predict')
25+
if fn is None:
26+
fn = find(p.body, ast.AsyncFunctionDef, 'name', 'predict') # type: ignore
27+
args_and_defaults = zip(fn.args.args[-len(fn.args.defaults) :], fn.args.defaults) # type: ignore
28+
none_defaults = []
29+
for a, d in args_and_defaults:
30+
if type(a.annotation) is not ast.Name:
31+
continue
32+
if type(d) is not ast.Call or d.func.id != 'Input': # type: ignore
33+
continue
34+
v = find(d.keywords, ast.keyword, 'arg', 'default')
35+
if v is None or type(v.value) is not ast.Constant:
36+
continue
37+
if v.value.value is None:
38+
pos = f'{file}:{a.lineno}:{a.col_offset}'
39+
40+
# Add Optional[] to type annotation
41+
# No need to remove `default=None` since `x: Optional[T] = Input(default=None)` is valid
42+
ta = a.annotation
43+
l = lines[ta.lineno - 1]
44+
parts = l[:ta.col_offset], l[ta.col_offset:ta.end_col_offset], l[ta.end_col_offset:]
45+
l = f'{parts[0]}Optional[{parts[1]}]{parts[2]}'
46+
lines[ta.lineno - 1] = l
47+
48+
none_defaults.append(f'{pos}: {a.arg}: {ta.id}={ast.unparse(d)}')
49+
if len(none_defaults) > 0:
50+
print('Default value of None without explicit Optional[T] type hint is ambiguous and deprecated, for example:')
51+
print('- x: str=Input(default=None)')
52+
print('+ x: Optional[str]=Input(default=None)')
53+
print()
54+
for l in none_defaults:
55+
print(l)
56+
print('\n'.join(lines))
57+
58+
59+
check(sys.argv[1], sys.argv[2])
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from typing import Optional
2+
3+
from cog import BasePredictor, Input
4+
5+
6+
class Predictor(BasePredictor):
7+
def predict(
8+
self,
9+
s1: str,
10+
s2: Optional[str],
11+
s3: str = Input(),
12+
s4: Optional[str] = Input(),
13+
s5: str = Input(default=None),
14+
s6: str = Input(default=None, description='s6'),
15+
s7: str = Input(description='s7', default=None),
16+
s8: Optional[str] = Input(default=None),
17+
) -> str:
18+
return f'{s1}:{s2}:{s3}:{s4}:{s5}:{s6}'

0 commit comments

Comments
 (0)