104 changes: 93 additions & 11 deletions python/paddle/distributed/auto_tuner/prune.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import os
import subprocess
@@ -21,8 +22,8 @@
_PRUNE_HISTORY_FUNC = []


def log_pruned_info(cur_cfg, pruned_reason):
pruned_strategy = "DP{}_MP{}_PP{}_VPP_{}_Sharding{}_Stage{}_MBS{}_Recompute_{}_Granularity_{}".format(
def log_pruned_info(cur_cfg, pruned_reason, tuner_cfg):
pruned_strategy = "DP{}_MP{}_PP{}_VPP{}_Sharding{}_Stage{}_MBS{}_Recompute_{}_Granularity_{}".format(
cur_cfg["dp_degree"],
cur_cfg["mp_degree"],
cur_cfg["pp_degree"],
@@ -33,6 +34,11 @@ def log_pruned_info(cur_cfg, pruned_reason):
cur_cfg["use_recompute"],
cur_cfg["recompute_granularity"],
)
if "refined_recompute" in tuner_cfg:
for key in tuner_cfg["refined_recompute"]:
strategy = "".join(i.capitalize() for i in key.split("_"))
strategy += str(cur_cfg[key])
pruned_strategy = pruned_strategy + "_" + strategy

try:
from paddle.distributed.launch.main import ctx
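Note: `log_pruned_info` now receives `tuner_cfg` so it can append the refined-recompute knobs to the logged strategy tag. A minimal sketch of the resulting string, assuming two hypothetical refined-recompute key names and assumed values for the middle format arguments that the diff view hides:

```python
# Illustrative only: "recompute_ffn" / "recompute_attn" are assumed key names,
# and the sharding/MBS fields are assumptions for the hidden format arguments.
cur_cfg = {
    "dp_degree": 2, "mp_degree": 2, "pp_degree": 4, "vpp_degree": 1,
    "sharding_degree": 1, "sharding_stage": 1, "micro_batch_size": 2,
    "use_recompute": True, "recompute_granularity": "full",
    "recompute_ffn": 2, "recompute_attn": 0,
}
tuner_cfg = {"refined_recompute": ["recompute_ffn", "recompute_attn"]}

tag = "DP{}_MP{}_PP{}_VPP{}_Sharding{}_Stage{}_MBS{}_Recompute_{}_Granularity_{}".format(
    cur_cfg["dp_degree"], cur_cfg["mp_degree"], cur_cfg["pp_degree"],
    cur_cfg["vpp_degree"], cur_cfg["sharding_degree"], cur_cfg["sharding_stage"],
    cur_cfg["micro_batch_size"], cur_cfg["use_recompute"],
    cur_cfg["recompute_granularity"],
)
for key in tuner_cfg["refined_recompute"]:
    # "recompute_ffn" -> "RecomputeFfn" + its value, joined with "_"
    tag += "_" + "".join(i.capitalize() for i in key.split("_")) + str(cur_cfg[key])

print(tag)  # DP2_MP2_PP4_VPP1_..._Granularity_full_RecomputeFfn2_RecomputeAttn0
```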
@@ -215,7 +221,7 @@ def prune_by_mp_pp_history(tuner_cfg, cur_cfg, history_cfgs, pruned_cfgs):
and cfg.get("max_mem_usage") == "OOM"
):
pruned_reason = f"mp_degree {mp_degree}, pp_degree {pp_degree} may cause oom because {cfg['mp_degree']}, {cfg['pp_degree']} already oom."
log_pruned_info(cur_cfg, pruned_reason)
log_pruned_info(cur_cfg, pruned_reason, tuner_cfg)
cur_cfg["max_mem_usage"] = "OOM"
return True

@@ -292,7 +298,7 @@ def prune_by_vpp_history(tuner_cfg, cur_cfg, history_cfgs=[], pruned_cfgs=[]):
and cfg.get("max_mem_usage") == "OOM"
):
pruned_reason = f"vpp_degree {vpp_degree} may cause oom because { cfg['vpp_degree']} already oom."
log_pruned_info(cur_cfg, pruned_reason)
log_pruned_info(cur_cfg, pruned_reason, tuner_cfg)
cur_cfg["max_mem_usage"] = "OOM"
return True

@@ -336,9 +342,12 @@ def prune_by_mbs(tuner_cfg, cur_cfg, history_cfgs=[]):
if local_batch_size % micro_batch_size != 0:
return True
acc_steps = local_batch_size // micro_batch_size
pp_degree = cur_cfg.get("pp_degree", None)
if pp_degree is not None:
if acc_steps < pp_degree:
return True
vpp_degree = cur_cfg.get("vpp_degree", None)
if vpp_degree is not None and vpp_degree > 1:
pp_degree = cur_cfg.get("pp_degree", None)
if pp_degree is not None:
if acc_steps % pp_degree != 0:
return True
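The new lines above make `prune_by_mbs` pipeline-aware: the accumulation steps (`local_batch_size // micro_batch_size`) must be at least `pp_degree`, and when virtual pipeline is enabled (`vpp_degree > 1`) they must also divide evenly by `pp_degree`. A self-contained restatement of that check with a few worked cases:

```python
def would_prune(local_batch_size, micro_batch_size, pp_degree=None, vpp_degree=None):
    """Standalone sketch of the acc_steps checks added above (illustrative)."""
    if local_batch_size % micro_batch_size != 0:
        return True
    acc_steps = local_batch_size // micro_batch_size
    if pp_degree is not None and acc_steps < pp_degree:
        return True  # too few micro-batches to fill the pipeline
    if vpp_degree is not None and vpp_degree > 1:
        if pp_degree is not None and acc_steps % pp_degree != 0:
            return True  # virtual pipeline needs acc_steps divisible by pp_degree
    return False

assert would_prune(8, 4, pp_degree=4)                     # acc_steps=2 < pp_degree=4
assert not would_prune(16, 2, pp_degree=4, vpp_degree=2)  # acc_steps=8, 8 % 4 == 0
assert would_prune(12, 2, pp_degree=4, vpp_degree=2)      # acc_steps=6, 6 % 4 != 0
```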
@@ -375,7 +384,7 @@ def prune_by_mbs_history(tuner_cfg, cur_cfg, history_cfgs=[], pruned_cfgs=[]):
and cfg.get("time", -1) > 0
):
pruned_reason = f"micro_batch_size {micro_batch_size} may be slower because {cfg['micro_batch_size']} has been already runnable."
log_pruned_info(cur_cfg, pruned_reason)
log_pruned_info(cur_cfg, pruned_reason, tuner_cfg)
cur_cfg["time"] = cfg["time"]
return True
# memory prune
@@ -384,7 +393,7 @@ def prune_by_mbs_history(tuner_cfg, cur_cfg, history_cfgs=[], pruned_cfgs=[]):
and cfg.get("max_mem_usage") == "OOM"
):
pruned_reason = f"micro_batch_size {micro_batch_size} may cause oom because {cfg['micro_batch_size']} already oom."
log_pruned_info(cur_cfg, pruned_reason)
log_pruned_info(cur_cfg, pruned_reason, tuner_cfg)
cur_cfg["max_mem_usage"] = "OOM"
return True
return False
@@ -459,7 +468,7 @@ def prune_by_sharding_history(
and cfg.get("time", -1) > 0
):
pruned_reason = f"sharding_stage {sharding_stage} may be slower because {cfg['sharding_stage'] } has been already runnable."
log_pruned_info(cur_cfg, pruned_reason)
log_pruned_info(cur_cfg, pruned_reason, tuner_cfg)
cur_cfg["time"] = cfg["time"]
return True

@@ -469,7 +478,7 @@ def prune_by_sharding_history(
and cfg.get("max_mem_usage") == "OOM"
):
pruned_reason = f"sharding_stage {sharding_stage} may cause oom because {cfg['sharding_stage']} already oom."
log_pruned_info(cur_cfg, pruned_reason)
log_pruned_info(cur_cfg, pruned_reason, tuner_cfg)
cur_cfg["max_mem_usage"] = "OOM"
return True

@@ -567,7 +576,7 @@ def prune_by_recompute_history(
and cfg.get("time", -1) > 0
):
pruned_reason = f"use_recompute may be slower because {cfg['use_recompute']} has been already runnable."
log_pruned_info(cur_cfg, pruned_reason)
log_pruned_info(cur_cfg, pruned_reason, tuner_cfg)
cur_cfg["time"] = cfg["time"]
return True

@@ -576,7 +585,7 @@ def prune_by_recompute_history(
and cfg.get("max_mem_usage") == "OOM"
):
pruned_reason = f"use_recompute may cause oom because {cfg['use_recompute']} already oom."
log_pruned_info(cur_cfg, pruned_reason)
log_pruned_info(cur_cfg, pruned_reason, tuner_cfg)
cur_cfg["max_mem_usage"] = "OOM"
return True

@@ -816,3 +825,76 @@ def prune_by_invalid_strategy(tuner_cfg, cur_cfg, history_cfgs=[]):
return True

return False


@register_prune
def prune_by_refined_recompute(tuner_cfg, cur_cfg, history_cfgs=[]):
if tuner_cfg.get("refined_recompute", None):
rr = tuner_cfg.get("refined_recompute")
pp_degree = cur_cfg["pp_degree"]
recompute = cur_cfg["use_recompute"]
recompute_granularity = cur_cfg["recompute_granularity"]
compare = [cur_cfg[item] for item in rr]
if recompute:
if recompute_granularity and recompute_granularity != "full":
if compare.count(0) != len(compare):
return True
if pp_degree == 1 and compare.count(0) != len(compare):
return True
if tuner_cfg["model_cfg"]["num_layers"] % pp_degree != 0:
return True
max_value = tuner_cfg["model_cfg"]["num_layers"] / pp_degree
if cur_cfg[rr[0]] > max_value:
return True
i = 1
while i < len(rr):
if cur_cfg[rr[i]] > max_value or (
cur_cfg[rr[i - 1]] != max_value and cur_cfg[rr[i]] != 0
):
return True
i += 1

return False
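`prune_by_refined_recompute` first requires a sane setting (recompute enabled with "full" granularity, `pp_degree > 1`, and `num_layers` divisible by `pp_degree`; otherwise every refined value must be 0), then enforces an ordering constraint: each value is capped at `num_layers / pp_degree` (the layers hosted per pipeline stage), and a later key may be non-zero only once the previous key is saturated at that cap. A hedged walk-through with assumed key names:

```python
# num_layers=32 and pp_degree=4 give at most 8 layers per pipeline stage.
num_layers, pp_degree = 32, 4
max_value = num_layers / pp_degree
rr = ["recompute_ffn", "recompute_attn"]  # hypothetical refined-recompute keys

def pruned(cfg):
    if cfg[rr[0]] > max_value:
        return True
    for i in range(1, len(rr)):
        # a later knob may only engage once the previous knob is saturated
        if cfg[rr[i]] > max_value or (cfg[rr[i - 1]] != max_value and cfg[rr[i]] != 0):
            return True
    return False

assert not pruned({"recompute_ffn": 8, "recompute_attn": 3})  # ffn saturated first
assert pruned({"recompute_ffn": 5, "recompute_attn": 3})      # attn engaged too early
assert pruned({"recompute_ffn": 9, "recompute_attn": 0})      # exceeds per-stage cap
```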


@register_prune_history
def prune_by_refined_recompute_history(
tuner_cfg, cur_cfg, history_cfgs=[], pruned_cfgs=[]
):
if tuner_cfg.get("refined_recompute", None):
history_cfgs.extend(pruned_cfgs)
rr = tuner_cfg.get("refined_recompute")
compare = copy.deepcopy(rr)
compare.append("use_recompute")
cfgs = same_cfgs_beside(compare, cur_cfg, history_cfgs)
for item in rr:
if cfgs:
for cfg in cfgs:
if not cfg["use_recompute"] and cfg.get("time", -1) > 0:
pruned_reason = f"{item} {cur_cfg[item]} may be slower because not recompute has been already runnable."
log_pruned_info(cur_cfg, pruned_reason, tuner_cfg)
cur_cfg["time"] = cfg["time"]
return True
if (
cfg[item] > cur_cfg[item]
and cfg.get("time", -1) > 0
and cfg["use_recompute"]
and cur_cfg["use_recompute"]
):
pruned_reason = f"{item} {cur_cfg[item]} may be slower because {cfg[item]} has been already runnable."
log_pruned_info(cur_cfg, pruned_reason, tuner_cfg)
cur_cfg["time"] = cfg["time"]
return True
# memory prune
if (
cfg[item] < cur_cfg[item]
and cfg.get("max_mem_usage") == "OOM"
and cfg["use_recompute"]
and cur_cfg["use_recompute"]
):
pruned_reason = f"{item} {cur_cfg[item]} may cause oom because {cfg[item]} already oom."
log_pruned_info(cur_cfg, pruned_reason, tuner_cfg)
cur_cfg["max_mem_usage"] = "OOM"
return True

return False
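Taken together, the comparisons above imply that a larger refined-recompute value trades memory for speed: a run that already succeeded with a larger value prunes smaller values as likely slower, an OOM at a smaller value prunes larger values as likely OOM, and a successful run without recompute prunes everything on speed. A compact, hedged restatement of the per-key rules (key name assumed):

```python
def prune_against(cur, cfg, item="recompute_ffn"):  # "recompute_ffn" is hypothetical
    """Sketch of the history rules in prune_by_refined_recompute_history."""
    if not cfg["use_recompute"] and cfg.get("time", -1) > 0:
        return "slower: running without recompute already succeeded"
    if cfg["use_recompute"] and cur["use_recompute"]:
        if cfg[item] > cur[item] and cfg.get("time", -1) > 0:
            return "slower: a larger (less-recompute) value already ran"
        if cfg[item] < cur[item] and cfg.get("max_mem_usage") == "OOM":
            return "oom: a smaller (more-recompute) value already OOMed"
    return None

hist = {"use_recompute": True, "recompute_ffn": 6, "time": 120.0}
cur = {"use_recompute": True, "recompute_ffn": 4}
print(prune_against(cur, hist))  # slower: a larger (less-recompute) value already ran
```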
4 changes: 3 additions & 1 deletion python/paddle/distributed/auto_tuner/search.py
@@ -63,7 +63,9 @@ def search_once(self, history_cfgs):
stop = False
if history_cfgs:
if history_cfgs[-1].get("time", -1) > 0:
if self.baseline is None:
if self.baseline is None and self.tuner_cfg.get(
"need_baseline", False
):
from .utils import performance_sort

self.baseline = history_cfgs[-1]
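With this change, the first runnable config is promoted to a performance baseline only when the tuner config opts in via `need_baseline`. A minimal sketch of the gating, using bare dicts in place of the real objects:

```python
baseline = None
tuner_cfg = {"need_baseline": True}  # omit or set False to skip baseline handling
history_cfgs = [{"time": 95.3}]      # the last trial was runnable

if history_cfgs and history_cfgs[-1].get("time", -1) > 0:
    if baseline is None and tuner_cfg.get("need_baseline", False):
        baseline = history_cfgs[-1]  # first runnable config becomes the baseline
```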
5 changes: 5 additions & 0 deletions python/paddle/distributed/auto_tuner/tuner.py
@@ -133,6 +133,11 @@ def get_cfg_from_resume(self, cur_cfg):
'sharding_overlap',
'acc_steps',
]

if self.tuner_cfg.get("refined_recompute", None):
for rr in self.tuner_cfg["refined_recompute"]:
keys_to_compare.append(rr)

for cfg in self.resume_cfgs:
ret_is_same = True
for key in keys_to_compare:
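Finally, resume matching gains the same awareness: the refined-recompute keys are appended to `keys_to_compare`, so a previously-run config is only reused when those values agree too. A small sketch with assumed keys (the real list is longer, as shown above):

```python
keys_to_compare = ["dp_degree", "mp_degree", "acc_steps"]  # abridged
tuner_cfg = {"refined_recompute": ["recompute_ffn"]}       # hypothetical key
if tuner_cfg.get("refined_recompute", None):
    keys_to_compare.extend(tuner_cfg["refined_recompute"])

cur_cfg = {"dp_degree": 2, "mp_degree": 2, "acc_steps": 8, "recompute_ffn": 4}
resume_cfg = {"dp_degree": 2, "mp_degree": 2, "acc_steps": 8, "recompute_ffn": 0}
is_same = all(cur_cfg.get(k) == resume_cfg.get(k) for k in keys_to_compare)
print(is_same)  # False: refined-recompute values differ, so no resume reuse
```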