66import time
77from collections import UserDict , UserList
88from copy import deepcopy
9- from typing import Any , List , Union
9+ from typing import Any , List , Optional , Union
1010
1111from .common import jsonpickle
1212from .config import config , s3_config
@@ -359,6 +359,7 @@ def get_step(
359359 phase : Union [str , List [str ]] = None ,
360360 id : Union [str , List [str ]] = None ,
361361 type : Union [str , List [str ]] = None ,
362+ parent_id : Optional [str ] = None ,
362363 sort_by_generation : bool = False ,
363364 ) -> List [ArgoStep ]:
364365 if name is not None and not isinstance (name , list ):
@@ -373,7 +374,11 @@ def get_step(
373374 type = [type ]
374375 step_list = []
375376 if hasattr (self .status , "nodes" ):
376- for step in self .status .nodes .values ():
377+ if parent_id is not None :
378+ nodes = self .get_sub_nodes (parent_id )
379+ else :
380+ nodes = self .status .nodes .values ()
381+ for step in nodes :
377382 if step ["startedAt" ] is None :
378383 continue
379384 if name is not None and not match (step ["displayName" ], name ):
@@ -396,6 +401,8 @@ def get_step(
396401 continue
397402 step = ArgoStep (step , self .metadata .name )
398403 step_list .append (step )
404+ else :
405+ return []
399406 if sort_by_generation :
400407 self .generation = {}
401408 self .record_generation (self .id , 0 )
@@ -405,13 +412,35 @@ def get_step(
405412 step_list .sort (key = lambda x : x ["startedAt" ])
406413 return step_list
407414
415+ def get_sub_nodes (self , node_id ):
416+ assert node_id in self .status .nodes
417+ node = self .status .nodes [node_id ]
418+ if node ["type" ] not in ["Steps" , "DAG" ]:
419+ return [node ]
420+ if node .get ("memoizationStatus" , {}).get ("hit" , False ):
421+ return [node ]
422+ sub_nodes = []
423+ outbound_nodes = node .get ("outboundNodes" , [])
424+ children = node .get ("children" , [])
425+ # order by generation (BFS)
426+ current_generation = children
427+ while len (current_generation ) > 0 :
428+ for id in current_generation :
429+ sub_nodes .append (self .status .nodes [id ])
430+ next_generation = []
431+ for id in current_generation :
432+ if id not in outbound_nodes :
433+ next_generation += self .status .nodes [id ].get (
434+ "children" , [])
435+ current_generation = next_generation
436+ return sub_nodes
437+
408438 def record_generation (self , node_id , generation ):
409439 self .generation [node_id ] = generation
410- if "children" in self .status .nodes [node_id ]:
411- for child in self .status .nodes [node_id ]["children" ]:
412- if child in self .generation :
413- continue
414- self .record_generation (child , generation + 1 )
440+ for child in self .status .nodes [node_id ].get ("children" , []):
441+ if child in self .generation :
442+ continue
443+ self .record_generation (child , generation + 1 )
415444
416445 def get_duration (self ) -> datetime .timedelta :
417446 return get_duration (self .status )
0 commit comments