Skip to content

Commit 7f77787

Browse files
committed
fix: support default annotation of OP function
fix: replace argo function in debug mode Signed-off-by: zjgemi <[email protected]>
1 parent d17244c commit 7f77787

File tree

3 files changed

+46
-28
lines changed

3 files changed

+46
-28
lines changed

src/dflow/io.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,9 +214,9 @@ def __repr__(self):
214214
def __getitem__(self, i):
215215
if config["mode"] == "debug":
216216
if isinstance(i, str):
217-
return Expression("%s['%s']" % (self.expr, i))
217+
return ArgoVar("%s['%s']" % (self.expr, i))
218218
elif isinstance(i, int):
219-
return Expression("%s[%s]" % (self.expr, i))
219+
return ArgoVar("%s[%s]" % (self.expr, i))
220220
if isinstance(i, str):
221221
item = "jsonpath(%s, '$')['%s']" % (self.expr, i)
222222
else:

src/dflow/python/op.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from functools import partial
1212
from importlib import import_module
1313
from pathlib import Path
14-
from typing import Dict, List, Set, Union
14+
from typing import Any, Dict, List, Set, Union
1515

1616
from ..argo_objects import ArgoObjectDict
1717
from ..config import config
@@ -209,11 +209,7 @@ def function(cls, func=None, **kwargs):
209209
if func is None:
210210
return partial(cls.function, **kwargs)
211211

212-
signature = func.__annotations__
213-
return_type = signature.get('return', None)
214-
input_sign = OPIOSign(
215-
{k: v for k, v in signature.items() if k != 'return'})
216-
output_sign, ret2opio, opio2ret = type2opiosign(return_type)
212+
input_sign, output_sign, ret2opio, opio2ret = get_sign_from_func(func)
217213

218214
class subclass(cls):
219215
task_kwargs = {}
@@ -272,11 +268,7 @@ def superfunction(cls, func=None, **kwargs):
272268
if func is None:
273269
return partial(cls.superfunction, **kwargs)
274270

275-
signature = func.__annotations__
276-
return_type = signature.get('return', None)
277-
input_sign = OPIOSign(
278-
{k: v for k, v in signature.items() if k != 'return'})
279-
output_sign, ret2opio, opio2ret = type2opiosign(return_type)
271+
input_sign, output_sign, ret2opio, opio2ret = get_sign_from_func(func)
280272

281273
from ..dag import DAG
282274
from ..task import Task
@@ -432,6 +424,29 @@ def handle_outputs(self, outputs, symlink=False):
432424
name, outputs[name], sign, slices, self.tmp_root)
433425

434426

427+
def get_sign_from_func(func):
428+
signature = inspect.signature(func)
429+
input_sign = OPIOSign()
430+
for parameter in signature.parameters.values():
431+
_type = parameter.annotation
432+
if _type is inspect._empty:
433+
_type = Any
434+
if parameter.default is not inspect._empty:
435+
if isinstance(_type, Artifact):
436+
if parameter.default is None:
437+
_type.optional = True
438+
elif isinstance(_type, (Parameter, BigParameter)):
439+
_type.default = parameter.default
440+
else:
441+
_type = Parameter(_type, default=parameter.default)
442+
input_sign[parameter.name] = _type
443+
return_type = signature.return_annotation
444+
if return_type is inspect._empty:
445+
return_type = None
446+
output_sign, ret2opio, opio2ret = type2opiosign(return_type)
447+
return input_sign, output_sign, ret2opio, opio2ret
448+
449+
435450
def type2opiosign(t):
436451
from typing import Tuple
437452
try:

