Skip to content

Commit 25173d2

Browse files
committed
add buffer and best_cfg_log for autotuner
1 parent 59c1c3d commit 25173d2

2 files changed

Lines changed: 33 additions & 23 deletions

File tree

python/paddle/distributed/auto_tuner/recorder.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -51,33 +51,26 @@ def sort_metric(self, direction, metric_name) -> None:
5151
reverse=False,
5252
)
5353

def get_best(self, metric, direction, buffer, memory) -> Tuple[dict, bool]:
    """Return the best recorded config together with an error flag.

    The history is first re-sorted through ``self.sort_metric`` and is then
    read front to back, i.e. ``self.history[0]`` is treated as the
    top-ranked entry (assumes ``sort_metric`` orders best-first -- its body
    is not visible here, TODO confirm).

    Args:
        metric: Metric name forwarded to ``sort_metric`` as ``metric_name``.
        direction: Optimization direction forwarded to ``sort_metric``.
        buffer: Memory headroom to keep free. ``None`` or a value ``<= 0``
            disables the headroom filter.
        memory: Total memory available; must be expressed in the same unit
            as the recorded ``max_mem_usage`` values. Only used when
            ``buffer > 0``.

    Returns:
        ``(cfg, err)``. ``err`` is True when there is no usable result:
        either the history is empty (``cfg`` is then ``None``) or the
        top-ranked entry is a failed run (string ``max_mem_usage`` or
        ``time == -1``). Otherwise ``err`` is False and ``cfg`` is the
        highest-ranked successful entry whose ``max_mem_usage`` is below
        ``memory - buffer``; if the filter is disabled or no entry fits,
        the top-ranked entry is returned.

    Raises:
        ValueError: If ``buffer > 0`` but ``memory`` is ``None``.
    """
    self.sort_metric(direction=direction, metric_name=metric)

    if not self.history:
        # Indexing an empty history would raise IndexError; report the
        # "no result" condition through the error flag instead.
        return (None, True)

    def _succeeded(cfg):
        # A string max_mem_usage or a time of -1 marks a run that did not
        # produce a usable measurement. A missing "time" key is tolerated.
        return (
            not isinstance(cfg["max_mem_usage"], str)
            and cfg.get("time") != -1
        )

    top_cfg = self.history[0]
    if not _succeeded(top_cfg):
        return (top_cfg, True)

    if buffer is None or buffer <= 0:
        return (top_cfg, False)

    if memory is None:
        raise ValueError(
            "memory must be provided when buffer is greater than 0."
        )

    mem_limit = memory - buffer
    for cfg in self.history:
        usage = cfg["max_mem_usage"]
        if _succeeded(cfg) and usage is not None and usage < mem_limit:
            # First (highest-ranked) config that leaves the headroom.
            return (cfg, False)

    # No config leaves the requested headroom: fall back to the validated
    # top-ranked entry instead of whatever entry the scan happened to end on.
    return (top_cfg, False)
8275

8376
def _store_history_impl(self, data, path="./history.csv"):

python/paddle/distributed/launch/main.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,12 @@ def launch():
436436
# max_search_time
437437
max_search_time = tuner_cfg.get("max_search_time", None)
438438

439+
# buffer and memory
440+
buffer = tuner_cfg.get("buffer", 0)
441+
tuner_cfg['buffer'] = buffer
442+
memory = tuner_cfg.get("memory", None)
443+
tuner_cfg['buffer'] = buffer
444+
439445
is_first_task = True
440446
# build history recorder
441447
recorder = HistoryRecorder(tuner_cfg)
@@ -681,6 +687,8 @@ def launch():
681687
cur_best_cfgs, err = recorder.get_best(
682688
metric=tuner_cfg['metric_cfg']['name'],
683689
direction=tuner_cfg['metric_cfg']['OptimizationDirection'],
690+
buffer=tuner_cfg['buffer'],
691+
memory=tuner_cfg['memory'],
684692
)
685693
if not err:
686694
ctx.logger.info(f"Current best config: {cur_best_cfgs}")
@@ -781,6 +789,8 @@ def launch():
781789
direction=tuner_cfg['metric_cfg'][
782790
'OptimizationDirection'
783791
],
792+
buffer=tuner_cfg['buffer'],
793+
memory=tuner_cfg['memory'],
784794
)
785795
if not err:
786796
ctx.logger.info(f"Current best config: {cur_best_cfgs}")
@@ -1158,7 +1168,8 @@ def launch():
11581168
cur_best_cfgs, err = recorder.get_best(
11591169
metric=tuner_cfg['metric_cfg']['name'],
11601170
direction=tuner_cfg['metric_cfg']['OptimizationDirection'],
1161-
mode=mode,
1171+
buffer=tuner_cfg['buffer'],
1172+
memory=tuner_cfg['memory'],
11621173
)
11631174
if not err:
11641175
ctx.logger.info(f"Current best config: {cur_best_cfgs}")
@@ -1206,7 +1217,8 @@ def launch():
12061217
best_cfg, err = recorder.get_best(
12071218
metric=tuner_cfg['metric_cfg']['name'],
12081219
direction=tuner_cfg['metric_cfg']['OptimizationDirection'],
1209-
mode=mode,
1220+
buffer=tuner_cfg['buffer'],
1221+
memory=tuner_cfg['memory'],
12101222
)
12111223
if err:
12121224
raise ValueError(
@@ -1232,7 +1244,8 @@ def launch():
12321244
best_cfg, err = recorder.get_best(
12331245
metric=tuner_cfg['metric_cfg']['name'],
12341246
direction=tuner_cfg['metric_cfg']['OptimizationDirection'],
1235-
mode=mode,
1247+
buffer=tuner_cfg['buffer'],
1248+
memory=tuner_cfg['memory'],
12361249
)
12371250
if err:
12381251
raise ValueError(
@@ -1255,9 +1268,13 @@ def launch():
12551268
ctx.args.job_id = "best_cfg"
12561269
ctx.logger.info(f"Launch best cfg: {best_cfg}")
12571270
logger.info(f"Launch best cfg: {best_cfg}")
1258-
ctx.args.log_dir = ctx.args.log_dir = os.path.join(
1259-
os.path.dirname(ctx.args.auto_tuner_json), "best_cfg"
1260-
)
1271+
1272+
if tuner_cfg.get("best_cfg_dir", None):
1273+
ctx.args.log_dir = tuner_cfg["best_cfg_dir"]
1274+
else:
1275+
ctx.args.log_dir = os.path.join(
1276+
os.path.dirname(ctx.args.auto_tuner_json), "best_cfg"
1277+
)
12611278
# run best cfg
12621279
c = controllers.init(ctx)
12631280
c.run()

0 commit comments

Comments
 (0)