|
20 | 20 | #include "paddle/fluid/framework/var_type.h" |
21 | 21 | #include "paddle/fluid/operators/math/math_function.h" |
22 | 22 | #include "paddle/fluid/platform/device_memory_aligment.h" |
| 23 | +#ifdef PADDLE_WITH_ASCEND_CL |
| 24 | +#include "paddle/fluid/operators/npu_op_runner.h" |
| 25 | +#endif |
23 | 26 |
|
24 | 27 | namespace paddle { |
25 | 28 | namespace operators { |
26 | 29 |
|
| 30 | +template <typename DeviceContext> |
| 31 | +struct FillConstantVisitor { |
| 32 | + FillConstantVisitor(const DeviceContext &dev_ctx, |
| 33 | + framework::LoDTensor *tensor, const float value, |
| 34 | + framework::proto::VarType::Type dtype, |
| 35 | + const framework::ExecutionContext &context) |
| 36 | + : dev_ctx_(dev_ctx), |
| 37 | + tensor_(tensor), |
| 38 | + value_(value), |
| 39 | + dtype_(dtype), |
| 40 | + context_(context) {} |
| 41 | + |
| 42 | + template <typename T> |
| 43 | + void apply(typename std::enable_if<std::is_same<T, int8_t>::value || |
| 44 | + std::is_same<T, int16_t>::value>::type * = |
| 45 | + nullptr) const { |
| 46 | + PADDLE_THROW(platform::errors::InvalidArgument( |
| 47 | + "Not support data type for set_constant attr")); |
| 48 | + } |
| 49 | + |
| 50 | + template <typename T> |
| 51 | + void apply(typename std::enable_if<!(std::is_same<T, int8_t>::value || |
| 52 | + std::is_same<T, int16_t>::value)>::type |
| 53 | + * = nullptr) const { |
| 54 | +#ifdef PADDLE_WITH_ASCEND_CL |
| 55 | + if (platform::is_npu_place(dev_ctx_.GetPlace())) { |
| 56 | + Tensor tensor_tmp(dtype_); |
| 57 | + tensor_tmp.mutable_data<T>({1}, context_.GetPlace()); |
| 58 | + FillNpuTensorWithConstant<T>(&tensor_tmp, static_cast<T>(value_)); |
| 59 | + |
| 60 | + const auto &runner = |
| 61 | + NpuOpRunner("FillD", {tensor_tmp}, {*tensor_}, |
| 62 | + {{"dims", framework::vectorize(tensor_->dims())}}); |
| 63 | + auto stream = |
| 64 | + context_.template device_context<paddle::platform::NPUDeviceContext>() |
| 65 | + .stream(); |
| 66 | + runner.Run(stream); |
| 67 | + } else { |
| 68 | + math::SetConstant<DeviceContext, T> set_constant; |
| 69 | + set_constant(dev_ctx_, tensor_, static_cast<T>(value_)); |
| 70 | + } |
| 71 | +#else |
| 72 | + math::SetConstant<DeviceContext, T> set_constant; |
| 73 | + set_constant(dev_ctx_, tensor_, static_cast<T>(value_)); |
| 74 | +#endif |
| 75 | + } |
| 76 | + |
| 77 | + const DeviceContext &dev_ctx_; |
| 78 | + framework::LoDTensor *tensor_; |
| 79 | + float value_; |
| 80 | + framework::proto::VarType::Type dtype_; |
| 81 | + const framework::ExecutionContext &context_; |
| 82 | +}; |
| 83 | + |
27 | 84 | template <typename DeviceContext, typename T> |
28 | 85 | class CoalesceTensorOpKernel : public framework::OpKernel<T> { |
29 | 86 | public: |
|
0 commit comments