
Commit 5fd6dd2

add enable_sp_async_reduce_scatter (#8803)
1 parent 77f6e98 commit 5fd6dd2

1 file changed, 6 additions & 1 deletion

File tree

paddlenlp/trainer/training_args.py

@@ -245,6 +245,7 @@ class TrainingArguments:
             enable_mp_async_allreduce, it supports all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear backward when it set True, which can accelerate model parallel performance.
             enable_mp_skip_c_identity, it supports skip c_identity in ColumnParallelLinear and RowParallelLinear. It only works when set mp_async_allreduce is True. It can accelerate model parallel further.
             enable_mp_fused_linear_param_grad_add, it supports fused_linear_param_grad_add in ColumnParallelLinear (cuda >= 11.6). It only works when mp_async_allreduce is true. It can accelerate model parallel further.
+            enable_sp_async_reduce_scatter, it supports async reduce_scatter in ColumnSequenceParallelLinear. It only works when set sp_async_reduce_scatter is True. It can accelerate sequence parallel further.
             enable_delay_scale_loss, accumulate gradients until optimizer step, all gradients div by accumute step. instead of div accumute step on loss directly.
             sync_param, in optimizer step, use broadcast to sync parameters those attr 'is_distributed' is False.
             sync_grad, in optimizer step, use broadcast to sync gradients those attr 'is_distributed' is False.
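For orientation, a hedged usage sketch of the new flag: the options documented above are passed as a space-separated string and, judging from the validation message in split_parallel_config further down, belong to the tensor-parallel config. The field names used below (output_dir, tensor_parallel_degree, sequence_parallel, tensor_parallel_config) are assumptions about the surrounding TrainingArguments API, not part of this diff.

# Hedged usage sketch (not part of this commit): enable the new async
# reduce_scatter overlap alongside the existing mp overlap flag.
# Field names are assumptions about PaddleNLP's TrainingArguments.
from paddlenlp.trainer import TrainingArguments

training_args = TrainingArguments(
    output_dir="./checkpoints",
    tensor_parallel_degree=4,   # assumed field: degree of tensor/model parallelism
    sequence_parallel=True,     # assumed field: the flag targets ColumnSequenceParallelLinear
    tensor_parallel_config="enable_mp_async_allreduce enable_sp_async_reduce_scatter",
)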
@@ -629,6 +630,7 @@ class TrainingArguments:
             "enable_mp_async_allreduce, it supports all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear backward when it set True, which can accelerate model parallel performance. \n"
             "enable_mp_skip_c_identity, it supports skip c_identity in ColumnParallelLinear and RowParallelLinear. It only works when set mp_async_allreduce is True. It can accelerate model parallel further.\n"
             "enable_mp_fused_linear_param_grad_add, it supports fused_linear_param_grad_add in ColumnParallelLinear (cuda >= 11.6). It only works when mp_async_allreduce is true. It can accelerate model parallel further.\n"
+            "enable_sp_async_reduce_scatter, it supports async reduce_scatter in ColumnSequenceParallelLinear. It only works when set sp_async_reduce_scatter is True. It can accelerate sequence parallel further.\n"
             "enable_delay_scale_loss, accumulate gradients until optimizer step, all gradients div by accumute step. instead of div accumute step on loss directly.\n"
             "sync_param, in optimizer step, use broadcast to sync parameters those attr 'is_distributed' is False.\n"
             "sync_grad, in optimizer step, use broadcast to sync gradients those attr 'is_distributed' is False.\n"
@@ -1128,14 +1130,15 @@ def split_parallel_config(parallel_config):
                 "enable_mp_async_allreduce",
                 "enable_mp_skip_c_identity",
                 "enable_mp_fused_linear_param_grad_add",
+                "enable_sp_async_reduce_scatter",
                 "enable_delay_scale_loss",
                 "sync_param",
                 "sync_grad",
                 "sync_moment",
             ]:
                 raise ValueError(
                     f"Found unknown tensor parallell config {x}, "
-                    f"accept config is enable_mp_async_allreduce, enable_mp_skip_c_identity, enable_mp_fused_linear_param_grad_add, sync_param, sync_grad and sync_moment."
+                    f"accept config is enable_mp_async_allreduce, enable_mp_skip_c_identity, enable_mp_fused_linear_param_grad_add, enable_sp_async_reduce_scatter, enable_delay_scale_loss, sync_param, sync_grad and sync_moment."
                 )
         try:
             if "enable_mp_async_allreduce" in mp_config:
@@ -1153,6 +1156,8 @@ def split_parallel_config(parallel_config):
                     warnings.warn(
                         "enable_mp_fused_linear_param_grad_add only works with enable_mp_async_allreduce. It will not work."
                     )
+            if "enable_sp_async_reduce_scatter" in mp_config:
+                strategy.hybrid_configs["mp_configs"].sp_async_reduce_scatter = True

             sync_param = "sync_param" in mp_config
             sync_grad = "sync_grad" in mp_config
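To make the parsing behavior above concrete, here is a minimal, self-contained sketch of the pattern split_parallel_config follows: split the space-separated string, reject unknown flag names, and toggle the matching attribute. MpConfigs is a hypothetical stand-in for Paddle's strategy.hybrid_configs["mp_configs"]; only the pattern, not the real Paddle API surface, is shown.

# Minimal sketch of the parse-and-set pattern (MpConfigs is a hypothetical
# stand-in for strategy.hybrid_configs["mp_configs"], not the real Paddle API).
from dataclasses import dataclass

ALLOWED_FLAGS = {
    "enable_mp_async_allreduce",
    "enable_mp_skip_c_identity",
    "enable_mp_fused_linear_param_grad_add",
    "enable_sp_async_reduce_scatter",
    "enable_delay_scale_loss",
    "sync_param",
    "sync_grad",
    "sync_moment",
}

@dataclass
class MpConfigs:
    mp_async_allreduce: bool = False
    sp_async_reduce_scatter: bool = False

def parse_mp_config(parallel_config: str) -> MpConfigs:
    mp_config = set(parallel_config.split())
    unknown = mp_config - ALLOWED_FLAGS
    if unknown:
        raise ValueError(f"Found unknown tensor parallel config {unknown}")
    configs = MpConfigs()
    configs.mp_async_allreduce = "enable_mp_async_allreduce" in mp_config
    # The check this commit adds: turn on async reduce_scatter for
    # ColumnSequenceParallelLinear when the flag is present.
    configs.sp_async_reduce_scatter = "enable_sp_async_reduce_scatter" in mp_config
    return configs

print(parse_mp_config("enable_mp_async_allreduce enable_sp_async_reduce_scatter"))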
