From 4ac4a06724bfb887dafc01671f8784f093ccd13d Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Wed, 4 Aug 2021 14:37:53 +0800 Subject: [PATCH 1/6] Add NPU kernel for TopKV2 op --- paddle/fluid/operators/top_k_v2_op_npu.cc | 250 ++++++++++++++++ .../unittests/npu/test_top_k_v2_op_npu.py | 279 ++++++++++++++++++ .../static_mode_white_list.cpython-37.pyc | Bin 0 -> 20802 bytes 3 files changed, 529 insertions(+) create mode 100644 paddle/fluid/operators/top_k_v2_op_npu.cc create mode 100644 python/paddle/fluid/tests/unittests/npu/test_top_k_v2_op_npu.py create mode 100644 tools/__pycache__/static_mode_white_list.cpython-37.pyc diff --git a/paddle/fluid/operators/top_k_v2_op_npu.cc b/paddle/fluid/operators/top_k_v2_op_npu.cc new file mode 100644 index 00000000000000..62bd0c015b8f12 --- /dev/null +++ b/paddle/fluid/operators/top_k_v2_op_npu.cc @@ -0,0 +1,250 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/top_k_v2_op.h" +#include +#include +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +template +class TopkV2NPUKernel : public framework::OpKernel { + public: + // Use CANN TopKV2 operator to implement paddle TopKV2Op + void Compute(const framework::ExecutionContext& context) const override { + using Tensor = framework::Tensor; + + // Read message from context + auto* input = context.Input("X"); + auto* k_tensor = context.Input("K"); + auto* output = context.Output("Out"); + auto* indices = context.Output("Indices"); + + int k = static_cast(context.Attr("k")); + int axis = static_cast(context.Attr("axis")); + const bool sorted = static_cast(context.Attr("sorted")); + const bool largest = static_cast(context.Attr("largest")); + + // Calculate the real value of axis and k + if (axis < 0) { + axis += input->dims().size(); + } + + if (k_tensor != nullptr) { + // seems complicated, but I really don't know how to assign a NPU value to + // a CPU variable by an elegant way + std::vector v_tmp(1); + TensorToVector( + *k_tensor, + context.template device_context(), + &v_tmp); + k = v_tmp[0]; + } + + // Allocate space for output tensors on NPU + framework::DDim output_dims = input->dims(); + output_dims[axis] = k; + + output->Resize(output_dims); + indices->Resize(output_dims); + + output->mutable_data(context.GetPlace()); + indices->mutable_data(context.GetPlace()); + + // Construct the input tensor x of CANN TopKV2 operator + // as CANN TopKV2 operator does not support setting 'axis'(defaults to the + // last dimension) and 'largest'(defaults to true) parameter yet, + // 1. when the 'axis' is not the last dimension, we use CANN Transpose + // operator to permutes the dimension 'axis' to the last dimension + // 2. when the 'largest' is false, we use CANN Neg operator to negate the + // input tensor element-wise, which convert descending to ascending order + // once the functino of the parameter 'dim' and 'largest' is further + // improved, these additional actions can be removed + Tensor* input_transpose = nullptr; + Tensor* input_neg = nullptr; + const Tensor* x_cann = + input; // the input tensor "x" of CANN TopKV2 operator + std::vector perm; + const int last_axis = static_cast( + input->dims().size() - + 1); // attention: there may be bugs when the input tensor is empty + + if (axis != + last_axis) { // in this case, the 'input' tensor should be transposed + // compute perm vector + perm.resize(last_axis + 1); + for (int i = 0; i <= last_axis; ++i) { + perm[i] = i; + } + std::swap(perm[axis], perm[last_axis]); + + // construct 'input_transpose' + input_transpose = new Tensor(input->type()); + + framework::DDim input_transpose_dims = input->dims(); + std::swap(input_transpose_dims[axis], input_transpose_dims[last_axis]); + + input_transpose->Resize(input_transpose_dims); + input_transpose->mutable_data(context.GetPlace()); + + // run CANN Transpose operator + NpuOpRunner npu_op_runner_transpose; + auto npu_stream_transpose = + context.template device_context() + .stream(); + npu_op_runner_transpose.SetType("Transpose") + .AddInput(*input) + .AddInput(std::move(perm)) + .AddOutput(*input_transpose) + .Run(npu_stream_transpose); + + x_cann = input_transpose; + } + + if (!largest) { // in this case, the 'input' tensor should be negated + // element-wise + // construct 'input_neg' + auto* input_tensor = + (input_transpose == nullptr ? input : input_transpose); + input_neg = new Tensor(input_tensor->type()); + input_neg->Resize(input_tensor->dims()); + input_neg->mutable_data(context.GetPlace()); + + // run CANN Neg operator + const auto& npu_op_runner_neg = + NpuOpRunner("Neg", {*input_tensor}, {*input_neg}); + auto npu_stream_neg = + context.template device_context() + .stream(); + npu_op_runner_neg.Run(npu_stream_neg); + + x_cann = input_neg; + } + + // Construct the input and output tensors of CANN TopKV2 operator (except x) + // input k: a 0D tensor of type int32, Number of top elements to look for + // along the last dimension (along each row for matrices) + Tensor* k_cann = new Tensor(framework::proto::VarType::INT32); + k_cann->mutable_data({1}, context.GetPlace()); + FillNpuTensorWithConstant(k_cann, static_cast(k)); + + // output values: a tensor specifying the sorted data, which has the same + // type as 'x' + Tensor* values_cann = nullptr; + if (axis == last_axis && largest) { // in this case, the CANN TopKV2 result + // will directly output to the 'output' + // tensor, which save an operation of + // tensor copy + values_cann = output; + } else { + values_cann = new Tensor(x_cann->type()); + framework::DDim values_cann_dims = x_cann->dims(); + values_cann_dims[last_axis] = k; + values_cann->Resize(values_cann_dims); + values_cann->mutable_data(context.GetPlace()); + } + + // output indices: a tensor of type int32 specifying the indices of sorted + // data + Tensor* indices_cann = new Tensor(framework::proto::VarType::INT32); + indices_cann->Resize(values_cann->dims()); + indices_cann->mutable_data(context.GetPlace()); + + // Run CANN TopKV2 operator + const auto& npu_op_runner_topkv2 = + NpuOpRunner("TopKV2", {*x_cann, *k_cann}, {*values_cann, *indices_cann}, + {{"sorted", sorted}}); + auto npu_stream_topkv2 = + context.template device_context() + .stream(); + npu_op_runner_topkv2.Run(npu_stream_topkv2); + + // Convert the computing result into paddle's output tensors + // 'values_cann' to 'output' and 'indices_cann' to 'indices_transpose' + Tensor* values_cann_neg = nullptr; + Tensor* indices_cann_transpose = nullptr; + + if (!largest) { + // run CANN Neg operator + if (axis == last_axis) { + values_cann_neg = output; // in this case, the CANN Neg result will + // directly output to the 'output' tensor + } else { + values_cann_neg = input_neg; // as the 'input_neg' tensor is no longer + // needed, we reuse its resources to + // 'values_cann_neg' tensor + values_cann_neg->Resize(values_cann->dims()); + } + const auto& npu_op_runner_neg = + NpuOpRunner("Neg", {*values_cann}, {*values_cann_neg}); + auto npu_stream_neg = + context.template device_context() + .stream(); + npu_op_runner_neg.Run(npu_stream_neg); + } + + if (axis != last_axis) { + // run CANN Transpose operator + // transpose values + Tensor* input_tensor = (largest ? values_cann : values_cann_neg); + NpuOpRunner npu_op_runner_transpose_values; + auto npu_stream_transpose_values = + context.template device_context() + .stream(); + npu_op_runner_transpose_values.SetType("Transpose") + .AddInput(*input_tensor) + .AddInput(std::move(perm)) + .AddOutput(*output) + .Run(npu_stream_transpose_values); + + // transpose indices + indices_cann_transpose = new Tensor(indices_cann->type()); + indices_cann_transpose->Resize(indices->dims()); + indices_cann_transpose->mutable_data(context.GetPlace()); + + NpuOpRunner npu_op_runner_transpose_indices; + auto npu_stream_transpose_indices = + context.template device_context() + .stream(); + npu_op_runner_transpose_indices.SetType("Transpose") + .AddInput(*indices_cann) + .AddInput(std::move(perm)) + .AddOutput(*indices_cann_transpose) + .Run(npu_stream_transpose_indices); + } else { + indices_cann_transpose = indices_cann; + } + + // 'indices_cann_transpose' to 'indices', from INT32 to INT64 + auto dst_dtype = ConvertToNpuDtype(indices->type()); + const auto& npu_op_runner_cast = + NpuOpRunner("Cast", {*indices_cann_transpose}, {*indices}, + {{"dst_type", static_cast(dst_dtype)}}); + auto npu_stream_cast = + context.template device_context() + .stream(); + npu_op_runner_cast.Run(npu_stream_cast); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_NPU_KERNEL( + top_k_v2, ops::TopkV2NPUKernel, + ops::TopkV2NPUKernel, + ops::TopkV2NPUKernel, + ops::TopkV2NPUKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_top_k_v2_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_top_k_v2_op_npu.py new file mode 100644 index 00000000000000..11d752cc6fbaa5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_top_k_v2_op_npu.py @@ -0,0 +1,279 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid.core as core + + +def numpy_topk(x, k=1, axis=-1, largest=True): + if axis < 0: + axis = len(x.shape) + axis + if largest: + indices = np.argsort(-x, axis=axis) + else: + indices = np.argsort(x, axis=axis) + if largest: + value = -np.sort(-x, axis=axis) + else: + value = np.sort(x, axis=axis) + indices = indices.take(indices=range(0, k), axis=axis) + value = value.take(indices=range(0, k), axis=axis) + return value, indices + + +class TestTopkV2NPUOp(OpTest): + def init_args(self): + self.k = 3 + self.axis = 1 + self.largest = True + + def setUp(self): + self.__class__.use_npu = True + self.place = paddle.NPUPlace(0) + self.op_type = "top_k_v2" + + self.dtype = np.float64 + self.input_data = np.random.rand(10, 20) + self.init_args() + self.inputs = {'X': self.input_data} + self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest} + output, indices = numpy_topk( + self.input_data, axis=self.axis, k=self.k, largest=self.largest) + self.outputs = {'Out': output, 'Indices': indices} + + def test_check_output(self): + paddle.enable_static() + self.check_output() + + def test_check_grad(self): + paddle.enable_static() + self.check_grad(set(['X']), 'Out') + + +class TestTopkOp1(TestTopkV2NPUOp): + def init_args(self): + self.k = 3 + self.axis = 0 + self.largest = False + + +class TestTopkOp2(TestTopkV2NPUOp): + def init_args(self): + self.k = 4 + self.axis = 0 + self.largest = False + + +class TestTopkOp3(OpTest): + def init_args(self): + self.k = 6 + self.axis = 1 + self.largest = True + + def setUp(self): + self.__class__.use_npu = True + self.place = paddle.NPUPlace(0) + self.op_type = "top_k_v2" + self.dtype = np.float64 + self.input_data = np.random.rand(16, 100) + self.init_args() + self.inputs = {'X': self.input_data} + self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest} + output, indices = numpy_topk( + self.input_data, axis=self.axis, k=self.k, largest=self.largest) + self.outputs = {'Out': output, 'Indices': indices} + + +class TestTopkOp4(TestTopkV2NPUOp): + def init_args(self): + self.k = 3 + self.axis = 1 + self.largest = True + + def setUp(self): + self.op_type = "top_k_v2" + self.dtype = np.float64 + self.input_data = np.random.rand(10, 10, 5) + self.init_args() + self.inputs = {'X': self.input_data} + self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest} + output, indices = numpy_topk( + self.input_data, axis=self.axis, k=self.k, largest=self.largest) + self.outputs = {'Out': output, 'Indices': indices} + + +class TestTopkOp5(TestTopkV2NPUOp): + def init_args(self): + self.k = 3 + self.axis = 1 + self.largest = True + + def setUp(self): + self.op_type = "top_k_v2" + self.dtype = np.float64 + self.input_data = np.random.rand(10, 10, 5) + self.init_args() + self.inputs = {'X': self.input_data} + self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest} + output, indices = numpy_topk( + self.input_data, axis=self.axis, k=self.k, largest=self.largest) + self.outputs = {'Out': output, 'Indices': indices} + + +class TestTopKAPI(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.input_data = np.random.rand(6, 7, 8) + #self.input_data = np.random.rand(2, 3, 4) + self.large_input_data = np.random.rand(2, 1030) + + def run_dygraph(self, place): + paddle.disable_static(place) + input_tensor = paddle.to_tensor(self.input_data) + large_input_tensor = paddle.to_tensor(self.large_input_data) + # test case for basic test case 1 + paddle_result = paddle.topk(input_tensor, k=2) + numpy_result = numpy_topk(self.input_data, k=2) + self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0])) + self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1])) + + # test case for basic test case 2 with axis + paddle_result = paddle.topk(input_tensor, k=2, axis=1) + numpy_result = numpy_topk(self.input_data, k=2, axis=1) + self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0])) + self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1])) + # test case for basic test case 3 with tensor K + k_tensor = paddle.to_tensor(np.array([2])) + paddle_result = paddle.topk(input_tensor, k=k_tensor, axis=1) + numpy_result = numpy_topk(self.input_data, k=2, axis=1) + self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0])) + self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1])) + + # test case for basic test case 4 with tensor largest + k_tensor = paddle.to_tensor(np.array([2])) + paddle_result = paddle.topk(input_tensor, k=2, axis=1, largest=False) + numpy_result = numpy_topk(self.input_data, k=2, axis=1, largest=False) + self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0])) + self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1])) + + # test case for basic test case 5 with axis -1 + k_tensor = paddle.to_tensor(np.array([2])) + paddle_result = paddle.topk(input_tensor, k=2, axis=-1, largest=False) + numpy_result = numpy_topk(self.input_data, k=2, axis=-1, largest=False) + self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0])) + self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1])) + + # test case for basic test case 6 for the partial sort + paddle_result = paddle.topk(large_input_tensor, k=1, axis=-1) + numpy_result = numpy_topk(self.large_input_data, k=1, axis=-1) + self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0])) + self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1])) + # test case for basic test case 7 for the unsorted + paddle_result = paddle.topk(input_tensor, k=2, axis=1, sorted=False) + sort_paddle = numpy_topk( + np.array(paddle_result[0].numpy()), axis=1, k=2) + numpy_result = numpy_topk(self.input_data, k=2, axis=1) + self.assertTrue(np.allclose(sort_paddle[0], numpy_result[0])) + + def run_static(self, place): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + input_tensor = paddle.static.data( + name="x", shape=[6, 7, 8], dtype="float64") + #input_tensor = paddle.static.data(name="x", shape=[2, 3, 4], dtype="float64") + large_input_tensor = paddle.static.data( + name="large_x", shape=[2, 1030], dtype="float64") + k_tensor = paddle.static.data(name="k", shape=[1], dtype="int32") + result1 = paddle.topk(input_tensor, k=2) + result2 = paddle.topk(input_tensor, k=2, axis=-1) + result3 = paddle.topk(input_tensor, k=k_tensor, axis=1) + self.assertEqual(result3[0].shape, (6, -1, 8)) + self.assertEqual(result3[1].shape, (6, -1, 8)) + #self.assertEqual(result3[0].shape, (2, -1, 4)) + #self.assertEqual(result3[1].shape, (2, -1, 4)) + result4 = paddle.topk(input_tensor, k=2, axis=1, largest=False) + result5 = paddle.topk(input_tensor, k=2, axis=-1, largest=False) + result6 = paddle.topk(large_input_tensor, k=1, axis=-1) + result7 = paddle.topk(input_tensor, k=2, axis=1, sorted=False) + exe = paddle.static.Executor(place) + input_data = np.random.rand(10, 20).astype("float64") + large_input_data = np.random.rand(2, 100).astype("float64") + paddle_result = exe.run( + feed={ + "x": self.input_data, + "large_x": self.large_input_data, + "k": np.array([2]).astype("int32") + }, + fetch_list=[ + result1[0], result1[1], result2[0], result2[1], result3[0], + result3[1], result4[0], result4[1], result5[0], result5[1], + result6[0], result6[1], result7[0], result7[1] + ]) + numpy_result = numpy_topk(self.input_data, k=2) + self.assertTrue(np.allclose(paddle_result[0], numpy_result[0])) + self.assertTrue(np.allclose(paddle_result[1], numpy_result[1])) + + numpy_result = numpy_topk(self.input_data, k=2, axis=-1) + self.assertTrue(np.allclose(paddle_result[2], numpy_result[0])) + self.assertTrue(np.allclose(paddle_result[3], numpy_result[1])) + + numpy_result = numpy_topk(self.input_data, k=2, axis=1) + self.assertTrue(np.allclose(paddle_result[4], numpy_result[0])) + self.assertTrue(np.allclose(paddle_result[5], numpy_result[1])) + + numpy_result = numpy_topk( + self.input_data, k=2, axis=1, largest=False) + self.assertTrue(np.allclose(paddle_result[6], numpy_result[0])) + self.assertTrue(np.allclose(paddle_result[7], numpy_result[1])) + + numpy_result = numpy_topk( + self.input_data, k=2, axis=-1, largest=False) + self.assertTrue(np.allclose(paddle_result[8], numpy_result[0])) + self.assertTrue(np.allclose(paddle_result[9], numpy_result[1])) + + numpy_result = numpy_topk(self.large_input_data, k=1, axis=-1) + self.assertTrue(np.allclose(paddle_result[10], numpy_result[0])) + self.assertTrue(np.allclose(paddle_result[11], numpy_result[1])) + sort_paddle = numpy_topk(paddle_result[12], axis=1, k=2) + numpy_result = numpy_topk(self.input_data, k=2, axis=1) + self.assertTrue(np.allclose(sort_paddle[0], numpy_result[0])) + + def test_cases(self): + places = [core.NPUPlace(0)] + #if core.is_compiled_with_cuda(): + # places.append(core.CUDAPlace(0)) + for place in places: + self.run_dygraph(place) + self.run_static(place) + + def test_errors(self): + paddle.disable_static() + x = paddle.to_tensor([1, 2, 3]) + with self.assertRaises(BaseException): + paddle.topk(x, k=-1) + + with self.assertRaises(BaseException): + paddle.topk(x, k=0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/__pycache__/static_mode_white_list.cpython-37.pyc b/tools/__pycache__/static_mode_white_list.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d284ff69a05a89055ba0838284530b93ed9e867b GIT binary patch literal 20802 zcmeI4cbp_gb;nPK07)Puvd9@EKoWN+M9v5hK?I4kgr;|1&up_Z)1&U0-CL1!29a~l zIcFRI8*n5z%h@>R9KbgIzEIsgGb=vW{~Gl9N!zcgtE*nUTs7SD#1l_A!2iDfy36F> zfB1m|-_?iozporTaNzDdGzX7Aa6k^oF>+8YBge{R<#KX)IZlq36XXhVMLAKfBqzzq za%H)STve_nSC?zZHRTk!mRwt|BiEJd$@S$1aznY1+*ocRHo&E*zys+=ablv~NI z<#ahi&Xn88Avr8Z~ zT)DqIKprR$lJn%j@(_8bJWS4)3*_PQ2zjJDN**m6@)&unJWd`jnLI&q*_45Z6f%^N zY)L6klx-QyM5a>7OfHmK8kx&Nc4SwM%AQ;#Pm(9gQ{<`gGeL6nqSP z9DD+N5_}4L8hi$P7JLqT9()0O5qt^!1o%ntQ{bn;m%&%SSHaJKp9Nn7UkBd+KL>su z`~vtz@JryA!8gITz_-D7z^{N`1-}M<9sCCPP4HXb67bvLcfjw0-vhr7{s8eL6nqSP z9DD+N5_}4L8hi$P7JTmFgQKr(9N@3dpMT)gFAxuHTw`7|^DG|@%W07fw({w;7-!XN z<4Q&!RMSB|*XL{6b1^O^#dN+~Hbs_;?BCEVHnrgtdt+FR$Js?iT{YRb+%67xEn~iz z*ZR$Jnholqwp`iX8CR0kMN`;^uNI_gdoj!A`R2ILj@MeJZVnHftxeZmZ`wJWrPdRz zb(RyGY3Gydv0;((qE>=`t|=~D6w?75F$Qd0#k=*WJFB!=RgHs7?c2rCSw594Z&ru( zhE*!F(kb|sw1X_BL3f6>^xmXb-e|@pQ)eV!Z2K79W;w0%DwAS1-_mZ^@zq_p$m>F~ z@!?F$W}fTvX*}5)^TA+I=Yu_^xG3a&pXk%|`cLz6QX(kXaCYP{|2q4u?(Io_G|OjY2qnWQf%koV`n;=;&LAFVObr3u z8D+XFhh()yUyqhF1%c3%wj4ITl5#rIA+H@XRDon!wKA=k3SIWvugk<1b5-x^nl!3Q zTW=7@XU6$nQE%MXdroEue=*q3(7kfbd=V8k)df7(G97GHbv7!dMV-(2iz!{f z->&oNsMv_{>rpmAi-Ik|^?KA)wdqg{R@eESts~p5^@9xNYMwVu;XT{uax~R;R|nel zc_$w)?5!{w=*-kOHg4v|f|%N3ps$9Q@X)|aRy9-UqTvR_HP-x>+EC z4F-t!rC`HtVb0Ee8}s$mZkB@uV6Iq-nXbca&)2@B99Orb-$62d_Zs2He_YI)waQJwiHHIvKaKt@@6^9>$z^}#2CU1d@D9xG+l=`tE1UqT-w}Y9!g>9 zCRf*Y?m=>X)FzOz^^6#U%xYt!uF|a{5<4*cFOh~zj<-<;c~gY0^QXRYRjlpkoIRN4 zBh=h^j0f2{gjF@ocS{?jIV#*Rtz;Iu{`e;E@9L=i-CcF>^L**=A6b7N5=-yi=Xv#g z7Xba(QlpBN33WHW*5Y^L^+rqXAa;}SvSzCsniRNc8W0UMrRi*J2G$kn9gV9^n7WhK ziKT=ECXHc^W1rbgU5#t&K%YW0y4@wQbA zXD7wxVx*$xwiZqyWyBkuzje#6EL_Y?sNDhZ%1umB_i4%Wo9>y?1dAaS)a8>O^YywQq1$gR{OAn=LsE^LLagb0gy!0 zA!R=)PgPS4R;u_hjp55Ws}8tUS@2N`ASz1*_*1T+;ENNDW8PzZ00D7 z24iZh)AsI7)vYFKY<)!Qf`3+*bsZUP~`}7Z1;hJ!6$`!7zt$J!8g;&J$~B?H}Y5 zY}P}IvRxBwU7g*=-{v!NyuDcwS~e^+P}amcJY}xQKzLAZTudiLJ%VRGeJ83bj#+3T zpHI6D?vYV`O&CqD+yYtJ&s1Y+%S8a?oC_q~G#n(7D|H03-X0bX37wg$xo%hHplC`y zFGq`N(PWkbb;6cxqq5jRDs*669=4G_m9$_9by_Gjxv7_F(CNiC{D|G~Z zE7a`B(l85>7}DgTAw&=9_ZJrPavaj%5t@iAhXmw_8BC+tt?F&00Kcw2BvfSBSiBRg zo!6$)E>*(?+mduHIw%PqxO*Hfrex~Xl( ziwLe~WD#kR+LRnUsf@OGK5Q@5#dyK;DNI<6NoVYRZsao0+mfjn^Ac27gJ8Yx7j+oA zOg72xF0^Pwc@b)o+QN{bLMe3(Az3t9&`>ADOCDb`pwd#ub|fX%R0Ax6VI-S6Z?&t@ z4$Q&Upal`?WvP#wr5Kq$T8xNVlYY2eYz*Pe2rgqqJ6QmaT?wNL zX+twd=;dsz(Vni|qYI^lr5e`Ev0={#8YK-klh~Nf01j9{~KB&^`tX=ceg6NZvyM@#6nuF3@UAKr4Ev@Wz zTWm%~tTf-Za&ow>^Fef(D<_0pM6{Zc4@>00#M1IgpL98yX_-iCUsf#%J?@c(3}P@0 zM)l~!WO~}~wV0c#ULUO3lX#`o$7raL#&cjTRr$Jr$le& zZbJJ?W|N~GDyd01s)d#!sMHbVSrIDax?ayphuA~2qi!zC^)BYl+&Yw$A1ok-?}-wH zUMOxMw#S+vXfaH8$8E@VjoaYKmoPrmuGY8Ita8>&s>*Bk{?a(d-?hV9?Wlb(j!M){ zIBmjbiT>&CPNdh}PD#3UH!Hn%l1I3^Hlq4mchk<9F-4B4mTWW$rl!3`jcts~#CTr& z>>A4Qo*LIIt2?Z_O-e^n)vG9oD|r(vRXtO6s4WIP!TDP4#o>A?Ut@+Y7rxQ`Ns`$ zB9V$@rPzi!11CkZm5x<>>pAi$h}vtyy`$#ut#?eLoRT24e3_K)ncYBP*6^ShzpW+T z8{et)h|{Zj&eTbTmnW!?ky(hsH z#R_8{I@%xvS(-`^F?!YT=zySDUvs;lXCNq6MS|yfI(9LnQcqR&Om|6a^9bW|3j*k*}TE#Rv3=hKwLqZ5BLL zAe_m`s5@8nDIbR#`$L=#adxF@TZakda1=Y7q&Yr@R+->Buc!cKF+CUCMCfPh(ENlt z6v>_uA2(U3hH;$un}ltUa!T0m$8uV047KT^7PqeLGaF*LD_LbmacDOjkRI+bEUYDe z)0~ydFAcLe(Zr;Kek;j`eoaPFB|FDX^Zm ztEL{dm}b2*K~lsD{+gacAu}4}hgn;HXEJ5vjngk>M5e}m!WA1jReCl&dRQOKM)A5G zfkrvqII=u$)8ra)Q7QdlQ}w7yN1qf_B-*|4)pTiBI}2OoN`0{xwnccvfYvb%3EZ}6 zkU-c~jsh?Hjy%Dj17R}p{L;d~m+Vsn1TnQWrb%*3ZK!Upt zWa`TuAke|hp5>TPcnrz0198HoX+a*TS%v8uZ@5v@@_lK!vDWdu=HLoQ3u!zZJ!bL8 zF04qMuuz?QSvny|Jt_)BMSY1|h;$$qT=e5PbVpK=D-{K2h>)g0aenNkIW>Ajs`Whg z?9k^GK`>Q^W{|2B5RZ~>QSj=HQrhm`;>G#czB-+J=z6?7XOl?V+n$VQdu+&_*;pmG zH=CuS8J0BJET@*=MTpk5X&^E#c1&8Z^I?m-_L*J-Fn=0xL$7Ip51j4uy(oK*RVzJ3 z=w6DHqD|i;QR+HS9U;(5_9yLlucX$oggZ5hB9^bUhuDALY$aGA<9v?VgdOfNWtUMU zVwkLC;arA7^5zo3-(6f2B=??v+LlPYLjYI%)g$0|?gZ91D99@@vIoO;G1MxiSs>zU zzx)vkD@HK3-b-sJMhi_gaFD^Qx}%tWy05MaI|sxRrj$8xjyOoGyKu$Ti`5w+U2*a) zWK)qLJU)Uwr*@*n0uGS$0LKNeMv5Xk3c{p8?k!#nR^g!s1&Xa!Bi0EiLb|PD%wFg; zqBng#YXJXTcu-xT`;1i^ma|kDVC#gq^w~Mb0m4;%+~&6!W$_9$mcXriLy=kYsXZx}9T<=J72@ zn_{E^EQEB#=Cum^pC;5U10B0*G*cbJ!^TZ5M)H+ZeU5X;SYnbrXEBmT`0dpPX5gl_ z*{VxaQ4bZv6a`D&zhRm6FoNzb>EZVDCSbaKc$#e$yE)NILQB;FFBo*&;<}l-4l6uS z8x970LMYs8^^$!$_j80skC&NMaDtU_9;0XB&|Wr+_kHGXure_26Radvh`b35ldt)Seg!tLxn)n%YMf^PIZf2(Zu`H`8 zX9cc}OPUp~A$BWj&0G^qla8-*`mx@_DJMOl*kq(s7bFsFi3+5N3_ zf5nYJ{|rrQfBqcyi_;zxBj<{qw;NaDkxP`;>Ym5^g~r&isb(PUN>IDcR=ZaV*lj(x z492*c()8A)T1l&Am6+>li)TG0^}5%M`UoM(l&f%2PUy8&1CI7_U{(Uj%#W^j@-lXotD@6x z3))_jb!WOhGIxG`oAk7ZgQeMG<2K8=bBgLYsYzX}bQF=WzHx^i?G2k$??hN%byHx~ z<>pfB8Ga2xYYVXj?YUSnrOaam=2iu|wpv#NJ$!z;xIZjwDr3#OXk(Am!l_y`q(CTs z0ib80$y=w+5pl1REuvoiXbJ?FjwjabG;t2w?KO2xN!@bkRP54#9~1k{XJYWK4B)w; z&oou){HVmMft0T!4YixcF(WLZnLr=M($A>EP(tgfe4JHVt;u%AWUH(1dt^a(4!k?6 zTVtrrHrI>0_j;%M=D}8eX3`DdDU0>lEORdS`4L6e!A+lc(7GeLYlEL^;%D^!9Go5y^Yd>#Fo^i6)z9{*!Gi!Tl!vQTCvIE8y%l$vz3cs5XzNzqu9|m5cT6!C zyy)$dSn=JZwv}Mmc@gjJGo0SVUU{V*PTN*K)efg^-Ef`AL+*B{>)?T^E`F)Q%9p#* zDdu%AJFj}T$+{-r(-m#$z3TG6+P?DH^+EI;r2EMFAlkk%NUKEMd#$#tbahL(68dzn z`3;qoW&+)=hfjDlsOs819$szVXNVAXRq~aMZk0uqLEBf>;Cs@0zx%q&Pc+$0qpr;b zrQf#l-SrMDANGl@M)=(q+>huvVn=hi3dQWmKG7hre8zokv@LPYGe?zxm8r)oc7<6v z)_rKTh4n8TG20=TaT&AIGh{x&sD7_}ZK-W#^1j}9ubYO?9Q#VZN_C)X@Xf5dd+A-1 z%W(4g9o-FKqOW{Eq7IH7yB8wihPKk#wV^W_E9+csUU}PD+r8a=&?EkMHnZFNm0!05 zr<>F5cMD@_W&fae-8~Hna%bwJ5Y_cNqd#o)>t-1-S8uD``NX(hRg1T}F`X6BC*4cw znmwO$mx7DNF*wip;Q+}2(yC<`FGqHU7IVx9R!>d+H4N9^tLlIC?26iM7&p6N%9&DK zrfV$vbsPu*x_iOxy3NYFZFuIKj`+QdW>Ao|dSEvh*WPdo^;l^=mS=tC9COU+4V%a4DKfk;^1GyzZYm5 z5Ry{ygUyDTP8ZYoXdx28=#HWBYsd)6o!(wF=BRBq(sq{zc+&U8U1-fbokjFhQM}Yu z(yuvvO8wH7R5`N+tH|q9w7mwDEAd@alJ|<5dd3RyYlL zh)l^>Yd3Ue6-FzBOVbHoMUTBMh`NDs5*~#+2&)BS4_jRv-1zEiT!eLbwcYv=li<<+ zJ*CxG|9APbz7aN_>O+*pQ!;;+TSL1Z5_ds5>ynk60od7mvOHc6iCT(T5jNCrko4Y) zPHy$$8dY|bSvc2v#-5s#(!FP@QYpp#RD06Wd%G8yFTEW<0JWa}c6f{E<FLQHphQ|B#OI`lR^vzLDI2?Q^l(783gtxBu(~pRONJ8Rw}O zssE3zF#)4^J(o*i{KA20MyU*g)Bez>nw@vrF_)Zj{ss5A;M{v<4|?#u&&e)0=ll!K zJ@3BR1J6DGg7+V&FU$X*mi_B{htAZiMRrf`%!jap Date: Wed, 4 Aug 2021 15:26:33 +0800 Subject: [PATCH 2/6] deleted unnecessary cache file static_mode_white_list.cpython-37.pyc --- .../static_mode_white_list.cpython-37.pyc | Bin 20802 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 tools/__pycache__/static_mode_white_list.cpython-37.pyc diff --git a/tools/__pycache__/static_mode_white_list.cpython-37.pyc b/tools/__pycache__/static_mode_white_list.cpython-37.pyc deleted file mode 100644 index d284ff69a05a89055ba0838284530b93ed9e867b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 20802 zcmeI4cbp_gb;nPK07)Puvd9@EKoWN+M9v5hK?I4kgr;|1&up_Z)1&U0-CL1!29a~l zIcFRI8*n5z%h@>R9KbgIzEIsgGb=vW{~Gl9N!zcgtE*nUTs7SD#1l_A!2iDfy36F> zfB1m|-_?iozporTaNzDdGzX7Aa6k^oF>+8YBge{R<#KX)IZlq36XXhVMLAKfBqzzq za%H)STve_nSC?zZHRTk!mRwt|BiEJd$@S$1aznY1+*ocRHo&E*zys+=ablv~NI z<#ahi&Xn88Avr8Z~ zT)DqIKprR$lJn%j@(_8bJWS4)3*_PQ2zjJDN**m6@)&unJWd`jnLI&q*_45Z6f%^N zY)L6klx-QyM5a>7OfHmK8kx&Nc4SwM%AQ;#Pm(9gQ{<`gGeL6nqSP z9DD+N5_}4L8hi$P7JLqT9()0O5qt^!1o%ntQ{bn;m%&%SSHaJKp9Nn7UkBd+KL>su z`~vtz@JryA!8gITz_-D7z^{N`1-}M<9sCCPP4HXb67bvLcfjw0-vhr7{s8eL6nqSP z9DD+N5_}4L8hi$P7JTmFgQKr(9N@3dpMT)gFAxuHTw`7|^DG|@%W07fw({w;7-!XN z<4Q&!RMSB|*XL{6b1^O^#dN+~Hbs_;?BCEVHnrgtdt+FR$Js?iT{YRb+%67xEn~iz z*ZR$Jnholqwp`iX8CR0kMN`;^uNI_gdoj!A`R2ILj@MeJZVnHftxeZmZ`wJWrPdRz zb(RyGY3Gydv0;((qE>=`t|=~D6w?75F$Qd0#k=*WJFB!=RgHs7?c2rCSw594Z&ru( zhE*!F(kb|sw1X_BL3f6>^xmXb-e|@pQ)eV!Z2K79W;w0%DwAS1-_mZ^@zq_p$m>F~ z@!?F$W}fTvX*}5)^TA+I=Yu_^xG3a&pXk%|`cLz6QX(kXaCYP{|2q4u?(Io_G|OjY2qnWQf%koV`n;=;&LAFVObr3u z8D+XFhh()yUyqhF1%c3%wj4ITl5#rIA+H@XRDon!wKA=k3SIWvugk<1b5-x^nl!3Q zTW=7@XU6$nQE%MXdroEue=*q3(7kfbd=V8k)df7(G97GHbv7!dMV-(2iz!{f z->&oNsMv_{>rpmAi-Ik|^?KA)wdqg{R@eESts~p5^@9xNYMwVu;XT{uax~R;R|nel zc_$w)?5!{w=*-kOHg4v|f|%N3ps$9Q@X)|aRy9-UqTvR_HP-x>+EC z4F-t!rC`HtVb0Ee8}s$mZkB@uV6Iq-nXbca&)2@B99Orb-$62d_Zs2He_YI)waQJwiHHIvKaKt@@6^9>$z^}#2CU1d@D9xG+l=`tE1UqT-w}Y9!g>9 zCRf*Y?m=>X)FzOz^^6#U%xYt!uF|a{5<4*cFOh~zj<-<;c~gY0^QXRYRjlpkoIRN4 zBh=h^j0f2{gjF@ocS{?jIV#*Rtz;Iu{`e;E@9L=i-CcF>^L**=A6b7N5=-yi=Xv#g z7Xba(QlpBN33WHW*5Y^L^+rqXAa;}SvSzCsniRNc8W0UMrRi*J2G$kn9gV9^n7WhK ziKT=ECXHc^W1rbgU5#t&K%YW0y4@wQbA zXD7wxVx*$xwiZqyWyBkuzje#6EL_Y?sNDhZ%1umB_i4%Wo9>y?1dAaS)a8>O^YywQq1$gR{OAn=LsE^LLagb0gy!0 zA!R=)PgPS4R;u_hjp55Ws}8tUS@2N`ASz1*_*1T+;ENNDW8PzZ00D7 z24iZh)AsI7)vYFKY<)!Qf`3+*bsZUP~`}7Z1;hJ!6$`!7zt$J!8g;&J$~B?H}Y5 zY}P}IvRxBwU7g*=-{v!NyuDcwS~e^+P}amcJY}xQKzLAZTudiLJ%VRGeJ83bj#+3T zpHI6D?vYV`O&CqD+yYtJ&s1Y+%S8a?oC_q~G#n(7D|H03-X0bX37wg$xo%hHplC`y zFGq`N(PWkbb;6cxqq5jRDs*669=4G_m9$_9by_Gjxv7_F(CNiC{D|G~Z zE7a`B(l85>7}DgTAw&=9_ZJrPavaj%5t@iAhXmw_8BC+tt?F&00Kcw2BvfSBSiBRg zo!6$)E>*(?+mduHIw%PqxO*Hfrex~Xl( ziwLe~WD#kR+LRnUsf@OGK5Q@5#dyK;DNI<6NoVYRZsao0+mfjn^Ac27gJ8Yx7j+oA zOg72xF0^Pwc@b)o+QN{bLMe3(Az3t9&`>ADOCDb`pwd#ub|fX%R0Ax6VI-S6Z?&t@ z4$Q&Upal`?WvP#wr5Kq$T8xNVlYY2eYz*Pe2rgqqJ6QmaT?wNL zX+twd=;dsz(Vni|qYI^lr5e`Ev0={#8YK-klh~Nf01j9{~KB&^`tX=ceg6NZvyM@#6nuF3@UAKr4Ev@Wz zTWm%~tTf-Za&ow>^Fef(D<_0pM6{Zc4@>00#M1IgpL98yX_-iCUsf#%J?@c(3}P@0 zM)l~!WO~}~wV0c#ULUO3lX#`o$7raL#&cjTRr$Jr$le& zZbJJ?W|N~GDyd01s)d#!sMHbVSrIDax?ayphuA~2qi!zC^)BYl+&Yw$A1ok-?}-wH zUMOxMw#S+vXfaH8$8E@VjoaYKmoPrmuGY8Ita8>&s>*Bk{?a(d-?hV9?Wlb(j!M){ zIBmjbiT>&CPNdh}PD#3UH!Hn%l1I3^Hlq4mchk<9F-4B4mTWW$rl!3`jcts~#CTr& z>>A4Qo*LIIt2?Z_O-e^n)vG9oD|r(vRXtO6s4WIP!TDP4#o>A?Ut@+Y7rxQ`Ns`$ zB9V$@rPzi!11CkZm5x<>>pAi$h}vtyy`$#ut#?eLoRT24e3_K)ncYBP*6^ShzpW+T z8{et)h|{Zj&eTbTmnW!?ky(hsH z#R_8{I@%xvS(-`^F?!YT=zySDUvs;lXCNq6MS|yfI(9LnQcqR&Om|6a^9bW|3j*k*}TE#Rv3=hKwLqZ5BLL zAe_m`s5@8nDIbR#`$L=#adxF@TZakda1=Y7q&Yr@R+->Buc!cKF+CUCMCfPh(ENlt z6v>_uA2(U3hH;$un}ltUa!T0m$8uV047KT^7PqeLGaF*LD_LbmacDOjkRI+bEUYDe z)0~ydFAcLe(Zr;Kek;j`eoaPFB|FDX^Zm ztEL{dm}b2*K~lsD{+gacAu}4}hgn;HXEJ5vjngk>M5e}m!WA1jReCl&dRQOKM)A5G zfkrvqII=u$)8ra)Q7QdlQ}w7yN1qf_B-*|4)pTiBI}2OoN`0{xwnccvfYvb%3EZ}6 zkU-c~jsh?Hjy%Dj17R}p{L;d~m+Vsn1TnQWrb%*3ZK!Upt zWa`TuAke|hp5>TPcnrz0198HoX+a*TS%v8uZ@5v@@_lK!vDWdu=HLoQ3u!zZJ!bL8 zF04qMuuz?QSvny|Jt_)BMSY1|h;$$qT=e5PbVpK=D-{K2h>)g0aenNkIW>Ajs`Whg z?9k^GK`>Q^W{|2B5RZ~>QSj=HQrhm`;>G#czB-+J=z6?7XOl?V+n$VQdu+&_*;pmG zH=CuS8J0BJET@*=MTpk5X&^E#c1&8Z^I?m-_L*J-Fn=0xL$7Ip51j4uy(oK*RVzJ3 z=w6DHqD|i;QR+HS9U;(5_9yLlucX$oggZ5hB9^bUhuDALY$aGA<9v?VgdOfNWtUMU zVwkLC;arA7^5zo3-(6f2B=??v+LlPYLjYI%)g$0|?gZ91D99@@vIoO;G1MxiSs>zU zzx)vkD@HK3-b-sJMhi_gaFD^Qx}%tWy05MaI|sxRrj$8xjyOoGyKu$Ti`5w+U2*a) zWK)qLJU)Uwr*@*n0uGS$0LKNeMv5Xk3c{p8?k!#nR^g!s1&Xa!Bi0EiLb|PD%wFg; zqBng#YXJXTcu-xT`;1i^ma|kDVC#gq^w~Mb0m4;%+~&6!W$_9$mcXriLy=kYsXZx}9T<=J72@ zn_{E^EQEB#=Cum^pC;5U10B0*G*cbJ!^TZ5M)H+ZeU5X;SYnbrXEBmT`0dpPX5gl_ z*{VxaQ4bZv6a`D&zhRm6FoNzb>EZVDCSbaKc$#e$yE)NILQB;FFBo*&;<}l-4l6uS z8x970LMYs8^^$!$_j80skC&NMaDtU_9;0XB&|Wr+_kHGXure_26Radvh`b35ldt)Seg!tLxn)n%YMf^PIZf2(Zu`H`8 zX9cc}OPUp~A$BWj&0G^qla8-*`mx@_DJMOl*kq(s7bFsFi3+5N3_ zf5nYJ{|rrQfBqcyi_;zxBj<{qw;NaDkxP`;>Ym5^g~r&isb(PUN>IDcR=ZaV*lj(x z492*c()8A)T1l&Am6+>li)TG0^}5%M`UoM(l&f%2PUy8&1CI7_U{(Uj%#W^j@-lXotD@6x z3))_jb!WOhGIxG`oAk7ZgQeMG<2K8=bBgLYsYzX}bQF=WzHx^i?G2k$??hN%byHx~ z<>pfB8Ga2xYYVXj?YUSnrOaam=2iu|wpv#NJ$!z;xIZjwDr3#OXk(Am!l_y`q(CTs z0ib80$y=w+5pl1REuvoiXbJ?FjwjabG;t2w?KO2xN!@bkRP54#9~1k{XJYWK4B)w; z&oou){HVmMft0T!4YixcF(WLZnLr=M($A>EP(tgfe4JHVt;u%AWUH(1dt^a(4!k?6 zTVtrrHrI>0_j;%M=D}8eX3`DdDU0>lEORdS`4L6e!A+lc(7GeLYlEL^;%D^!9Go5y^Yd>#Fo^i6)z9{*!Gi!Tl!vQTCvIE8y%l$vz3cs5XzNzqu9|m5cT6!C zyy)$dSn=JZwv}Mmc@gjJGo0SVUU{V*PTN*K)efg^-Ef`AL+*B{>)?T^E`F)Q%9p#* zDdu%AJFj}T$+{-r(-m#$z3TG6+P?DH^+EI;r2EMFAlkk%NUKEMd#$#tbahL(68dzn z`3;qoW&+)=hfjDlsOs819$szVXNVAXRq~aMZk0uqLEBf>;Cs@0zx%q&Pc+$0qpr;b zrQf#l-SrMDANGl@M)=(q+>huvVn=hi3dQWmKG7hre8zokv@LPYGe?zxm8r)oc7<6v z)_rKTh4n8TG20=TaT&AIGh{x&sD7_}ZK-W#^1j}9ubYO?9Q#VZN_C)X@Xf5dd+A-1 z%W(4g9o-FKqOW{Eq7IH7yB8wihPKk#wV^W_E9+csUU}PD+r8a=&?EkMHnZFNm0!05 zr<>F5cMD@_W&fae-8~Hna%bwJ5Y_cNqd#o)>t-1-S8uD``NX(hRg1T}F`X6BC*4cw znmwO$mx7DNF*wip;Q+}2(yC<`FGqHU7IVx9R!>d+H4N9^tLlIC?26iM7&p6N%9&DK zrfV$vbsPu*x_iOxy3NYFZFuIKj`+QdW>Ao|dSEvh*WPdo^;l^=mS=tC9COU+4V%a4DKfk;^1GyzZYm5 z5Ry{ygUyDTP8ZYoXdx28=#HWBYsd)6o!(wF=BRBq(sq{zc+&U8U1-fbokjFhQM}Yu z(yuvvO8wH7R5`N+tH|q9w7mwDEAd@alJ|<5dd3RyYlL zh)l^>Yd3Ue6-FzBOVbHoMUTBMh`NDs5*~#+2&)BS4_jRv-1zEiT!eLbwcYv=li<<+ zJ*CxG|9APbz7aN_>O+*pQ!;;+TSL1Z5_ds5>ynk60od7mvOHc6iCT(T5jNCrko4Y) zPHy$$8dY|bSvc2v#-5s#(!FP@QYpp#RD06Wd%G8yFTEW<0JWa}c6f{E<FLQHphQ|B#OI`lR^vzLDI2?Q^l(783gtxBu(~pRONJ8Rw}O zssE3zF#)4^J(o*i{KA20MyU*g)Bez>nw@vrF_)Zj{ss5A;M{v<4|?#u&&e)0=ll!K zJ@3BR1J6DGg7+V&FU$X*mi_B{htAZiMRrf`%!jap Date: Fri, 6 Aug 2021 16:38:48 +0800 Subject: [PATCH 3/6] A draft for error checking --- paddle/fluid/operators/top_k_v2_op_npu.cc | 208 ++++-------------- .../unittests/npu/test_top_k_v2_op_npu.py | 153 +++++++------ 2 files changed, 125 insertions(+), 236 deletions(-) mode change 100644 => 100755 python/paddle/fluid/tests/unittests/npu/test_top_k_v2_op_npu.py diff --git a/paddle/fluid/operators/top_k_v2_op_npu.cc b/paddle/fluid/operators/top_k_v2_op_npu.cc index 62bd0c015b8f12..3590a8854ebe70 100644 --- a/paddle/fluid/operators/top_k_v2_op_npu.cc +++ b/paddle/fluid/operators/top_k_v2_op_npu.cc @@ -23,17 +23,15 @@ namespace operators { template class TopkV2NPUKernel : public framework::OpKernel { public: - // Use CANN TopKV2 operator to implement paddle TopKV2Op + // Use Ascend TopKV2 operator to implement paddle TopKV2Op void Compute(const framework::ExecutionContext& context) const override { - using Tensor = framework::Tensor; - // Read message from context auto* input = context.Input("X"); auto* k_tensor = context.Input("K"); - auto* output = context.Output("Out"); - auto* indices = context.Output("Indices"); + auto* out = context.Output("Out"); + auto* indices = context.Output("Indices"); // type:INT64 - int k = static_cast(context.Attr("k")); + int32_t k = static_cast(context.Attr("k")); int axis = static_cast(context.Attr("axis")); const bool sorted = static_cast(context.Attr("sorted")); const bool largest = static_cast(context.Attr("largest")); @@ -44,199 +42,67 @@ class TopkV2NPUKernel : public framework::OpKernel { } if (k_tensor != nullptr) { - // seems complicated, but I really don't know how to assign a NPU value to - // a CPU variable by an elegant way + // seems complicated, but I really don't know + // how to assign a NPU value to a CPU variable by an elegant way std::vector v_tmp(1); TensorToVector( *k_tensor, context.template device_context(), &v_tmp); - k = v_tmp[0]; + k = static_cast(v_tmp[0]); } - // Allocate space for output tensors on NPU + // Allocate space for output tensors of Paddle topKV2 operator framework::DDim output_dims = input->dims(); output_dims[axis] = k; - output->Resize(output_dims); + out->Resize(output_dims); indices->Resize(output_dims); - output->mutable_data(context.GetPlace()); + out->mutable_data(context.GetPlace()); indices->mutable_data(context.GetPlace()); - // Construct the input tensor x of CANN TopKV2 operator - // as CANN TopKV2 operator does not support setting 'axis'(defaults to the - // last dimension) and 'largest'(defaults to true) parameter yet, - // 1. when the 'axis' is not the last dimension, we use CANN Transpose - // operator to permutes the dimension 'axis' to the last dimension - // 2. when the 'largest' is false, we use CANN Neg operator to negate the - // input tensor element-wise, which convert descending to ascending order - // once the functino of the parameter 'dim' and 'largest' is further - // improved, these additional actions can be removed - Tensor* input_transpose = nullptr; - Tensor* input_neg = nullptr; - const Tensor* x_cann = - input; // the input tensor "x" of CANN TopKV2 operator - std::vector perm; - const int last_axis = static_cast( - input->dims().size() - - 1); // attention: there may be bugs when the input tensor is empty - - if (axis != - last_axis) { // in this case, the 'input' tensor should be transposed - // compute perm vector - perm.resize(last_axis + 1); - for (int i = 0; i <= last_axis; ++i) { - perm[i] = i; - } - std::swap(perm[axis], perm[last_axis]); - - // construct 'input_transpose' - input_transpose = new Tensor(input->type()); - - framework::DDim input_transpose_dims = input->dims(); - std::swap(input_transpose_dims[axis], input_transpose_dims[last_axis]); - - input_transpose->Resize(input_transpose_dims); - input_transpose->mutable_data(context.GetPlace()); - - // run CANN Transpose operator - NpuOpRunner npu_op_runner_transpose; - auto npu_stream_transpose = - context.template device_context() - .stream(); - npu_op_runner_transpose.SetType("Transpose") - .AddInput(*input) - .AddInput(std::move(perm)) - .AddOutput(*input_transpose) - .Run(npu_stream_transpose); - - x_cann = input_transpose; - } - - if (!largest) { // in this case, the 'input' tensor should be negated - // element-wise - // construct 'input_neg' - auto* input_tensor = - (input_transpose == nullptr ? input : input_transpose); - input_neg = new Tensor(input_tensor->type()); - input_neg->Resize(input_tensor->dims()); - input_neg->mutable_data(context.GetPlace()); - - // run CANN Neg operator - const auto& npu_op_runner_neg = - NpuOpRunner("Neg", {*input_tensor}, {*input_neg}); - auto npu_stream_neg = - context.template device_context() - .stream(); - npu_op_runner_neg.Run(npu_stream_neg); - - x_cann = input_neg; - } - - // Construct the input and output tensors of CANN TopKV2 operator (except x) - // input k: a 0D tensor of type int32, Number of top elements to look for - // along the last dimension (along each row for matrices) - Tensor* k_cann = new Tensor(framework::proto::VarType::INT32); - k_cann->mutable_data({1}, context.GetPlace()); - FillNpuTensorWithConstant(k_cann, static_cast(k)); - - // output values: a tensor specifying the sorted data, which has the same - // type as 'x' - Tensor* values_cann = nullptr; - if (axis == last_axis && largest) { // in this case, the CANN TopKV2 result - // will directly output to the 'output' - // tensor, which save an operation of - // tensor copy - values_cann = output; - } else { - values_cann = new Tensor(x_cann->type()); - framework::DDim values_cann_dims = x_cann->dims(); - values_cann_dims[last_axis] = k; - values_cann->Resize(values_cann_dims); - values_cann->mutable_data(context.GetPlace()); - } - - // output indices: a tensor of type int32 specifying the indices of sorted - // data - Tensor* indices_cann = new Tensor(framework::proto::VarType::INT32); - indices_cann->Resize(values_cann->dims()); - indices_cann->mutable_data(context.GetPlace()); - - // Run CANN TopKV2 operator + // Allocate space for input k and output indices of Ascend topkV2 operator + framework::Tensor* k_Ascend = new Tensor(framework::proto::VarType::INT32); + k_Ascend->mutable_data({1}, context.GetPlace()); + FillNpuTensorWithConstant(k_Ascend, static_cast(k)); + + framework::Tensor* indices_int32 = + new Tensor(framework::proto::VarType::INT32); + indices_int32->Resize(output_dims); + indices_int32->mutable_data(context.GetPlace()); + + VLOG(4) << "input: " << *input; + VLOG(4) << "k: " << *k_Ascend; + VLOG(4) << "sorted: " << sorted; + VLOG(4) << "dim: " << axis; + VLOG(4) << "largest: " << largest; + VLOG(4) << "output: " << *out; + VLOG(4) << "indices_int32: " << *indices_int32; + + // Run CANN TopKV2 operator, error occurred when the dtype is 'float16' const auto& npu_op_runner_topkv2 = - NpuOpRunner("TopKV2", {*x_cann, *k_cann}, {*values_cann, *indices_cann}, - {{"sorted", sorted}}); + NpuOpRunner("TopKV2", {*input, *k_Ascend}, {*out, *indices_int32}, + {{"sorted", sorted}, {"dim", axis}, {"largest", largest}}); auto npu_stream_topkv2 = context.template device_context() .stream(); npu_op_runner_topkv2.Run(npu_stream_topkv2); - // Convert the computing result into paddle's output tensors - // 'values_cann' to 'output' and 'indices_cann' to 'indices_transpose' - Tensor* values_cann_neg = nullptr; - Tensor* indices_cann_transpose = nullptr; - - if (!largest) { - // run CANN Neg operator - if (axis == last_axis) { - values_cann_neg = output; // in this case, the CANN Neg result will - // directly output to the 'output' tensor - } else { - values_cann_neg = input_neg; // as the 'input_neg' tensor is no longer - // needed, we reuse its resources to - // 'values_cann_neg' tensor - values_cann_neg->Resize(values_cann->dims()); - } - const auto& npu_op_runner_neg = - NpuOpRunner("Neg", {*values_cann}, {*values_cann_neg}); - auto npu_stream_neg = - context.template device_context() - .stream(); - npu_op_runner_neg.Run(npu_stream_neg); - } - - if (axis != last_axis) { - // run CANN Transpose operator - // transpose values - Tensor* input_tensor = (largest ? values_cann : values_cann_neg); - NpuOpRunner npu_op_runner_transpose_values; - auto npu_stream_transpose_values = - context.template device_context() - .stream(); - npu_op_runner_transpose_values.SetType("Transpose") - .AddInput(*input_tensor) - .AddInput(std::move(perm)) - .AddOutput(*output) - .Run(npu_stream_transpose_values); - - // transpose indices - indices_cann_transpose = new Tensor(indices_cann->type()); - indices_cann_transpose->Resize(indices->dims()); - indices_cann_transpose->mutable_data(context.GetPlace()); - - NpuOpRunner npu_op_runner_transpose_indices; - auto npu_stream_transpose_indices = - context.template device_context() - .stream(); - npu_op_runner_transpose_indices.SetType("Transpose") - .AddInput(*indices_cann) - .AddInput(std::move(perm)) - .AddOutput(*indices_cann_transpose) - .Run(npu_stream_transpose_indices); - } else { - indices_cann_transpose = indices_cann; - } + VLOG(4) << "output: " << *out; + VLOG(4) << "indices_int32: " << *indices_int32; - // 'indices_cann_transpose' to 'indices', from INT32 to INT64 + // Cast 'indices_int32' to 'indices', from INT32 to INT64 auto dst_dtype = ConvertToNpuDtype(indices->type()); const auto& npu_op_runner_cast = - NpuOpRunner("Cast", {*indices_cann_transpose}, {*indices}, + NpuOpRunner("Cast", {*indices_int32}, {*indices}, {{"dst_type", static_cast(dst_dtype)}}); auto npu_stream_cast = context.template device_context() .stream(); npu_op_runner_cast.Run(npu_stream_cast); + + VLOG(4) << "indices: " << *indices; } }; } // namespace operators diff --git a/python/paddle/fluid/tests/unittests/npu/test_top_k_v2_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_top_k_v2_op_npu.py old mode 100644 new mode 100755 index 11d752cc6fbaa5..f43fd7bc26a7f4 --- a/python/paddle/fluid/tests/unittests/npu/test_top_k_v2_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_top_k_v2_op_npu.py @@ -40,106 +40,126 @@ def numpy_topk(x, k=1, axis=-1, largest=True): class TestTopkV2NPUOp(OpTest): - def init_args(self): - self.k = 3 - self.axis = 1 - self.largest = True - def setUp(self): - self.__class__.use_npu = True - self.place = paddle.NPUPlace(0) + paddle.enable_static() self.op_type = "top_k_v2" - self.dtype = np.float64 - self.input_data = np.random.rand(10, 20) - self.init_args() - self.inputs = {'X': self.input_data} - self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest} + self.set_npu() + self.set_dtype() + self.set_input_data() + self.set_attrs() output, indices = numpy_topk( self.input_data, axis=self.axis, k=self.k, largest=self.largest) + + self.inputs = {'X': self.input_data} + self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest} self.outputs = {'Out': output, 'Indices': indices} + def set_dtype(self): + self.dtype = np.int32 + + def set_attrs(self): + self.k = 3 + self.axis = 1 + self.largest = True + + def set_input_data(self): + self.input_data = np.random.choice( + 10000, size=(10, 20), replace=False).astype(self.dtype) + def test_check_output(self): - paddle.enable_static() - self.check_output() + self.__class__.no_need_check_grad = True + self.check_output_with_place(self.place) - def test_check_grad(self): - paddle.enable_static() - self.check_grad(set(['X']), 'Out') + def set_npu(self): + self.__class__.use_npu = True + self.place = paddle.NPUPlace(0) -class TestTopkOp1(TestTopkV2NPUOp): - def init_args(self): +class TestTopkV2OP1Int32(TestTopkV2NPUOp): + def set_attrs(self): self.k = 3 self.axis = 0 self.largest = False -class TestTopkOp2(TestTopkV2NPUOp): - def init_args(self): +''' +class TestTopkV2OP2Int32(TestTopkV2NPUOp): + def set_attrs(self): self.k = 4 self.axis = 0 self.largest = False - -class TestTopkOp3(OpTest): - def init_args(self): +class TestTopkV2OP3Int32(TestTopkV2NPUOp): + def set_attrs(self): self.k = 6 self.axis = 1 self.largest = True - def setUp(self): - self.__class__.use_npu = True - self.place = paddle.NPUPlace(0) - self.op_type = "top_k_v2" - self.dtype = np.float64 - self.input_data = np.random.rand(16, 100) - self.init_args() - self.inputs = {'X': self.input_data} - self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest} - output, indices = numpy_topk( - self.input_data, axis=self.axis, k=self.k, largest=self.largest) - self.outputs = {'Out': output, 'Indices': indices} - - -class TestTopkOp4(TestTopkV2NPUOp): - def init_args(self): +class TestTopkV2OP4Int32(TestTopkV2NPUOp): + def set_attrs(self): self.k = 3 self.axis = 1 self.largest = True - def setUp(self): - self.op_type = "top_k_v2" + +class TestTopkV2Op1Int64(TestTopkV2OP1Int32): + def set_dtype(self): + self.dtype = np.int64 + +class TestTopkV2Op2Int64(TestTopkV2OP2Int32): + def set_dtype(self): + self.dtype = np.int64 + +class TestTopkV2Op3Int64(TestTopkV2OP3Int32): + def set_dtype(self): + self.dtype = np.int64 + +class TestTopkV2Op4Int64(TestTopkV2OP4Int32): + def set_dtype(self): + self.dtype = np.int64 +''' + + +# Error occurred in this test case +class TestTopkOp1Float32(TestTopkV2OP1Int32): + def set_dtype(self): + self.dtype = np.float32 + + def set_input_data(self): + self.input_data = np.random.rand(10, 20).astype(self.dtype) + + +''' +class TestTopkOp1Float64(TestTopkV2OP1Int32): + def set_dtype(self): self.dtype = np.float64 - self.input_data = np.random.rand(10, 10, 5) - self.init_args() - self.inputs = {'X': self.input_data} - self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest} - output, indices = numpy_topk( - self.input_data, axis=self.axis, k=self.k, largest=self.largest) - self.outputs = {'Out': output, 'Indices': indices} + def set_input_data(self): + self.input_data = np.random.rand(10, 20).astype(self.dtype) +class TestTopkOp2Float64(TestTopkV2OP2Int32): + def set_dtype(self): + self.dtype = np.float64 + def set_input_data(self): + self.input_data = np.random.rand(10, 20).astype(self.dtype) -class TestTopkOp5(TestTopkV2NPUOp): - def init_args(self): - self.k = 3 - self.axis = 1 - self.largest = True +class TestTopkOp3Float64(TestTopkV2OP3Int32): + def set_dtype(self): + self.dtype = np.float64 + def set_input_data(self): + self.input_data = np.random.rand(10, 20).astype(self.dtype) - def setUp(self): - self.op_type = "top_k_v2" +class TestTopkOp4Float64(TestTopkV2OP4Int32): + def set_dtype(self): self.dtype = np.float64 - self.input_data = np.random.rand(10, 10, 5) - self.init_args() - self.inputs = {'X': self.input_data} - self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest} - output, indices = numpy_topk( - self.input_data, axis=self.axis, k=self.k, largest=self.largest) - self.outputs = {'Out': output, 'Indices': indices} + def set_input_data(self): + self.input_data = np.random.rand(10, 20).astype(self.dtype) class TestTopKAPI(unittest.TestCase): def setUp(self): + self.__class__.use_npu = True + self.place = paddle.NPUPlace(0) np.random.seed(123) self.input_data = np.random.rand(6, 7, 8) #self.input_data = np.random.rand(2, 3, 4) @@ -193,6 +213,7 @@ def run_dygraph(self, place): numpy_result = numpy_topk(self.input_data, k=2, axis=1) self.assertTrue(np.allclose(sort_paddle[0], numpy_result[0])) + def run_static(self, place): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program(), @@ -257,7 +278,7 @@ def run_static(self, place): numpy_result = numpy_topk(self.input_data, k=2, axis=1) self.assertTrue(np.allclose(sort_paddle[0], numpy_result[0])) - def test_cases(self): + def test_cases(self): places = [core.NPUPlace(0)] #if core.is_compiled_with_cuda(): # places.append(core.CUDAPlace(0)) @@ -266,6 +287,8 @@ def test_cases(self): self.run_static(place) def test_errors(self): + self.__class__.use_npu = True + self.place = paddle.NPUPlace(0) paddle.disable_static() x = paddle.to_tensor([1, 2, 3]) with self.assertRaises(BaseException): @@ -273,7 +296,7 @@ def test_errors(self): with self.assertRaises(BaseException): paddle.topk(x, k=0) - +''' if __name__ == "__main__": unittest.main() From afcd6efff90572c56c6d88d00229b7c14f96f85e Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Tue, 10 Aug 2021 11:48:55 +0800 Subject: [PATCH 4/6] A commit with accuracy error for float32 data --- paddle/fluid/operators/top_k_v2_op_npu.cc | 50 +++++++++-------------- 1 file changed, 19 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/operators/top_k_v2_op_npu.cc b/paddle/fluid/operators/top_k_v2_op_npu.cc index 3590a8854ebe70..1b5d7920674f65 100644 --- a/paddle/fluid/operators/top_k_v2_op_npu.cc +++ b/paddle/fluid/operators/top_k_v2_op_npu.cc @@ -20,7 +20,7 @@ limitations under the License. */ namespace paddle { namespace operators { -template +template class TopkV2NPUKernel : public framework::OpKernel { public: // Use Ascend TopKV2 operator to implement paddle TopKV2Op @@ -42,8 +42,6 @@ class TopkV2NPUKernel : public framework::OpKernel { } if (k_tensor != nullptr) { - // seems complicated, but I really don't know - // how to assign a NPU value to a CPU variable by an elegant way std::vector v_tmp(1); TensorToVector( *k_tensor, @@ -62,36 +60,27 @@ class TopkV2NPUKernel : public framework::OpKernel { out->mutable_data(context.GetPlace()); indices->mutable_data(context.GetPlace()); - // Allocate space for input k and output indices of Ascend topkV2 operator - framework::Tensor* k_Ascend = new Tensor(framework::proto::VarType::INT32); - k_Ascend->mutable_data({1}, context.GetPlace()); - FillNpuTensorWithConstant(k_Ascend, static_cast(k)); - + // Allocate space for output indices of Ascend topkV2 operator framework::Tensor* indices_int32 = new Tensor(framework::proto::VarType::INT32); indices_int32->Resize(output_dims); indices_int32->mutable_data(context.GetPlace()); - - VLOG(4) << "input: " << *input; - VLOG(4) << "k: " << *k_Ascend; - VLOG(4) << "sorted: " << sorted; - VLOG(4) << "dim: " << axis; - VLOG(4) << "largest: " << largest; - VLOG(4) << "output: " << *out; - VLOG(4) << "indices_int32: " << *indices_int32; - - // Run CANN TopKV2 operator, error occurred when the dtype is 'float16' - const auto& npu_op_runner_topkv2 = - NpuOpRunner("TopKV2", {*input, *k_Ascend}, {*out, *indices_int32}, - {{"sorted", sorted}, {"dim", axis}, {"largest", largest}}); + VLOG(4) << "input:" << *input; + // Run Ascend TopKV2 operator + NpuOpRunner npu_op_runner_topkv2; auto npu_stream_topkv2 = context.template device_context() .stream(); - npu_op_runner_topkv2.Run(npu_stream_topkv2); - - VLOG(4) << "output: " << *out; - VLOG(4) << "indices_int32: " << *indices_int32; - + npu_op_runner_topkv2.SetType("TopKV2") + .AddInput(*input) + .AddInput(std::vector{k}) + .AddOutput(*out) + .AddOutput(*indices_int32) + .AddAttr("sorted", sorted) + .AddAttr("dim", axis) + .AddAttr("largest", largest) + .Run(npu_stream_topkv2); + VLOG(4) << "output:" << *out; // Cast 'indices_int32' to 'indices', from INT32 to INT64 auto dst_dtype = ConvertToNpuDtype(indices->type()); const auto& npu_op_runner_cast = @@ -109,8 +98,7 @@ class TopkV2NPUKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_NPU_KERNEL( - top_k_v2, ops::TopkV2NPUKernel, - ops::TopkV2NPUKernel, - ops::TopkV2NPUKernel, - ops::TopkV2NPUKernel); +REGISTER_OP_NPU_KERNEL(top_k_v2, ops::TopkV2NPUKernel, + ops::TopkV2NPUKernel, + ops::TopkV2NPUKernel, + ops::TopkV2NPUKernel); From 469f63b822ca3d646feabe6789f06bfe9b73eb26 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Tue, 10 Aug 2021 18:35:56 +0800 Subject: [PATCH 5/6] Modify codes according to the review comments --- paddle/fluid/operators/top_k_v2_op_npu.cc | 42 ++++------- .../unittests/npu/test_top_k_v2_op_npu.py | 73 +++++++++++++++---- 2 files changed, 76 insertions(+), 39 deletions(-) diff --git a/paddle/fluid/operators/top_k_v2_op_npu.cc b/paddle/fluid/operators/top_k_v2_op_npu.cc index 1b5d7920674f65..dc80173a191093 100644 --- a/paddle/fluid/operators/top_k_v2_op_npu.cc +++ b/paddle/fluid/operators/top_k_v2_op_npu.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,24 +19,22 @@ limitations under the License. */ namespace paddle { namespace operators { - +// Attention: the Ascend TopKV2 operator used in this kernel +// may lead to large accuracy error for float32 data template class TopkV2NPUKernel : public framework::OpKernel { public: - // Use Ascend TopKV2 operator to implement paddle TopKV2Op void Compute(const framework::ExecutionContext& context) const override { - // Read message from context auto* input = context.Input("X"); auto* k_tensor = context.Input("K"); auto* out = context.Output("Out"); - auto* indices = context.Output("Indices"); // type:INT64 + auto* indices = context.Output("Indices"); // type: INT64 int32_t k = static_cast(context.Attr("k")); int axis = static_cast(context.Attr("axis")); const bool sorted = static_cast(context.Attr("sorted")); const bool largest = static_cast(context.Attr("largest")); - // Calculate the real value of axis and k if (axis < 0) { axis += input->dims().size(); } @@ -50,7 +48,6 @@ class TopkV2NPUKernel : public framework::OpKernel { k = static_cast(v_tmp[0]); } - // Allocate space for output tensors of Paddle topKV2 operator framework::DDim output_dims = input->dims(); output_dims[axis] = k; @@ -60,38 +57,31 @@ class TopkV2NPUKernel : public framework::OpKernel { out->mutable_data(context.GetPlace()); indices->mutable_data(context.GetPlace()); - // Allocate space for output indices of Ascend topkV2 operator - framework::Tensor* indices_int32 = - new Tensor(framework::proto::VarType::INT32); - indices_int32->Resize(output_dims); - indices_int32->mutable_data(context.GetPlace()); - VLOG(4) << "input:" << *input; - // Run Ascend TopKV2 operator - NpuOpRunner npu_op_runner_topkv2; - auto npu_stream_topkv2 = + framework::Tensor indices_int32(framework::proto::VarType::INT32); + indices_int32.Resize(output_dims); + indices_int32.mutable_data(context.GetPlace()); + + auto npu_stream = context.template device_context() .stream(); + + NpuOpRunner npu_op_runner_topkv2; npu_op_runner_topkv2.SetType("TopKV2") .AddInput(*input) .AddInput(std::vector{k}) .AddOutput(*out) - .AddOutput(*indices_int32) + .AddOutput(indices_int32) .AddAttr("sorted", sorted) .AddAttr("dim", axis) .AddAttr("largest", largest) - .Run(npu_stream_topkv2); - VLOG(4) << "output:" << *out; + .Run(npu_stream); + // Cast 'indices_int32' to 'indices', from INT32 to INT64 auto dst_dtype = ConvertToNpuDtype(indices->type()); const auto& npu_op_runner_cast = - NpuOpRunner("Cast", {*indices_int32}, {*indices}, + NpuOpRunner("Cast", {indices_int32}, {*indices}, {{"dst_type", static_cast(dst_dtype)}}); - auto npu_stream_cast = - context.template device_context() - .stream(); - npu_op_runner_cast.Run(npu_stream_cast); - - VLOG(4) << "indices: " << *indices; + npu_op_runner_cast.Run(npu_stream); } }; } // namespace operators diff --git a/python/paddle/fluid/tests/unittests/npu/test_top_k_v2_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_top_k_v2_op_npu.py index f43fd7bc26a7f4..471fa5d2e5089e 100755 --- a/python/paddle/fluid/tests/unittests/npu/test_top_k_v2_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_top_k_v2_op_npu.py @@ -69,13 +69,29 @@ def set_input_data(self): def test_check_output(self): self.__class__.no_need_check_grad = True - self.check_output_with_place(self.place) + if self.dtype == np.float32: + self.check_output_with_place(self.place, atol=1e-3) + else: + self.check_output_with_place(self.place) def set_npu(self): self.__class__.use_npu = True self.place = paddle.NPUPlace(0) +class TestTopkV2OpFloat16(TestTopkV2NPUOp): + def set_attrs(self): + self.k = 3 + self.axis = 1 + self.largest = True + + def set_dtype(self): + self.dtype = np.float32 + + def set_input_data(self): + self.input_data = np.random.rand(3, 4).astype(self.dtype) + + class TestTopkV2OP1Int32(TestTopkV2NPUOp): def set_attrs(self): self.k = 3 @@ -83,19 +99,20 @@ def set_attrs(self): self.largest = False -''' class TestTopkV2OP2Int32(TestTopkV2NPUOp): def set_attrs(self): self.k = 4 self.axis = 0 self.largest = False + class TestTopkV2OP3Int32(TestTopkV2NPUOp): def set_attrs(self): self.k = 6 self.axis = 1 self.largest = True + class TestTopkV2OP4Int32(TestTopkV2NPUOp): def set_attrs(self): self.k = 3 @@ -107,22 +124,31 @@ class TestTopkV2Op1Int64(TestTopkV2OP1Int32): def set_dtype(self): self.dtype = np.int64 + class TestTopkV2Op2Int64(TestTopkV2OP2Int32): def set_dtype(self): self.dtype = np.int64 + class TestTopkV2Op3Int64(TestTopkV2OP3Int32): def set_dtype(self): self.dtype = np.int64 + class TestTopkV2Op4Int64(TestTopkV2OP4Int32): def set_dtype(self): self.dtype = np.int64 -''' -# Error occurred in this test case -class TestTopkOp1Float32(TestTopkV2OP1Int32): +class TestTopkV2Op1Float32(TestTopkV2OP1Int32): + def set_dtype(self): + self.dtype = np.float32 + + def set_input_data(self): + self.input_data = np.random.rand(10, 20).astype(self.dtype) + + +class TestTopkV2Op2Float32(TestTopkV2OP2Int32): def set_dtype(self): self.dtype = np.float32 @@ -130,28 +156,50 @@ def set_input_data(self): self.input_data = np.random.rand(10, 20).astype(self.dtype) -''' -class TestTopkOp1Float64(TestTopkV2OP1Int32): +class TestTopkV2Op3Float32(TestTopkV2OP3Int32): + def set_dtype(self): + self.dtype = np.float32 + + def set_input_data(self): + self.input_data = np.random.rand(10, 20).astype(self.dtype) + + +class TestTopkV2Op4Float32(TestTopkV2OP4Int32): + def set_dtype(self): + self.dtype = np.float32 + + def set_input_data(self): + self.input_data = np.random.rand(10, 20).astype(self.dtype) + + +class TestTopkV2Op1Float64(TestTopkV2OP1Int32): def set_dtype(self): self.dtype = np.float64 + def set_input_data(self): self.input_data = np.random.rand(10, 20).astype(self.dtype) -class TestTopkOp2Float64(TestTopkV2OP2Int32): + +class TestTopkV2Op2Float64(TestTopkV2OP2Int32): def set_dtype(self): self.dtype = np.float64 + def set_input_data(self): self.input_data = np.random.rand(10, 20).astype(self.dtype) -class TestTopkOp3Float64(TestTopkV2OP3Int32): + +class TestTopkV2Op3Float64(TestTopkV2OP3Int32): def set_dtype(self): self.dtype = np.float64 + def set_input_data(self): self.input_data = np.random.rand(10, 20).astype(self.dtype) -class TestTopkOp4Float64(TestTopkV2OP4Int32): + +class TestTopkV2Op4Float64(TestTopkV2OP4Int32): def set_dtype(self): self.dtype = np.float64 + def set_input_data(self): self.input_data = np.random.rand(10, 20).astype(self.dtype) @@ -213,7 +261,6 @@ def run_dygraph(self, place): numpy_result = numpy_topk(self.input_data, k=2, axis=1) self.assertTrue(np.allclose(sort_paddle[0], numpy_result[0])) - def run_static(self, place): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program(), @@ -278,7 +325,7 @@ def run_static(self, place): numpy_result = numpy_topk(self.input_data, k=2, axis=1) self.assertTrue(np.allclose(sort_paddle[0], numpy_result[0])) - def test_cases(self): + def test_cases(self): places = [core.NPUPlace(0)] #if core.is_compiled_with_cuda(): # places.append(core.CUDAPlace(0)) @@ -296,7 +343,7 @@ def test_errors(self): with self.assertRaises(BaseException): paddle.topk(x, k=0) -''' + if __name__ == "__main__": unittest.main() From 41d1f4c48d1430ca3855af12d52339401c4b2642 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Tue, 10 Aug 2021 18:35:56 +0800 Subject: [PATCH 6/6] Modify codes according to the review comments --- paddle/fluid/operators/top_k_v2_op_npu.cc | 42 ++++------ .../unittests/npu/test_top_k_v2_op_npu.py | 81 ++++++++++++++----- 2 files changed, 77 insertions(+), 46 deletions(-) diff --git a/paddle/fluid/operators/top_k_v2_op_npu.cc b/paddle/fluid/operators/top_k_v2_op_npu.cc index 1b5d7920674f65..dc80173a191093 100644 --- a/paddle/fluid/operators/top_k_v2_op_npu.cc +++ b/paddle/fluid/operators/top_k_v2_op_npu.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,24 +19,22 @@ limitations under the License. */ namespace paddle { namespace operators { - +// Attention: the Ascend TopKV2 operator used in this kernel +// may lead to large accuracy error for float32 data template class TopkV2NPUKernel : public framework::OpKernel { public: - // Use Ascend TopKV2 operator to implement paddle TopKV2Op void Compute(const framework::ExecutionContext& context) const override { - // Read message from context auto* input = context.Input("X"); auto* k_tensor = context.Input("K"); auto* out = context.Output("Out"); - auto* indices = context.Output("Indices"); // type:INT64 + auto* indices = context.Output("Indices"); // type: INT64 int32_t k = static_cast(context.Attr("k")); int axis = static_cast(context.Attr("axis")); const bool sorted = static_cast(context.Attr("sorted")); const bool largest = static_cast(context.Attr("largest")); - // Calculate the real value of axis and k if (axis < 0) { axis += input->dims().size(); } @@ -50,7 +48,6 @@ class TopkV2NPUKernel : public framework::OpKernel { k = static_cast(v_tmp[0]); } - // Allocate space for output tensors of Paddle topKV2 operator framework::DDim output_dims = input->dims(); output_dims[axis] = k; @@ -60,38 +57,31 @@ class TopkV2NPUKernel : public framework::OpKernel { out->mutable_data(context.GetPlace()); indices->mutable_data(context.GetPlace()); - // Allocate space for output indices of Ascend topkV2 operator - framework::Tensor* indices_int32 = - new Tensor(framework::proto::VarType::INT32); - indices_int32->Resize(output_dims); - indices_int32->mutable_data(context.GetPlace()); - VLOG(4) << "input:" << *input; - // Run Ascend TopKV2 operator - NpuOpRunner npu_op_runner_topkv2; - auto npu_stream_topkv2 = + framework::Tensor indices_int32(framework::proto::VarType::INT32); + indices_int32.Resize(output_dims); + indices_int32.mutable_data(context.GetPlace()); + + auto npu_stream = context.template device_context() .stream(); + + NpuOpRunner npu_op_runner_topkv2; npu_op_runner_topkv2.SetType("TopKV2") .AddInput(*input) .AddInput(std::vector{k}) .AddOutput(*out) - .AddOutput(*indices_int32) + .AddOutput(indices_int32) .AddAttr("sorted", sorted) .AddAttr("dim", axis) .AddAttr("largest", largest) - .Run(npu_stream_topkv2); - VLOG(4) << "output:" << *out; + .Run(npu_stream); + // Cast 'indices_int32' to 'indices', from INT32 to INT64 auto dst_dtype = ConvertToNpuDtype(indices->type()); const auto& npu_op_runner_cast = - NpuOpRunner("Cast", {*indices_int32}, {*indices}, + NpuOpRunner("Cast", {indices_int32}, {*indices}, {{"dst_type", static_cast(dst_dtype)}}); - auto npu_stream_cast = - context.template device_context() - .stream(); - npu_op_runner_cast.Run(npu_stream_cast); - - VLOG(4) << "indices: " << *indices; + npu_op_runner_cast.Run(npu_stream); } }; } // namespace operators diff --git a/python/paddle/fluid/tests/unittests/npu/test_top_k_v2_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_top_k_v2_op_npu.py index f43fd7bc26a7f4..a8242be855c80a 100755 --- a/python/paddle/fluid/tests/unittests/npu/test_top_k_v2_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_top_k_v2_op_npu.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -69,13 +69,29 @@ def set_input_data(self): def test_check_output(self): self.__class__.no_need_check_grad = True - self.check_output_with_place(self.place) + if self.dtype == np.float32: + self.check_output_with_place(self.place, atol=1e-3) + else: + self.check_output_with_place(self.place) def set_npu(self): self.__class__.use_npu = True self.place = paddle.NPUPlace(0) +class TestTopkV2OpFloat16(TestTopkV2NPUOp): + def set_attrs(self): + self.k = 3 + self.axis = 1 + self.largest = True + + def set_dtype(self): + self.dtype = np.float32 + + def set_input_data(self): + self.input_data = np.random.rand(3, 4).astype(self.dtype) + + class TestTopkV2OP1Int32(TestTopkV2NPUOp): def set_attrs(self): self.k = 3 @@ -83,19 +99,20 @@ def set_attrs(self): self.largest = False -''' class TestTopkV2OP2Int32(TestTopkV2NPUOp): def set_attrs(self): self.k = 4 self.axis = 0 self.largest = False + class TestTopkV2OP3Int32(TestTopkV2NPUOp): def set_attrs(self): self.k = 6 self.axis = 1 self.largest = True + class TestTopkV2OP4Int32(TestTopkV2NPUOp): def set_attrs(self): self.k = 3 @@ -107,22 +124,31 @@ class TestTopkV2Op1Int64(TestTopkV2OP1Int32): def set_dtype(self): self.dtype = np.int64 + class TestTopkV2Op2Int64(TestTopkV2OP2Int32): def set_dtype(self): self.dtype = np.int64 + class TestTopkV2Op3Int64(TestTopkV2OP3Int32): def set_dtype(self): self.dtype = np.int64 + class TestTopkV2Op4Int64(TestTopkV2OP4Int32): def set_dtype(self): self.dtype = np.int64 -''' -# Error occurred in this test case -class TestTopkOp1Float32(TestTopkV2OP1Int32): +class TestTopkV2Op1Float32(TestTopkV2OP1Int32): + def set_dtype(self): + self.dtype = np.float32 + + def set_input_data(self): + self.input_data = np.random.rand(10, 20).astype(self.dtype) + + +class TestTopkV2Op2Float32(TestTopkV2OP2Int32): def set_dtype(self): self.dtype = np.float32 @@ -130,28 +156,50 @@ def set_input_data(self): self.input_data = np.random.rand(10, 20).astype(self.dtype) -''' -class TestTopkOp1Float64(TestTopkV2OP1Int32): +class TestTopkV2Op3Float32(TestTopkV2OP3Int32): + def set_dtype(self): + self.dtype = np.float32 + + def set_input_data(self): + self.input_data = np.random.rand(10, 20).astype(self.dtype) + + +class TestTopkV2Op4Float32(TestTopkV2OP4Int32): + def set_dtype(self): + self.dtype = np.float32 + + def set_input_data(self): + self.input_data = np.random.rand(10, 20).astype(self.dtype) + + +class TestTopkV2Op1Float64(TestTopkV2OP1Int32): def set_dtype(self): self.dtype = np.float64 + def set_input_data(self): self.input_data = np.random.rand(10, 20).astype(self.dtype) -class TestTopkOp2Float64(TestTopkV2OP2Int32): + +class TestTopkV2Op2Float64(TestTopkV2OP2Int32): def set_dtype(self): self.dtype = np.float64 + def set_input_data(self): self.input_data = np.random.rand(10, 20).astype(self.dtype) -class TestTopkOp3Float64(TestTopkV2OP3Int32): + +class TestTopkV2Op3Float64(TestTopkV2OP3Int32): def set_dtype(self): self.dtype = np.float64 + def set_input_data(self): self.input_data = np.random.rand(10, 20).astype(self.dtype) -class TestTopkOp4Float64(TestTopkV2OP4Int32): + +class TestTopkV2Op4Float64(TestTopkV2OP4Int32): def set_dtype(self): self.dtype = np.float64 + def set_input_data(self): self.input_data = np.random.rand(10, 20).astype(self.dtype) @@ -162,7 +210,6 @@ def setUp(self): self.place = paddle.NPUPlace(0) np.random.seed(123) self.input_data = np.random.rand(6, 7, 8) - #self.input_data = np.random.rand(2, 3, 4) self.large_input_data = np.random.rand(2, 1030) def run_dygraph(self, place): @@ -213,14 +260,12 @@ def run_dygraph(self, place): numpy_result = numpy_topk(self.input_data, k=2, axis=1) self.assertTrue(np.allclose(sort_paddle[0], numpy_result[0])) - def run_static(self, place): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()): input_tensor = paddle.static.data( name="x", shape=[6, 7, 8], dtype="float64") - #input_tensor = paddle.static.data(name="x", shape=[2, 3, 4], dtype="float64") large_input_tensor = paddle.static.data( name="large_x", shape=[2, 1030], dtype="float64") k_tensor = paddle.static.data(name="k", shape=[1], dtype="int32") @@ -229,8 +274,6 @@ def run_static(self, place): result3 = paddle.topk(input_tensor, k=k_tensor, axis=1) self.assertEqual(result3[0].shape, (6, -1, 8)) self.assertEqual(result3[1].shape, (6, -1, 8)) - #self.assertEqual(result3[0].shape, (2, -1, 4)) - #self.assertEqual(result3[1].shape, (2, -1, 4)) result4 = paddle.topk(input_tensor, k=2, axis=1, largest=False) result5 = paddle.topk(input_tensor, k=2, axis=-1, largest=False) result6 = paddle.topk(large_input_tensor, k=1, axis=-1) @@ -278,10 +321,8 @@ def run_static(self, place): numpy_result = numpy_topk(self.input_data, k=2, axis=1) self.assertTrue(np.allclose(sort_paddle[0], numpy_result[0])) - def test_cases(self): + def test_cases(self): places = [core.NPUPlace(0)] - #if core.is_compiled_with_cuda(): - # places.append(core.CUDAPlace(0)) for place in places: self.run_dygraph(place) self.run_static(place) @@ -296,7 +337,7 @@ def test_errors(self): with self.assertRaises(BaseException): paddle.topk(x, k=0) -''' + if __name__ == "__main__": unittest.main()