diff --git a/src/orion/core/cli/db/rm.py b/src/orion/core/cli/db/rm.py index 9b00e181a..dbbec03f8 100644 --- a/src/orion/core/cli/db/rm.py +++ b/src/orion/core/cli/db/rm.py @@ -32,7 +32,7 @@ TRIALS_RM_MESSAGE = """ -Matching trials of all experiments above will be deleted. +Matching trials ({}) of all experiments above will be deleted. To select a specific version use --version . Note that trials of all children of a given version will be deleted. @@ -110,6 +110,27 @@ def add_subparser(parser): return rm_parser +def get_trial_count(storage, root, status): + """Select the matching trials of the given experiment.""" + trials_total = 0 + for node in root: + if status == "*": + query = {} + else: + query = {"status": status} + + count = len(storage.fetch_trials(uid=node.item.id, where=query)) + logger.debug( + "%d trials selected in experiment %s-v%d", + count, + node.item.name, + node.item.version, + ) + trials_total += count + + return trials_total + + def process_trial_rm(storage, root, status): """Delete the matching trials of the given experiment.""" trials_total = 0 @@ -167,7 +188,8 @@ def delete_experiments(storage, root, name, force): def delete_trials(storage, root, name, status, force): """Delete all matching trials after user confirmation.""" - confirmed = confirm_name(TRIALS_RM_MESSAGE, name, force) + count = get_trial_count(storage, root, status) + confirmed = confirm_name(TRIALS_RM_MESSAGE.format(count), name, force) if not confirmed: print("Confirmation failed, aborting operation.")