Skip to content

Commit 70e3970

Browse files
committed
copy FillConstantVisitor from PaddlePaddle#35004 and PaddlePaddle#35105
1 parent 94beecc commit 70e3970

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

paddle/fluid/operators/coalesce_tensor_op.cc

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,67 @@
2020
#include "paddle/fluid/framework/var_type.h"
2121
#include "paddle/fluid/operators/math/math_function.h"
2222
#include "paddle/fluid/platform/device_memory_aligment.h"
23+
#ifdef PADDLE_WITH_ASCEND_CL
24+
#include "paddle/fluid/operators/npu_op_runner.h"
25+
#endif
2326

2427
namespace paddle {
2528
namespace operators {
2629

30+
template <typename DeviceContext>
31+
struct FillConstantVisitor {
32+
FillConstantVisitor(const DeviceContext &dev_ctx,
33+
framework::LoDTensor *tensor, const float value,
34+
framework::proto::VarType::Type dtype,
35+
const framework::ExecutionContext &context)
36+
: dev_ctx_(dev_ctx),
37+
tensor_(tensor),
38+
value_(value),
39+
dtype_(dtype),
40+
context_(context) {}
41+
42+
template <typename T>
43+
void apply(typename std::enable_if<std::is_same<T, int8_t>::value ||
44+
std::is_same<T, int16_t>::value>::type * =
45+
nullptr) const {
46+
PADDLE_THROW(platform::errors::InvalidArgument(
47+
"Not support data type for set_constant attr"));
48+
}
49+
50+
template <typename T>
51+
void apply(typename std::enable_if<!(std::is_same<T, int8_t>::value ||
52+
std::is_same<T, int16_t>::value)>::type
53+
* = nullptr) const {
54+
#ifdef PADDLE_WITH_ASCEND_CL
55+
if (platform::is_npu_place(dev_ctx_.GetPlace())) {
56+
Tensor tensor_tmp(dtype_);
57+
tensor_tmp.mutable_data<T>({1}, context_.GetPlace());
58+
FillNpuTensorWithConstant<T>(&tensor_tmp, static_cast<T>(value_));
59+
60+
const auto &runner =
61+
NpuOpRunner("FillD", {tensor_tmp}, {*tensor_},
62+
{{"dims", framework::vectorize(tensor_->dims())}});
63+
auto stream =
64+
context_.template device_context<paddle::platform::NPUDeviceContext>()
65+
.stream();
66+
runner.Run(stream);
67+
} else {
68+
math::SetConstant<DeviceContext, T> set_constant;
69+
set_constant(dev_ctx_, tensor_, static_cast<T>(value_));
70+
}
71+
#else
72+
math::SetConstant<DeviceContext, T> set_constant;
73+
set_constant(dev_ctx_, tensor_, static_cast<T>(value_));
74+
#endif
75+
}
76+
77+
const DeviceContext &dev_ctx_;
78+
framework::LoDTensor *tensor_;
79+
float value_;
80+
framework::proto::VarType::Type dtype_;
81+
const framework::ExecutionContext &context_;
82+
};
83+
2784
template <typename DeviceContext, typename T>
2885
class CoalesceTensorOpKernel : public framework::OpKernel<T> {
2986
public:

0 commit comments

Comments
 (0)