|
23 | 23 |
|
24 | 24 |
|
25 | 25 | def log_pruned_info(cur_cfg, pruned_reason, tuner_cfg): |
26 | | - pruned_strategy = "DP{}_MP{}_PP{}_VPP_{}_Sharding{}_Stage{}_MBS{}_Recompute_{}_Granularity_{}".format( |
| 26 | + pruned_strategy = "DP{}_MP{}_PP{}_VPP{}_Sharding{}_Stage{}_MBS{}_Recompute_{}_Granularity_{}".format( |
27 | 27 | cur_cfg["dp_degree"], |
28 | 28 | cur_cfg["mp_degree"], |
29 | 29 | cur_cfg["pp_degree"], |
@@ -834,10 +834,12 @@ def prune_by_refined_recompute(tuner_cfg, cur_cfg, history_cfgs=[]): |
834 | 834 | pp_degree = cur_cfg["pp_degree"] |
835 | 835 | recompute = cur_cfg["use_recompute"] |
836 | 836 | recompute_granularity = cur_cfg["recompute_granularity"] |
| 837 | + compare = [cur_cfg[item] for item in rr] |
837 | 838 | if recompute: |
838 | 839 | if recompute_granularity and recompute_granularity != "full": |
839 | | - return True |
840 | | - if pp_degree == 1: |
| 840 | + if compare.count(0) != len(compare): |
| 841 | + return True |
| 842 | + if pp_degree == 1 and compare.count(0) != len(compare): |
841 | 843 | return True |
842 | 844 | if tuner_cfg["model_cfg"]["num_layers"] % pp_degree != 0: |
843 | 845 | return True |
@@ -873,7 +875,12 @@ def prune_by_refined_recompute_history( |
873 | 875 | log_pruned_info(cur_cfg, pruned_reason, tuner_cfg) |
874 | 876 | cur_cfg["time"] = cfg["time"] |
875 | 877 | return True |
876 | | - if cfg[item] > cur_cfg[item] and cfg.get("time", -1) > 0: |
| 878 | + if ( |
| 879 | + cfg[item] > cur_cfg[item] |
| 880 | + and cfg.get("time", -1) > 0 |
| 881 | + and cfg["use_recompute"] |
| 882 | + and cur_cfg["use_recompute"] |
| 883 | + ): |
877 | 884 | pruned_reason = f"{item} {cur_cfg[item]} may be slower because {cfg[item]} has been already runnable." |
878 | 885 | log_pruned_info(cur_cfg, pruned_reason, tuner_cfg) |
879 | 886 | cur_cfg["time"] = cfg["time"] |
|
0 commit comments