
Commit c0363cd

[cherry-pick][hybrid performance] Grad fuse for gradient merge under pipeline mode (PaddlePaddle#35004)
1 parent: 6fb58ae · commit: c0363cd

File tree

10 files changed: +534 -11 lines changed

paddle/fluid/framework/distributed_strategy.proto

Lines changed: 1 addition & 0 deletions
@@ -200,6 +200,7 @@ message DistributedStrategy {
   optional int32 fuse_grad_size_in_num = 31 [ default = 8 ];
   optional bool calc_comm_same_stream = 32 [ default = false ];
   optional bool asp = 33 [ default = false ];
+  optional bool fuse_grad_merge = 34 [ default = false ];
 
   optional RecomputeConfig recompute_configs = 101;
   optional AMPConfig amp_configs = 102;
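
The new proto field surfaces on the Python side as DistributedStrategy.fuse_grad_merge (see the distributed_strategy.py change below). As a minimal sketch, assuming a pipeline-parallel job is configured elsewhere, enabling it looks like this; the fuse_all_reduce_ops line is only illustrative of how the sharding pass later consults both flags:

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    # Fuse the grad vars produced by gradient merge; only takes effect under pipeline mode.
    strategy.fuse_grad_merge = True
    # Illustrative: insert_allreduce_ops checks fuse_all_reduce_ops together with fuse_grad_merge.
    strategy.fuse_all_reduce_ops = True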

paddle/fluid/operators/coalesce_tensor_op.cc

Lines changed: 59 additions & 6 deletions
@@ -20,10 +20,49 @@
 #include "paddle/fluid/framework/var_type.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/device_memory_aligment.h"
+#ifdef PADDLE_WITH_ASCEND_CL
+#include "paddle/fluid/operators/npu_op_runner.h"
+#endif
 
 namespace paddle {
 namespace operators {
 
+template <typename DeviceContext>
+struct FillConstantVisitor {
+  FillConstantVisitor(const DeviceContext &dev_ctx,
+                      framework::LoDTensor *tensor, const float value)
+      : dev_ctx_(dev_ctx), tensor_(tensor), value_(value) {}
+
+  template <typename T>
+  void apply(typename std::enable_if<std::is_same<T, int8_t>::value ||
+                                     std::is_same<T, int16_t>::value>::type * =
+                 nullptr) const {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Not support data type for set_constant attr"));
+  }
+
+  template <typename T>
+  void apply(typename std::enable_if<!(std::is_same<T, int8_t>::value ||
+                                       std::is_same<T, int16_t>::value)>::type
+                 * = nullptr) const {
+#ifdef PADDLE_WITH_ASCEND_CL
+    if (platform::is_npu_place(dev_ctx_.GetPlace())) {
+      FillNpuTensorWithConstant<T>(tensor_, static_cast<T>(value_));
+    } else {
+      math::SetConstant<DeviceContext, T> set_constant;
+      set_constant(dev_ctx_, tensor_, static_cast<T>(value_));
+    }
+#else
+    math::SetConstant<DeviceContext, T> set_constant;
+    set_constant(dev_ctx_, tensor_, static_cast<T>(value_));
+#endif
+  }
+
+  const DeviceContext &dev_ctx_;
+  framework::LoDTensor *tensor_;
+  float value_;
+};
+
 template <typename DeviceContext, typename T>
 class CoalesceTensorOpKernel : public framework::OpKernel<T> {
  public:
@@ -70,6 +109,7 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
     auto in_tensors = context.MultiInput<framework::LoDTensor>("Input");
     bool use_align = context.Attr<bool>("use_align");
     auto align_size = context.Attr<int>("align_size");
+    auto size_of_dtype = context.Attr<int>("user_defined_size_of_dtype");
 
     if (context.Attr<bool>("check_name")) {
       for (size_t i = 0; i < in_var_names.size(); ++i) {
@@ -94,7 +134,9 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
     size_t numel = 0;
     auto dtype = static_cast<framework::proto::VarType::Type>(
         context.Attr<int>("dtype"));
-    size_t size_of_dtype = framework::SizeOfType(dtype);
+    if (size_of_dtype == -1) {
+      size_of_dtype = framework::SizeOfType(dtype);
+    }
     GetMemSizeAndDtype(in_tensors, in_var_names, &numel, size_of_dtype,
                        context.GetPlace(), use_align, align_size);
 
@@ -121,10 +163,9 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
                          : len;
       }
     } else if (context.Attr<bool>("set_constant")) {
-      // TODO(Liu yuang) ADD NPU SET_CONSTANT FUNCTION.
-      math::SetConstant<DeviceContext, T> set_constant;
-      set_constant(dev_ctx, fused_tensor,
-                   static_cast<T>(context.Attr<float>("constant")));
+      framework::VisitDataType(
+          dtype, FillConstantVisitor<DeviceContext>(
+                     dev_ctx, fused_tensor, context.Attr<float>("constant")));
     } else if (context.Attr<bool>("persist_output")) {
       for (size_t i = 0; i < out_var_names.size(); ++i) {
         size_t len = static_cast<size_t>(out_tensors[i]->numel());
@@ -227,10 +268,13 @@ class CoalesceTensorOp : public framework::OperatorWithKernel {
     }
     auto use_align = ctx->Attrs().Get<bool>("use_align");
     auto align_size = ctx->Attrs().Get<int>("align_size");
+    auto size_of_dtype = ctx->Attrs().Get<int>("user_defined_size_of_dtype");
 
     auto dtype = static_cast<framework::proto::VarType::Type>(
         ctx->Attrs().Get<int>("dtype"));
-    size_t size_of_dtype = framework::SizeOfType(dtype);
+    if (size_of_dtype == -1) {
+      size_of_dtype = framework::SizeOfType(dtype);
+    }
 
     auto alignment = [](size_t size, size_t align_size) {
       size_t remaining = size % align_size;
@@ -308,6 +352,15 @@ class CoalesceTensorOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(true);
     AddAttr<int>("align_size", "The alignment size when use_align is True")
         .SetDefault(-1);
+    AddAttr<int>("user_defined_size_of_dtype",
+                 "The user defined size of dtype. This is used to coalesce "
+                 "grad vars and merged_grad vars at the same time. For some "
+                 "strategy, the dtype of fused_grad_vars and the dtype of "
+                 "fused_grad_merged_vars are not identical, which will cause "
+                 "the shape of these two coalesced vars are different. To "
+                 "make sure the shape of these two vars are identical with "
+                 "each other, this attr is added.")
+        .SetDefault(-1);
     AddComment(R"DOC(
 CoalesceTensor Operator.
 
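
The new user_defined_size_of_dtype attribute lets a pass coalesce fp16 grad vars into a buffer whose layout matches the fp32 merged-grad buffer, by treating every element as if it occupied a caller-chosen number of bytes. Below is a hedged sketch of how a fusion pass might append the op with this attribute; the variable names, the Output/FusedOutput slot names, and the copy_data attribute are assumptions based on how coalesce_tensor is commonly used, not part of this diff:

    import paddle
    from paddle.fluid import core

    paddle.enable_static()
    block = paddle.static.default_main_program().global_block()

    grads = [block.create_var(name="g0@GRAD", shape=[128], dtype="float16"),
             block.create_var(name="g1@GRAD", shape=[256], dtype="float16")]
    fused_grad = block.create_var(name="fused_grads", dtype="float16")

    # Lay the fp16 grads out as if each element occupied 4 bytes, so the fused
    # buffer's shape matches the fp32 fused merged-grad buffer built by the
    # gradient merge pass.
    block.append_op(
        type="coalesce_tensor",
        inputs={"Input": grads},
        outputs={"Output": grads, "FusedOutput": fused_grad},
        attrs={
            "copy_data": False,
            "use_align": True,
            "dtype": core.VarDesc.VarType.FP16,
            "user_defined_size_of_dtype": 4,
        })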

python/paddle/distributed/fleet/base/distributed_strategy.py

Lines changed: 22 additions & 0 deletions
@@ -967,6 +967,28 @@ def _calc_comm_same_stream(self, same):
                 "WARNING: calc_comm_same_stream should have value of boolean type"
             )
 
+    @property
+    def fuse_grad_merge(self):
+        """
+        Set whether fuse the grad for gradient merge.
+        Note: this flag will only effect the gradient merge under pipeline mode
+        The default value for the fuse_grad_merge is False
+        Examples:
+          .. code-block:: python
+            import paddle.distributed.fleet as fleet
+            strategy = fleet.DistributedStrategy()
+            strategy.fuse_grad_merge = True
+        """
+        return self.strategy.fuse_grad_merge
+
+    @fuse_grad_merge.setter
+    @is_strict_auto
+    def fuse_grad_merge(self, fuse_grad_merge):
+        if isinstance(fuse_grad_merge, bool):
+            self.strategy.fuse_grad_merge = fuse_grad_merge
+        else:
+            print("WARNING: fuse_grad_merge should have value of boolean type")
+
     @property
     def fuse_grad_size_in_num(self):
         """

python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py

Lines changed: 3 additions & 0 deletions
@@ -122,6 +122,9 @@ def remove_param(input_name):
         for idx, op in enumerate(block.ops):
             if is_optimizer_op(op):
                 break
+            # TODO (Yuang Liu): tmp solution for fuse_grad_merge + optimize_cast
+            if not offload and op.type == 'coalesce_tensor':
+                continue
             for input_name in op.desc.input_arg_names():
                 if input_name not in param_to_idx:
                     continue

python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py

Lines changed: 5 additions & 1 deletion
@@ -341,7 +341,11 @@ def insert_allreduce_ops(block,
     if len(allreduce_vars) == 0:
         return
 
-    if user_defined_strategy and user_defined_strategy.fuse_all_reduce_ops:
+    if user_defined_strategy and \
+            user_defined_strategy.fuse_all_reduce_ops and \
+            not user_defined_strategy.fuse_grad_merge:
+        # If fuse_grad_merge is enable, the grad vars have already been fused during
+        # gradient merge pass, therefore, those vars are not need to be fused here
         insert_fused_allreduce_ops(block, insert_idx, ring_id, allreduce_vars,
                                    op_role, use_calc_stream,
                                    user_defined_strategy.fuse_grad_size_in_MB)
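
The intent of the new guard, restated as a standalone predicate for clarity; _should_insert_fused_allreduce is a hypothetical helper, not part of this commit:

    def _should_insert_fused_allreduce(user_defined_strategy):
        # Fuse at allreduce-insertion time only when fused allreduce is requested
        # and the gradient merge pass has not already coalesced the grad vars.
        return bool(user_defined_strategy) \
            and user_defined_strategy.fuse_all_reduce_ops \
            and not user_defined_strategy.fuse_grad_merge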

python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py

Lines changed: 3 additions & 1 deletion
@@ -319,7 +319,9 @@ def _insert_allreduce_for_pp(self):
             main_block._remove_op(idx)
 
         accumulated_grad_names = self._pp_optimizer._accumulate_gradients(
-            main_block, fp16_allreduce=fp16_allreduce)
+            main_block,
+            fp16_allreduce=fp16_allreduce,
+            user_defined_strategy=strategy)
 
         len_of_ops = len(main_block.ops)
         first_optimize_op_index = get_first_optimize_op_idx(main_block)
