From 35aebd5e1f7e05f8aca9fbb5350af82bf74d9c56 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Mon, 7 Jun 2021 03:12:23 +0000 Subject: [PATCH 01/12] new api trunc, test=develop --- paddle/fluid/operators/trunc_op.cc | 87 +++++++++++++++++++ paddle/fluid/operators/trunc_op.cu | 74 ++++++++++++++++ paddle/fluid/operators/trunc_op.h | 53 +++++++++++ python/paddle/__init__.py | 4 +- .../fluid/tests/unittests/test_trunc_op.py | 86 ++++++++++++++++++ python/paddle/tensor/__init__.py | 4 +- python/paddle/tensor/math.py | 39 +++++++++ 7 files changed, 345 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/operators/trunc_op.cc create mode 100644 paddle/fluid/operators/trunc_op.cu create mode 100644 paddle/fluid/operators/trunc_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_trunc_op.py diff --git a/paddle/fluid/operators/trunc_op.cc b/paddle/fluid/operators/trunc_op.cc new file mode 100644 index 00000000000000..00f254d6735c79 --- /dev/null +++ b/paddle/fluid/operators/trunc_op.cc @@ -0,0 +1,87 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/trunc_op.h" + +namespace paddle { +namespace operators { + +class TruncOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "trunc"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "trunc"); + auto input_dims = ctx->GetInputDim("X"); + ctx->SetOutputDim("Out", input_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class TruncOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of trunc op."); + AddOutput("Out", "(Tensor), The output tensor of trunc op."); + AddComment(R"DOC( +Trunc Operator. +Returns a new tensor with the truncated integer values of input. +$$out = trunc(x)$$ +)DOC"); + } +}; + +class TruncGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + framework::GradVarName("Out"), "TruncGrad"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", + framework::GradVarName("X"), "TruncGrad"); + + auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); + ctx->SetOutputDim(framework::GradVarName("X"), dout_dims); + } +}; + +template +class TruncGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr retv) const override { + retv->SetType("trunc_grad"); + retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + retv->SetAttrMap(this->Attrs()); + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(trunc, ops::TruncOp, ops::TruncOpMaker, + ops::TruncGradOpMaker, + ops::TruncGradOpMaker); + +REGISTER_OPERATOR(trunc_grad, ops::TruncGradOp); + +REGISTER_OP_CPU_KERNEL(trunc, ops::TruncKernel, ops::TruncKernel, + ops::TruncKernel); + +REGISTER_OP_CPU_KERNEL(trunc_grad, ops::TruncGradKernel, + ops::TruncGradKernel, ops::TruncGradKernel); + diff --git a/paddle/fluid/operators/trunc_op.cu b/paddle/fluid/operators/trunc_op.cu new file mode 100644 index 00000000000000..01fef7903990b4 --- /dev/null +++ b/paddle/fluid/operators/trunc_op.cu @@ -0,0 +1,74 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/trunc_op.h" + +namespace paddle { +namespace operators { + +template +__global__ void Trunc(const T* x, T* out, int N) { + CUDA_KERNEL_LOOP(index, N) { out[index] = trunc(x[index]); } +} + +template +__global__ void TruncGrad(const T* dout, T* dx, int N) { + CUDA_KERNEL_LOOP(index, N) { dx[index] = 0.0; } +} + +template +class TruncCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* out = context.Output("Out"); + + const T* x_data = x->data(); + T* out_data = out->mutable_data(context.GetPlace()); + + int numel = x->numel(); + + dim3 blockSize(256); + dim3 gridSize((numel + blockSize.x - 1) / blockSize.x); + Trunc<<>>(x_data, out_data, numel); + } +}; + +template +class TruncCUDAGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* dout = context.Input(framework::GradVarName("Out")); + auto* dx = context.Output(framework::GradVarName("X")); + + const T* dout_data = dout->data(); + T* dx_data = dx->mutable_data(context.GetPlace()); + + int numel = dout->numel(); + + dim3 blockSize(256); + dim3 gridSize((numel + blockSize.x - 1) / blockSize.x); + TruncGrad<<>>(dout_data, dx_data, numel); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(trunc, ops::TruncCUDAKernel, + ops::TruncCUDAKernel, + ops::TruncCUDAKernel); + +REGISTER_OP_CUDA_KERNEL(trunc_grad, ops::TruncCUDAGradKernel, + ops::TruncCUDAGradKernel, + ops::TruncCUDAGradKernel); + diff --git a/paddle/fluid/operators/trunc_op.h b/paddle/fluid/operators/trunc_op.h new file mode 100644 index 00000000000000..c2ef228afd3967 --- /dev/null +++ b/paddle/fluid/operators/trunc_op.h @@ -0,0 +1,53 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class TruncKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* x = context.Input("X"); + Tensor* out = context.Output("Out"); + + size_t numel = x->numel(); + const T* x_data = x->data(); + T* out_data = out->mutable_data(context.GetPlace()); + + for (size_t i = 0; i < numel; i++) { + out_data[i] = trunc(x_data[i]); + } + } +}; + +template +class TruncGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* dx = context.Output(framework::GradVarName("X")); + T* dx_data = dx->mutable_data(context.GetPlace()); + + int numel = dx->numel(); + memset(dx_data, 0.0, numel * sizeof(T)); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 7bac330376c44f..dc4c8c55280abe 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -206,6 +206,7 @@ from .tensor.math import prod # noqa: F401 from .tensor.math import broadcast_shape # noqa: F401 from .tensor.math import conj # noqa: F401 +from .tensor.math import trunc # noqa: F401 from .tensor.random import multinomial # noqa: F401 from .tensor.random import standard_normal # noqa: F401 @@ -493,5 +494,6 @@ 'log2', 'log10', 'concat', - 'check_shape' + 'check_shape', + 'trunc' ] diff --git a/python/paddle/fluid/tests/unittests/test_trunc_op.py b/python/paddle/fluid/tests/unittests/test_trunc_op.py new file mode 100644 index 00000000000000..4723916c8dfeb6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_trunc_op.py @@ -0,0 +1,86 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard + +paddle.enable_static() + +class TestTruncOp(OpTest): + def setUp(self): + self.op_type = "trunc" + self.dtype = np.float64 + np.random.seed(2021) + self.inputs = {'X': np.random.random((20, 20)).astype(self.dtype)} + self.outputs = {'Out': (np.trunc(self.inputs['X']))} + + def init_dtype_type(self): + pass + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + print(self.inputs) + + self.check_grad(['X'], 'Out', numeric_grad_delta=1e-5) + +class TestFloatTruncOp(TestTruncOp): + def init_dtype_type(self): + self.dtype = np.float32 + +class TestIntTruncOp(TestTruncOp): + def init_dtype_type(self): + self.dtype = np.int32 + +class TestTruncAPI(unittest.TestCase): + def setUp(self): + self.shape = [20, 20] + self.x = np.random.random((20, 20)).astype(np.float32) + self.place = paddle.CPUPlace() + + def test_api_static(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('X', self.shape) + out = paddle.trunc(x) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x}, fetch_list=[out]) + out_ref = np.trunc(self.x) + for out in res: + self.assertEqual(np.allclose(out, out_ref, rtol=1e-08), True) + + def test_api_dygraph(self): + paddle.disable_static(self.place) + x_tensor = paddle.to_tensor(self.x) + out = paddle.trunc(x_tensor) + out_ref = np.trunc(self.x) + self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-08), True) + paddle.enable_static() + + def test_errors(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('X', [20, 20], 'bool') + self.assertRaises(TypeError, paddle.trunc, x) + +if __name__ == "__main__": + unittest.main() + diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index c8d80fc9bc68cb..fa62740930bb78 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -163,6 +163,7 @@ from .math import any # noqa: F401 from .math import broadcast_shape # noqa: F401 from .math import conj # noqa: F401 +from .math import trunc # noqa: F401 from .random import multinomial # noqa: F401 from .random import standard_normal # noqa: F401 @@ -345,5 +346,6 @@ 'rank', 'shape', 'real', - 'imag' + 'imag', + 'trunc' ] diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 2f69946c52139b..1226eeb2b259fd 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -855,6 +855,45 @@ def add_n(inputs, name=None): return out +def trunc(x, name=None): + ''' + Returns a new tensor with the truncated integer values of input. + Args: + x (Tensor): The input tensor, it's data type should be int, float, double. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: The product Tensor. + + Examples: + .. code-block:: python + import paddle + import numpy as np + paddle.set_device('cpu') + x_data = np.random.random([2,2]).astype(np.float32) + x = paddle.to_tensor(x_data) + out = paddle.trunc(x, y) + print(out) + # [[0., 1.], + # [0., 0.]]) + ''' + if in_dygraph_mode(): + out = _varbase_creator(dtype=x.dtype) + return core.ops.trunc(x) + else: + inputs = {"X": x} + attrs = {} + + helper = LayerHelper("trunc", **locals()) + check_variable_and_dtype(x, 'X', ['float16', 'float32', 'float64'], 'trunc') + out = helper.create_variable_for_type_inference(dtype=x.dtype) + + helper.append_op( + type="trunc", inputs=inputs, attrs=attrs, outputs={"Out": out}) + return out + + + def mm(input, mat2, name=None): """ From cd716f25b6675c98ef77acb0e6f488aca40a7883 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Mon, 7 Jun 2021 05:54:04 +0000 Subject: [PATCH 02/12] new api trunc, test=develop --- paddle/fluid/operators/trunc_op.cc | 16 ++++++++++++++-- .../fluid/tests/unittests/test_trunc_op.py | 8 ++++++-- python/paddle/tensor/math.py | 8 +++----- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/trunc_op.cc b/paddle/fluid/operators/trunc_op.cc index 00f254d6735c79..7a5e927bfef257 100644 --- a/paddle/fluid/operators/trunc_op.cc +++ b/paddle/fluid/operators/trunc_op.cc @@ -3,7 +3,9 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -54,6 +56,14 @@ class TruncGradOp : public framework::OperatorWithKernel { auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); ctx->SetOutputDim(framework::GradVarName("X"), dout_dims); } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + // Note: don't get data type from ctx.Input("Input"); + auto dtype = + ctx.Input(framework::GradVarName("Out"))->type(); + return framework::OpKernelType(dtype, ctx.GetPlace()); + } }; template @@ -69,6 +79,8 @@ class TruncGradOpMaker : public framework::SingleGradOpMaker { } }; +DECLARE_NO_NEED_BUFFER_VARS_INFERER(SliceOpGradNoNeedBufferVarsInference, "X"); + } // namespace operators } // namespace paddle @@ -77,11 +89,11 @@ REGISTER_OPERATOR(trunc, ops::TruncOp, ops::TruncOpMaker, ops::TruncGradOpMaker, ops::TruncGradOpMaker); -REGISTER_OPERATOR(trunc_grad, ops::TruncGradOp); +REGISTER_OPERATOR(trunc_grad, ops::TruncGradOp, + ops::SliceOpGradNoNeedBufferVarsInference); REGISTER_OP_CPU_KERNEL(trunc, ops::TruncKernel, ops::TruncKernel, ops::TruncKernel); REGISTER_OP_CPU_KERNEL(trunc_grad, ops::TruncGradKernel, ops::TruncGradKernel, ops::TruncGradKernel); - diff --git a/python/paddle/fluid/tests/unittests/test_trunc_op.py b/python/paddle/fluid/tests/unittests/test_trunc_op.py index 4723916c8dfeb6..b6a7b81bb46548 100644 --- a/python/paddle/fluid/tests/unittests/test_trunc_op.py +++ b/python/paddle/fluid/tests/unittests/test_trunc_op.py @@ -24,6 +24,7 @@ paddle.enable_static() + class TestTruncOp(OpTest): def setUp(self): self.op_type = "trunc" @@ -33,7 +34,7 @@ def setUp(self): self.outputs = {'Out': (np.trunc(self.inputs['X']))} def init_dtype_type(self): - pass + self.dtype = np.float64 def test_check_output(self): self.check_output() @@ -43,14 +44,17 @@ def test_check_grad(self): self.check_grad(['X'], 'Out', numeric_grad_delta=1e-5) + class TestFloatTruncOp(TestTruncOp): def init_dtype_type(self): self.dtype = np.float32 + class TestIntTruncOp(TestTruncOp): def init_dtype_type(self): self.dtype = np.int32 + class TestTruncAPI(unittest.TestCase): def setUp(self): self.shape = [20, 20] @@ -81,6 +85,6 @@ def test_errors(self): x = paddle.fluid.data('X', [20, 20], 'bool') self.assertRaises(TypeError, paddle.trunc, x) + if __name__ == "__main__": unittest.main() - diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 1226eeb2b259fd..c54ad7bfe52cfe 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -870,22 +870,20 @@ def trunc(x, name=None): import paddle import numpy as np paddle.set_device('cpu') - x_data = np.random.random([2,2]).astype(np.float32) - x = paddle.to_tensor(x_data) - out = paddle.trunc(x, y) + x = paddle.rand([2,2],'float32') + out = paddle.trunc(x) print(out) # [[0., 1.], # [0., 0.]]) ''' if in_dygraph_mode(): - out = _varbase_creator(dtype=x.dtype) return core.ops.trunc(x) else: inputs = {"X": x} attrs = {} helper = LayerHelper("trunc", **locals()) - check_variable_and_dtype(x, 'X', ['float16', 'float32', 'float64'], 'trunc') + check_variable_and_dtype(x, 'X', ['int32', 'float32', 'float64'], 'trunc') out = helper.create_variable_for_type_inference(dtype=x.dtype) helper.append_op( From 33cc89b0a2b5bc324f21c4e40d8e1bb8292b009a Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Mon, 7 Jun 2021 06:57:03 +0000 Subject: [PATCH 03/12] new api paddle.trunc, test=develop --- paddle/fluid/operators/trunc_op.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/fluid/operators/trunc_op.cu b/paddle/fluid/operators/trunc_op.cu index 01fef7903990b4..7937674d9bd0da 100644 --- a/paddle/fluid/operators/trunc_op.cu +++ b/paddle/fluid/operators/trunc_op.cu @@ -71,4 +71,3 @@ REGISTER_OP_CUDA_KERNEL(trunc, ops::TruncCUDAKernel, REGISTER_OP_CUDA_KERNEL(trunc_grad, ops::TruncCUDAGradKernel, ops::TruncCUDAGradKernel, ops::TruncCUDAGradKernel); - From b5772966c388fc6fa05b4230c85c219d550e9d57 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Tue, 8 Jun 2021 03:52:09 +0000 Subject: [PATCH 04/12] new api paddle.trunc, test=develop --- paddle/fluid/framework/unused_var_check.cc | 1 + paddle/fluid/operators/trunc_op.cc | 18 ++----- paddle/fluid/operators/trunc_op.cu | 54 ++++++++++++++----- .../fluid/tests/unittests/test_trunc_op.py | 2 - python/paddle/tensor/math.py | 2 +- 5 files changed, 47 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/framework/unused_var_check.cc b/paddle/fluid/framework/unused_var_check.cc index 0f8465ab8948e4..f8ace3e85a643e 100644 --- a/paddle/fluid/framework/unused_var_check.cc +++ b/paddle/fluid/framework/unused_var_check.cc @@ -75,6 +75,7 @@ static const std::unordered_set &GetOpWithUnusedVarAllowSet() { "data_norm_grad", // 0 "update_loss_scaling", // 0 "fused_embedding_eltwise_layernorm", // 0 + "trunc_grad", // 1 }); return *allow_set; } diff --git a/paddle/fluid/operators/trunc_op.cc b/paddle/fluid/operators/trunc_op.cc index 7a5e927bfef257..2b79e2152b2f34 100644 --- a/paddle/fluid/operators/trunc_op.cc +++ b/paddle/fluid/operators/trunc_op.cc @@ -56,14 +56,6 @@ class TruncGradOp : public framework::OperatorWithKernel { auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); ctx->SetOutputDim(framework::GradVarName("X"), dout_dims); } - - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - // Note: don't get data type from ctx.Input("Input"); - auto dtype = - ctx.Input(framework::GradVarName("Out"))->type(); - return framework::OpKernelType(dtype, ctx.GetPlace()); - } }; template @@ -79,8 +71,6 @@ class TruncGradOpMaker : public framework::SingleGradOpMaker { } }; -DECLARE_NO_NEED_BUFFER_VARS_INFERER(SliceOpGradNoNeedBufferVarsInference, "X"); - } // namespace operators } // namespace paddle @@ -89,11 +79,11 @@ REGISTER_OPERATOR(trunc, ops::TruncOp, ops::TruncOpMaker, ops::TruncGradOpMaker, ops::TruncGradOpMaker); -REGISTER_OPERATOR(trunc_grad, ops::TruncGradOp, - ops::SliceOpGradNoNeedBufferVarsInference); +REGISTER_OPERATOR(trunc_grad, ops::TruncGradOp); REGISTER_OP_CPU_KERNEL(trunc, ops::TruncKernel, ops::TruncKernel, - ops::TruncKernel); + ops::TruncKernel, ops::TruncKernel); REGISTER_OP_CPU_KERNEL(trunc_grad, ops::TruncGradKernel, - ops::TruncGradKernel, ops::TruncGradKernel); + ops::TruncGradKernel, ops::TruncGradKernel, + ops::TruncGradKernel); diff --git a/paddle/fluid/operators/trunc_op.cu b/paddle/fluid/operators/trunc_op.cu index 7937674d9bd0da..2c341695db5142 100644 --- a/paddle/fluid/operators/trunc_op.cu +++ b/paddle/fluid/operators/trunc_op.cu @@ -15,12 +15,39 @@ namespace paddle { namespace operators { template -__global__ void Trunc(const T* x, T* out, int N) { - CUDA_KERNEL_LOOP(index, N) { out[index] = trunc(x[index]); } +class truncFunctor { + public: + __device__ truncFunctor(const T x) : _x(x) {} + __device__ T operator()() { return trunc(_x); } + const T _x; +}; + +template <> +class truncFunctor { + public: + __device__ truncFunctor(const int x) : _x(x) {} + __device__ int operator()() { return _x; } + const int _x; +}; + +template <> +class truncFunctor { + public: + __device__ truncFunctor(const int64_t x) : _x(x) {} + __device__ int64_t operator()() { return _x; } + const int64_t _x; +}; + +template +__global__ void Trunc(const T* x, T* out, int64_t N) { + CUDA_KERNEL_LOOP(index, N) { + truncFunctor functor(x[index]); + out[index] = functor(); + } } template -__global__ void TruncGrad(const T* dout, T* dx, int N) { +__global__ void TruncGrad(T* dx, int64_t N) { CUDA_KERNEL_LOOP(index, N) { dx[index] = 0.0; } } @@ -31,10 +58,10 @@ class TruncCUDAKernel : public framework::OpKernel { auto* x = context.Input("X"); auto* out = context.Output("Out"); - const T* x_data = x->data(); - T* out_data = out->mutable_data(context.GetPlace()); + const auto* x_data = x->data(); + auto* out_data = out->mutable_data(context.GetPlace()); - int numel = x->numel(); + int64_t numel = x->numel(); dim3 blockSize(256); dim3 gridSize((numel + blockSize.x - 1) / blockSize.x); @@ -49,14 +76,14 @@ class TruncCUDAGradKernel : public framework::OpKernel { auto* dout = context.Input(framework::GradVarName("Out")); auto* dx = context.Output(framework::GradVarName("X")); - const T* dout_data = dout->data(); - T* dx_data = dx->mutable_data(context.GetPlace()); + const auto* dout_data = dout->data(); + auto* dx_data = dx->mutable_data(context.GetPlace()); - int numel = dout->numel(); + int64_t numel = dout->numel(); dim3 blockSize(256); dim3 gridSize((numel + blockSize.x - 1) / blockSize.x); - TruncGrad<<>>(dout_data, dx_data, numel); + TruncGrad<<>>(dx_data, numel); } }; @@ -65,9 +92,10 @@ class TruncCUDAGradKernel : public framework::OpKernel { namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL(trunc, ops::TruncCUDAKernel, - ops::TruncCUDAKernel, - ops::TruncCUDAKernel); + ops::TruncCUDAKernel, ops::TruncCUDAKernel, + ops::TruncCUDAKernel); REGISTER_OP_CUDA_KERNEL(trunc_grad, ops::TruncCUDAGradKernel, ops::TruncCUDAGradKernel, - ops::TruncCUDAGradKernel); + ops::TruncCUDAGradKernel, + ops::TruncCUDAGradKernel); diff --git a/python/paddle/fluid/tests/unittests/test_trunc_op.py b/python/paddle/fluid/tests/unittests/test_trunc_op.py index b6a7b81bb46548..51844071138c70 100644 --- a/python/paddle/fluid/tests/unittests/test_trunc_op.py +++ b/python/paddle/fluid/tests/unittests/test_trunc_op.py @@ -40,8 +40,6 @@ def test_check_output(self): self.check_output() def test_check_grad(self): - print(self.inputs) - self.check_grad(['X'], 'Out', numeric_grad_delta=1e-5) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index c54ad7bfe52cfe..6851b8858b8b13 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -883,7 +883,7 @@ def trunc(x, name=None): attrs = {} helper = LayerHelper("trunc", **locals()) - check_variable_and_dtype(x, 'X', ['int32', 'float32', 'float64'], 'trunc') + check_variable_and_dtype(x, 'X', ['int32', 'int64', 'float32', 'float64'], 'trunc') out = helper.create_variable_for_type_inference(dtype=x.dtype) helper.append_op( From 9a6272b2d259241a05db8dea7b869f2bac72f4e9 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 9 Jun 2021 05:16:03 +0000 Subject: [PATCH 05/12] new api paddle.trunc, test=develop --- paddle/fluid/operators/trunc_op.cu | 43 ++++++++++++++++++------------ paddle/fluid/operators/trunc_op.h | 2 ++ 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/operators/trunc_op.cu b/paddle/fluid/operators/trunc_op.cu index 2c341695db5142..927d83d43829d4 100644 --- a/paddle/fluid/operators/trunc_op.cu +++ b/paddle/fluid/operators/trunc_op.cu @@ -2,7 +2,9 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -15,40 +17,46 @@ namespace paddle { namespace operators { template -class truncFunctor { +class TruncFunctor { + public: + __device__ TruncFunctor(const T x) : x_(x) {} + __device__ T operator()() { return trunc(x_); } + public: - __device__ truncFunctor(const T x) : _x(x) {} - __device__ T operator()() { return trunc(_x); } - const T _x; + const T x_; }; template <> -class truncFunctor { +class TruncFunctor { + public: + __device__ TruncFunctor(const int x) : x_(x) {} + __device__ int operator()() { return x_; } + public: - __device__ truncFunctor(const int x) : _x(x) {} - __device__ int operator()() { return _x; } - const int _x; + const int x_; }; template <> -class truncFunctor { +class TruncFunctor { public: - __device__ truncFunctor(const int64_t x) : _x(x) {} - __device__ int64_t operator()() { return _x; } - const int64_t _x; + __device__ TruncFunctor(const int64_t x) : x_(x) {} + __device__ int64_t operator()() { return x_; } + + public: + const int64_t x_; }; template __global__ void Trunc(const T* x, T* out, int64_t N) { CUDA_KERNEL_LOOP(index, N) { - truncFunctor functor(x[index]); + TruncFunctor functor(x[index]); out[index] = functor(); } } template __global__ void TruncGrad(T* dx, int64_t N) { - CUDA_KERNEL_LOOP(index, N) { dx[index] = 0.0; } + CUDA_KERNEL_LOOP(index, N) { dx[index] = static_cast(0.0); } } template @@ -63,9 +71,10 @@ class TruncCUDAKernel : public framework::OpKernel { int64_t numel = x->numel(); - dim3 blockSize(256); - dim3 gridSize((numel + blockSize.x - 1) / blockSize.x); - Trunc<<>>(x_data, out_data, numel); + int theads = platform::PADDLE_CUDA_NUM_THREADS; + int blocks = (numel + theads - 1) / theads; + + Trunc<<>>(x_data, out_data, numel); } }; diff --git a/paddle/fluid/operators/trunc_op.h b/paddle/fluid/operators/trunc_op.h index c2ef228afd3967..0f788eae5249c5 100644 --- a/paddle/fluid/operators/trunc_op.h +++ b/paddle/fluid/operators/trunc_op.h @@ -2,7 +2,9 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. From de107a392eebeb0481e7611e1dd4850d6a7f0662 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 9 Jun 2021 05:37:29 +0000 Subject: [PATCH 06/12] new api paddle.trunc, test=develop --- paddle/fluid/operators/trunc_op.cu | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/trunc_op.cu b/paddle/fluid/operators/trunc_op.cu index 927d83d43829d4..a284e0ea6e3939 100644 --- a/paddle/fluid/operators/trunc_op.cu +++ b/paddle/fluid/operators/trunc_op.cu @@ -12,10 +12,14 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/trunc_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/gpu_info.h" namespace paddle { namespace operators { +using platform::PADDLE_CUDA_NUM_THREADS; + template class TruncFunctor { public: @@ -71,7 +75,7 @@ class TruncCUDAKernel : public framework::OpKernel { int64_t numel = x->numel(); - int theads = platform::PADDLE_CUDA_NUM_THREADS; + int theads = PADDLE_CUDA_NUM_THREADS; int blocks = (numel + theads - 1) / theads; Trunc<<>>(x_data, out_data, numel); @@ -90,9 +94,10 @@ class TruncCUDAGradKernel : public framework::OpKernel { int64_t numel = dout->numel(); - dim3 blockSize(256); - dim3 gridSize((numel + blockSize.x - 1) / blockSize.x); - TruncGrad<<>>(dx_data, numel); + int theads = PADDLE_CUDA_NUM_THREADS; + int blocks = (numel + theads - 1) / theads; + + TruncGrad<<>>(dx_data, numel); } }; From 58b19cb1ffef68d6742d8318e069ca9ed187b8a8 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 9 Jun 2021 08:04:43 +0000 Subject: [PATCH 07/12] new api paddle.trunc, test=develop --- python/paddle/tensor/math.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 6851b8858b8b13..462e560a4d2df3 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -855,24 +855,25 @@ def add_n(inputs, name=None): return out -def trunc(x, name=None): +def trunc(input, name=None): ''' - Returns a new tensor with the truncated integer values of input. + This API is used to returns a new tensor with the truncated integer values of input. Args: - x (Tensor): The input tensor, it's data type should be int, float, double. + input (Tensor): The input tensor, it's data type should be int, float, double. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor: The product Tensor. + Tensor: The output Tensor of trunc. Examples: .. code-block:: python + import paddle - import numpy as np + paddle.set_device('cpu') - x = paddle.rand([2,2],'float32') - out = paddle.trunc(x) - print(out) + input = paddle.rand([2,2],'float32') + output = paddle.trunc(input) + print(output) # [[0., 1.], # [0., 0.]]) ''' From b7b5803409bc6d1d198d018b22a5eb580fdda1dd Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 9 Jun 2021 08:52:46 +0000 Subject: [PATCH 08/12] new api paddle.trunc, test=develop --- python/paddle/tensor/math.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 462e560a4d2df3..5c1f8b7de232fa 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -858,6 +858,7 @@ def add_n(inputs, name=None): def trunc(input, name=None): ''' This API is used to returns a new tensor with the truncated integer values of input. + Args: input (Tensor): The input tensor, it's data type should be int, float, double. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. @@ -869,7 +870,7 @@ def trunc(input, name=None): .. code-block:: python import paddle - + paddle.set_device('cpu') input = paddle.rand([2,2],'float32') output = paddle.trunc(input) @@ -878,14 +879,14 @@ def trunc(input, name=None): # [0., 0.]]) ''' if in_dygraph_mode(): - return core.ops.trunc(x) + return core.ops.trunc(input) else: - inputs = {"X": x} + inputs = {"X": input} attrs = {} helper = LayerHelper("trunc", **locals()) - check_variable_and_dtype(x, 'X', ['int32', 'int64', 'float32', 'float64'], 'trunc') - out = helper.create_variable_for_type_inference(dtype=x.dtype) + check_variable_and_dtype(input, 'X', ['int32', 'int64', 'float32', 'float64'], 'trunc') + out = helper.create_variable_for_type_inference(dtype=input.dtype) helper.append_op( type="trunc", inputs=inputs, attrs=attrs, outputs={"Out": out}) From f76b32e9d7e848e64dcbbeac4366790495382995 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 9 Jun 2021 09:21:55 +0000 Subject: [PATCH 09/12] new api paddle.trunc, test=develop --- python/paddle/tensor/math.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 5c1f8b7de232fa..342ccfaeea7b5d 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -871,12 +871,17 @@ def trunc(input, name=None): import paddle - paddle.set_device('cpu') input = paddle.rand([2,2],'float32') + print(input) + # Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + # [[0.02331470, 0.42374918], + # [0.79647720, 0.74970269]]) + output = paddle.trunc(input) print(output) - # [[0., 1.], - # [0., 0.]]) + # Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + # [[0., 0.], + # [0., 0.]]) ''' if in_dygraph_mode(): return core.ops.trunc(input) From 8158a58c1be27e00ca413c936aa8a5f3b2ef942a Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Thu, 10 Jun 2021 01:33:45 +0000 Subject: [PATCH 10/12] new api paddle.trunc, test=develop --- python/paddle/tensor/math.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 342ccfaeea7b5d..a7872d1c9c97d3 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -871,17 +871,12 @@ def trunc(input, name=None): import paddle - input = paddle.rand([2,2],'float32') - print(input) - # Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True, - # [[0.02331470, 0.42374918], - # [0.79647720, 0.74970269]]) - + input = paddle.to_tensor([[1.45, 3.54], [0.23, -4.21]], dtype='float32') output = paddle.trunc(input) print(output) # Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True, - # [[0., 0.], - # [0., 0.]]) + # [[1., 3.], + # [0., -4.]]) ''' if in_dygraph_mode(): return core.ops.trunc(input) From b7f18f9f0a790723b966949fc4f0ce447c41fa64 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Thu, 10 Jun 2021 02:36:15 +0000 Subject: [PATCH 11/12] new api paddle.trunc, test=develop --- python/paddle/tensor/math.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 85841362e292fa..b10413182c0154 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -872,12 +872,17 @@ def trunc(input, name=None): import paddle - input = paddle.to_tensor([[1.45, 3.54], [0.23, -4.21]], dtype='float32') + input = paddle.rand([2,2],'float32') + print(input) + # Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + # [[0.02331470, 0.42374918], + # [0.79647720, 0.74970269]]) + output = paddle.trunc(input) print(output) # Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True, - # [[1., 3.], - # [0., -4.]]) + # [[0., 0.], + # [0., 0.]])) ''' if in_dygraph_mode(): return core.ops.trunc(input) From 2783560defc699937cd062c73f4c0eea2cb18d49 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 11 Jun 2021 02:04:38 +0000 Subject: [PATCH 12/12] new api paddle.trunc, test=develop --- python/paddle/tensor/math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index b10413182c0154..66ad2741ece434 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -861,7 +861,7 @@ def trunc(input, name=None): This API is used to returns a new tensor with the truncated integer values of input. Args: - input (Tensor): The input tensor, it's data type should be int, float, double. + input (Tensor): The input tensor, it's data type should be int32, int64, float32, float64. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: