def get_best(
    self, metric, direction, buffer=None, max_mem_usage=None
) -> Tuple[dict, bool]:
    """Return the best recorded config and an error flag.

    The history is first sorted with ``self.sort_metric`` so that
    ``self.history[0]`` is the best entry for ``metric``.

    Without ``buffer`` the top entry is returned as-is. With ``buffer`` the
    history is walked best-first and the walk stops at the first usable
    config whose ``max_mem_usage`` is below ``max_mem_usage - buffer``. If
    no config fits that budget, the usable config with the lowest
    ``max_mem_usage`` is returned instead.

    A config is "usable" when its ``max_mem_usage`` is a truthy non-string
    value and its ``time`` is not ``-1``.

    Callers in ``launch/main.py`` read both ``buffer`` and
    ``max_mem_usage`` from ``tuner_cfg`` with a default of ``None``.

    Args:
        metric: Name of the metric the history is sorted by.
        direction: Optimization direction forwarded to ``sort_metric``.
        buffer: Optional memory headroom to keep free; must be ``>= 0``.
        max_mem_usage: Memory limit the headroom is subtracted from.
            Required (and ``> 0``) whenever ``buffer`` is given.

    Returns:
        ``(cfg, err)``. ``(None, True)`` when the history is empty,
        ``(top_cfg, True)`` when the top entry has a string
        ``max_mem_usage`` or ``time == -1``, otherwise ``(cfg, False)``.

    Raises:
        ValueError: If ``buffer`` is negative, or ``buffer`` is given and
            ``max_mem_usage`` is ``None`` or not greater than 0.
    """
    self.sort_metric(direction=direction, metric_name=metric)
    if not self.history:
        return (None, True)

    top_cfg = self.history[0]
    # "time" may be absent from an entry; only an explicit -1 marks failure.
    if isinstance(top_cfg["max_mem_usage"], str) or top_cfg.get("time") == -1:
        return (top_cfg, True)

    if buffer is None:
        return (top_cfg, False)

    if buffer < 0:
        raise ValueError("The buffer should be not less than 0.")
    # Raised explicitly (not asserted) so the check survives ``python -O``.
    if max_mem_usage is None:
        raise ValueError(
            "max_mem_usage cannot be None when buffer is greater than 0."
        )
    if max_mem_usage <= 0:
        raise ValueError("max_mem_usage should be greater than 0.")

    def _is_usable(cfg):
        """Tell whether cfg has a comparable memory value and did not fail."""
        mem = cfg["max_mem_usage"]
        return (
            not isinstance(mem, str)
            and bool(mem)
            and cfg.get("time") != -1
        )

    mem_budget = max_mem_usage - buffer
    best_cfg = top_cfg
    for cfg in self.history:
        # Filter first so memory values are never compared against None
        # or a string, which would raise TypeError.
        if not _is_usable(cfg):
            continue

        mem = cfg["max_mem_usage"]
        best_mem = best_cfg["max_mem_usage"]
        # Track the lowest-memory usable config as the fallback answer.
        if not best_mem or mem < best_mem:
            best_cfg = cfg

        # History is sorted best-first, so the first config inside the
        # budget is the answer. This check must also run for a config that
        # has just been adopted as best_cfg.
        if mem < mem_budget:
            break
    return (best_cfg, False)
-1158,7 +1166,8 @@ def launch(): cur_best_cfgs, err = recorder.get_best( metric=tuner_cfg['metric_cfg']['name'], direction=tuner_cfg['metric_cfg']['OptimizationDirection'], - mode=mode, + buffer=buffer, + max_mem_usage=max_mem_usage, ) if not err: ctx.logger.info(f"Current best config: {cur_best_cfgs}") @@ -1206,7 +1215,8 @@ def launch(): best_cfg, err = recorder.get_best( metric=tuner_cfg['metric_cfg']['name'], direction=tuner_cfg['metric_cfg']['OptimizationDirection'], - mode=mode, + buffer=buffer, + max_mem_usage=max_mem_usage, ) if err: raise ValueError( @@ -1232,7 +1242,8 @@ def launch(): best_cfg, err = recorder.get_best( metric=tuner_cfg['metric_cfg']['name'], direction=tuner_cfg['metric_cfg']['OptimizationDirection'], - mode=mode, + buffer=buffer, + max_mem_usage=max_mem_usage, ) if err: raise ValueError( @@ -1255,9 +1266,13 @@ def launch(): ctx.args.job_id = "best_cfg" ctx.logger.info(f"Launch best cfg: {best_cfg}") logger.info(f"Launch best cfg: {best_cfg}") - ctx.args.log_dir = ctx.args.log_dir = os.path.join( - os.path.dirname(ctx.args.auto_tuner_json), "best_cfg" - ) + + if tuner_cfg.get("best_cfg_dir", None): + ctx.args.log_dir = tuner_cfg["best_cfg_dir"] + else: + ctx.args.log_dir = os.path.join( + os.path.dirname(ctx.args.auto_tuner_json), "best_cfg" + ) # run best cfg c = controllers.init(ctx) c.run()