Skip to content

Commit 7446003

Browse files
committed
fix: support get sub steps of a specific step
Signed-off-by: zjgemi <[email protected]>
1 parent a3e7ee1 commit 7446003

File tree

2 files changed

+39
-8
lines changed

2 files changed

+39
-8
lines changed

src/dflow/argo_objects.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import time
77
from collections import UserDict, UserList
88
from copy import deepcopy
9-
from typing import Any, List, Union
9+
from typing import Any, List, Optional, Union
1010

1111
from .common import jsonpickle
1212
from .config import config, s3_config
@@ -359,6 +359,7 @@ def get_step(
359359
phase: Union[str, List[str]] = None,
360360
id: Union[str, List[str]] = None,
361361
type: Union[str, List[str]] = None,
362+
parent_id: Optional[str] = None,
362363
sort_by_generation: bool = False,
363364
) -> List[ArgoStep]:
364365
if name is not None and not isinstance(name, list):
@@ -373,7 +374,11 @@ def get_step(
373374
type = [type]
374375
step_list = []
375376
if hasattr(self.status, "nodes"):
376-
for step in self.status.nodes.values():
377+
if parent_id is not None:
378+
nodes = self.get_sub_nodes(parent_id)
379+
else:
380+
nodes = self.status.nodes.values()
381+
for step in nodes:
377382
if step["startedAt"] is None:
378383
continue
379384
if name is not None and not match(step["displayName"], name):
@@ -396,6 +401,8 @@ def get_step(
396401
continue
397402
step = ArgoStep(step, self.metadata.name)
398403
step_list.append(step)
404+
else:
405+
return []
399406
if sort_by_generation:
400407
self.generation = {}
401408
self.record_generation(self.id, 0)
@@ -405,13 +412,35 @@ def get_step(
405412
step_list.sort(key=lambda x: x["startedAt"])
406413
return step_list
407414

415+
def get_sub_nodes(self, node_id):
416+
assert node_id in self.status.nodes
417+
node = self.status.nodes[node_id]
418+
if node["type"] not in ["Steps", "DAG"]:
419+
return [node]
420+
if node.get("memoizationStatus", {}).get("hit", False):
421+
return [node]
422+
sub_nodes = []
423+
outbound_nodes = node.get("outboundNodes", [])
424+
children = node.get("children", [])
425+
# order by generation (BFS)
426+
current_generation = children
427+
while len(current_generation) > 0:
428+
for id in current_generation:
429+
sub_nodes.append(self.status.nodes[id])
430+
next_generation = []
431+
for id in current_generation:
432+
if id not in outbound_nodes:
433+
next_generation += self.status.nodes[id].get(
434+
"children", [])
435+
current_generation = next_generation
436+
return sub_nodes
437+
408438
def record_generation(self, node_id, generation):
409439
self.generation[node_id] = generation
410-
if "children" in self.status.nodes[node_id]:
411-
for child in self.status.nodes[node_id]["children"]:
412-
if child in self.generation:
413-
continue
414-
self.record_generation(child, generation+1)
440+
for child in self.status.nodes[node_id].get("children", []):
441+
if child in self.generation:
442+
continue
443+
self.record_generation(child, generation+1)
415444

416445
def get_duration(self) -> datetime.timedelta:
417446
return get_duration(self.status)

src/dflow/workflow.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -999,6 +999,7 @@ def query_step(
999999
phase: Union[str, List[str]] = None,
10001000
id: Union[str, List[str]] = None,
10011001
type: Union[str, List[str]] = None,
1002+
parent_id: Optional[str] = None,
10021003
sort_by_generation: bool = False,
10031004
) -> List[ArgoStep]:
10041005
"""
@@ -1017,6 +1018,7 @@ def query_step(
10171018
phase: filter by phase of step
10181019
id: filter by id of step
10191020
type: filter by type of step
1021+
parent_id: get sub steps of a specific step
10201022
sort_by_generation: sort results by the number of generation from
10211023
the root node
10221024
Returns:
@@ -1105,7 +1107,7 @@ def query_step(
11051107

11061108
return self.query().get_step(
11071109
name=name, key=key, phase=phase, id=id, type=type,
1108-
sort_by_generation=sort_by_generation)
1110+
parent_id=parent_id, sort_by_generation=sort_by_generation)
11091111

11101112
def query_keys_of_steps(
11111113
self,

0 commit comments

Comments
 (0)