20 changes: 16 additions & 4 deletions paddlenlp/trainer/training_args.py
@@ -246,6 +246,8 @@ class TrainingArguments:
enable_dp_comm_overlap, fuse data parallel gradient communication.
enable_sharding_comm_overlap, fuse sharding stage 1 parallel gradient communication.
enable_release_grads, reduce peak memory usage by releasing gradients after each iteration. The creation of gradients will be postponed until backward propagation of the next iteration.
enable_overlap_p2p_comm, overlap p2p communication with computation.
enable_clear_every_step_cache, clear the pipeline parallel cache after every training step.
sharding_parallel_config (`str`, *optional*):
Some additional configs that highly affect the usage of sharding parallel; we provide some options to configure them.
The following configs are supported:
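(The list of sharding options is collapsed in this diff view.) The two new pipeline_parallel_config flags documented above, enable_overlap_p2p_comm and enable_clear_every_step_cache, are passed in the same space-separated string as the existing options. A minimal usage sketch follows; the output_dir, degree, and companion flag here are illustrative assumptions, not taken from this PR:

```python
from paddlenlp.trainer import TrainingArguments

# Hypothetical setup (values are illustrative): combine the two new flags
# with an existing one in the space-separated pipeline_parallel_config string.
args = TrainingArguments(
    output_dir="./checkpoints",
    pipeline_parallel_degree=4,
    pipeline_parallel_config="enable_release_grads enable_overlap_p2p_comm enable_clear_every_step_cache",
)
```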
@@ -595,6 +597,8 @@ class TrainingArguments:
"enable_delay_scale_loss, accumulate gradients util optimizer step, all gradients div by inner pipeline accumute step. instead of div accumute step on loss directly.\n"
"enable_dp_comm_overlap, fuse data parallel gradient communication. \n"
"enable_sharding_comm_overlap, fuse sharding stage 1 parallel gradient communication. \n"
"enable_overlap_p2p_comm, overlap p2p communication with computation. \n"
"enable_clear_every_step_cache, clear every step cache for pipeline parallel. \n"
)
},
)
Expand Down Expand Up @@ -963,6 +967,8 @@ def __post_init__(self):
"enable_sharding_comm_overlap",
"enable_timer",
"enable_release_grads",
"enable_dp_comm_overlap",
"enable_clear_every_step_cache",
]:
raise ValueError(
f"Found unknown pipeline mode config {x}, accpet config is disable_p2p_cache_shape, disable_partial_send_recv."
@@ -976,14 +982,20 @@ def __post_init__(self):
# "delay_scale_loss": True, Fix ME
}
logger.info(f"PP configs:{strategy.pipeline_configs}, use master_grad: {self.amp_master_grad}")

using_comm_overlap = (
"enable_sharding_comm_overlap" in pipeline_parallel_config
and self.sharding_parallel_degree > 1
) or ("enable_dp_comm_overlap" in pipeline_parallel_config and self.data_parallel_degree > 1)

dygraph_pp_configs = {
"delay_scale_loss": "enable_delay_scale_loss" in pipeline_parallel_config,
"dp_comm_overlap": using_comm_overlap,
"sharding_comm_overlap": using_comm_overlap,
"enable_timer": "enable_timer" in pipeline_parallel_config,
"release_gradients": "enable_release_grads" in pipeline_parallel_config,
"overlap_p2p_comm": "enable_overlap_p2p_comm" in pipeline_parallel_config,
"clear_every_step_cache": "enable_clear_every_step_cache" in pipeline_parallel_config,
}
if dygraph_pp_configs["dp_comm_overlap"]:
raise ValueError("overlap has accuracy issue") # TODO: fix `overalap` + `delay_scale` issue