python/paddle/distributed/auto_tuner/recorder.py (36 additions, 16 deletions)
@@ -51,33 +51,53 @@ def sort_metric(self, direction, metric_name) -> None:
reverse=False,
)

def get_best(self, metric, direction, mode=None) -> Tuple[dict, bool]:
def get_best(
self, metric, direction, buffer=None, max_mem_usage=None
) -> Tuple[dict, bool]:
self.sort_metric(direction=direction, metric_name=metric)
if len(self.history) == 0:
return (self.history[0], True)
if mode == "SFT" or mode == "LoRA" or mode == "Pretrain":
best_cfg = self.history[0]
if (
isinstance(best_cfg["max_mem_usage"], str)
or best_cfg["time"] == -1
):
return (best_cfg, True)
first_few = 0
return (None, True)

best_cfg = self.history[0]
if isinstance(best_cfg["max_mem_usage"], str) or best_cfg["time"] == -1:
return (best_cfg, True)

if buffer is not None:
if buffer < 0:
raise ValueError("The buffer should be not less than 0.")
assert (
max_mem_usage is not None
), "max_mem_usage cannot be None when buffer is greater than 0."
if max_mem_usage <= 0:
raise ValueError("max_mem_usage should be greater than 0.")

for cfg in self.history:
if (
not best_cfg["max_mem_usage"]
and cfg["max_mem_usage"]
and not isinstance(cfg["max_mem_usage"], str)
and cfg["time"] != -1
):
best_cfg = cfg
continue

if (
not isinstance(cfg["max_mem_usage"], str)
and cfg["max_mem_usage"]
and cfg["max_mem_usage"] < best_cfg["max_mem_usage"]
and cfg["time"] != -1
):
best_cfg = cfg
first_few += 1
if first_few >= 3:

if (
not isinstance(cfg["max_mem_usage"], str)
and cfg["max_mem_usage"]
and cfg["max_mem_usage"] < max_mem_usage - buffer
and cfg["time"] != -1
):
break
return (best_cfg, False)
if isinstance(self.history[0]["max_mem_usage"], str) or (
"time" in self.history[0] and self.history[0]["time"] == -1
):
return (self.history[0], True)

return (self.history[0], False)

def _store_history_impl(self, data, path="./history.csv"):
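For readers skimming the hunk above, here is a minimal, self-contained sketch of the selection rule the new get_best implements: walk the metric-sorted history, keep the valid config with the lowest peak memory, and stop early once a config fits under max_mem_usage - buffer. The pick_best helper, the field values, and the units below are illustrative; only the dict keys (max_mem_usage, time) and the buffer/max_mem_usage semantics come from the diff.

from typing import Optional, Tuple


def pick_best(history, buffer=None, max_mem_usage=None) -> Tuple[Optional[dict], bool]:
    """history is assumed to be pre-sorted by the target metric (best first)."""
    if not history:
        return None, True
    best = history[0]
    # A string max_mem_usage (e.g. "OOM") or time == -1 marks a failed run.
    if isinstance(best["max_mem_usage"], str) or best["time"] == -1:
        return best, True
    if buffer is None:
        return best, False
    for cfg in history:
        mem = cfg["max_mem_usage"]
        if isinstance(mem, str) or not mem or cfg["time"] == -1:
            continue  # skip failed or unmeasured runs
        if not best["max_mem_usage"] or mem < best["max_mem_usage"]:
            best = cfg  # lower peak memory among otherwise valid runs
        if mem < max_mem_usage - buffer:
            break  # this config already fits within the requested safety margin
    return best, False


runs = [
    {"dp_degree": 2, "max_mem_usage": 78000, "time": 120.0},
    {"dp_degree": 4, "max_mem_usage": 61000, "time": 130.0},
    {"dp_degree": 8, "max_mem_usage": "OOM", "time": -1},
]
print(pick_best(runs, buffer=8000, max_mem_usage=81920))

With those made-up numbers, the second run is picked because it is the first valid config whose peak memory stays under the budget minus the buffer; this mirrors the early-break condition added in the PR.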
python/paddle/distributed/launch/main.py (21 additions, 6 deletions)
@@ -436,6 +436,10 @@ def launch():
# max_search_time
max_search_time = tuner_cfg.get("max_search_time", None)

# buffer and memory
buffer = tuner_cfg.get("buffer", None)
max_mem_usage = tuner_cfg.get("max_mem_usage", None)

is_first_task = True
# build history recorder
recorder = HistoryRecorder(tuner_cfg)
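The two tuner_cfg.get(...) reads added above come from the auto-tuner JSON that launch() loads (ctx.args.auto_tuner_json). A hypothetical config fragment exercising the new keys is sketched below; only buffer, max_mem_usage, metric_cfg, max_search_time, and best_cfg_dir appear in this diff, and all values (metric name, direction, numbers, paths) are illustrative.

import json

tuner_cfg = {
    "metric_cfg": {
        "name": "interval_samples_per_second",   # illustrative metric name
        "OptimizationDirection": "Maximize",     # illustrative direction value
    },
    "max_search_time": 3600,                     # illustrative search budget, seconds
    "buffer": 8000,                              # safety margin subtracted from max_mem_usage
    "max_mem_usage": 81920,                      # per-device memory budget (illustrative units)
    "best_cfg_dir": "./auto_tuner/best_cfg",     # optional override for the best-run log dir
}

# Dump the fragment as it would appear in the JSON file consumed by launch().
print(json.dumps(tuner_cfg, indent=2))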
@@ -681,6 +685,8 @@ def launch():
cur_best_cfgs, err = recorder.get_best(
metric=tuner_cfg['metric_cfg']['name'],
direction=tuner_cfg['metric_cfg']['OptimizationDirection'],
buffer=buffer,
max_mem_usage=max_mem_usage,
)
if not err:
ctx.logger.info(f"Current best config: {cur_best_cfgs}")
@@ -781,6 +787,8 @@ def launch():
direction=tuner_cfg['metric_cfg'][
'OptimizationDirection'
],
buffer=buffer,
max_mem_usage=max_mem_usage,
)
if not err:
ctx.logger.info(f"Current best config: {cur_best_cfgs}")
@@ -1158,7 +1166,8 @@ def launch():
cur_best_cfgs, err = recorder.get_best(
metric=tuner_cfg['metric_cfg']['name'],
direction=tuner_cfg['metric_cfg']['OptimizationDirection'],
mode=mode,
buffer=buffer,
max_mem_usage=max_mem_usage,
)
if not err:
ctx.logger.info(f"Current best config: {cur_best_cfgs}")
@@ -1206,7 +1215,8 @@ def launch():
best_cfg, err = recorder.get_best(
metric=tuner_cfg['metric_cfg']['name'],
direction=tuner_cfg['metric_cfg']['OptimizationDirection'],
mode=mode,
buffer=buffer,
max_mem_usage=max_mem_usage,
)
if err:
raise ValueError(
@@ -1232,7 +1242,8 @@ def launch():
best_cfg, err = recorder.get_best(
metric=tuner_cfg['metric_cfg']['name'],
direction=tuner_cfg['metric_cfg']['OptimizationDirection'],
mode=mode,
buffer=buffer,
max_mem_usage=max_mem_usage,
)
if err:
raise ValueError(
@@ -1255,9 +1266,13 @@ def launch():
ctx.args.job_id = "best_cfg"
ctx.logger.info(f"Launch best cfg: {best_cfg}")
logger.info(f"Launch best cfg: {best_cfg}")
ctx.args.log_dir = ctx.args.log_dir = os.path.join(
os.path.dirname(ctx.args.auto_tuner_json), "best_cfg"
)

if tuner_cfg.get("best_cfg_dir", None):
ctx.args.log_dir = tuner_cfg["best_cfg_dir"]
else:
ctx.args.log_dir = os.path.join(
os.path.dirname(ctx.args.auto_tuner_json), "best_cfg"
)
# run best cfg
c = controllers.init(ctx)
c.run()
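One behavioral note on this last hunk: the best configuration found by the search is re-launched once at the end, and the change lets the tuner config redirect where that run's logs are written instead of always placing them next to the auto-tuner JSON. A small sketch of the fallback, with made-up paths:

import os


def resolve_best_cfg_log_dir(tuner_cfg, auto_tuner_json):
    # Prefer an explicit "best_cfg_dir" from the tuner config; otherwise keep the
    # previous behavior of writing next to the auto-tuner JSON under "best_cfg".
    if tuner_cfg.get("best_cfg_dir", None):
        return tuner_cfg["best_cfg_dir"]
    return os.path.join(os.path.dirname(auto_tuner_json), "best_cfg")


print(resolve_best_cfg_log_dir({}, "/workspace/llama_auto_tuner.json"))
print(resolve_best_cfg_log_dir({"best_cfg_dir": "/logs/best"}, "/workspace/llama_auto_tuner.json"))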