Skip to content

Commit 485f6b1

Browse files
committed
fix rr pp prune bug
1 parent 5e6d82a commit 485f6b1

1 file changed

Lines changed: 11 additions & 4 deletions

File tree

  • python/paddle/distributed/auto_tuner

python/paddle/distributed/auto_tuner/prune.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424

2525
def log_pruned_info(cur_cfg, pruned_reason, tuner_cfg):
26-
pruned_strategy = "DP{}_MP{}_PP{}_VPP_{}_Sharding{}_Stage{}_MBS{}_Recompute_{}_Granularity_{}".format(
26+
pruned_strategy = "DP{}_MP{}_PP{}_VPP{}_Sharding{}_Stage{}_MBS{}_Recompute_{}_Granularity_{}".format(
2727
cur_cfg["dp_degree"],
2828
cur_cfg["mp_degree"],
2929
cur_cfg["pp_degree"],
@@ -834,10 +834,12 @@ def prune_by_refined_recompute(tuner_cfg, cur_cfg, history_cfgs=[]):
834834
pp_degree = cur_cfg["pp_degree"]
835835
recompute = cur_cfg["use_recompute"]
836836
recompute_granularity = cur_cfg["recompute_granularity"]
837+
compare = [cur_cfg[item] for item in rr]
837838
if recompute:
838839
if recompute_granularity and recompute_granularity != "full":
839-
return True
840-
if pp_degree == 1:
840+
if compare.count(0) != len(compare):
841+
return True
842+
if pp_degree == 1 and compare.count(0) != len(compare):
841843
return True
842844
if tuner_cfg["model_cfg"]["num_layers"] % pp_degree != 0:
843845
return True
@@ -873,7 +875,12 @@ def prune_by_refined_recompute_history(
873875
log_pruned_info(cur_cfg, pruned_reason, tuner_cfg)
874876
cur_cfg["time"] = cfg["time"]
875877
return True
876-
if cfg[item] > cur_cfg[item] and cfg.get("time", -1) > 0:
878+
if (
879+
cfg[item] > cur_cfg[item]
880+
and cfg.get("time", -1) > 0
881+
and cfg["use_recompute"]
882+
and cur_cfg["use_recompute"]
883+
):
877884
pruned_reason = f"{item} {cur_cfg[item]} may be slower because {cfg[item]} has been already runnable."
878885
log_pruned_info(cur_cfg, pruned_reason, tuner_cfg)
879886
cur_cfg["time"] = cfg["time"]

0 commit comments

Comments
 (0)