Skip to content

Commit 54c46f2

Browse files
committed
fix: support setting outputs before the OP returns
Signed-off-by: zjgemi <[email protected]>
2 parents d887fce + f68a848 commit 54c46f2

File tree

6 files changed

+131
-35
lines changed

6 files changed

+131
-35
lines changed

CHANGELOG.md

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,45 @@
11
# CHANGELOG
22

3+
## [1.8.98](https://github.com/deepmodeling/dflow/compare/v1.8.97...v1.8.98) (2024-10-31)
4+
5+
6+
### Bug Fixes
7+
8+
* support specifying hooks after defining a step ([3c016b7](https://github.com/deepmodeling/dflow/commit/3c016b7d42d1c547cce20888959339ae4d019735))
9+
10+
## [1.8.97](https://github.com/deepmodeling/dflow/compare/v1.8.96...v1.8.97) (2024-10-25)
11+
12+
13+
### Bug Fixes
14+
15+
* add raise_for_group to Slices ([a3e7ee1](https://github.com/deepmodeling/dflow/commit/a3e7ee11413eb9c0bebed7d59ce84efe52b7b196))
16+
* add sort_by_generation to query_step ([a3e7ee1](https://github.com/deepmodeling/dflow/commit/a3e7ee11413eb9c0bebed7d59ce84efe52b7b196))
17+
* hooks of task ([67b5876](https://github.com/deepmodeling/dflow/commit/67b58766845dd8d98443e89a5b0a56558c21291d))
18+
* support getting sub-steps of a specific step ([7446003](https://github.com/deepmodeling/dflow/commit/74460034816a7a2ba3c7af44ce9581cda8c923f2))
19+
20+
## [1.8.96](https://github.com/deepmodeling/dflow/compare/v1.8.95...v1.8.96) (2024-10-23)
21+
22+
23+
### Bug Fixes
24+
25+
* add lifecycle hooks to step/task ([9030946](https://github.com/deepmodeling/dflow/commit/9030946e0f0edcffe5e806197bba4d3afa74b6e3))
26+
* add onExit hook ([acbd12d](https://github.com/deepmodeling/dflow/commit/acbd12d23e083d739290a57d04b241b522b9e063))
27+
28+
## [1.8.95](https://github.com/deepmodeling/dflow/compare/v1.8.94...v1.8.95) (2024-10-18)
29+
30+
31+
### Bug Fixes
32+
33+
* pass None for HDF5Datasets ([1310554](https://github.com/deepmodeling/dflow/commit/1310554ead008acee1754112ef6032c9a062dfd9))
34+
35+
## [1.8.94](https://github.com/deepmodeling/dflow/compare/v1.8.93...v1.8.94) (2024-10-17)
36+
37+
38+
### Bug Fixes
39+
40+
* HDF5Datasets with grouped slices ([1bfae67](https://github.com/deepmodeling/dflow/commit/1bfae67ab6533c9dddf8066a40919c56bb7c53b5))
41+
* None in HDF5Datasets ([1bfae67](https://github.com/deepmodeling/dflow/commit/1bfae67ab6533c9dddf8066a40919c56bb7c53b5))
42+
343
## [1.8.93](https://github.com/deepmodeling/dflow/compare/v1.8.92...v1.8.93) (2024-09-20)
444

545

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.8.93
1+
1.8.98

src/dflow/plugins/dispatcher.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ class DispatcherExecutor(Executor):
5858
remote_root: remote root path for working
5959
retry_on_submission_error: max retries on submission error
6060
merge_sliced_step: handle multi slices in a single dispatcher job
61+
terminate_grace_period: grace period in seconds after termination
62+
for collecting outputs
6163
"""
6264

6365
def __init__(self,
@@ -89,6 +91,7 @@ def __init__(self,
8991
remove_scheduling_strategies: bool = True,
9092
envs: Optional[Dict[str, str]] = None,
9193
merge_bohrium_job_group: bool = False,
94+
terminate_grace_period: Optional[int] = None,
9295
) -> None:
9396
self.host = host
9497
self.queue_name = queue_name
@@ -127,6 +130,7 @@ def __init__(self,
127130
self.remove_scheduling_strategies = remove_scheduling_strategies
128131
self.envs = envs
129132
self.merge_bohrium_job_group = merge_bohrium_job_group
133+
self.terminate_grace_period = terminate_grace_period
130134

131135
conf = {}
132136
if json_file is not None:
@@ -541,6 +545,10 @@ def render(self, template):
541545
new_template.script += " print('Got SIGTERM, kill unfinished tasks"\
542546
"!')\n"
543547
new_template.script += " submission.remove_unfinished_tasks()\n"
548+
if self.terminate_grace_period:
549+
new_template.script += " import time\n"
550+
new_template.script += " time.sleep(%s)\n" % \
551+
self.terminate_grace_period
544552
new_template.script += "import signal\n"
545553
new_template.script += "signal.signal(signal.SIGTERM, sigterm_handler"\
546554
")\n"

src/dflow/python/op.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
OutputParameter, type_to_str)
2121
from ..utils import dict2list, get_key, randstr, s3_config
2222
from .opio import OPIO, Artifact, BigParameter, OPIOSign, Parameter
23+
from .utils import handle_output_artifact, handle_output_parameter
2324
from .vendor.typeguard import check_type
2425

2526
iwd = os.getcwd()
@@ -48,6 +49,10 @@ class OP(ABC):
4849
progress_current = 0
4950
key = None
5051
workflow_name = None
52+
outputs = {}
53+
slices = {}
54+
tmp_root = "/tmp"
55+
create_slice_dir = False
5156

5257
def __init__(
5358
self,
@@ -400,6 +405,26 @@ def from_graph(cls, graph):
400405
op = getattr(mod, graph["name"])
401406
return op
402407

408+
def set_output(self, name, value):
    """Record a single output and write it out immediately.

    The value is stored in ``self.outputs`` and then handed to
    ``handle_outputs`` with ``symlink=True``, so the output is already
    in place even if the OP does not return normally afterwards.

    Args:
        name: key of the output in the OP's output signature.
        value: value to record for that output.
    """
    # ``outputs`` is declared as a class-level dict. Writing into it
    # directly would mutate state shared by every OP instance (and
    # subclass) in the process, so give this instance its own copy the
    # first time it records an output.
    if "outputs" not in self.__dict__:
        self.outputs = dict(self.outputs)
    self.outputs[name] = value
    self.handle_outputs({name: value}, symlink=True)
411+
412+
def handle_outputs(self, outputs, symlink=False):
    """Write the given outputs below ``self.tmp_root``.

    Creates ``<tmp_root>/outputs/parameters`` and
    ``<tmp_root>/outputs/artifacts`` if needed, then dispatches every
    entry of ``outputs`` to ``handle_output_artifact`` or
    ``handle_output_parameter`` according to its declaration in the
    OP's output signature.

    Args:
        outputs: mapping from output name to value; every name must be
            present in ``self.get_output_sign()``.
        symlink: forwarded to ``handle_output_artifact``.

    Raises:
        KeyError: if a name in ``outputs`` is not in the output signature.
    """
    os.makedirs("%s/outputs/parameters" % self.tmp_root, exist_ok=True)
    os.makedirs("%s/outputs/artifacts" % self.tmp_root, exist_ok=True)
    output_sign = self.get_output_sign()
    for name in outputs:
        sign = output_sign[name]
        # None when no slice was registered for this output.
        slices = self.slices.get(name)
        if isinstance(sign, Artifact):
            # Test against None explicitly: slice index 0 (and an empty
            # slice list) are falsy but are still valid slices that need
            # their own directory.
            create_dir = bool(self.create_slice_dir) and slices is not None
            handle_output_artifact(
                name, outputs[name], sign, slices, self.tmp_root,
                create_dir, symlink=symlink)
        else:
            handle_output_parameter(
                name, outputs[name], sign, slices, self.tmp_root)
427+
403428

404429
def type2opiosign(t):
405430
from typing import Tuple

src/dflow/python/python_op_template.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,21 @@ def render_script(self):
593593
"%s, r'%s')\n" % (name, name, name, name, slices,
594594
self.tmp_root)
595595

596+
script += " op_obj.tmp_root = '%s'\n" % self.tmp_root
597+
script += " op_obj.create_slice_dir = %s\n" % (
598+
self.create_slice_dir and getattr(self.slices, "pool_size", None)
599+
is None)
600+
for name, sign in output_sign.items():
601+
if isinstance(sign, Artifact):
602+
slices = self.get_slices(output_artifact_slices, name)
603+
script += " op_obj.slices['%s'] = %s\n" % (name, slices)
604+
605+
script += " import signal\n"
606+
script += " def sigterm_handler(signum, frame):\n"
607+
script += " print('Got SIGTERM')\n"
608+
script += " raise RuntimeError('Got SIGTERM')\n"
609+
script += " signal.signal(signal.SIGTERM, sigterm_handler)\n"
610+
596611
if self.slices is not None and self.slices.pool_size is not None:
597612
sliced_inputs = self.slices.input_artifact + \
598613
self.slices.input_parameter
@@ -651,7 +666,7 @@ def render_script(self):
651666
script += " if o is not None:\n"
652667
script += " output = o\n"
653668
for name in sliced_outputs:
654-
script += " output['%s'] = [o['%s'] if o is not None"\
669+
script += " output['%s'] = [o.get('%s') if o is not None"\
655670
" else None for o in output_list]\n" % (name, name)
656671
if isinstance(output_sign[name], Artifact):
657672
if output_sign[name].type == str:
@@ -662,31 +677,21 @@ def render_script(self):
662677
"]\n" % name
663678
else:
664679
script += " try:\n"
665-
script += " output = op_obj.execute(input)\n"
680+
script += " try:\n"
681+
script += " output = op_obj.execute(input)\n"
682+
script += " except Exception as e:\n"
683+
script += " if op_obj.outputs:\n"
684+
script += " op_obj.handle_outputs(op_obj.outputs)\n"
685+
script += " raise e\n"
666686
script += " except TransientError:\n"
667687
script += " traceback.print_exc()\n"
668688
script += " sys.exit(1)\n"
669689
script += " except FatalError:\n"
670690
script += " traceback.print_exc()\n"
671691
script += " sys.exit(2)\n"
672692

673-
script += " os.makedirs(r'%s/outputs/parameters', exist_ok=True)\n"\
674-
% self.tmp_root
675-
script += " os.makedirs(r'%s/outputs/artifacts', exist_ok=True)\n" \
676-
% self.tmp_root
677-
for name, sign in output_sign.items():
678-
if isinstance(sign, Artifact):
679-
slices = self.get_slices(output_artifact_slices, name)
680-
script += " handle_output_artifact('%s', output['%s'], "\
681-
"output_sign['%s'], %s, r'%s', %s)\n" % (
682-
name, name, name, slices, self.tmp_root,
683-
slices is not None and self.create_slice_dir and
684-
getattr(self.slices, "pool_size", None) is None)
685-
else:
686-
slices = self.get_slices(output_parameter_slices, name)
687-
script += " handle_output_parameter('%s', output['%s'], "\
688-
"output_sign['%s'], %s, r'%s')\n" % (name, name, name,
689-
slices, self.tmp_root)
693+
script += " op_obj.handle_outputs(output)\n"
694+
690695
if config["register_tasks"]:
691696
if self.slices is not None and self.slices.register_first_only:
692697
if "{{item}}" in self.dflow_vars:

src/dflow/python/utils.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,9 @@ def slice_to_dir(slice):
210210

211211

212212
def handle_output_artifact(name, value, sign, slices=None, data_root="/tmp",
213-
create_dir=False):
213+
create_dir=False, symlink=False):
214+
if os.path.isdir(data_root + '/outputs/artifacts/' + name):
215+
shutil.rmtree(data_root + '/outputs/artifacts/' + name)
214216
path_list = []
215217
if sign.type == HDF5Datasets:
216218
import h5py
@@ -268,15 +270,16 @@ def handle_output_artifact(name, value, sign, slices=None, data_root="/tmp",
268270
slices = 0
269271
path_list.append(copy_results_and_return_path_item(
270272
value, name, slices, data_root,
271-
slice_to_dir(slices) if create_dir else None))
273+
slice_to_dir(slices) if create_dir else None, symlink=symlink))
272274
elif sign.type in [List[str], List[Path], Set[str], Set[Path]]:
273275
os.makedirs(data_root + '/outputs/artifacts/' + name, exist_ok=True)
274276
if slices is not None:
275277
if isinstance(slices, int):
276278
for path in value:
277279
path_list.append(copy_results_and_return_path_item(
278280
path, name, slices, data_root,
279-
slice_to_dir(slices) if create_dir else None))
281+
slice_to_dir(slices) if create_dir else None,
282+
symlink=symlink))
280283
else:
281284
assert len(slices) == len(value)
282285
for path, s in zip(value, slices):
@@ -285,25 +288,26 @@ def handle_output_artifact(name, value, sign, slices=None, data_root="/tmp",
285288
path_list.append(
286289
copy_results_and_return_path_item(
287290
p, name, s, data_root, slice_to_dir(
288-
s) if create_dir else None))
291+
s) if create_dir else None,
292+
symlink=symlink))
289293
else:
290294
path_list.append(copy_results_and_return_path_item(
291295
path, name, s, data_root, slice_to_dir(
292-
s) if create_dir else None))
296+
s) if create_dir else None, symlink=symlink))
293297
else:
294298
for s, path in enumerate(value):
295299
path_list.append(copy_results_and_return_path_item(
296-
path, name, s, data_root))
300+
path, name, s, data_root, symlink=symlink))
297301
elif sign.type in [Dict[str, str], Dict[str, Path]]:
298302
os.makedirs(data_root + '/outputs/artifacts/' + name, exist_ok=True)
299303
for s, path in value.items():
300304
path_list.append(copy_results_and_return_path_item(
301-
path, name, s, data_root))
305+
path, name, s, data_root, symlink=symlink))
302306
elif sign.type in [NestedDict[str], NestedDict[Path]]:
303307
os.makedirs(data_root + '/outputs/artifacts/' + name, exist_ok=True)
304308
for s, path in flatten(value).items():
305309
path_list.append(copy_results_and_return_path_item(
306-
path, name, s, data_root))
310+
path, name, s, data_root, symlink=symlink))
307311

308312
os.makedirs(data_root + "/outputs/artifacts/%s/%s" % (name, config[
309313
"catalog_dir_name"]), exist_ok=True)
@@ -343,15 +347,16 @@ def handle_output_parameter(name, value, sign, slices=None, data_root="/tmp"):
343347

344348

345349
def copy_results_and_return_path_item(path, name, order, data_root="/tmp",
346-
slice_dir=None):
347-
if path and os.path.exists(str(path)):
350+
slice_dir=None, symlink=False):
351+
if (path and os.path.exists(str(path))) or symlink:
348352
return {"dflow_list_item": copy_results(
349-
path, name, data_root, slice_dir), "order": order}
353+
path, name, data_root, slice_dir, symlink), "order": order}
350354
else:
351355
return {"dflow_list_item": None, "order": order}
352356

353357

354-
def copy_results(source, name, data_root="/tmp", slice_dir=None):
358+
def copy_results(source, name, data_root="/tmp", slice_dir=None,
359+
symlink=False):
355360
source = str(source)
356361
# if refer to input artifact
357362
if source.find(data_root + "/inputs/artifacts/") == 0:
@@ -364,7 +369,12 @@ def copy_results(source, name, data_root="/tmp", slice_dir=None):
364369
if slice_dir is not None:
365370
rel_path = "%s/%s" % (slice_dir, rel_path)
366371
target = data_root + "/outputs/artifacts/%s/%s" % (name, rel_path)
367-
copy_file(source, target, shutil.copy)
372+
if symlink:
373+
os.makedirs(os.path.abspath(os.path.dirname(target)),
374+
exist_ok=True)
375+
os.symlink(source, target)
376+
else:
377+
copy_file(source, target, shutil.copy)
368378
if rel_path[:1] == "/":
369379
rel_path = rel_path[1:]
370380
return rel_path
@@ -378,7 +388,12 @@ def copy_results(source, name, data_root="/tmp", slice_dir=None):
378388
if slice_dir is not None:
379389
rel_path = "%s/%s" % (slice_dir, rel_path)
380390
target = data_root + "/outputs/artifacts/%s/%s" % (name, rel_path)
381-
copy_file(source, target)
391+
if symlink:
392+
os.makedirs(os.path.abspath(os.path.dirname(target)),
393+
exist_ok=True)
394+
os.symlink(source, target)
395+
else:
396+
copy_file(source, target)
382397
return rel_path
383398

384399

@@ -456,7 +471,10 @@ def try_to_execute(input, slice_dir, op_obj, output_sign, cwd, timeout=None):
456471
except Exception as e:
457472
traceback.print_exc()
458473
os.chdir(cwd)
459-
return None, e
474+
if op_obj.outputs:
475+
return op_obj.outputs, e
476+
else:
477+
return None, e
460478
finally:
461479
if timeout is not None:
462480
signal.alarm(0)

0 commit comments

Comments
 (0)