From d986e4ffea05e6c722671851abe0bd5b6b2973e3 Mon Sep 17 00:00:00 2001 From: qili93 Date: Fri, 6 Aug 2021 10:52:43 +0800 Subject: [PATCH 1/2] fix npu compile error, test=develop --- paddle/fluid/framework/ir/CMakeLists.txt | 2 +- paddle/fluid/operators/expand_op_npu.cc | 21 ++++++++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 0107f5976499ce..384f80395c7784 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -59,7 +59,7 @@ cc_library(coalesce_grad_tensor_pass SRCS coalesce_grad_tensor_pass.cc DEPS grap pass_library(graph_to_program_pass base) pass_library(graph_viz_pass base) -pass_library(lock_free_optimize_pass base) +pass_library(lock_free_optimize_pass base DEPS string_helper) pass_library(fc_fuse_pass inference) pass_library(map_matmul_to_mul_pass inference) pass_library(attention_lstm_fuse_pass inference) diff --git a/paddle/fluid/operators/expand_op_npu.cc b/paddle/fluid/operators/expand_op_npu.cc index 76d5a203f306b9..2f66316c483a9c 100644 --- a/paddle/fluid/operators/expand_op_npu.cc +++ b/paddle/fluid/operators/expand_op_npu.cc @@ -39,7 +39,26 @@ class ExpandNPUKernel : public framework::OpKernel { "The number of dimensions of the input 'x' for Op(expand) " "must be less than or equal to %d, but the value received is %d.", MAX_RANK_SUPPORTED, rank)); - switch (rank) { REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED) } + switch (rank) { + case 1: + Expand<1>(context); + break; + case 2: + Expand<2>(context); + break; + case 3: + Expand<3>(context); + break; + case 4: + Expand<4>(context); + break; + case 5: + Expand<5>(context); + break; + case 6: + Expand<6>(context); + break; + } } protected: From 37e0d313e9c33697257b0275a8c663021324670d Mon Sep 17 00:00:00 2001 From: andyjpaddle Date: Tue, 10 Aug 2021 00:50:41 +0800 Subject: [PATCH 2/2] add fill constant batch size lilke op npu,test=develop --- paddle/fluid/framework/ir/CMakeLists.txt | 2 +- .../fill_constant_batch_size_like_op_npu.cc | 108 +++++++++++++ ...st_fill_constant_batch_size_like_op_npu.py | 150 ++++++++++++++++++ 3 files changed, 259 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/fill_constant_batch_size_like_op_npu.cc create mode 100644 python/paddle/fluid/tests/unittests/npu/test_fill_constant_batch_size_like_op_npu.py diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 384f80395c7784..0107f5976499ce 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -59,7 +59,7 @@ cc_library(coalesce_grad_tensor_pass SRCS coalesce_grad_tensor_pass.cc DEPS grap pass_library(graph_to_program_pass base) pass_library(graph_viz_pass base) -pass_library(lock_free_optimize_pass base DEPS string_helper) +pass_library(lock_free_optimize_pass base) pass_library(fc_fuse_pass inference) pass_library(map_matmul_to_mul_pass inference) pass_library(attention_lstm_fuse_pass inference) diff --git a/paddle/fluid/operators/fill_constant_batch_size_like_op_npu.cc b/paddle/fluid/operators/fill_constant_batch_size_like_op_npu.cc new file mode 100644 index 00000000000000..a58815748b6758 --- /dev/null +++ b/paddle/fluid/operators/fill_constant_batch_size_like_op_npu.cc @@ -0,0 +1,108 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/fluid/operators/fill_constant_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" +#include "paddle/fluid/operators/utils.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class FillConstantBatchSizeLikeOpNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto data_type = + static_cast(ctx.Attr("dtype")); + auto float_value = ctx.Attr("value"); + auto str_value = ctx.Attr("str_value"); + auto force_cpu = ctx.Attr("force_cpu"); + + auto *out = ctx.Output("Out"); + auto *input = ctx.Input("Input"); + if (ctx.Attr("input_dim_idx") == 0) { + // set the correct batch size. + auto odims = out->dims(); + auto idims = input->dims(); + int output_dim_idx = ctx.Attr("output_dim_idx"); + odims[output_dim_idx] = static_cast(idims[0]); + out->mutable_data(odims, ctx.GetPlace()); + } + + T value; + if (str_value.empty()) { + value = static_cast(float_value); + } else { + // handle NaN/Inf first, which cannot be read from stream. + if (str_value == "inf") { + value = static_cast(std::numeric_limits::infinity()); + } else if (str_value == "-inf") { + value = static_cast(-std::numeric_limits::infinity()); + } else if (str_value == "nan") { + value = static_cast(std::numeric_limits::quiet_NaN()); + } else { + std::stringstream convert_stream(str_value); + if (std::is_same::value) { + int64_t tmp_value; + convert_stream >> tmp_value; + value = static_cast(tmp_value); + } else { + double tmp_value; + convert_stream >> tmp_value; + value = static_cast(tmp_value); + } + } + } + + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(ctx.GetPlace()); + if (force_cpu) { + math::SetConstant functor; + out->mutable_data(platform::CPUPlace(), data_type); + functor(reinterpret_cast(dev_ctx), + out, static_cast(value)); + } else { + out->mutable_data(ctx.GetPlace(), data_type); + Tensor tensor_tmp(data_type); + tensor_tmp.mutable_data({1}, ctx.GetPlace()); + FillNpuTensorWithConstant(&tensor_tmp, value); + + auto stream = + ctx.template device_context() + .stream(); + const auto &runner = + NpuOpRunner("FillD", {tensor_tmp}, {*out}, + {{"dims", framework::vectorize(out->dims())}}); + runner.Run(stream); + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_NPU_KERNEL( + fill_constant_batch_size_like, + ops::FillConstantBatchSizeLikeOpNPUKernel< + paddle::platform::NPUDeviceContext, float>, + ops::FillConstantBatchSizeLikeOpNPUKernel< + paddle::platform::NPUDeviceContext, int>, + ops::FillConstantBatchSizeLikeOpNPUKernel< + paddle::platform::NPUDeviceContext, paddle::platform::float16>); diff --git a/python/paddle/fluid/tests/unittests/npu/test_fill_constant_batch_size_like_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_fill_constant_batch_size_like_op_npu.py new file mode 100644 index 00000000000000..9fb650ddb514e1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_fill_constant_batch_size_like_op_npu.py @@ -0,0 +1,150 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid +from paddle.fluid import core + +paddle.enable_static() +SEED = 2021 + + +class TestFillConstantBatchSizeLike(OpTest): + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "fill_constant_batch_size_like" + self.init_shape() + self.init_value() + self.init_dtype() + self.init_force_cpu() + self.init_dim_idx() + + self.inputs = { + 'Input': np.random.random(self.input_shape).astype("float32") + } + self.attrs = { + 'shape': self.shape, + 'value': self.value, + 'str_value': self.str_value, + 'dtype': self.dtype, + 'force_cpu': self.force_cpu, + 'input_dim_idx': self.input_dim_idx, + 'output_dim_idx': self.output_dim_idx + } + self.outputs = { + 'Out': np.full(self.output_shape, self.output_value, + self.output_dtype) + } + + def set_npu(self): + self.__class__.use_npu = True + + def init_shape(self): + self.input_shape = [4, 5] + self.shape = [123, 92] + self.output_shape = (4, 92) + + def init_value(self): + self.value = 3.8 + self.str_value = '' + self.output_value = 3.8 + + def init_dtype(self): + self.dtype = core.VarDesc.VarType.FP32 + self.output_dtype = np.float32 + + def init_force_cpu(self): + self.force_cpu = False + + def init_dim_idx(self): + self.input_dim_idx = 0 + self.output_dim_idx = 0 + + def test_check_output(self): + self.check_output_with_place(self.place) + + +class TestFillConstantBatchSizeLike2(TestFillConstantBatchSizeLike): + def init_shape(self): + # test shape + self.input_shape = [4, 5, 6, 7] + self.shape = [10, 123, 92] + self.output_shape = (4, 123, 92) + + +class TestFillConstantBatchSizeLike3(TestFillConstantBatchSizeLike): + def init_value(self): + # use 'str_value' rather than 'value' + self.value = 3.8 + self.str_value = '4.5' + self.output_value = 4.5 + + +class TestFillConstantBatchSizeLike4(TestFillConstantBatchSizeLike): + def init_value(self): + # str_value = 'inf' + self.value = 3.8 + self.str_value = 'inf' + self.output_value = float('inf') + + +class TestFillConstantBatchSizeLike5(TestFillConstantBatchSizeLike): + def init_value(self): + # str_value = '-inf' + self.value = 3.8 + self.str_value = '-inf' + self.output_value = -float('inf') + + +class TestFillConstantBatchSizeLike6(TestFillConstantBatchSizeLike): + def init_dtype(self): + self.dtype = core.VarDesc.VarType.FP16 + self.output_dtype = np.float16 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-2) + + +class TestFillConstantBatchSizeLike7(TestFillConstantBatchSizeLike): + def init_dtype(self): + self.dtype = core.VarDesc.VarType.INT32 + self.output_dtype = np.int32 + + +class TestFillConstantBatchSizeLike8(TestFillConstantBatchSizeLike): + def init_force_cpu(self): + self.force_cpu = True + + +class TestFillConstantBatchSizeLike9(TestFillConstantBatchSizeLike): + def init_shape(self): + self.input_shape = [4, 5] + self.shape = [123, 92] + self.output_shape = (123, 4) + + def init_dim_idx(self): + self.input_dim_idx = 0 + self.output_dim_idx = 1 + + +if __name__ == '__main__': + unittest.main()