2020#include " paddle/fluid/framework/var_type.h"
2121#include " paddle/fluid/operators/math/math_function.h"
2222#include " paddle/fluid/platform/device_memory_aligment.h"
23- #ifdef PADDLE_WITH_ASCEND_CL
24- #include " paddle/fluid/operators/npu_op_runner.h"
25- #endif
2623
2724namespace paddle {
2825namespace operators {
2926
// NOTE(review): this span is the removed ("-") side of a scraped unified
// diff; every code line keeps its original "NN- " number/marker prefix and
// the scrape-inserted spaces (including those inside string literals)
// byte-for-byte. Only comments have been added.
//
// FillConstantVisitor: a visitor intended for framework::VisitDataType that
// fills `tensor_` with `value_` cast to the visited element type T.
30- template <typename DeviceContext>
31- struct FillConstantVisitor {
// Stores (does not own) the device context and target tensor; the fill
// constant is kept as float and cast to T inside apply<T>().
32- FillConstantVisitor (const DeviceContext &dev_ctx,
33- framework::LoDTensor *tensor, const float value)
34- : dev_ctx_(dev_ctx), tensor_(tensor), value_(value) {}
35-
// SFINAE overload selected when T is int8_t or int16_t: these dtypes are
// rejected outright — presumably math::SetConstant does not support them
// here (TODO confirm against math_function.h).
36- template <typename T>
37- void apply (typename std::enable_if<std::is_same<T, int8_t >::value ||
38- std::is_same<T, int16_t >::value>::type * =
39- nullptr ) const {
40- PADDLE_THROW (platform::errors::InvalidArgument (
41- " Not support data type for set_constant attr" ));
42- }
43-
// Overload for every other T: fills the tensor with value_. In Ascend-NPU
// builds, an NPU-placed context takes the FillNpuTensorWithConstant path
// (NOTE(review): presumably because math::SetConstant cannot target NPU
// devices — verify); every other placement, and non-NPU builds, use
// math::SetConstant.
44- template <typename T>
45- void apply (typename std::enable_if<!(std::is_same<T, int8_t >::value ||
46- std::is_same<T, int16_t >::value)>::type
47- * = nullptr) const {
48- #ifdef PADDLE_WITH_ASCEND_CL
49- if (platform::is_npu_place (dev_ctx_.GetPlace ())) {
50- FillNpuTensorWithConstant<T>(tensor_, static_cast <T>(value_));
51- } else {
52- math::SetConstant<DeviceContext, T> set_constant;
53- set_constant (dev_ctx_, tensor_, static_cast <T>(value_));
54- }
55- #else
56- math::SetConstant<DeviceContext, T> set_constant;
57- set_constant (dev_ctx_, tensor_, static_cast <T>(value_));
58- #endif
59- }
60-
// Borrowed references/pointers: both must outlive this visitor.
61- const DeviceContext &dev_ctx_;
62- framework::LoDTensor *tensor_;
63- float value_;
64- };
65-
6627template <typename DeviceContext, typename T>
6728class CoalesceTensorOpKernel : public framework ::OpKernel<T> {
6829 public:
@@ -109,7 +70,6 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
10970 auto in_tensors = context.MultiInput <framework::LoDTensor>(" Input" );
11071 bool use_align = context.Attr <bool >(" use_align" );
11172 auto align_size = context.Attr <int >(" align_size" );
112- auto size_of_dtype = context.Attr <int >(" user_defined_size_of_dtype" );
11373
11474 if (context.Attr <bool >(" check_name" )) {
11575 for (size_t i = 0 ; i < in_var_names.size (); ++i) {
@@ -134,9 +94,7 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
13494 size_t numel = 0 ;
13595 auto dtype = static_cast <framework::proto::VarType::Type>(
13696 context.Attr <int >(" dtype" ));
137- if (size_of_dtype == -1 ) {
138- size_of_dtype = framework::SizeOfType (dtype);
139- }
97+ size_t size_of_dtype = framework::SizeOfType (dtype);
14098 GetMemSizeAndDtype (in_tensors, in_var_names, &numel, size_of_dtype,
14199 context.GetPlace (), use_align, align_size);
142100
@@ -163,9 +121,10 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
163121 : len;
164122 }
165123 } else if (context.Attr <bool >(" set_constant" )) {
166- framework::VisitDataType (
167- dtype, FillConstantVisitor<DeviceContext>(
168- dev_ctx, fused_tensor, context.Attr <float >(" constant" )));
124+ // TODO(Liu yuang) ADD NPU SET_CONSTANT FUNCTION.
125+ math::SetConstant<DeviceContext, T> set_constant;
126+ set_constant (dev_ctx, fused_tensor,
127+ static_cast <T>(context.Attr <float >(" constant" )));
169128 } else if (context.Attr <bool >(" persist_output" )) {
170129 for (size_t i = 0 ; i < out_var_names.size (); ++i) {
171130 size_t len = static_cast <size_t >(out_tensors[i]->numel ());
@@ -268,13 +227,10 @@ class CoalesceTensorOp : public framework::OperatorWithKernel {
268227 }
269228 auto use_align = ctx->Attrs ().Get <bool >(" use_align" );
270229 auto align_size = ctx->Attrs ().Get <int >(" align_size" );
271- auto size_of_dtype = ctx->Attrs ().Get <int >(" user_defined_size_of_dtype" );
272230
273231 auto dtype = static_cast <framework::proto::VarType::Type>(
274232 ctx->Attrs ().Get <int >(" dtype" ));
275- if (size_of_dtype == -1 ) {
276- size_of_dtype = framework::SizeOfType (dtype);
277- }
233+ size_t size_of_dtype = framework::SizeOfType (dtype);
278234
279235 auto alignment = [](size_t size, size_t align_size) {
280236 size_t remaining = size % align_size;
@@ -352,15 +308,6 @@ class CoalesceTensorOpMaker : public framework::OpProtoAndCheckerMaker {
352308 .SetDefault (true );
353309 AddAttr<int >(" align_size" , " The alignment size when use_align is True" )
354310 .SetDefault (-1 );
355- AddAttr<int >(" user_defined_size_of_dtype" ,
356- " The user defined size of dtype. This is used to coalesce "
357- " grad vars and merged_grad vars at the same time. For some "
358- " strategy, the dtype of fused_grad_vars and the dtype of "
359- " fused_grad_merged_vars are not identical, which will cause "
360- " the shape of these two coalesced vars are different. To "
361- " make sure the shape of these two vars are identical with "
362- " each other, this attr is added." )
363- .SetDefault (-1 );
364311 AddComment (R"DOC(
365312CoalesceTensor Operator.
366313
0 commit comments