@@ -20,10 +20,49 @@
 #include "paddle/fluid/framework/var_type.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/device_memory_aligment.h"
+#ifdef PADDLE_WITH_ASCEND_CL
+#include "paddle/fluid/operators/npu_op_runner.h"
+#endif
 
 namespace paddle {
 namespace operators {
 
+template <typename DeviceContext>
+struct FillConstantVisitor {
+  FillConstantVisitor(const DeviceContext &dev_ctx,
+                      framework::LoDTensor *tensor, const float value)
+      : dev_ctx_(dev_ctx), tensor_(tensor), value_(value) {}
+
+  template <typename T>
+  void apply(typename std::enable_if<std::is_same<T, int8_t>::value ||
+                                     std::is_same<T, int16_t>::value>::type * =
+                 nullptr) const {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Not support data type for set_constant attr"));
+  }
+
+  template <typename T>
+  void apply(typename std::enable_if<!(std::is_same<T, int8_t>::value ||
+                                       std::is_same<T, int16_t>::value)>::type
+                 * = nullptr) const {
+#ifdef PADDLE_WITH_ASCEND_CL
+    if (platform::is_npu_place(dev_ctx_.GetPlace())) {
+      FillNpuTensorWithConstant<T>(tensor_, static_cast<T>(value_));
+    } else {
+      math::SetConstant<DeviceContext, T> set_constant;
+      set_constant(dev_ctx_, tensor_, static_cast<T>(value_));
+    }
+#else
+    math::SetConstant<DeviceContext, T> set_constant;
+    set_constant(dev_ctx_, tensor_, static_cast<T>(value_));
+#endif
+  }
+
+  const DeviceContext &dev_ctx_;
+  framework::LoDTensor *tensor_;
+  float value_;
+};
+
 template <typename DeviceContext, typename T>
 class CoalesceTensorOpKernel : public framework::OpKernel<T> {
  public:
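A note on the two `apply` overloads in `FillConstantVisitor` above: they use `std::enable_if` so that, for any element type T, exactly one overload survives template substitution (SFINAE); int8_t/int16_t select the overload that throws, presumably because set_constant lacks kernels for those types. The following standalone sketch uses hypothetical names, not the Paddle classes, to show the selection behavior:

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <type_traits>

// Hypothetical visitor: the first overload is the only viable one for
// int8_t/int16_t and rejects them; the second handles every other element
// type. enable_if removes the non-matching overload from the candidate set.
struct FillVisitor {
  float value;

  template <typename T>
  void apply(typename std::enable_if<std::is_same<T, int8_t>::value ||
                                     std::is_same<T, int16_t>::value>::type * =
                 nullptr) const {
    throw std::runtime_error("set_constant does not support int8/int16");
  }

  template <typename T>
  void apply(typename std::enable_if<!(std::is_same<T, int8_t>::value ||
                                       std::is_same<T, int16_t>::value)>::type
                 * = nullptr) const {
    std::cout << "fill with " << static_cast<T>(value) << "\n";
  }
};

int main() {
  FillVisitor v{1.5f};
  v.apply<float>();    // generic overload: prints "fill with 1.5"
  v.apply<int32_t>();  // generic overload: prints "fill with 1"
  // v.apply<int8_t>(); // would select the throwing overload instead
}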
@@ -70,6 +109,7 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
     auto in_tensors = context.MultiInput<framework::LoDTensor>("Input");
     bool use_align = context.Attr<bool>("use_align");
     auto align_size = context.Attr<int>("align_size");
+    auto size_of_dtype = context.Attr<int>("user_defined_size_of_dtype");
 
     if (context.Attr<bool>("check_name")) {
       for (size_t i = 0; i < in_var_names.size(); ++i) {
@@ -94,7 +134,9 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
     size_t numel = 0;
     auto dtype = static_cast<framework::proto::VarType::Type>(
         context.Attr<int>("dtype"));
-    size_t size_of_dtype = framework::SizeOfType(dtype);
+    if (size_of_dtype == -1) {
+      size_of_dtype = framework::SizeOfType(dtype);
+    }
     GetMemSizeAndDtype(in_tensors, in_var_names, &numel, size_of_dtype,
                        context.GetPlace(), use_align, align_size);
 
@@ -121,10 +163,9 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
                          : len;
       }
     } else if (context.Attr<bool>("set_constant")) {
-      // TODO(Liu yuang) ADD NPU SET_CONSTANT FUNCTION.
-      math::SetConstant<DeviceContext, T> set_constant;
-      set_constant(dev_ctx, fused_tensor,
-                   static_cast<T>(context.Attr<float>("constant")));
+      framework::VisitDataType(
+          dtype, FillConstantVisitor<DeviceContext>(
+                     dev_ctx, fused_tensor, context.Attr<float>("constant")));
     } else if (context.Attr<bool>("persist_output")) {
       for (size_t i = 0; i < out_var_names.size(); ++i) {
         size_t len = static_cast<size_t>(out_tensors[i]->numel());
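The replacement in this hunk swaps the kernel-templated `math::SetConstant<DeviceContext, T>` call for `framework::VisitDataType`, which dispatches on the runtime `dtype` attribute rather than the kernel's compile-time T, and lets the visitor handle the NPU branch. A minimal sketch of that runtime-to-compile-time dispatch pattern, using a toy enum and visitor rather than the real Paddle API surface:

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <typeinfo>

// Toy stand-in for framework::proto::VarType::Type.
enum class DType { FP32, FP64, INT32 };

// Simplified stand-in for framework::VisitDataType: it maps the runtime
// dtype tag to a compile-time type T and invokes visitor.apply<T>().
template <typename Visitor>
void VisitDataType(DType dtype, const Visitor &visitor) {
  switch (dtype) {
    case DType::FP32:
      visitor.template apply<float>();
      break;
    case DType::FP64:
      visitor.template apply<double>();
      break;
    case DType::INT32:
      visitor.template apply<int32_t>();
      break;
    default:
      throw std::runtime_error("unsupported dtype");
  }
}

struct PrintVisitor {
  template <typename T>
  void apply() const {
    std::cout << "dispatched to " << typeid(T).name() << "\n";
  }
};

int main() {
  // The dtype is only known at run time; dispatch picks apply<float>().
  VisitDataType(DType::FP32, PrintVisitor{});
}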
@@ -227,10 +268,13 @@ class CoalesceTensorOp : public framework::OperatorWithKernel {
     }
     auto use_align = ctx->Attrs().Get<bool>("use_align");
     auto align_size = ctx->Attrs().Get<int>("align_size");
+    auto size_of_dtype = ctx->Attrs().Get<int>("user_defined_size_of_dtype");
 
     auto dtype = static_cast<framework::proto::VarType::Type>(
         ctx->Attrs().Get<int>("dtype"));
-    size_t size_of_dtype = framework::SizeOfType(dtype);
+    if (size_of_dtype == -1) {
+      size_of_dtype = framework::SizeOfType(dtype);
+    }
 
     auto alignment = [](size_t size, size_t align_size) {
       size_t remaining = size % align_size;
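The `alignment` lambda visible in the context lines rounds a byte count up to the next multiple of `align_size`. A small runnable sketch of the same arithmetic:

#include <cstddef>
#include <iostream>

// Mirror of the alignment lambda: round `size` bytes up to a multiple of
// `align_size` bytes.
size_t Alignment(size_t size, size_t align_size) {
  size_t remaining = size % align_size;
  return remaining == 0 ? size : size + align_size - remaining;
}

int main() {
  std::cout << Alignment(1024, 512) << "\n";  // 1024 (already aligned)
  std::cout << Alignment(1300, 512) << "\n";  // 1536 (padded up)
}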
@@ -308,6 +352,15 @@ class CoalesceTensorOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(true);
     AddAttr<int>("align_size", "The alignment size when use_align is True")
         .SetDefault(-1);
+    AddAttr<int>("user_defined_size_of_dtype",
+                 "The user defined size of dtype. This is used to coalesce "
+                 "grad vars and merged_grad vars at the same time. For some "
+                 "strategies, the dtype of fused_grad_vars and the dtype of "
+                 "fused_grad_merged_vars are not identical, which would "
+                 "make the shapes of the two coalesced vars differ. This "
+                 "attr is added to make sure the two fused vars have "
+                 "identical shapes.")
+        .SetDefault(-1);
     AddComment(R"DOC(
CoalesceTensor Operator.

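To see why `user_defined_size_of_dtype` matters: padding happens in bytes, so the padded element count of a tensor depends on its element width, and the same variable list coalesced as fp16 and as fp32 can come out with different fused shapes. The sketch below assumes a 512-byte alignment and infers the accounting from this op's `alignment(...) / size_of_dtype` pattern; it shows the mismatch and how overriding the element size removes it:

#include <cstddef>
#include <iostream>

size_t Alignment(size_t size, size_t align_size) {
  size_t remaining = size % align_size;
  return remaining == 0 ? size : size + align_size - remaining;
}

// Fused-buffer slots occupied by a `numel`-element tensor: its bytes are
// padded to the alignment, then divided back by the element width.
size_t Slots(size_t numel, size_t size_of_dtype) {
  return Alignment(numel * size_of_dtype, 512) / size_of_dtype;
}

int main() {
  // The same 100-element tensor, coalesced twice with default accounting:
  std::cout << Slots(100, 2) << "\n";  // fp16: 256 slots
  std::cout << Slots(100, 4) << "\n";  // fp32: 128 slots -> shapes differ
  // Passing user_defined_size_of_dtype = 4 to the fp16 coalesce makes both
  // buffers use fp32-sized accounting, so the element counts line up:
  std::cout << Slots(100, 4) << "\n";  // 128 for both fused vars
}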