diff --git a/Paddle b/Paddle index 0759e99d8a4..d5b4570dd98 160000 --- a/Paddle +++ b/Paddle @@ -1 +1 @@ -Subproject commit 0759e99d8a4ba233850dbffe87954a2b6a628776 +Subproject commit d5b4570dd98bf59a1ea81fb04fc5ef62f8826e59 diff --git a/backends/npu/CMakeLists.txt b/backends/npu/CMakeLists.txt index 19cdd447e33..c0cc808f2ce 100644 --- a/backends/npu/CMakeLists.txt +++ b/backends/npu/CMakeLists.txt @@ -54,6 +54,15 @@ include(third_party) add_dependencies(${CUSTOM_NPU_NAME} third_party) target_link_libraries(${CUSTOM_NPU_NAME} PRIVATE ${PADDLE_CORE_LIB}) +# testing +if (WITH_TESTING) + set(PYTHON_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../python") + add_subdirectory(tests) + add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/tests/.timestamp + COMMAND cp -r ${CMAKE_SOURCE_DIR}/tests ${CMAKE_CURRENT_BINARY_DIR}) + add_custom_target(python_tests ALL DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/tests/.timestamp) +endif() + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in ${CMAKE_CURRENT_BINARY_DIR}/setup.py) diff --git a/backends/npu/kernels/sgd_kernel.cc b/backends/npu/kernels/sgd_kernel.cc index c40cbd61b46..56797e12d63 100644 --- a/backends/npu/kernels/sgd_kernel.cc +++ b/backends/npu/kernels/sgd_kernel.cc @@ -22,7 +22,10 @@ void SGDKernel(const Context& dev_ctx, const phi::DenseTensor& param_var, const phi::DenseTensor& learning_rate, const phi::DenseTensor& grad_var, - phi::DenseTensor* param_out) { + paddle::optional master_param, + bool multi_precision, + phi::DenseTensor* param_out, + phi::DenseTensor* master_param_out) { aclrtStream stream = static_cast(dev_ctx.stream()); dev_ctx.template Alloc(param_out); diff --git a/backends/npu/tests/CMakeLists.txt b/backends/npu/tests/CMakeLists.txt new file mode 100644 index 00000000000..29908b97b59 --- /dev/null +++ b/backends/npu/tests/CMakeLists.txt @@ -0,0 +1,33 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +function(py_test_modules TARGET_NAME) + set(options SERIAL) + set(oneValueArgs "") + set(multiValueArgs MODULES DEPS ENVS) + cmake_parse_arguments(py_test_modules "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + add_test(NAME ${TARGET_NAME} + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PYTHON_SOURCE_DIR}:$ENV{PYTHONPATH} ${py_test_modules_ENVS} + python ${PYTHON_SOURCE_DIR}/tools/test_runner.py ${py_test_modules_MODULES} + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) + + if (py_test_modules_SERIAL) + set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1) + endif() +endfunction() + +py_test_modules(test_MNIST_model MODULES test_MNIST_model) + +add_subdirectory(unittests) diff --git a/backends/npu/tests/unittests/CMakeLists.txt b/backends/npu/tests/unittests/CMakeLists.txt new file mode 100644 index 00000000000..49715de4818 --- /dev/null +++ b/backends/npu/tests/unittests/CMakeLists.txt @@ -0,0 +1,20 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + +foreach(TEST_OP ${TEST_OPS}) + py_test_modules(${TEST_OP} MODULES ${TEST_OP}) +endforeach(TEST_OP) diff --git a/backends/npu/tests/unittests/test_softmax_op_npu.py b/backends/npu/tests/unittests/test_softmax_op_npu.py new file mode 100644 index 00000000000..7f4083b82e6 --- /dev/null +++ b/backends/npu/tests/unittests/test_softmax_op_npu.py @@ -0,0 +1,120 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +from tests.op_test import OpTest +import paddle +import paddle.fluid as fluid +from paddle.fluid import core + +paddle.enable_static() +SEED = 2021 + + +class TestSoftmax(OpTest): + def setUp(self): + self.set_npu() + self.place = paddle.CustomPlace('ascend', 0) + self.op_type = "softmax" + self.init_dtype() + + x = np.random.random([3, 3]).astype(self.dtype) + np_out = np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True) + self.inputs = {'X': x} + + self.attrs = {} + self.outputs = {'Out': np_out} + + def set_npu(self): + self.__class__.use_custom_device = True + self.__class__.no_need_check_grad = True + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place) + + +class TestSoftmaxNet(unittest.TestCase): + def _test(self, run_npu=True): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + + a_np = np.random.random(size=(4, 32)).astype('float32') + b_np = np.random.random(size=(4, 32)).astype('float32') + label_np = np.random.randint(2, size=(4, 1)).astype('int64') + + with paddle.static.program_guard(main_prog, startup_prog): + a = paddle.static.data(name="a", shape=[4, 32], dtype='float32') + b = paddle.static.data(name="b", shape=[4, 32], dtype='float32') + label = paddle.static.data( + name="label", shape=[4, 1], dtype='int64') + + c = paddle.multiply(a, b) + d = paddle.sqrt(c) + + # 4 x 128 + fc_1 = fluid.layers.fc(input=d, size=128) + # 4 x 2 + prediction = fluid.layers.fc(input=fc_1, size=2) + + # 4 x 2 + prob = fluid.layers.softmax(prediction, axis=1) + + cost = fluid.layers.cross_entropy(input=prob, label=label) + loss = fluid.layers.mean(cost) + sgd = fluid.optimizer.SGD(learning_rate=0.01) + sgd.minimize(loss) + + if run_npu: + place = paddle.CustomPlace('ascend', 0) + else: + place = paddle.CPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + print("Start run on {}".format(place)) + for epoch in range(100): + + pred_res, loss_res = exe.run( + main_prog, + feed={"a": a_np, + "b": b_np, + "label": label_np}, + fetch_list=[prediction, loss]) + if epoch % 10 == 0: + print("Epoch {} | Prediction[0]: {}, Loss: {}".format( + epoch, pred_res[0], loss_res)) + + return pred_res, loss_res + + def test_npu(self): + cpu_pred, cpu_loss = self._test(False) + npu_pred, npu_loss = self._test(True) + + self.assertTrue(np.allclose(npu_pred, cpu_pred, rtol=1e-2)) + self.assertTrue(np.allclose(npu_loss, cpu_loss, rtol=1e-2)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/__init__.py b/python/__init__.py new file mode 100644 index 00000000000..513558501a0 --- /dev/null +++ b/python/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/tests/__init__.py b/python/tests/__init__.py new file mode 120000 index 00000000000..f4c1ea4faab --- /dev/null +++ b/python/tests/__init__.py @@ -0,0 +1 @@ +../../Paddle/python/paddle/fluid/tests/unittests/__init__.py \ No newline at end of file diff --git a/python/tests/op_test.py b/python/tests/op_test.py new file mode 120000 index 00000000000..3f65a6e3b4f --- /dev/null +++ b/python/tests/op_test.py @@ -0,0 +1 @@ +../../Paddle/python/paddle/fluid/tests/unittests/op_test.py \ No newline at end of file diff --git a/python/tests/testsuite.py b/python/tests/testsuite.py new file mode 120000 index 00000000000..c9eaa456157 --- /dev/null +++ b/python/tests/testsuite.py @@ -0,0 +1 @@ +../../Paddle/python/paddle/fluid/tests/unittests/testsuite.py \ No newline at end of file diff --git a/python/tests/white_list b/python/tests/white_list new file mode 120000 index 00000000000..a5b9f97d72a --- /dev/null +++ b/python/tests/white_list @@ -0,0 +1 @@ +../../Paddle/python/paddle/fluid/tests/unittests/white_list \ No newline at end of file diff --git a/python/tools/__init__.py b/python/tools/__init__.py new file mode 120000 index 00000000000..393955a4b42 --- /dev/null +++ b/python/tools/__init__.py @@ -0,0 +1 @@ +../../Paddle/tools/__init__.py \ No newline at end of file diff --git a/python/tools/static_mode_white_list.py b/python/tools/static_mode_white_list.py new file mode 120000 index 00000000000..908dbd07063 --- /dev/null +++ b/python/tools/static_mode_white_list.py @@ -0,0 +1 @@ +../../Paddle/tools/static_mode_white_list.py \ No newline at end of file diff --git a/python/tools/test_runner.py b/python/tools/test_runner.py new file mode 120000 index 00000000000..278cedb0304 --- /dev/null +++ b/python/tools/test_runner.py @@ -0,0 +1 @@ +../../Paddle/tools/test_runner.py \ No newline at end of file diff --git a/scripts/paddle_ci.sh b/scripts/paddle_ci.sh index d0c09af5b39..717df5b576b 100755 --- a/scripts/paddle_ci.sh +++ b/scripts/paddle_ci.sh @@ -42,10 +42,10 @@ function custom_npu_test() { pip install hypothesis pip install ${WORKSPACE_ROOT}/Paddle/build/python/dist/*whl - # custom_npu install + # custom_npu build and install cd ${WORKSPACE_ROOT}/PaddleCustomDevice/backends/npu mkdir build && cd build - cmake .. -DWITH_TESTING=ON -DWITH_KERNELS=ON + cmake .. -DWITH_TESTING=ON if [[ "$?" != "0" ]];then exit 7; fi @@ -55,10 +55,9 @@ function custom_npu_test() { fi pip install dist/*.whl - # simple test now + # run ut ut_total_startTime_s=`date +%s` - cd ${WORKSPACE_ROOT}/PaddleCustomDevice/backends/npu/tests - python test_MNIST_model.py + ctest --output-on-failure EXIT_CODE=$? ut_total_endTime_s=`date +%s` echo "TestCases Total Time: $[ $ut_total_endTime_s - $ut_total_startTime_s ]s"