src/dflow/step.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def argo_range(
6666
Each argument can be Argo parameter
6767
"""
6868
if config["mode"] == "debug":
69-
return Expression("list(range(%s))" % ", ".join(
69+
return expression("list(range(%s))" % ", ".join(
7070
map(lambda x: "int(%s)" % to_expr(x), args)))
7171
start = 0
7272
step = 1
@@ -198,7 +198,7 @@ def argo_len(
198198
param: the Argo parameter which is a list
199199
"""
200200
if config["mode"] == "debug":
201-
return Expression("len(%s)" % to_expr(param))
201+
return expression("len(%s)" % to_expr(param))
202202
return ArgoLen(param)
203203

204204

@@ -237,7 +237,7 @@ def argo_enumerate(
237237
if config["mode"] == "debug":
238238
values = "".join([", '%s': %s[i]" % (k, to_expr(v))
239239
for k, v in kwargs.items()])
240-
expr = Expression("[{'order': i%s} for i in range(len(%s))]" % (
240+
expr = expression("[{'order': i%s} for i in range(len(%s))]" % (
241241
values, to_expr(list(kwargs.values())[0])))
242242
expr.kwargs = kwargs
243243
return expr
@@ -1476,7 +1476,7 @@ def run(self, scope, context=None, order=None):
14761476
elif isinstance(self.when, (InputParameter, OutputParameter)):
14771477
value = get_var(self.when, scope).value
14781478
elif isinstance(self.when, ArgoVar):
1479-
value = Expression(self.when.expr).eval(scope)
1479+
value = expression(self.when.expr).eval(scope)
14801480
elif isinstance(self.when, str):
14811481
value = eval_expr(render_expr(self.when, scope))
14821482
if not value:
@@ -1517,7 +1517,7 @@ def handle_expr(val, scope):
15171517
elif isinstance(value, (InputParameter, OutputParameter)):
15181518
par.value = get_var(value, scope).value
15191519
elif isinstance(value, ArgoVar):
1520-
par.value = Expression(value.expr).eval(scope)
1520+
par.value = expression(value.expr).eval(scope)
15211521
elif isinstance(value, str):
15221522
par.value = render_expr(value, scope)
15231523
else:
@@ -1542,8 +1542,7 @@ def handle_expr(val, scope):
15421542
OutputParameter)):
15431543
item_list = self.with_param.value
15441544
elif isinstance(self.with_param, ArgoVar):
1545-
item_list = Expression(replace_argo_func(
1546-
self.with_param.expr)).eval(scope)
1545+
item_list = expression(self.with_param.expr).eval(scope)
15471546
if isinstance(item_list, str):
15481547
item_list = eval(item_list)
15491548
elif isinstance(self.with_param, str):
@@ -1560,15 +1559,15 @@ def handle_expr(val, scope):
15601559
elif isinstance(start, (InputParameter, OutputParameter)):
15611560
start = start.value
15621561
elif isinstance(start, ArgoVar):
1563-
start = int(Expression(start.expr).eval(scope))
1562+
start = int(expression(start.expr).eval(scope))
15641563
if self.with_sequence.count is not None:
15651564
count = self.with_sequence.count
15661565
if isinstance(count, Expression):
15671566
count = int(count.eval(scope))
15681567
elif isinstance(count, (InputParameter, OutputParameter)):
15691568
count = count.value
15701569
elif isinstance(count, ArgoVar):
1571-
count = int(Expression(count.expr).eval(scope))
1570+
count = int(expression(count.expr).eval(scope))
15721571
sequence = list(range(start, start + count))
15731572
if self.with_sequence.end is not None:
15741573
end = self.with_sequence.end
@@ -1577,7 +1576,7 @@ def handle_expr(val, scope):
15771576
elif isinstance(end, (InputParameter, OutputParameter)):
15781577
end = end.value
15791578
elif isinstance(end, ArgoVar):
1580-
end = int(Expression(end.expr).eval(scope))
1579+
end = int(expression(end.expr).eval(scope))
15811580
if end >= start:
15821581
sequence = list(range(start, end + 1))
15831582
else:
@@ -1973,8 +1972,8 @@ def exec_steps(self, scope, parameters, item=None, context=None,
19731972
raise e
19741973
elif par1.value_from_expression is not None:
19751974
if isinstance(par1.value_from_expression, str):
1976-
expr = replace_argo_func(par1.value_from_expression)
1977-
par1.value_from_expression = Expression(expr)
1975+
par1.value_from_expression = expression(
1976+
par1.value_from_expression)
19781977
par.value = par1.value_from_expression.eval(steps)
19791978

19801979
for name, art in self.outputs.artifacts.items():
@@ -1988,8 +1987,7 @@ def exec_steps(self, scope, parameters, item=None, context=None,
19881987
art.local_path = get_var(art1._from, steps).local_path
19891988
elif art1.from_expression is not None:
19901989
if isinstance(art1.from_expression, str):
1991-
expr = replace_argo_func(art1.from_expression)
1992-
art1.from_expression = Expression(expr)
1990+
art1.from_expression = expression(art1.from_expression)
19931991
art.local_path = art1.from_expression.eval(steps)
19941992

19951993
self.record_output_parameters(stepdir, self.outputs.parameters)
@@ -2262,7 +2260,7 @@ def render_expr(expr, scope):
22622260
while i >= 0:
22632261
j = expr.find("}}", i+2)
22642262
if expr[i:i+3] == "{{=":
2265-
value = Expression(replace_argo_func(expr[i+3:j])).eval(scope)
2263+
value = expression(expr[i+3:j]).eval(scope)
22662264
value = value if isinstance(value, str) else \
22672265
jsonpickle.dumps(value)
22682266
expr = expr[:i] + value.strip() + expr[j+2:]
@@ -2390,6 +2388,10 @@ def backup(path):
23902388
shutil.move(path, bk)
23912389

23922390

2391+
def expression(expr):
2392+
return Expression(replace_argo_func(expr))
2393+
2394+
23932395
def replace_argo_func(expr):
23942396
i = expr.find("toJson(map(sprig.untilStep(0, ")
23952397
j = expr.find(", 1), { {'order': #")
@@ -2412,6 +2414,7 @@ def replace_argo_func(expr):
24122414
expr = expr.replace("jsonpath",
24132415
"(lambda x, y: eval(x) if isinstance(x, str) else x)")
24142416
expr = expr.replace("string", "str")
2417+
expr = expr.replace("sprig.int", "int")
24152418
return expr
24162419

24172420

0 commit comments

Comments
 (0)