diff --git a/CMakeLists.txt b/CMakeLists.txt index 1a0b0b36861605..036a5faf24f24a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,6 +19,36 @@ set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}) include(system) +if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + cmake_minimum_required(VERSION 3.10) + # TODO(TJ): make as function check_default + if(NOT DEFINED ARM_TARGET_OS) + set(ARM_TARGET_OS "android" CACHE STRING "Choose ARM Target OS") + endif() + set(ARM_TARGET_OS_LIST "android" "armlinux") # TODO: "ios" + set_property(CACHE ARM_TARGET_OS PROPERTY STRINGS ${ARM_TARGET_OS_LIST}) + if (NOT ARM_TARGET_OS IN_LIST ARM_TARGET_OS_LIST) + message(FATAL_ERROR "ARM_TARGET_OS must be in one of ${ARM_TARGET_OS_LIST}") + endif() + + if(NOT DEFINED ARM_TARGET_ARCH_ABI) + set(ARM_TARGET_ARCH_ABI "arm64-v8a" CACHE STRING "Choose ARM Target ARCH ABI") + endif() + set(ARM_TARGET_ARCH_ABI_LIST "arm64-v8a" "armeabi-v7a" "armeabi-v7a-softfp" "armeabi-v7a-hf") + set_property(CACHE ARM_TARGET_ARCH_ABI PROPERTY STRINGS ${ARM_TARGET_ARCH_ABI_LIST}) + if (NOT ARM_TARGET_ARCH_ABI IN_LIST ARM_TARGET_ARCH_ABI_LIST) + message(FATAL_ERROR "ARM_TARGET_ARCH_ABI must be in one of ${ARM_TARGET_ARCH_ABI_LIST}") + endif() + + if(NOT DEFINED TARGET_ARCH_ABI) + set(ARCH_ABI "arm64-v8a" CACHE STRING "Choose android platform") + endif() + + include(cross_compiling/host) + include(cross_compiling/armlinux) + include(cross_compiling/android) +endif() + project(paddle CXX C) message(STATUS "CXX compiler: ${CMAKE_CXX_COMPILER}, version: " "${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}") @@ -41,7 +71,9 @@ if(WIN32) set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${PADDLE_LINK_FLAGS}") endif(WIN32) -find_package(CUDA QUIET) +if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + find_package(CUDA QUIET) +endif() find_package(Git REQUIRED) find_package(Threads REQUIRED) @@ -79,18 +111,42 @@ option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VER option(WITH_FAST_MATH "Make use of fast math 
library, might affect the precision to some extent" ON) option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ON) +if(ANDROID OR IOS OR ARMLINUX) + set(WITH_GPU OFF CACHE STRING + "Disable GPU when cross-compiling for Android and iOS" FORCE) + set(WITH_DSO OFF CACHE STRING + "Disable DSO when cross-compiling for Android and iOS" FORCE) + set(WITH_AVX OFF CACHE STRING + "Disable AVX when cross-compiling for Android and iOS" FORCE) + set(WITH_PYTHON OFF CACHE STRING + "Disable PYTHON when cross-compiling for Android and iOS" FORCE) + set(WITH_RDMA OFF CACHE STRING + "Disable RDMA when cross-compiling for Android and iOS" FORCE) + set(WITH_MKL OFF CACHE STRING + "Disable MKL when cross-compiling for Android and iOS" FORCE) + + if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "Release" CACHE STRING + "Default use Release in android" FORCE) + endif() + if(NOT THIRD_PARTY_BUILD_TYPE) + set(THIRD_PARTY_BUILD_TYPE "MinSizeRel" CACHE STRING + "Default use MinSizeRel in android" FORCE) + endif() +endif() + # for lite, both server and mobile framework. 
option(WITH_LITE "Enable lite framework" OFF) option(LITE_WITH_CUDA "Enable CUDA in lite mode" OFF) -option(LITE_WITH_X86 "Enable X86 in lite mode" ON) +option(LITE_WITH_X86 "Enable X86 in lite mode" ON) +option(LITE_WITH_ARM "Enable ARM in lite mode" OFF) option(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK "Enable light-weight framework" OFF) +option(LITE_WITH_PROFILE "Enable profile mode in lite framework" OFF) set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING "A path setting third party libraries download & build directories.") -set(THIRD_PARTY_BUILD_TYPE Release) - # CMAKE_BUILD_TYPE if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING @@ -107,7 +163,7 @@ if (WITH_LITE AND LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) include(external/gflags) # download, build, install gflags include(external/glog) # download, build, install glog include(external/gtest) # download, build, install gtest - include(external/zlib) # download, build, install gtest + #include(external/zlib) # download, build, install gtest include(external/protobuf) # download, build, install protobuf include(external/eigen) # download eigen3 @@ -115,7 +171,7 @@ if (WITH_LITE AND LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) include(configure) # add paddle env configuration add_definitions(-std=c++11) - + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") add_subdirectory(paddle) return() diff --git a/cmake/configure.cmake b/cmake/configure.cmake index 6c9c3fd488901f..385a9572f58d52 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -172,6 +172,14 @@ if (LITE_WITH_X86) add_definitions("-DLITE_WITH_X86") endif() +if (LITE_WITH_ARM) + add_definitions("-DLITE_WITH_ARM") +endif() + +if (LITE_WITH_PROFILE) + add_definitions("-DLITE_WITH_PROFILE") +endif() + if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) add_definitions("-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK") endif() diff --git a/cmake/cross_compiling/android.cmake b/cmake/cross_compiling/android.cmake new file mode 100644 index 
00000000000000..e57f32aae7c1d5 --- /dev/null +++ b/cmake/cross_compiling/android.cmake @@ -0,0 +1,79 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if(NOT ARM_TARGET_OS STREQUAL "android") + return() +endif() + +set(ANDROID TRUE) +add_definitions(-DLITE_WITH_LINUX) + +if(NOT DEFINED ANDROID_NDK) + set(ANDROID_NDK $ENV{NDK_ROOT}) + if(NOT ANDROID_NDK) + message(FATAL_ERROR "Must set ANDROID_NDK or env NDK_ROOT") + endif() +endif() + + +if(NOT DEFINED ANDROID_API_LEVEL) + set(ANDROID_API_LEVEL "22") +endif() + +if(NOT DEFINED ANDROID_STL_TYPE) + set(ANDROID_STL_TYPE "c++_static" CACHE STRING "stl type") +endif() + +# TODO(TJ): enable me +if(ARM_TARGET_ARCH_ABI STREQUAL "armeabi-v7a-hf") + message(FATAL_ERROR "Not supported building android armeabi-v7a-hf yet") +endif() + +set(ANDROID_ARCH_ABI ${ARM_TARGET_ARCH_ABI} CACHE STRING "Choose Android Arch ABI") + +if(ANDROID_ARCH_ABI STREQUAL "armeabi-v7a-softfp") + set(ANDROID_ARCH_ABI "armeabi-v7a") +endif() + +set(ANDROID_ARCH_ABI_LIST "arm64-v8a" "armeabi-v7a" "armeabi-v6" "armeabi" + "mips" "mips64" "x86" "x86_64" "armeabi-v7a-hf") +set_property(CACHE ANDROID_ARCH_ABI PROPERTY STRINGS ${ANDROID_ARCH_ABI_LIST}) +if(NOT ANDROID_ARCH_ABI IN_LIST ANDROID_ARCH_ABI_LIST) + message(FATAL_ERROR "ANDROID_ARCH_ABI must be in one of ${ANDROID_ARCH_ABI_LIST}") +endif() + +if(ANDROID_ARCH_ABI STREQUAL "armeabi-v7a") + message(STATUS "armeabi-v7a default use softfp") + 
set(CMAKE_ANDROID_ARM_NEON ON) + message(STATUS "NEON is enabled on arm-v7a with softfp") +endif() + +if(ANDROID_ARCH_ABI STREQUAL "armeabi-v7a-hf") + set(ANDROID_ARCH_ABI "armeabi-v7a") + set(CMAKE_CXX_FLAGS "-std=c++11 -march=armv7-a -mfloat-abi=hard -mfpu=neon-vfpv4 ${CMAKE_CXX_FLAGS}" ) + set(CMAKE_C_FLAGS "-march=armv7-a -mfloat-abi=hard -mfpu=neon-vfpv4 ${CMAKE_C_FLAGS}" ) + message(STATUS "NEON is enabled on arm-v7a with hard float") +endif() + +set(ANDROID_STL_TYPE_LITS "gnustl_static" "c++_static") +set_property(CACHE ANDROID_STL_TYPE PROPERTY STRINGS ${ANDROID_STL_TYPE_LITS}) +if (NOT ANDROID_STL_TYPE IN_LIST ANDROID_STL_TYPE_LITS) + message(FATAL_ERROR "ANDROID_STL_TYPE must be in one of ${ANDROID_STL_TYPE_LITS}") +endif() + +set(CMAKE_SYSTEM_NAME Android) +set(CMAKE_SYSTEM_VERSION ${ANDROID_API_LEVEL}) +set(CMAKE_ANDROID_ARCH_ABI ${ANDROID_ARCH_ABI}) +set(CMAKE_ANDROID_NDK ${ANDROID_NDK}) +set(CMAKE_ANDROID_STL_TYPE ${ANDROID_STL_TYPE}) diff --git a/cmake/cross_compiling/armlinux.cmake b/cmake/cross_compiling/armlinux.cmake new file mode 100644 index 00000000000000..1d752075cca2d4 --- /dev/null +++ b/cmake/cross_compiling/armlinux.cmake @@ -0,0 +1,57 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +if(NOT ARM_TARGET_OS STREQUAL "armlinux") + return() +endif() + +set(ARMLINUX TRUE) +add_definitions(-DLITE_WITH_LINUX) +set(CMAKE_SYSTEM_NAME Linux) + +if(ARM_TARGET_ARCH_ABI STREQUAL "arm64-v8a") + set(CMAKE_SYSTEM_PROCESSOR aarch64) + set(CMAKE_C_COMPILER "aarch64-linux-gnu-gcc") + set(CMAKE_CXX_COMPILER "aarch64-linux-gnu-g++") + + set(CMAKE_CXX_FLAGS "-march=armv8-a ${CMAKE_CXX_FLAGS}") + set(CMAKE_C_FLAGS "-march=armv8-a ${CMAKE_C_FLAGS}") + message(STATUS "NEON is enabled on arm64-v8a") +endif() + +if(ARM_TARGET_ARCH_ABI STREQUAL "armeabi-v7a" + OR ARM_TARGET_ARCH_ABI STREQUAL "armeabi-v7a-hf") + message(FATAL_ERROR "Not supported building arm linux arm-v7 yet") +endif() + +# TODO(TJ): make sure v7 works +if(ARM_TARGET_ARCH_ABI STREQUAL "armeabi-v7a") + set(CMAKE_SYSTEM_PROCESSOR arm) + set(CMAKE_C_COMPILER "arm-linux-gnueabi-gcc") + set(CMAKE_CXX_COMPILER "arm-linux-gnueabi-g++") + + set(CMAKE_CXX_FLAGS "-march=armv7-a -mfloat-abi=softfp -mfpu=neon-vfpv4 ${CMAKE_CXX_FLAGS}") + set(CMAKE_C_FLAGS "-march=armv7-a -mfloat-abi=softfp -mfpu=neon-vfpv4 ${CMAKE_C_FLAGS}") + message(STATUS "NEON is enabled on arm-v7a with softfp") +endif() + +if(ARM_TARGET_ARCH_ABI STREQUAL "armeabi-v7a-hf") + set(CMAKE_SYSTEM_PROCESSOR arm) + set(CMAKE_C_COMPILER "arm-linux-gnueabihf-gcc") + set(CMAKE_CXX_COMPILER "arm-linux-gnueabihf-g++") + + set(CMAKE_CXX_FLAGS "-march=armv7-a -mfloat-abi=hard -mfpu=neon-vfpv4 ${CMAKE_CXX_FLAGS}") + set(CMAKE_C_FLAGS "-march=armv7-a -mfloat-abi=hard -mfpu=neon-vfpv4 ${CMAKE_C_FLAGS}" ) + message(STATUS "NEON is enabled on arm-v7a with hard float") +endif() diff --git a/cmake/cross_compiling/host.cmake b/cmake/cross_compiling/host.cmake new file mode 100644 index 00000000000000..b65e45208d8602 --- /dev/null +++ b/cmake/cross_compiling/host.cmake @@ -0,0 +1,40 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set(HOST_C_COMPILER $ENV{CC}) +set(HOST_CXX_COMPILER $ENV{CXX}) + +if(NOT HOST_C_COMPILER) + find_program(HOST_C_COMPILER NAMES gcc PATH + /usr/bin + /usr/local/bin) +endif() + +if(NOT HOST_CXX_COMPILER) + find_program(HOST_CXX_COMPILER NAMES g++ PATH + /usr/bin + /usr/local/bin) +endif() + +if(NOT HOST_C_COMPILER OR NOT EXISTS ${HOST_C_COMPILER}) + MESSAGE(FATAL_ERROR "Cannot find host C compiler. export CC=/path/to/cc") +ENDIF() + +if(NOT HOST_CXX_COMPILER OR NOT EXISTS ${HOST_CXX_COMPILER}) + MESSAGE(FATAL_ERROR "Cannot find host C compiler. 
export CC=/path/to/cc") +ENDIF() + +MESSAGE(STATUS "Found host C compiler: " ${HOST_C_COMPILER}) +MESSAGE(STATUS "Found host CXX compiler: " ${HOST_CXX_COMPILER}) + diff --git a/cmake/external/gflags.cmake b/cmake/external/gflags.cmake index 911920ed6212b8..42ce7c644f3e8e 100644 --- a/cmake/external/gflags.cmake +++ b/cmake/external/gflags.cmake @@ -25,6 +25,24 @@ ENDIF(WIN32) INCLUDE_DIRECTORIES(${GFLAGS_INCLUDE_DIR}) +SET(OPTIONAL_ARGS "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" + "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}" + "-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}" + "-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}" + "-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}" + "-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}" + "-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}" + "-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}") + +if(ANDROID) + SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} + "-DCMAKE_SYSTEM_NAME=${CMAKE_SYSTEM_NAME}" + "-DCMAKE_SYSTEM_VERSION=${CMAKE_SYSTEM_VERSION}" + "-DCMAKE_ANDROID_ARCH_ABI=${CMAKE_ANDROID_ARCH_ABI}" + "-DCMAKE_ANDROID_NDK=${CMAKE_ANDROID_NDK}" + "-DCMAKE_ANDROID_STL_TYPE=${CMAKE_ANDROID_STL_TYPE}") +endif() + ExternalProject_Add( extern_gflags ${EXTERNAL_PROJECT_LOG_ARGS} @@ -32,19 +50,12 @@ ExternalProject_Add( GIT_TAG 77592648e3f3be87d6c7123eb81cbad75f9aef5a PREFIX ${GFLAGS_SOURCES_DIR} UPDATE_COMMAND "" - CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} - -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} - -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} - -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE} - -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG} - -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} - -DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG} - -DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE} - -DBUILD_STATIC_LIBS=ON + CMAKE_ARGS -DBUILD_STATIC_LIBS=ON -DCMAKE_INSTALL_PREFIX=${GFLAGS_INSTALL_DIR} -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_TESTING=OFF -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} + ${OPTIONAL_ARGS} ${EXTERNAL_OPTIONAL_ARGS} CMAKE_CACHE_ARGS 
-DCMAKE_INSTALL_PREFIX:PATH=${GFLAGS_INSTALL_DIR} -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON diff --git a/cmake/external/glog.cmake b/cmake/external/glog.cmake index 7fa17ce6b7b106..9ac9b8326431ad 100644 --- a/cmake/external/glog.cmake +++ b/cmake/external/glog.cmake @@ -31,6 +31,24 @@ INCLUDE_DIRECTORIES(${GLOG_INCLUDE_DIR}) SET(GLOG_REPOSITORY "https://github.com/google/glog.git") SET(GLOG_TAG "v0.3.5") +SET(OPTIONAL_ARGS "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" + "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}" + "-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}" + "-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}" + "-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}" + "-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}" + "-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}" + "-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}") + +if(ANDROID) + SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} + "-DCMAKE_SYSTEM_NAME=${CMAKE_SYSTEM_NAME}" + "-DCMAKE_SYSTEM_VERSION=${CMAKE_SYSTEM_VERSION}" + "-DCMAKE_ANDROID_ARCH_ABI=${CMAKE_ANDROID_ARCH_ABI}" + "-DCMAKE_ANDROID_NDK=${CMAKE_ANDROID_NDK}" + "-DCMAKE_ANDROID_STL_TYPE=${CMAKE_ANDROID_STL_TYPE}") +endif() + ExternalProject_Add( extern_glog ${EXTERNAL_PROJECT_LOG_ARGS} @@ -39,14 +57,7 @@ ExternalProject_Add( GIT_TAG ${GLOG_TAG} PREFIX ${GLOG_SOURCES_DIR} UPDATE_COMMAND "" - CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} - -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} - -DCMAKE_CXX_FLAGS=${GLOG_CMAKE_CXX_FLAGS} - -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE} - -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG} - -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} - -DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG} - -DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE} + CMAKE_ARGS ${OPTIONAL_ARGS} -DCMAKE_INSTALL_PREFIX=${GLOG_INSTALL_DIR} -DCMAKE_INSTALL_LIBDIR=${GLOG_INSTALL_DIR}/lib -DCMAKE_POSITION_INDEPENDENT_CODE=ON diff --git a/cmake/external/gtest.cmake b/cmake/external/gtest.cmake index e459526583bd5e..de44719803fc4f 100644 --- a/cmake/external/gtest.cmake +++ 
b/cmake/external/gtest.cmake @@ -43,6 +43,24 @@ IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC)) SET(GTEST_DEPENDS ${MKLML_PROJECT}) ENDIF() + SET(OPTIONAL_ARGS "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" + "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}" + "-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}" + "-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}" + "-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}" + "-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}" + "-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}" + "-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}") + + if(ANDROID) + SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} + "-DCMAKE_SYSTEM_NAME=${CMAKE_SYSTEM_NAME}" + "-DCMAKE_SYSTEM_VERSION=${CMAKE_SYSTEM_VERSION}" + "-DCMAKE_ANDROID_ARCH_ABI=${CMAKE_ANDROID_ARCH_ABI}" + "-DCMAKE_ANDROID_NDK=${CMAKE_ANDROID_NDK}" + "-DCMAKE_ANDROID_STL_TYPE=${CMAKE_ANDROID_STL_TYPE}") + endif() + ExternalProject_Add( extern_gtest ${EXTERNAL_PROJECT_LOG_ARGS} @@ -51,14 +69,7 @@ IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC)) GIT_TAG "release-1.8.0" PREFIX ${GTEST_SOURCES_DIR} UPDATE_COMMAND "" - CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} - -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} - -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} - -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE} - -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG} - -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} - -DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG} - -DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE} + CMAKE_ARGS ${OPTIONAL_ARGS} -DCMAKE_INSTALL_PREFIX=${GTEST_INSTALL_DIR} -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_GMOCK=ON diff --git a/cmake/external/protobuf.cmake b/cmake/external/protobuf.cmake index 09eb437aede436..41cd1ebaf33a6e 100644 --- a/cmake/external/protobuf.cmake +++ b/cmake/external/protobuf.cmake @@ -142,7 +142,6 @@ IF (WIN32) ENDIF(WIN32) if (NOT "${PROTOBUF_ROOT}" STREQUAL "") - find_path(PROTOBUF_INCLUDE_DIR google/protobuf/message.h PATHS ${PROTOBUF_ROOT}/include NO_DEFAULT_PATH) find_library(PROTOBUF_LIBRARY protobuf 
libprotobuf.lib PATHS ${PROTOBUF_ROOT}/lib NO_DEFAULT_PATH) find_library(PROTOBUF_LITE_LIBRARY protobuf-lite libprotobuf-lite.lib PATHS ${PROTOBUF_ROOT}/lib NO_DEFAULT_PATH) @@ -178,12 +177,28 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) "${PROTOBUF_INSTALL_DIR}/bin/protoc${CMAKE_EXECUTABLE_SUFFIX}" PARENT_SCOPE) + SET(PROTOBUF_REPO "https://github.com/protocolbuffers/protobuf.git") + SET(PROTOBUF_TAG "9f75c5aa851cd877fb0d93ccc31b8567a6706546") SET(OPTIONAL_CACHE_ARGS "") SET(OPTIONAL_ARGS "") + IF(BUILD_FOR_HOST) - SET(OPTIONAL_ARGS "-Dprotobuf_WITH_ZLIB=OFF") - ELSE() SET(OPTIONAL_ARGS + "-DCMAKE_C_COMPILER=${HOST_C_COMPILER}" + "-DCMAKE_CXX_COMPILER=${HOST_CXX_COMPILER}" + "-Dprotobuf_WITH_ZLIB=OFF" + "-DZLIB_ROOT:FILEPATH=${ZLIB_ROOT}") + SET(OPTIONAL_CACHE_ARGS "-DZLIB_ROOT:STRING=${ZLIB_ROOT}") + ELSE() + # protobuf have compile issue when use android stl c++_static + SET(PROTOBUF_REPO "https://github.com/tensor-tang/protobuf.git") + SET(PROTOBUF_TAG "mobile") + SET(OPTIONAL_ARGS "-Dprotobuf_WITH_ZLIB=OFF" + "-DCMAKE_SYSTEM_NAME=${CMAKE_SYSTEM_NAME}" + "-DCMAKE_SYSTEM_VERSION=${CMAKE_SYSTEM_VERSION}" + "-DCMAKE_ANDROID_ARCH_ABI=${CMAKE_ANDROID_ARCH_ABI}" + "-DCMAKE_ANDROID_NDK=${CMAKE_ANDROID_NDK}" + "-DCMAKE_ANDROID_STL_TYPE=${CMAKE_ANDROID_STL_TYPE}" "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}" "-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}" @@ -191,25 +206,18 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) "-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}" "-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}" "-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}" - "-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}" - "-Dprotobuf_WITH_ZLIB=ON" - "-DZLIB_ROOT:FILEPATH=${ZLIB_ROOT}" - ${EXTERNAL_OPTIONAL_ARGS}) - SET(OPTIONAL_CACHE_ARGS "-DZLIB_ROOT:STRING=${ZLIB_ROOT}") + "-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}") ENDIF() IF(WIN32) SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} 
"-DCMAKE_GENERATOR_PLATFORM=x64") ENDIF() - SET(PROTOBUF_REPO "https://github.com/protocolbuffers/protobuf.git") - SET(PROTOBUF_TAG "9f75c5aa851cd877fb0d93ccc31b8567a6706546") - ExternalProject_Add( ${TARGET_NAME} ${EXTERNAL_PROJECT_LOG_ARGS} PREFIX ${PROTOBUF_SOURCES_DIR} UPDATE_COMMAND "" - DEPENDS zlib + #DEPENDS zlib GIT_REPOSITORY ${PROTOBUF_REPO} GIT_TAG ${PROTOBUF_TAG} CONFIGURE_COMMAND @@ -233,6 +241,13 @@ ENDFUNCTION() SET(PROTOBUF_VERSION 3.1.0) +IF(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + build_protobuf(protobuf_host TRUE) + LIST(APPEND external_project_dependencies protobuf_host) + SET(PROTOBUF_PROTOC_EXECUTABLE ${protobuf_host_PROTOC_EXECUTABLE} + CACHE FILEPATH "protobuf executable." FORCE) +ENDIF() + IF(NOT PROTOBUF_FOUND) build_protobuf(extern_protobuf FALSE) @@ -245,7 +260,12 @@ IF(NOT PROTOBUF_FOUND) SET(PROTOBUF_PROTOC_LIBRARY ${extern_protobuf_PROTOC_LIBRARY} CACHE FILEPATH "protoc library." FORCE) - SET(PROTOBUF_PROTOC_EXECUTABLE ${extern_protobuf_PROTOC_EXECUTABLE} - CACHE FILEPATH "protobuf executable." FORCE) - PROMPT_PROTOBUF_LIB(extern_protobuf) + IF(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + PROMPT_PROTOBUF_LIB(protobuf_host extern_protobuf) + ELSE() + SET(PROTOBUF_PROTOC_EXECUTABLE ${extern_protobuf_PROTOC_EXECUTABLE} + CACHE FILEPATH "protobuf executable." 
FORCE) + PROMPT_PROTOBUF_LIB(extern_protobuf) + ENDIF() + ENDIF(NOT PROTOBUF_FOUND) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 99c078cf7db625..a028dcbd6be80d 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -93,7 +93,10 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}) if(NOT APPLE) find_package(Threads REQUIRED) link_libraries(${CMAKE_THREAD_LIBS_INIT}) - set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -pthread -ldl -lrt") + set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -pthread -ldl") + if (NOT ANDROID) + set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -lrt") + endif() endif(NOT APPLE) set_property(GLOBAL PROPERTY FLUID_MODULES "") @@ -424,7 +427,7 @@ function(raw_cc_test TARGET_NAME) endif() endfunction(raw_cc_test) -function(lite_cc_test args) +function(_lite_cc_test args) if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) message(STATUS "building lite raw test: ${args}") raw_cc_test(${args} ${ARGN}) diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index efdabffb9b33dd..6c60a041a191f1 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ syntax = "proto2"; -option optimize_for = LITE_RUNTIME; +// option optimize_for = LITE_RUNTIME; package paddle.framework.proto; // Any incompatible changes to ProgramDesc and its dependencies should diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 1ea93b7638a85e..8d3864c6b3da55 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -13,12 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/framework/op_desc.h" +#include #include #include #include // NOLINT #include #include -#include "glog/logging.h" +#include #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/operator.h" diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index fae33f55b054b1..dab35bae4d5241 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -39,6 +39,10 @@ DEFINE_int32(inner_op_parallelism, 0, "number of threads for inner op"); namespace paddle { namespace framework { +OpDuppy op_duppy; +Scope scope_duppy; +RuntimeContext runtime_context_duppy({}, {}); + std::vector> kKernelPriority = { std::make_tuple(platform::CUDAPlace(0), LibraryType::kCUDNN), std::make_tuple(platform::CUDAPlace(0), LibraryType::kPlain), diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 5b4bfc1eb47339..8f301c6ebce124 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -239,9 +239,10 @@ class OpDuppy : public OperatorBase { void RunImpl(const Scope& scope, const platform::Place& place) const override {} }; -OpDuppy op_duppy; -Scope scope_duppy; -RuntimeContext runtime_context_duppy({}, {}); + +extern OpDuppy op_duppy; +extern Scope scope_duppy; +extern RuntimeContext runtime_context_duppy; class ExecutionContext { public: @@ -255,7 +256,7 @@ class ExecutionContext { ctx_(ctx), kernel_configs_(configs) {} - ExecutionContext(const platform::DeviceContext& device_context) + explicit ExecutionContext(const platform::DeviceContext& device_context) : op_(op_duppy), scope_(scope_duppy), device_context_(device_context), diff --git a/paddle/fluid/incubate/CMakeLists.txt b/paddle/fluid/incubate/CMakeLists.txt index a6ded5204921be..552134ba6640b5 100644 --- a/paddle/fluid/incubate/CMakeLists.txt +++ b/paddle/fluid/incubate/CMakeLists.txt @@ -1 +1 @@ 
-include_directories(lite) \ No newline at end of file +include_directories(lite) diff --git a/paddle/fluid/inference/analysis/dot.h b/paddle/fluid/inference/analysis/dot.h index d1eef603be48ea..1cb790f1822900 100644 --- a/paddle/fluid/inference/analysis/dot.h +++ b/paddle/fluid/inference/analysis/dot.h @@ -23,10 +23,10 @@ #include #include #include -#include "paddle/fluid/lite/utils/logging.h" -#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +// #include "paddle/fluid/lite/utils/logging.h" +// #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #include -#endif +// #endif namespace paddle { namespace inference { diff --git a/paddle/fluid/lite/CMakeLists.txt b/paddle/fluid/lite/CMakeLists.txt index 3b9a6756953aca..269cc95b6587c8 100644 --- a/paddle/fluid/lite/CMakeLists.txt +++ b/paddle/fluid/lite/CMakeLists.txt @@ -3,11 +3,14 @@ if (NOT WITH_LITE) endif() message(WARNING "Lite enabled!") -message(STATUS "LIGHT_FRAMEWORK: ${LITE_WITH_LIGHT_WEIGHT_FRAMEWORK}") -message(STATUS "LITE_WITH_CUDA: ${LITE_WITH_CUDA}") -message(STATUS "LITE_WITH_X86: ${LITE_WITH_X86}") +message(STATUS "LIGHT_FRAMEWORK:\t${LITE_WITH_LIGHT_WEIGHT_FRAMEWORK}") +message(STATUS "LITE_WITH_CUDA:\t${LITE_WITH_CUDA}") +message(STATUS "LITE_WITH_X86:\t${LITE_WITH_X86}") +message(STATUS "LITE_WITH_ARM:\t${LITE_WITH_ARM}") +message(STATUS "LITE_WITH_PROFILE:\t${LITE_WITH_PROFILE}") set(LITE_MODEL_DIR "${THIRD_PARTY_PATH}/install") +set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inference download url") function(lite_download_and_uncompress INSTALL_DIR URL FILENAME) message(STATUS "Download inference test stuff from ${URL}/${FILENAME}") @@ -29,13 +32,144 @@ function(lite_download_and_uncompress INSTALL_DIR URL FILENAME) ) endfunction() +function (lite_deps TARGET) + set(options "") + set(oneValueArgs "") + set(multiValueArgs DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS LIGHT_DEPS HVY_DEPS ARGS) + cmake_parse_arguments(lite_deps "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + 
set(deps ${lite_deps_DEPS}) + + if(LITE_WITH_X86) + foreach(var ${lite_deps_X86_DEPS}) + set(deps ${deps} ${var}) + endforeach(var) + endif() + + if(LITE_WITH_CUDA) + foreach(var ${lite_deps_CUDA_DEPS}) + set(deps ${deps} ${var}) + endforeach(var) + endif() + + if(LITE_WITH_ARM) + foreach(var ${lite_deps_ARM_DEPS}) + set(deps ${deps} ${var}) + endforeach(var) + endif() + + if(LITE_WITH_PROFILE) + foreach(var ${lite_deps_PROFILE_DEPS}) + set(deps ${deps} ${var}) + endforeach(var) + endif() + + if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + foreach(var ${lite_deps_LIGHT_DEPS}) + set(deps ${deps} ${var}) + endforeach(var) + endif() + + if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + foreach(var ${lite_deps_HVY_DEPS}) + set(deps ${deps} ${var}) + endforeach(var) + endif() + + set(${TARGET} ${deps} PARENT_SCOPE) + +endfunction() + +# Add names for lite libraries for latter compile. We use this name list to avoid compiling +# the whole fluid project to accelerate the compile speed. +set(offline_lib_registry_file "${CMAKE_BINARY_DIR}/lite_libs.txt") +file(WRITE ${offline_lib_registry_file} "") # clean +# cc_library with branch support. +# The branches: +# X86_DEPS: works only when LITE_WITH_X86 is ON. +# CUDA_DEPS: LITE_WITH_CUDA +# ARM_DEPS: LITE_WITH_ARM +# PROFILE_DEPS: LITE_WITH_PROFILE +# LIGHT_DEPS: LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +# HVY_DEPS: NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +function(lite_cc_library TARGET) + set(options "") + set(oneValueArgs "") + set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS LIGHT_DEPS + HVY_DEPS ARGS) + cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + set(deps "") + lite_deps(deps + DEPS ${args_DEPS} + X86_DEPS ${args_X86_DEPS} + CUDA_DEPS ${args_CUDA_DEPS} + ARM_DEPS ${args_ARM_DEPS} + PROFILE_DEPS ${args_PROFILE_DEPS} + LIGHT_DEPS ${args_LIGHT_DEPS} + HVY_DEPS ${args_HVY_DEPS} + ) + + cc_library(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ${args_DEPS}) + + # register a library name. 
+ file(APPEND ${offline_lib_registry_file} "${TARGET}\n") +endfunction() + +function(lite_cc_binary TARGET) + set(options "") + set(oneValueArgs "") + set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS + LIGHT_DEPS HVY_DEPS ARGS) + cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + set(deps "") + lite_deps(deps + DEPS ${args_DEPS} + X86_DEPS ${args_X86_DEPS} + CUDA_DEPS ${args_CUDA_DEPS} + ARM_DEPS ${args_ARM_DEPS} + PROFILE_DEPS ${args_PROFILE_DEPS} + LIGHT_DEPS ${args_LIGHT_DEPS} + HVY_DEPS ${args_HVY_DEPS} + ) + cc_binary(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ${args_DEPS}) +endfunction() + +# Add a unit-test name to file for latter offline manual test. +set(offline_test_registry_file "${CMAKE_BINARY_DIR}/lite_tests.txt") +file(WRITE ${offline_test_registry_file} "") # clean +# Test lite modules. +function(lite_cc_test TARGET) + set(options "") + set(oneValueArgs "") + set(multiValueArgs SRCS DEPS X86_DEPS CUDA_DEPS ARM_DEPS PROFILE_DEPS + LIGHT_DEPS HVY_DEPS + ARGS) + cmake_parse_arguments(args "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + set(deps "") + lite_deps(deps + DEPS ${args_DEPS} + X86_DEPS ${args_X86_DEPS} + CUDA_DEPS ${args_CUDA_DEPS} + ARM_DEPS ${args_ARM_DEPS} + PROFILE_DEPS ${args_PROFILE_DEPS} + LIGHT_DEPS ${args_LIGHT_DEPS} + HVY_DEPS ${args_HVY_DEPS} + ) + _lite_cc_test(${TARGET} SRCS ${args_SRCS} DEPS ${deps} ARGS ${args_ARGS}) + file(APPEND ${offline_test_registry_file} "${TARGET}\n") +endfunction() + +add_subdirectory(operators) +add_subdirectory(kernels) add_subdirectory(core) add_subdirectory(x86) +add_subdirectory(arm) add_subdirectory(host) add_subdirectory(cuda) -add_subdirectory(operators) -add_subdirectory(kernels) add_subdirectory(model_parser) add_subdirectory(utils) add_subdirectory(api) +add_subdirectory(gen_code) diff --git a/paddle/fluid/lite/api/CMakeLists.txt b/paddle/fluid/lite/api/CMakeLists.txt index 1de5d09394d55c..78f85a8caebc2b 100644 --- 
a/paddle/fluid/lite/api/CMakeLists.txt +++ b/paddle/fluid/lite/api/CMakeLists.txt @@ -1,11 +1,11 @@ -set(cxx_api_lite_deps scope_lite optimizer_lite target_wrapper_host optimizer_lite model_parser_lite) +set(cxx_api_lite_deps scope_lite optimizer_lite target_wrapper_host model_parser_lite) if(LITE_WITH_CUDA) set(cxx_api_lite_deps ${cxx_api_lite_deps} kernels_cuda) cc_library(cxx_api_lite_cuda SRCS cxx_api.cc DEPS ${cxx_api_lite_deps} target_wrapper_cuda) nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda) endif() -cc_library(cxx_api_lite SRCS cxx_api.cc DEPS ${cxx_api_lite_deps} ${ops_lite}) +cc_library(cxx_api_lite SRCS cxx_api.cc DEPS ${cxx_api_lite_deps} ${ops_lite} program_lite) set(light_api_deps scope_lite target_wrapper_host model_parser_lite) @@ -17,23 +17,38 @@ endif() cc_library(light_api_lite SRCS light_api.cc DEPS ${light_api_deps} ${ops_lite} ${host_kernels}) message(STATUS "get ops ${ops_lite}") -message(STATUS "get kernels ${host_kernels}") +message(STATUS "get Host kernels ${host_kernels}") +message(STATUS "get ARM kernels ${arm_kernels}") include(ExternalProject) -set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inference download url") set(LITE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING "A path setting inference demo download directories.") +if((NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) AND WITH_TESTING) + lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc + DEPS cxx_api_lite mir_passes + ${ops_lite} ${host_kernels} ${x86_kernels} + ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model + --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) -lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc - DEPS cxx_api_lite model_parser_lite target_wrapper_host - ${ops_lite} ${host_kernels} ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model - --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) -if(WITH_TESTING) -lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} 
"lite_naive_model.tar.gz") -add_dependencies(test_cxx_api_lite extern_lite_download_lite_naive_model_tar_gz) -endif(WITH_TESTING) + lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_naive_model.tar.gz") + add_dependencies(test_cxx_api_lite extern_lite_download_lite_naive_model_tar_gz) +endif() + +if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND WITH_TESTING) + add_dependencies(test_cxx_api_lite extern_lite_download_lite_naive_model_tar_gz) +endif() + +# if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) +# lite_cc_test(test_light_api SRCS light_api_test.cc DEPS light_api_lite ARGS --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) +# endif() -lite_cc_test(test_light_api SRCS light_api_test.cc DEPS light_api_lite ARGS --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) -cc_binary(cxx_api_lite_bin SRCS cxx_api_bin.cc DEPS cxx_api_lite model_parser_lite target_wrapper_host ${ops_lite} ${host_kernels}) +lite_cc_binary(cxx_api_lite_bin SRCS cxx_api_bin.cc + DEPS + cxx_api_lite + model_parser_lite + target_wrapper_host + mir_passes + ${ops_lite} ${host_kernels} + ARM_DEPS ${arm_kernels}) diff --git a/paddle/fluid/lite/api/cxx_api_bin.cc b/paddle/fluid/lite/api/cxx_api_bin.cc index f7e7426a45ae60..f53f6105d1bf8a 100644 --- a/paddle/fluid/lite/api/cxx_api_bin.cc +++ b/paddle/fluid/lite/api/cxx_api_bin.cc @@ -13,28 +13,22 @@ // limitations under the License. 
#include "paddle/fluid/lite/api/cxx_api.h" + +#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #include "paddle/fluid/lite/core/mir/passes.h" +#endif + #include "paddle/fluid/lite/core/op_registry.h" namespace paddle { namespace lite { void Run(const char* model_dir) { - lite::Executor predictor; -#ifndef LITE_WITH_CUDA - std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}}); -#else - std::vector valid_places({ - Place{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)}, - Place{TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)}, - Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kNCHW)}, - Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kNCHW)}, - Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kAny)}, - Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)}, - }); -#endif + lite::ExecutorLite predictor; + std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}}); - predictor.Build(model_dir, Place{TARGET(kCUDA), PRECISION(kFloat)}, + predictor.Build(model_dir, Place{TARGET(kARM), PRECISION(kFloat)}, valid_places); auto* input_tensor = predictor.GetInput(0); @@ -44,8 +38,6 @@ void Run(const char* model_dir) { data[i] = i; } - LOG(INFO) << "input " << *input_tensor; - predictor.Run(); auto* out = predictor.GetOutput(0); @@ -53,7 +45,7 @@ void Run(const char* model_dir) { LOG(INFO) << "out " << out->data()[0]; LOG(INFO) << "out " << out->data()[1]; LOG(INFO) << "dims " << out->dims(); - LOG(INFO) << "out " << *out; + LOG(INFO) << "out data size: " << out->data_size(); } } // namespace lite @@ -72,12 +64,19 @@ USE_LITE_OP(scale); USE_LITE_OP(feed); USE_LITE_OP(fetch); USE_LITE_OP(io_copy); -USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def); -USE_LITE_KERNEL(mul, kHost, kFloat, kNCHW, def); -USE_LITE_KERNEL(scale, kHost, kFloat, kNCHW, def); + USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); +#ifdef LITE_WITH_ARM +USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def); 
+USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def); +// USE_LITE_KERNEL(feed, kARM, kAny, kAny, def); +// USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def); +#endif // LITE_WITH_ARM + #ifdef LITE_WITH_CUDA USE_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, def); USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device); diff --git a/paddle/fluid/lite/api/cxx_api_test.cc b/paddle/fluid/lite/api/cxx_api_test.cc index ae78a0c177fe3f..430bd9b58f80e5 100644 --- a/paddle/fluid/lite/api/cxx_api_test.cc +++ b/paddle/fluid/lite/api/cxx_api_test.cc @@ -32,7 +32,8 @@ namespace lite { TEST(CXXApi, test) { lite::ExecutorLite predictor; #ifndef LITE_WITH_CUDA - std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}}); + std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kX86), PRECISION(kFloat)}}); #else std::vector valid_places({ Place{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)}, @@ -44,7 +45,8 @@ TEST(CXXApi, test) { }); #endif - predictor.Build(FLAGS_model_dir, Place{TARGET(kCUDA), PRECISION(kFloat)}, + predictor.Build(FLAGS_model_dir, + Place{TARGET(kX86), PRECISION(kFloat)}, // origin cuda valid_places); auto* input_tensor = predictor.GetInput(0); @@ -69,16 +71,18 @@ TEST(CXXApi, test) { #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK TEST(CXXApi, save_model) { lite::ExecutorLite predictor; - std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}}); + std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kX86), PRECISION(kFloat)}}); predictor.Build(FLAGS_model_dir, Place{TARGET(kCUDA), PRECISION(kFloat)}, valid_places); + LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model; predictor.SaveModel(FLAGS_optimized_model); } #endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK -TEST(CXXTrainer, train) { +/*TEST(CXXTrainer, train) { Place prefer_place({TARGET(kHost), 
PRECISION(kFloat), DATALAYOUT(kNCHW)}); std::vector valid_places({prefer_place}); auto scope = std::make_shared(); @@ -92,7 +96,7 @@ TEST(CXXTrainer, train) { main_program_desc.ParseFromString(main_program_pb); startup_program_desc.ParseFromString(startup_program_pb); - LOG(INFO) << main_program_desc.DebugString(); + // LOG(INFO) << main_program_desc.DebugString(); for (const auto& op : main_program_desc.blocks(0).ops()) { LOG(INFO) << "get op " << op.type(); @@ -108,7 +112,7 @@ TEST(CXXTrainer, train) { data0[0] = 0; exe.Run(); -} +}*/ #endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK } // namespace lite @@ -116,16 +120,39 @@ TEST(CXXTrainer, train) { USE_LITE_OP(mul); USE_LITE_OP(fc); +USE_LITE_OP(relu); USE_LITE_OP(scale); USE_LITE_OP(feed); USE_LITE_OP(fetch); USE_LITE_OP(io_copy); -USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def); -USE_LITE_KERNEL(mul, kHost, kFloat, kNCHW, def); -USE_LITE_KERNEL(scale, kHost, kFloat, kNCHW, def); +USE_LITE_OP(elementwise_add) +USE_LITE_OP(elementwise_sub) +USE_LITE_OP(square) +USE_LITE_OP(softmax) +USE_LITE_OP(dropout) +USE_LITE_OP(concat) +USE_LITE_OP(conv2d) +USE_LITE_OP(depthwise_conv2d) +USE_LITE_OP(pool2d) USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); +#ifdef LITE_WITH_X86 +USE_LITE_KERNEL(relu, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(scale, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(square, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(dropout, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(concat, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(conv2d, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(depthwise_conv2d, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(pool2d, kX86, kFloat, kNCHW, def); +#endif + #ifdef 
LITE_WITH_CUDA USE_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, def); USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device); diff --git a/paddle/fluid/lite/api/light_api.h b/paddle/fluid/lite/api/light_api.h index 9cd9f62a0b0950..a43755c87387e6 100644 --- a/paddle/fluid/lite/api/light_api.h +++ b/paddle/fluid/lite/api/light_api.h @@ -22,6 +22,7 @@ #include #include #include +#include "paddle/fluid/lite/core/context.h" #include "paddle/fluid/lite/core/program.h" #include "paddle/fluid/lite/core/types.h" #include "paddle/fluid/lite/model_parser/model_parser.h" @@ -63,7 +64,7 @@ class LightPredictor { private: void BuildRuntimeProgram(const framework::proto::ProgramDesc& prog) { - std::vector insts; + std::vector insts; // 1. Create op first Program program(prog, scope_, {}); @@ -71,9 +72,8 @@ class LightPredictor { // Create the kernels of the target places, and filter out the specific // kernel with the target alias. - for (auto& op : program.ops) { - lite::pb::OpDesc desc(op->op_info()->desc()); - auto kernel_type = desc.GetAttr(kKernelTypeAttr).get(); + for (auto& op : program.ops()) { + auto kernel_type = op->op_info()->GetAttr(kKernelTypeAttr); std::string op_type, alias; Place place; KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place); @@ -84,11 +84,12 @@ class LightPredictor { return it->alias() == alias; }); CHECK(it != kernels.end()); + (*it)->SetContext(ContextScheduler::Global().NewContext((*it)->target())); insts.emplace_back(op, std::move(*it)); } program_.reset(new RuntimeProgram(std::move(insts))); - CHECK(program.exec_scope); - program_->set_exec_scope(program.exec_scope); + CHECK(program.exec_scope()); + program_->set_exec_scope(program.exec_scope()); } private: diff --git a/paddle/fluid/lite/api/light_api_test.cc b/paddle/fluid/lite/api/light_api_test.cc index ad0d87cf00bd98..b1e6741e09ebd0 100644 --- a/paddle/fluid/lite/api/light_api_test.cc +++ b/paddle/fluid/lite/api/light_api_test.cc @@ -44,8 +44,18 @@ USE_LITE_OP(scale); 
USE_LITE_OP(feed); USE_LITE_OP(fetch); USE_LITE_OP(io_copy); -USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def); -USE_LITE_KERNEL(mul, kHost, kFloat, kNCHW, def); -USE_LITE_KERNEL(scale, kHost, kFloat, kNCHW, def); + USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); + +#ifdef LITE_WITH_X86 +USE_LITE_KERNEL(relu, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(scale, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(square, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(dropout, kX86, kFloat, kNCHW, def); +#endif diff --git a/paddle/fluid/lite/arm/CMakeLists.txt b/paddle/fluid/lite/arm/CMakeLists.txt new file mode 100644 index 00000000000000..8abd04b5233829 --- /dev/null +++ b/paddle/fluid/lite/arm/CMakeLists.txt @@ -0,0 +1,2 @@ + +add_subdirectory(math) diff --git a/paddle/fluid/lite/arm/math/CMakeLists.txt b/paddle/fluid/lite/arm/math/CMakeLists.txt new file mode 100644 index 00000000000000..8af2c33943f7e2 --- /dev/null +++ b/paddle/fluid/lite/arm/math/CMakeLists.txt @@ -0,0 +1,5 @@ +if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)) + return() +endif() + +cc_library(math_arm SRCS funcs.cc packed_sgemm.cc softmax.cc scale.cc elementwise.cc DEPS ${lite_kernel_deps} eigen3) diff --git a/paddle/fluid/lite/arm/math/elementwise.cc b/paddle/fluid/lite/arm/math/elementwise.cc new file mode 100644 index 00000000000000..68140a5d7dbccc --- /dev/null +++ b/paddle/fluid/lite/arm/math/elementwise.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/arm/math/elementwise.h" +#include "paddle/fluid/lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void elementwise_add(const float* dinx, const float* diny, float* dout, + int num) { + int cnt = num >> 4; + int remain = num % 16; +#pragma omp parallel for + for (int i = 0; i < cnt; i++) { + const float* dinx_ptr = dinx + (i << 4); + const float* diny_ptr = diny + (i << 4); + float* dout_ptr = dout + (i << 4); + + float32x4_t dinx0 = vld1q_f32(dinx_ptr); + float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4); + float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8); + float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12); + + float32x4_t diny0 = vld1q_f32(diny_ptr); + float32x4_t diny1 = vld1q_f32(diny_ptr + 4); + float32x4_t diny2 = vld1q_f32(diny_ptr + 8); + float32x4_t diny3 = vld1q_f32(diny_ptr + 12); + + float32x4_t vsum0 = vaddq_f32(dinx0, diny0); + float32x4_t vsum1 = vaddq_f32(dinx1, diny1); + float32x4_t vsum2 = vaddq_f32(dinx2, diny2); + float32x4_t vsum3 = vaddq_f32(dinx3, diny3); + + vst1q_f32(dout_ptr, vsum0); + vst1q_f32(dout_ptr + 4, vsum1); + vst1q_f32(dout_ptr + 8, vsum2); + vst1q_f32(dout_ptr + 12, vsum3); + } + if (remain > 0) { + const float* dinx_ptr = dinx + (cnt << 4); + const float* diny_ptr = diny + (cnt << 4); + float* dout_ptr = dout + (cnt << 4); + for (int i = 0; i < remain; i++) { + *dout_ptr = *dinx_ptr + *diny_ptr; 
+ dout_ptr++; + dinx_ptr++; + diny_ptr++; + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/arm/math/elementwise.h b/paddle/fluid/lite/arm/math/elementwise.h new file mode 100644 index 00000000000000..cf4c8e46b0703a --- /dev/null +++ b/paddle/fluid/lite/arm/math/elementwise.h @@ -0,0 +1,28 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template +void elementwise_add(const T* dinx, const T* diny, T* dout, int num); + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/arm/math/funcs.cc b/paddle/fluid/lite/arm/math/funcs.cc new file mode 100644 index 00000000000000..4013ac31bfd1c5 --- /dev/null +++ b/paddle/fluid/lite/arm/math/funcs.cc @@ -0,0 +1,155 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/arm/math/funcs.h" +#include + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void fill_bias_fc(float *out, const float *bias, const int num, + const int channel) { + int cnt = channel >> 4; + int remain = channel & 15; + + for (int j = 0; j < num; ++j) { + const float *ptr_bias = bias; + float *ptr_out = out + j * channel; + + float32x4_t vout1; + float32x4_t vout2; + float32x4_t vout3; + float32x4_t vout4; + + for (int i = 0; i < cnt; ++i) { + float32x4_t vin1 = vld1q_f32(ptr_out); + float32x4_t vb1 = vld1q_f32(ptr_bias); + + float32x4_t vin2 = vld1q_f32(ptr_out + 4); + float32x4_t vb2 = vld1q_f32(ptr_bias + 4); + + float32x4_t vin3 = vld1q_f32(ptr_out + 8); + float32x4_t vb3 = vld1q_f32(ptr_bias + 8); + + float32x4_t vin4 = vld1q_f32(ptr_out + 12); + float32x4_t vb4 = vld1q_f32(ptr_bias + 12); + + vout1 = vaddq_f32(vin1, vb1); + vout2 = vaddq_f32(vin2, vb2); + vout3 = vaddq_f32(vin3, vb3); + vout4 = vaddq_f32(vin4, vb4); + + vst1q_f32(ptr_out, vout1); + vst1q_f32(ptr_out + 4, vout2); + vst1q_f32(ptr_out + 8, vout3); + vst1q_f32(ptr_out + 12, vout4); + + ptr_out += 16; + ptr_bias += 16; + } +#if 0 + if (cnt > 0) { + asm( + "1: \n" + "vld1.32 {d0-d1}, [%[ptr_out]] @ load data\n" + "vld1.32 {d2-d3}, [%[ptr_bias]]! @ load data\n" + "vadd.f32 q2, q0, q1 @ add bias\n" + "vst1.32 {d4-d5}, [%[ptr_out]]! 
@ store result\n" + "subs %[cnt], #1 @ loop count -1\n" + "bne 1b @ jump to main loop\n" + :[ptr_out] "+r"(ptr_out), [ptr_bias] "+r"(ptr_bias), \ + [cnt] "+r"(cnt) + : + :"q0", "q1", "q2" + ); + } +#endif + for (int i = 0; i < remain; ++i) { + *(ptr_out++) += *(ptr_bias++); + } + } +} + +template <> +void fill_bias_fc(int *out, const int *bias, const int num, + const int channel) { + int cnt = channel >> 4; + int remain = channel & 15; + + for (int j = 0; j < num; ++j) { + const int *ptr_bias = bias; + int *ptr_out = out + j * channel; + + int32x4_t vout1; + int32x4_t vout2; + int32x4_t vout3; + int32x4_t vout4; + + for (int i = 0; i < cnt; ++i) { + int32x4_t vin1 = vld1q_s32(ptr_out); + int32x4_t vb1 = vld1q_s32(ptr_bias); + + int32x4_t vin2 = vld1q_s32(ptr_out + 4); + int32x4_t vb2 = vld1q_s32(ptr_bias + 4); + + int32x4_t vin3 = vld1q_s32(ptr_out + 8); + int32x4_t vb3 = vld1q_s32(ptr_bias + 8); + + int32x4_t vin4 = vld1q_s32(ptr_out + 12); + int32x4_t vb4 = vld1q_s32(ptr_bias + 12); + + vout1 = vaddq_s32(vin1, vb1); + vout2 = vaddq_s32(vin2, vb2); + vout3 = vaddq_s32(vin3, vb3); + vout4 = vaddq_s32(vin4, vb4); + + vst1q_s32(ptr_out, vout1); + vst1q_s32(ptr_out + 4, vout2); + vst1q_s32(ptr_out + 8, vout3); + vst1q_s32(ptr_out + 12, vout4); + + ptr_out += 16; + ptr_bias += 16; + } + +#if 0 + if (cnt > 0) { + asm( + "1: \n" + "vld1.32 {d0-d1}, [%[ptr_out]] @ load data\n" + "vld1.32 {d2-d3}, [%[ptr_bias]]! @ load data\n" + "vadd.s32 q2, q0, q1 @ add bias\n" + "vst1.32 {d4-d5}, [%[ptr_out]]! 
@ store result\n" + "subs %[cnt], #1 @ loop count -1\n" + "bne 1b @ jump to main loop\n" + :[ptr_out] "+r"(ptr_out), [ptr_bias] "+r"(ptr_bias), \ + [cnt] "+r"(cnt) + : + :"q0", "q1", "q2" + ); + } +#endif + for (int i = 0; i < remain; ++i) { + *(ptr_out++) += *(ptr_bias++); + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/arm/math/funcs.h b/paddle/fluid/lite/arm/math/funcs.h new file mode 100644 index 00000000000000..e95506c1a968f2 --- /dev/null +++ b/paddle/fluid/lite/arm/math/funcs.h @@ -0,0 +1,336 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include + +#include "paddle/fluid/lite/arm/math/elementwise.h" +#include "paddle/fluid/lite/arm/math/packed_sgemm.h" +#include "paddle/fluid/lite/arm/math/scale.h" +#include "paddle/fluid/lite/arm/math/softmax.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +#define c_inv_mant_mask ~0x7f800000u +#define c_cephes_SQRTHF 0.707106781186547524 +#define c_cephes_log_p0 7.0376836292E-2 +#define c_cephes_log_p1 -1.1514610310E-1 +#define c_cephes_log_p2 1.1676998740E-1 +#define c_cephes_log_p3 -1.2420140846E-1 +#define c_cephes_log_p4 +1.4249322787E-1 +#define c_cephes_log_p5 -1.6668057665E-1 +#define c_cephes_log_p6 +2.0000714765E-1 +#define c_cephes_log_p7 -2.4999993993E-1 +#define c_cephes_log_p8 +3.3333331174E-1 +#define c_cephes_log_q1 -2.12194440e-4 +#define c_cephes_log_q2 0.693359375 + +// natural logarithm computed for 4 simultaneous float +// return NaN for x <= 0 +inline float32x4_t log_ps(float32x4_t x) { + float32x4_t one = vdupq_n_f32(1); + + x = vmaxq_f32(x, vdupq_n_f32(0)); // force flush to zero on denormal values + uint32x4_t invalid_mask = vcleq_f32(x, vdupq_n_f32(0)); + + int32x4_t ux = vreinterpretq_s32_f32(x); + + int32x4_t emm0 = vshrq_n_s32(ux, 23); + + // keep only the fractional part + ux = vandq_s32(ux, vdupq_n_s32(c_inv_mant_mask)); + ux = vorrq_s32(ux, vreinterpretq_s32_f32(vdupq_n_f32(0.5f))); + x = vreinterpretq_f32_s32(ux); + + emm0 = vsubq_s32(emm0, vdupq_n_s32(0x7f)); + float32x4_t e = vcvtq_f32_s32(emm0); + + e = vaddq_f32(e, one); + + // part2: + // if( x < SQRTHF ) { + // e -= 1; + // x = x + x - 1.0; + // } else { + // x = x - 1.0; + // } + // + uint32x4_t mask = vcltq_f32(x, vdupq_n_f32(c_cephes_SQRTHF)); + float32x4_t tmp = + vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask)); + x = vsubq_f32(x, one); + e = vsubq_f32( + e, vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(one), mask))); + x = vaddq_f32(x, tmp); + + float32x4_t z = vmulq_f32(x, x); + + 
float32x4_t y = vdupq_n_f32(c_cephes_log_p0); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p1)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p2)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p3)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p4)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p5)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p6)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p7)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p8)); + y = vmulq_f32(y, x); + + y = vmulq_f32(y, z); + + tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q1)); + y = vaddq_f32(y, tmp); + + tmp = vmulq_f32(z, vdupq_n_f32(0.5f)); + y = vsubq_f32(y, tmp); + + tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q2)); + x = vaddq_f32(x, y); + x = vaddq_f32(x, tmp); + x = vreinterpretq_f32_u32(vorrq_u32( + vreinterpretq_u32_f32(x), invalid_mask)); // negative arg will be NAN + return x; +} + +#define c_exp_hi 88.3762626647949f +#define c_exp_lo -88.3762626647949f + +#define c_cephes_LOG2EF 1.44269504088896341 +#define c_cephes_exp_C1 0.693359375 +#define c_cephes_exp_C2 -2.12194440e-4 + +#define c_cephes_exp_p0 1.9875691500E-4 +#define c_cephes_exp_p1 1.3981999507E-3 +#define c_cephes_exp_p2 8.3334519073E-3 +#define c_cephes_exp_p3 4.1665795894E-2 +#define c_cephes_exp_p4 1.6666665459E-1 +#define c_cephes_exp_p5 5.0000001201E-1 + +// exp() computed for 4 float at once +inline float32x4_t exp_ps(float32x4_t x) { + float32x4_t tmp, fx; + + float32x4_t one = vdupq_n_f32(1); + x = vminq_f32(x, vdupq_n_f32(c_exp_hi)); + x = vmaxq_f32(x, vdupq_n_f32(c_exp_lo)); + + // express exp(x) as exp(g + n*log(2)) + fx = vmlaq_f32(vdupq_n_f32(0.5f), x, vdupq_n_f32(c_cephes_LOG2EF)); + + // perform a floorf + tmp = vcvtq_f32_s32(vcvtq_s32_f32(fx)); + + // if greater, substract 1 + uint32x4_t mask = vcgtq_f32(tmp, fx); + mask = 
vandq_u32(mask, vreinterpretq_u32_f32(one)); + + fx = vsubq_f32(tmp, vreinterpretq_f32_u32(mask)); + + tmp = vmulq_f32(fx, vdupq_n_f32(c_cephes_exp_C1)); + float32x4_t z = vmulq_f32(fx, vdupq_n_f32(c_cephes_exp_C2)); + x = vsubq_f32(x, tmp); + x = vsubq_f32(x, z); + + static const float cephes_exp_p[6] = {c_cephes_exp_p0, c_cephes_exp_p1, + c_cephes_exp_p2, c_cephes_exp_p3, + c_cephes_exp_p4, c_cephes_exp_p5}; + float32x4_t y = vld1q_dup_f32(cephes_exp_p + 0); + float32x4_t c1 = vld1q_dup_f32(cephes_exp_p + 1); + float32x4_t c2 = vld1q_dup_f32(cephes_exp_p + 2); + float32x4_t c3 = vld1q_dup_f32(cephes_exp_p + 3); + float32x4_t c4 = vld1q_dup_f32(cephes_exp_p + 4); + float32x4_t c5 = vld1q_dup_f32(cephes_exp_p + 5); + + y = vmulq_f32(y, x); + z = vmulq_f32(x, x); + + y = vaddq_f32(y, c1); + y = vmulq_f32(y, x); + y = vaddq_f32(y, c2); + y = vmulq_f32(y, x); + y = vaddq_f32(y, c3); + y = vmulq_f32(y, x); + y = vaddq_f32(y, c4); + y = vmulq_f32(y, x); + y = vaddq_f32(y, c5); + + y = vmulq_f32(y, z); + y = vaddq_f32(y, x); + y = vaddq_f32(y, one); + + // build 2^n + int32x4_t mm; + mm = vcvtq_s32_f32(fx); + mm = vaddq_s32(mm, vdupq_n_s32(0x7f)); + mm = vshlq_n_s32(mm, 23); + float32x4_t pow2n = vreinterpretq_f32_s32(mm); + + y = vmulq_f32(y, pow2n); + return y; +} + +#define c_minus_cephes_DP1 -0.78515625 +#define c_minus_cephes_DP2 -2.4187564849853515625e-4 +#define c_minus_cephes_DP3 -3.77489497744594108e-8 +#define c_sincof_p0 -1.9515295891E-4 +#define c_sincof_p1 8.3321608736E-3 +#define c_sincof_p2 -1.6666654611E-1 +#define c_coscof_p0 2.443315711809948E-005 +#define c_coscof_p1 -1.388731625493765E-003 +#define c_coscof_p2 4.166664568298827E-002 +#define c_cephes_FOPI 1.27323954473516 // 4 / M_PI + +// evaluation of 4 sines & cosines at once. +// +// The code is the exact rewriting of the cephes sinf function. 
+// Precision is excellent as long as x < 8192 (I did not bother to +// take into account the special handling they have for greater values +// -- it does not return garbage for arguments over 8192, though, but +// the extra precision is missing). +// +// Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the +// surprising but correct result. +// +// Note also that when you compute sin(x), cos(x) is available at +// almost no extra price so both sin_ps and cos_ps make use of +// sincos_ps.. +// +inline void sincos_ps(float32x4_t x, float32x4_t* ysin, float32x4_t* ycos) { + // any x + float32x4_t xmm1, xmm2, xmm3, y; + + uint32x4_t emm2; + + uint32x4_t sign_mask_sin, sign_mask_cos; + sign_mask_sin = vcltq_f32(x, vdupq_n_f32(0)); + x = vabsq_f32(x); + + // scale by 4/Pi + y = vmulq_f32(x, vdupq_n_f32(c_cephes_FOPI)); + + // store the integer part of y in mm0 + emm2 = vcvtq_u32_f32(y); + // j=(j+1) & (~1) (see the cephes sources) + emm2 = vaddq_u32(emm2, vdupq_n_u32(1)); + emm2 = vandq_u32(emm2, vdupq_n_u32(~1)); + y = vcvtq_f32_u32(emm2); + + // get the polynom selection mask + // there is one polynom for 0 <= x <= Pi/4 + // and another one for Pi/4 +void fill_bias_fc(T* tensor, const T* bias, const int num, const int channel); + +template +void fc_compute_eigen(const T* x, int x_h, int x_w, // + const T* w, int w_h, int w_w, // + const T* b, // + T* out) { + using matrix_t = + Eigen::Matrix; + + Eigen::Map X(x, x_h, x_w); + Eigen::Map W(w, w_h, w_w); + Eigen::Map Out(out, x_h, w_w); + + Out = X * W; + + if (b) { + Eigen::Map> B(b, w_w); + Out = Out.array().rowwise() + B.array(); + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/arm/math/packed_sgemm.cc b/paddle/fluid/lite/arm/math/packed_sgemm.cc new file mode 100644 index 00000000000000..1028d371d3c720 --- /dev/null +++ b/paddle/fluid/lite/arm/math/packed_sgemm.cc @@ -0,0 +1,3049 @@ +// Copyright (c) 2019 PaddlePaddle Authors. 
All Rights Reserved. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/arm/math/packed_sgemm.h" +#include + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +#ifdef __aarch64__ +void prepackA_8x12(float *out, const float *in, const int ldin, const int m0, + const int mmax, const int k0, const int kmax); +void prepackA_trans_8x12(float *out, const float *in, const int ldin, + const int m0, const int mmax, const int k0, + const int kmax); +void sgemm_conv_8x12(const float *A_packed, const float *B, const float *bias, + float *C, int M, int N, int K, bool is_bias, bool is_relu, + bool transB, ARMContext *ctx); +#else +// for kA72 +void prepackA_6x8(float *out, const float *in, const int ldin, const int m0, + const int mmax, const int k0, const int kmax); +void prepackA_trans_6x8(float *out, const float *in, const int ldin, + const int m0, const int mmax, const int k0, + const int kmax); +// for kA73 +void prepackA_4x8(float *out, const float *in, const int ldin, const int m0, + const int mmax, const int k0, const int kmax); +void prepackA_trans_4x8(float *out, const float *in, const int ldin, + const int m0, const int mmax, const int k0, + const int kmax); +// for kA72, 6x8 +void sgemm_conv_6x8(const float *A_packed, const float *B, const float *bias, + float *C, int M, int N, int K, bool is_bias, bool is_relu, + bool transB, ARMContext *ctx); +// for kA73, 4x8 +void sgemm_conv_4x8(const float *A_packed, const 
float *B, const float *bias, + float *C, int M, int N, int K, bool is_bias, bool is_relu, + bool transB, ARMContext *ctx); +#endif // __aarch64__ + +/** + * \brief input data is not transpose + * for arm-v7a, transform data to block x k x 6 layout + * for arm-v8a, transform data to block x k x 8 layout + */ +void prepackA(float *out, const float *in, const int ldin, const int m0, + const int mmax, const int k0, const int kmax, bool is_trans, + ARMContext *ctx) { +#ifdef __aarch64__ + if (is_trans) { + prepackA_trans_8x12(out, in, ldin, m0, mmax, k0, kmax); + } else { + prepackA_8x12(out, in, ldin, m0, mmax, k0, kmax); + } +#else + if (ctx->arch() == kA73) { + if (is_trans) { + prepackA_trans_4x8(out, in, ldin, m0, mmax, k0, kmax); + } else { + prepackA_4x8(out, in, ldin, m0, mmax, k0, kmax); + } + } else { + if (is_trans) { + prepackA_trans_6x8(out, in, ldin, m0, mmax, k0, kmax); + } else { + prepackA_6x8(out, in, ldin, m0, mmax, k0, kmax); + } + } +#endif +} + +void prepackA(TensorLite *tout, const TensorLite &tin, int m, int k, int group, + bool is_trans, ARMContext *ctx) { + int hblock = get_hblock(ctx->arch()); + int m_roundup = hblock * ((m + hblock - 1) / hblock); + int group_size_round_up = ((m_roundup * k + 15) / 16) * 16; + if (tout->numel() < group_size_round_up * group) { + tout->Resize({group_size_round_up * group}); + } + int lda = k; + if (is_trans) { + lda = m; + } + for (int g = 0; g < group; ++g) { + const float *weights_group = tin.data() + g * m * k; + float *weights_trans_ptr = + tout->mutable_data() + g * group_size_round_up; + prepackA(weights_trans_ptr, weights_group, lda, 0, m, 0, k, is_trans, ctx); + } +} + +/// a: m*k b: k*n c: m*n +void sgemm_prepack(const float *A_packed, const float *B, const float *bias, + float *C, int M, int N, int K, bool is_bias, bool is_relu, + bool is_transB, ARMContext *ctx) { +#ifdef __aarch64__ + sgemm_conv_8x12(A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, + ctx); +#else // armv7 + if 
(ctx->arch() == kA73) { + sgemm_conv_4x8(A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, + ctx); + } else { + sgemm_conv_6x8(A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, + ctx); + } +#endif // arm64 +} + +#ifdef __aarch64__ +void prepackA_8x12(float *out, const float *in, const int ldin, const int m0, + const int mmax, const int k0, const int kmax) { + int x_len = kmax - k0; + uint32_t zerobuff[x_len]; // NOLINT + memset(zerobuff, 0, sizeof(uint32_t) * x_len); + + uint32_t *dout = reinterpret_cast(out); + const uint32_t *inptr = reinterpret_cast(in); + + int stride = x_len * 8; +#pragma omp parallel for + for (int y = m0; y < mmax; y += 8) { + uint32_t *outptr = dout + stride * (y - m0) / 8; + + const uint32_t *inptr0 = inptr + y * ldin + k0; + const uint32_t *inptr1 = inptr0 + ldin; + const uint32_t *inptr2 = inptr1 + ldin; + const uint32_t *inptr3 = inptr2 + ldin; + const uint32_t *inptr4 = inptr3 + ldin; + const uint32_t *inptr5 = inptr4 + ldin; + const uint32_t *inptr6 = inptr5 + ldin; + const uint32_t *inptr7 = inptr6 + ldin; + + asm volatile( + "prfm pldl1keep, [%[ptr0]] \n" + "prfm pldl1keep, [%[ptr0], #64] \n" + "prfm pldl1keep, [%[ptr1]] \n" + "prfm pldl1keep, [%[ptr1], #64] \n" + "prfm pldl1keep, [%[ptr2]] \n" + "prfm pldl1keep, [%[ptr2], #64] \n" + "prfm pldl1keep, [%[ptr3]] \n" + "prfm pldl1keep, [%[ptr3], #64] \n" + "prfm pldl1keep, [%[ptr4]] \n" + "prfm pldl1keep, [%[ptr4], #64] \n" + "prfm pldl1keep, [%[ptr5]] \n" + "prfm pldl1keep, [%[ptr5], #64] \n" + "prfm pldl1keep, [%[ptr6]] \n" + "prfm pldl1keep, [%[ptr6], #64] \n" + "prfm pldl1keep, [%[ptr7]] \n" + "prfm pldl1keep, [%[ptr7], #64] \n" + : + : [ptr0] "r"(inptr0), [ptr1] "r"(inptr1), [ptr2] "r"(inptr2), + [ptr3] "r"(inptr3), [ptr4] "r"(inptr4), [ptr5] "r"(inptr5), + [ptr6] "r"(inptr6), [ptr7] "r"(inptr7) + : "memory"); + + int x = x_len; + //! 
cope with row index exceed real size, set to zero buffer + if ((y + 7) >= mmax) { + switch ((y + 7) - mmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + default: + break; + } + } + for (; x > 7; x -= 8) { + asm volatile( + // Load up 8 elements (2 vectors) from each of 8 sources. + "LDP q0, q1, [%[inptr0]], #32\n" // q0=A0A1A2A3 + "LDP q2, q3, [%[inptr1]], #32\n" // q2=B0B1B2B3 + "LDP q4, q5, [%[inptr2]], #32\n" // q4=C0C1C2C3 + "ZIP1 v16.4s, v0.4s, v4.4s\n" // q16=A0C0A1C1 + "prfm pldl1keep, [%[inptr0], #128] \n" + "LDP q6, q7, [%[inptr3]], #32\n" // q6=D0D1D2D3 + "ZIP1 v17.4s, v2.4s, v6.4s\n" // q17=B0D0B1D1 + "LDP q8, q9, [%[inptr4]], #32\n" + "LDP q10, q11, [%[inptr5]], #32\n" + "LDP q12, q13, [%[inptr6]], #32\n" + "ZIP1 v18.4s, v8.4s, v12.4s\n" + "prfm pldl1keep, [%[inptr1], #128]\n" + "LDP q14, q15, [%[inptr7]], #32\n" + "ZIP1 v19.4s, v10.4s, v14.4s\n" + + "ZIP1 v20.4s, v16.4s, v17.4s\n" // q20=A0B0C0D0 + "prfm pldl1keep, [%[inptr2], #128]\n" + "ZIP1 v21.4s, v18.4s, v19.4s\n" + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + + "ZIP2 v16.4s, v0.4s, v4.4s\n" + "prfm pldl1keep, [%[inptr3], #128]\n" + "ZIP2 v17.4s, v2.4s, v6.4s\n" + "STP q20, q21, [%[outptr]], #32\n" // Write back the first + // element of each source + + "ZIP2 v18.4s, v8.4s, v12.4s\n" + "ZIP2 v19.4s, v10.4s, v14.4s\n" + "STP q22, q23, [%[outptr]], #32\n" // Write back the second + // element of each source + + "ZIP1 v20.4s, v16.4s, v17.4s\n" + "prfm pldl1keep, [%[inptr4], #128]\n" + "ZIP1 v21.4s, v18.4s, v19.4s\n" + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + + "ZIP1 v16.4s, v1.4s, v5.4s\n" + "prfm pldl1keep, [%[inptr5], #128]\n" + "ZIP1 v17.4s, v3.4s, v7.4s\n" + "STP q20, q21, [%[outptr]], #32\n" // Third element + + "ZIP1 v18.4s, v9.4s, v13.4s\n" + "ZIP1 v19.4s, v11.4s, 
v15.4s\n" + "STP q22, q23, [%[outptr]], #32\n" // Fourth element + + "ZIP1 v20.4s, v16.4s, v17.4s\n" + "ZIP1 v21.4s, v18.4s, v19.4s\n" + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "prfm pldl1keep, [%[inptr6], #128]\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + + "ZIP2 v16.4s, v1.4s, v5.4s\n" + "ZIP2 v17.4s, v3.4s, v7.4s\n" + "STP q20, q21, [%[outptr]], #32\n" // Fifth element + + "ZIP2 v18.4s, v9.4s, v13.4s\n" + "prfm pldl1keep, [%[inptr7], #128]\n" + "ZIP2 v19.4s, v11.4s, v15.4s\n" + "STP q22, q23, [%[outptr]], #32\n" // Sixth element + + "ZIP1 v20.4s, v16.4s, v17.4s\n" + "ZIP1 v21.4s, v18.4s, v19.4s\n" + "STP q20, q21, [%[outptr]], #32\n" // Seventh element + + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + "STP q22, q23, [%[outptr]], #32\n" // Eighth element + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "cc", "memory"); + } + + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + *outptr++ = *inptr6++; + *outptr++ = *inptr7++; + } + } +} + +void prepackA_trans_8x12(float *out, const float *in, const int ldin, + const int m0, const int mmax, const int k0, + const int kmax) { + uint32_t *outptr = reinterpret_cast(out); + const uint32_t *inptr = + reinterpret_cast(in) + k0 * ldin + m0; + + uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + int x_len = mmax - m0; + int y_len = kmax - k0; + int right_remain = x_len - 8 * (x_len / 8); + int right_pad = 8 - right_remain; + if (right_remain == 0) { + right_pad = 0; + } + + uint32_t *outptr_row = outptr; + int stride_out = 8 * y_len; + + uint32x4_t vzero = 
vdupq_n_u32(0); + uint32x4_t vmask1 = + vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain)); + uint32x4_t vmask2 = + vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain)); + +#pragma omp parallel for + for (int y = 0; y < y_len - 3; y += 4) { + const uint32_t *ptr0 = inptr + y * ldin; + const uint32_t *ptr1 = ptr0 + ldin; + const uint32_t *ptr2 = ptr1 + ldin; + const uint32_t *ptr3 = ptr2 + ldin; + + asm volatile( + "prfm pldl1keep, [%[ptr0]] \n" + "prfm pldl1keep, [%[ptr0], #64] \n" + "prfm pldl1keep, [%[ptr1]] \n" + "prfm pldl1keep, [%[ptr1], #64] \n" + "prfm pldl1keep, [%[ptr2]] \n" + "prfm pldl1keep, [%[ptr2], #64] \n" + "prfm pldl1keep, [%[ptr3]] \n" + "prfm pldl1keep, [%[ptr3], #64] \n" + : + : [ptr0] "r"(ptr0), [ptr1] "r"(ptr1), [ptr2] "r"(ptr2), [ptr3] "r"(ptr3) + : "memory"); + + uint32_t *outptr_row_col = outptr_row + y * 8; + int i = 0; + for (; i < x_len - 7; i += 8) { + uint32x4_t vr00 = vld1q_u32(ptr0); + uint32x4_t vr01 = vld1q_u32(ptr0 + 4); + + uint32x4_t vr10 = vld1q_u32(ptr1); + uint32x4_t vr11 = vld1q_u32(ptr1 + 4); + + vst1q_u32(outptr_row_col, vr00); + vst1q_u32(outptr_row_col + 4, vr01); + + uint32x4_t vr20 = vld1q_u32(ptr2); + uint32x4_t vr21 = vld1q_u32(ptr2 + 4); + + vst1q_u32(outptr_row_col + 8, vr10); + vst1q_u32(outptr_row_col + 12, vr11); + + uint32x4_t vr30 = vld1q_u32(ptr3); + uint32x4_t vr31 = vld1q_u32(ptr3 + 4); + + vst1q_u32(outptr_row_col + 16, vr20); + vst1q_u32(outptr_row_col + 20, vr21); + + vst1q_u32(outptr_row_col + 24, vr30); + vst1q_u32(outptr_row_col + 28, vr31); + + ptr0 += 8; + ptr1 += 8; + ptr2 += 8; + ptr3 += 8; + + outptr_row_col += stride_out; + } + if (right_remain > 0) { + uint32x4_t vr00 = vld1q_u32(ptr0); + uint32x4_t vr01 = vld1q_u32(ptr0 + 4); + + uint32x4_t vr10 = vld1q_u32(ptr1); + uint32x4_t vr11 = vld1q_u32(ptr1 + 4); + + uint32x4_t vr00_1 = vbslq_u32(vmask1, vr00, vzero); + uint32x4_t vr01_1 = vbslq_u32(vmask2, vr01, vzero); + + uint32x4_t vr20 = vld1q_u32(ptr2); + uint32x4_t vr21 = 
vld1q_u32(ptr2 + 4); + + vst1q_u32(outptr_row_col, vr00_1); + vst1q_u32(outptr_row_col + 4, vr01_1); + + uint32x4_t vr10_1 = vbslq_u32(vmask1, vr10, vzero); + uint32x4_t vr11_1 = vbslq_u32(vmask2, vr11, vzero); + + uint32x4_t vr30 = vld1q_u32(ptr3); + uint32x4_t vr31 = vld1q_u32(ptr3 + 4); + + vst1q_u32(outptr_row_col + 8, vr10_1); + vst1q_u32(outptr_row_col + 12, vr11_1); + + uint32x4_t vr20_1 = vbslq_u32(vmask1, vr20, vzero); + uint32x4_t vr21_1 = vbslq_u32(vmask2, vr21, vzero); + + uint32x4_t vr30_1 = vbslq_u32(vmask1, vr30, vzero); + uint32x4_t vr31_1 = vbslq_u32(vmask2, vr31, vzero); + + vst1q_u32(outptr_row_col + 16, vr20_1); + vst1q_u32(outptr_row_col + 20, vr21_1); + vst1q_u32(outptr_row_col + 24, vr30_1); + vst1q_u32(outptr_row_col + 28, vr31_1); + } + } + +#pragma omp parallel for + for (int y = 4 * (y_len / 4); y < y_len; ++y) { + const uint32_t *ptr0 = inptr + y * ldin; + uint32_t *outptr_row_col = outptr_row + y * 8; + int i = 0; + for (; i < x_len - 7; i += 8) { + uint32x4_t vr0 = vld1q_u32(ptr0); + uint32x4_t vr1 = vld1q_u32(ptr0 + 4); + vst1q_u32(outptr_row_col, vr0); + vst1q_u32(outptr_row_col + 4, vr1); + + ptr0 += 8; + + outptr_row_col += stride_out; + } + if (right_remain > 0) { + uint32x4_t vr0 = vld1q_u32(ptr0); + uint32x4_t vr1 = vld1q_u32(ptr0 + 4); + + uint32x4_t vr0_1 = vbslq_u32(vmask1, vr0, vzero); + uint32x4_t vr1_1 = vbslq_u32(vmask2, vr1, vzero); + + vst1q_u32(outptr_row_col, vr0_1); + vst1q_u32(outptr_row_col + 4, vr1_1); + } + } +} + +#else // __aarch64__ +void prepackA_6x8(float* out, const float* in, const int ldin, const int m0, + const int mmax, const int k0, const int kmax) { + int x_len = kmax - k0; + uint32_t zerobuff[x_len]; // NOLINT + memset(zerobuff, 0, sizeof(uint32_t) * x_len); + + uint32_t* dout = reinterpret_cast(out); + const uint32_t* inptr = reinterpret_cast(in); + + uint32_t* outptr = dout; + + //! 
data A is not transposed, transpose A to k * 6 + for (int y = m0; y < mmax; y += 6) { + const uint32_t* inptr0 = inptr + y * ldin + k0; + const uint32_t* inptr1 = inptr0 + ldin; + const uint32_t* inptr2 = inptr1 + ldin; + const uint32_t* inptr3 = inptr2 + ldin; + const uint32_t* inptr4 = inptr3 + ldin; + const uint32_t* inptr5 = inptr4 + ldin; + + int x = x_len; + //! cope with row index exceed real size, set to zero buffer + if ((y + 5) >= mmax) { + switch ((y + 5) - mmax) { + case 4: + inptr1 = zerobuff; + case 3: + inptr2 = zerobuff; + case 2: + inptr3 = zerobuff; + case 1: + inptr4 = zerobuff; + case 0: + inptr5 = zerobuff; + default: + break; + } + } + + for (; x > 7; x -= 8) { + //! zip load 8 elements (2 neon Q registers) from each of 6 rows + asm volatile( + "vld4.32 {d0-d3}, [%[inptr0]]! @ zip load r0, " + "q0,q1=r00,r04,r01,r05,r02,r06,r03,r07\n" + "vld4.32 {d4-d7}, [%[inptr1]]! @ zip load r1, " + "q2,q3=r10,r14,r11,r15,r12,r16,r13,r17\n" + "vld4.32 {d8-d11}, [%[inptr2]]! @ zip load r2, " + "q4,q5=r20,r24,r21,r25,r22,r26,r23,r27\n" + "vld4.32 {d12-d15}, [%[inptr3]]! @ zip load r3, " + "q6,q7=r30,r34,r31,r35,r32,r36,r33,r37\n" + "vld4.32 {d16-d19}, [%[inptr4]]! @ zip load r4, " + "q8,q9=r40,r44,r41,r45,r42,r46,r43,r47\n" + "vld4.32 {d20-d23}, [%[inptr5]]! @ zip load r5, " + "q10,q11=r50,r54,r51,r55,r52,r56,r53,r57\n" + + "vtrn.32 q0, q2 @ trans data: q0=r00,r10,r01,r11; " + "q2=r04,r14,r05,r15\n" + "vtrn.32 q4, q6 @ trans data: q4=r20,r30,r21,r31; " + "q6=r24,r34,r25,r35\n" + "vtrn.32 q8, q10 @ trans data: q8=r40,r50,r41,r51; " + "q10=r44,r54,r45,r55\n" + + "vswp d1, d8 @ swap d1, d8, q0=r00,r10,r20,r30; " + "q4=r01,r11,r21,r31\n" + "vst1.32 {d0-d1}, [%[outptr]]! @ write q0:r00,r10,r20,r30\n" + "vst1.32 {d16}, [%[outptr]]! @ write d16(q8,low),r40,r50\n" + "vst1.32 {d8-d9}, [%[outptr]]! @ write q4:r01,r11,r21,r31\n" + "vst1.32 {d17}, [%[outptr]]! 
@ write d16(q8,high),r41,r51\n" + + "vtrn.32 q1, q3 @ trans data: q1=r02,r12,r03,r13; " + "q3=r06,r16,r07,r17\n" + "vtrn.32 q5, q7 @ trans data: q5=r22,r32,r23,r33; " + "q7=r26,r36,r27,r37\n" + "vtrn.32 q9, q11 @ trans data: q9=r42,r52,r43,r53; " + "q11=r46,r56,r47,r57\n" + + "vswp d3, d10 @ swap d3, d10, " + "q1=r02,r12,r22,r32; q5=r03,r13,r23,r33\n" + "vst1.32 {d2-d3}, [%[outptr]]! @ write q1:r02,r12,r22,r32\n" + "vst1.32 {d18}, [%[outptr]]! @ write d18(q9,low),r42,r52\n" + "vst1.32 {d10-d11},[%[outptr]]! @ write q5:r03,r13,r23,r33\n" + "vst1.32 {d19}, [%[outptr]]! @ write d19(q9,high),r43,r53\n" + + "vswp d5, d12 @ swap d5, d12,q2=r04,r14,r24,r34; " + "q6=r05,r15,r25,r35\n" + "vst1.32 {d4-d5}, [%[outptr]]! @ write q2:r04,r14,r24,r34\n" + "vst1.32 {d20}, [%[outptr]]! @ write d20(q10,low),r44,r54\n" + "vst1.32 {d12-d13},[%[outptr]]! @ write q6:r05,r15,r25,r35\n" + "vst1.32 {d21}, [%[outptr]]! @ write d21(q10,high),r45,r55\n" + + "vswp d7, d14 @ swap d7, d14, " + "q3=r06,r16,r26,r36; q7=r07,r17,r27,r37\n" + "vst1.32 {d6-d7}, [%[outptr]]! @ write q3:r06,r16,r26,r36\n" + "vst1.32 {d22}, [%[outptr]]! @ write d22(q11,low),r46,r56\n" + "vst1.32 {d14-d15},[%[outptr]]! @ write q7:r07,r17,r27,r37\n" + "vst1.32 {d23}, [%[outptr]]! 
@ write d23(q11,high),r47,r57\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [outptr] "+r"(outptr) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "cc", "memory"); + } + + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + } + } +} + +void prepackA_trans_6x8(float* out, const float* in, const int ldin, + const int m0, const int mmax, const int k0, + const int kmax) { + uint32_t* outptr = reinterpret_cast(out); + const uint32_t* inptr = + reinterpret_cast(in) + k0 * ldin + m0; + + uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + int x_len = mmax - m0; + int y_len = kmax - k0; + int right_remain = x_len - 6 * (x_len / 6); + int right_pad = 6 - right_remain; + if (right_remain == 0) { + right_pad = 0; + } + + uint32_t* outptr_row = outptr; + int stride_out = 6 * y_len; + + uint32x4_t vzero = vdupq_n_u32(0); + uint32x4_t vmask1 = + vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain)); + uint32x4_t vmask2 = + vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain)); + +#pragma omp parallel for + for (int y = 0; y < y_len - 3; y += 4) { + const uint32_t* ptr0 = inptr + y * ldin; + const uint32_t* ptr1 = ptr0 + ldin; + const uint32_t* ptr2 = ptr1 + ldin; + const uint32_t* ptr3 = ptr2 + ldin; + + uint32_t* outptr_row_col = outptr_row + y * 6; + int i = 0; + for (; i < x_len - 5; i += 6) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d2}, [%[ptr0]]! @ load r0, 6 elements\n" + "vld1.32 {d4-d6}, [%[ptr1]]! @ load r1, 6 elements\n" + "vst1.32 {d0-d2}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d6}, [%[outptr]]! @ write to output ptr\n" + + "vld1.32 {d0-d2}, [%[ptr2]]! @ load r2, 6 elements\n" + "vld1.32 {d4-d6}, [%[ptr3]]! 
@ load r3, 6 elements\n" + "vst1.32 {d0-d2}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d6}, [%[outptr]]! @ write to output ptr\n" + : [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3) + : + : "q0", "q1", "q2", "q3", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_pad > 0) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d2}, [%[ptr0]]! @ load r0, 6 elements\n" + "vld1.32 {d4-d6}, [%[ptr1]]! @ load r1, 6 elements\n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif d2, %e[vzero], %e[vmask2] @ bit select, pad zero\n" + "vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif d6, %e[vzero], %e[vmask2] @ bit select, pad zero\n" + "vst1.32 {d0-d2}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d6}, [%[outptr]]! @ write to output ptr\n" + + "vld1.32 {d0-d2}, [%[ptr2]]! @ load r2, 8 elements\n" + "vld1.32 {d4-d6}, [%[ptr3]]! @ load r3, 8 elements\n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif d2, %e[vzero], %e[vmask2] @ bit select, pad zero\n" + "vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif d6, %e[vzero], %e[vmask2] @ bit select, pad zero\n" + "vst1.32 {d0-d2}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d6}, [%[outptr]]! @ write to output ptr\n" + : [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero) + : "q0", "q1", "q2", "q3", "cc", "memory"); + } + } + +#pragma omp parallel for + for (int y = 4 * (y_len / 4); y < y_len; ++y) { + const uint32_t* ptr0 = inptr + y * ldin; + uint32_t* outptr_row_col = outptr_row + y * 6; + int i = 0; + for (; i < x_len - 5; i += 6) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d2}, [%[ptr0]]! @ load r0, 6 elements\n" + "vst1.32 {d0-d2}, [%[outptr]]! 
@ write to output ptr\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : + : "q0", "q1", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_pad > 0) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d2}, [%[ptr0]]! @ load r0, 6 elements\n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif d2, %e[vzero], %e[vmask2] @ bit select, pad zero\n" + "vst1.32 {d0-d2}, [%[outptr]]! @ write to output ptr\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero) + : "q0", "q1", "cc", "memory"); + } + } +} + +void prepackA_4x8(float* out, const float* in, const int ldin, const int m0, + const int mmax, const int k0, const int kmax) { + int x_len = kmax - k0; + uint32_t zerobuff[x_len]; // NOLINT + memset(zerobuff, 0, sizeof(uint32_t) * x_len); + + uint32_t* dout = reinterpret_cast(out); + const uint32_t* inptr = reinterpret_cast(in); + + uint32_t* outptr = dout; + //! data A is not transposed, transpose A to k * 4 + for (int y = m0; y < mmax; y += 4) { + const uint32_t* inptr0 = inptr + y * ldin + k0; + const uint32_t* inptr1 = inptr0 + ldin; + const uint32_t* inptr2 = inptr1 + ldin; + const uint32_t* inptr3 = inptr2 + ldin; + + int x = x_len; + //! cope with row index exceed real size, set to zero buffer + if ((y + 3) >= mmax) { + switch ((y + 3) - mmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + default: + break; + } + } + + for (; x > 7; x -= 8) { + //! zip load 8 elements (2 neon Q registers) from each of 4 rows + asm volatile( + "vld4.32 {d0-d3}, [%[inptr0]]! @ zip load r0, " + "q0,q1=r00,r04,r01,r05,r02,r06,r03,r07\n" + "vld4.32 {d4-d7}, [%[inptr1]]! @ zip load r1, " + "q2,q3=r10,r14,r11,r15,r12,r16,r13,r17\n" + "vld4.32 {d8-d11}, [%[inptr2]]! @ zip load r2, " + "q4,q5=r20,r24,r21,r25,r22,r26,r23,r27\n" + "vld4.32 {d12-d15}, [%[inptr3]]! 
@ zip load r3, " + "q6,q7=r30,r34,r31,r35,r32,r36,r33,r37\n" + + "vtrn.32 q0, q2 @ trans data: q0=r00,r10,r01,r11; " + "q2=r04,r14,r05,r15\n" + "vtrn.32 q4, q6 @ trans data: q4=r20,r30,r21,r31; " + "q6=r24,r34,r25,r35\n" + + "vswp d1, d8 @ swap d1, d8, q0=r00,r10,r20,r30; " + "q4=r01,r11,r21,r31\n" + "vst1.32 {d0-d1}, [%[outptr]]! @ write q0:r00,r10,r20,r30\n" + "vst1.32 {d8-d9}, [%[outptr]]! @ write q4:r01,r11,r21,r31\n" + + "vtrn.32 q1, q3 @ trans data: q1=r02,r12,r03,r13; " + "q3=r06,r16,r07,r17\n" + "vtrn.32 q5, q7 @ trans data: q5=r22,r32,r23,r33; " + "q7=r26,r36,r27,r37\n" + + "vswp d3, d10 @ swap d3, d10, " + "q1=r02,r12,r22,r32; q5=r03,r13,r23,r33\n" + "vst1.32 {d2-d3}, [%[outptr]]! @ write q1:r02,r12,r22,r32\n" + "vst1.32 {d10-d11},[%[outptr]]! @ write q5:r03,r13,r23,r33\n" + + "vswp d5, d12 @ swap d5, d12,q2=r04,r14,r24,r34; " + "q6=r05,r15,r25,r35\n" + "vst1.32 {d4-d5}, [%[outptr]]! @ write q2:r04,r14,r24,r34\n" + "vst1.32 {d12-d13},[%[outptr]]! @ write q6:r05,r15,r25,r35\n" + + "vswp d7, d14 @ swap d7, d14, " + "q3=r06,r16,r26,r36; q7=r07,r17,r27,r37\n" + "vst1.32 {d6-d7}, [%[outptr]]! @ write q3:r06,r16,r26,r36\n" + "vst1.32 {d14-d15},[%[outptr]]! 
@ write q7:r07,r17,r27,r37\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "cc", "memory"); + } + + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + } + } +} + +void prepackA_trans_4x8(float* out, const float* in, const int ldin, + const int m0, const int mmax, const int k0, + const int kmax) { + uint32_t* outptr = reinterpret_cast(out); + const uint32_t* inptr = + reinterpret_cast(in) + k0 * ldin + m0; + + uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + int x_len = mmax - m0; + int y_len = kmax - k0; + int right_remain = x_len - 4 * (x_len / 4); + int right_pad = 4 - right_remain; + if (right_remain == 0) { + right_pad = 0; + } + + uint32_t* outptr_row = outptr; + int stride_out = 4 * y_len; + + uint32x4_t vzero = vdupq_n_u32(0); + uint32x4_t vmask1 = + vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain)); +// uint32x4_t vmask2 = vcltq_u32(vld1q_u32(mask_buffer + 4), +// vdupq_n_u32(right_remain)); + +#pragma omp parallel for + for (int y = 0; y < y_len - 3; y += 4) { + const uint32_t* ptr0 = inptr + y * ldin; + const uint32_t* ptr1 = ptr0 + ldin; + const uint32_t* ptr2 = ptr1 + ldin; + const uint32_t* ptr3 = ptr2 + ldin; + + uint32_t* outptr_row_col = outptr_row + y * 4; + int i = 0; + for (; i < x_len - 3; i += 4) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d1}, [%[ptr0]]! @ load r0, 4 elements\n" + "vld1.32 {d2-d3}, [%[ptr1]]! @ load r1, 4 elements\n" + "vst1.32 {d0-d1}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d2-d3}, [%[outptr]]! @ write to output ptr\n" + + "vld1.32 {d4-d5}, [%[ptr2]]! @ load r2, 4 elements\n" + "vld1.32 {d6-d7}, [%[ptr3]]! @ load r3, 4 elements\n" + "vst1.32 {d4-d5}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d6-d7}, [%[outptr]]! 
@ write to output ptr\n" + : [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3) + : + : "q0", "q1", "q2", "q3", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_pad > 0) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d1}, [%[ptr0]]! @ load r0, 4 elements\n" + "vld1.32 {d2-d3}, [%[ptr1]]! @ load r1, 4 elements\n" + "vld1.32 {d4-d5}, [%[ptr2]]! @ load r2, 4 elements\n" + "vld1.32 {d6-d7}, [%[ptr3]]! @ load r3, 4 elements\n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q1, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q3, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vst1.32 {d0-d1}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d2-d3}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d5}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d6-d7}, [%[outptr]]! @ write to output ptr\n" + : [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3) + : [vmask1] "w"(vmask1), [vzero] "w"(vzero) + : "q0", "q1", "q2", "q3", "cc", "memory"); + } + } + +#pragma omp parallel for + for (int y = 4 * (y_len / 4); y < y_len; ++y) { + const uint32_t* ptr0 = inptr + y * ldin; + uint32_t* outptr_row_col = outptr_row + y * 4; + int i = 0; + for (; i < x_len - 3; i += 4) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d1}, [%[ptr0]]! @ load r0, 4 elements\n" + "vst1.32 {d0-d1}, [%[outptr]]! @ write to output ptr\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : + : "q0", "q1", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_pad > 0) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d1}, [%[ptr0]]! @ load r0, 4 elements\n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vst1.32 {d0-d1}, [%[outptr]]! 
@ write to output ptr\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : [vmask1] "w"(vmask1), [vzero] "w"(vzero) + : "q0", "q1", "cc", "memory"); + } + } +} + +#endif // __aarch64__ + +/** +* \brief input data is transpose +* for arm-v7a, transform data to block x k x 8 layout +* for arm-v8a, transform data to block x k x 12 layout +*/ +#ifdef __aarch64__ +void loadb(float *out, const float *in, const int ldin, const int k0, + const int kmax, const int n0, const int nmax) { + uint32_t *outptr = reinterpret_cast(out); + const uint32_t *inptr = + reinterpret_cast(in) + k0 * ldin + n0; + uint32_t mask_buffer[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + int x_len = nmax - n0; + int y_len = kmax - k0; + int right_remain = x_len - 12 * (x_len / 12); + int right_pad = 12 - right_remain; + const size_t copy_len_remain = sizeof(float) * right_remain; + const size_t copy_len_pad = sizeof(float) * right_pad; + const size_t size_ldin = sizeof(float) * ldin; + + uint32_t *outptr_row = outptr; + int stride_out = 12 * y_len; + + uint32x4_t vzero = vdupq_n_u32(0); + uint32x4_t vmask1 = + vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain)); + uint32x4_t vmask2 = + vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain)); + uint32x4_t vmask3 = + vcltq_u32(vld1q_u32(mask_buffer + 8), vdupq_n_u32(right_remain)); + +#pragma omp parallel for + for (int y = 0; y < y_len - 3; y += 4) { + const uint32_t *ptr0 = inptr + y * ldin; + const uint32_t *ptr1 = ptr0 + ldin; + const uint32_t *ptr2 = ptr1 + ldin; + const uint32_t *ptr3 = ptr2 + ldin; + asm volatile( + "prfm pldl1keep, [%[ptr0]] \n" + "prfm pldl1keep, [%[ptr0], #64] \n" + "prfm pldl1keep, [%[ptr1]] \n" + "prfm pldl1keep, [%[ptr1], #64] \n" + "prfm pldl1keep, [%[ptr2]] \n" + "prfm pldl1keep, [%[ptr2], #64] \n" + "prfm pldl1keep, [%[ptr3]] \n" + "prfm pldl1keep, [%[ptr3], #64] \n" + : + : [ptr0] "r"(ptr0), [ptr1] "r"(ptr1), [ptr2] "r"(ptr2), [ptr3] "r"(ptr3) + : "memory"); + + uint32_t *outptr_row_col = 
outptr_row + y * 12; + + int i = 0; + for (; i < x_len - 11; i += 12) { + uint32x4_t vr00 = vld1q_u32(ptr0); + uint32x4_t vr01 = vld1q_u32(ptr0 + 4); + uint32x4_t vr02 = vld1q_u32(ptr0 + 8); + + uint32x4_t vr10 = vld1q_u32(ptr1); + uint32x4_t vr11 = vld1q_u32(ptr1 + 4); + uint32x4_t vr12 = vld1q_u32(ptr1 + 8); + + vst1q_u32(outptr_row_col, vr00); + vst1q_u32(outptr_row_col + 4, vr01); + vst1q_u32(outptr_row_col + 8, vr02); + + uint32x4_t vr20 = vld1q_u32(ptr2); + uint32x4_t vr21 = vld1q_u32(ptr2 + 4); + uint32x4_t vr22 = vld1q_u32(ptr2 + 8); + + vst1q_u32(outptr_row_col + 12, vr10); + vst1q_u32(outptr_row_col + 16, vr11); + vst1q_u32(outptr_row_col + 20, vr12); + + uint32x4_t vr30 = vld1q_u32(ptr3); + uint32x4_t vr31 = vld1q_u32(ptr3 + 4); + uint32x4_t vr32 = vld1q_u32(ptr3 + 8); + + vst1q_u32(outptr_row_col + 24, vr20); + vst1q_u32(outptr_row_col + 28, vr21); + vst1q_u32(outptr_row_col + 32, vr22); + + vst1q_u32(outptr_row_col + 36, vr30); + vst1q_u32(outptr_row_col + 40, vr31); + vst1q_u32(outptr_row_col + 44, vr32); + + ptr0 += 12; + ptr1 += 12; + ptr2 += 12; + ptr3 += 12; + + outptr_row_col += stride_out; + } + if (right_remain > 0) { + uint32x4_t vr00 = vld1q_u32(ptr0); + uint32x4_t vr01 = vld1q_u32(ptr0 + 4); + uint32x4_t vr02 = vld1q_u32(ptr0 + 8); + + uint32x4_t vr10 = vld1q_u32(ptr1); + uint32x4_t vr11 = vld1q_u32(ptr1 + 4); + uint32x4_t vr12 = vld1q_u32(ptr1 + 8); + + uint32x4_t vr00_1 = vbslq_u32(vmask1, vr00, vzero); + uint32x4_t vr01_1 = vbslq_u32(vmask2, vr01, vzero); + uint32x4_t vr02_1 = vbslq_u32(vmask3, vr02, vzero); + + uint32x4_t vr20 = vld1q_u32(ptr2); + uint32x4_t vr21 = vld1q_u32(ptr2 + 4); + uint32x4_t vr22 = vld1q_u32(ptr2 + 8); + + vst1q_u32(outptr_row_col, vr00_1); + vst1q_u32(outptr_row_col + 4, vr01_1); + vst1q_u32(outptr_row_col + 8, vr02_1); + + uint32x4_t vr10_1 = vbslq_u32(vmask1, vr10, vzero); + uint32x4_t vr11_1 = vbslq_u32(vmask2, vr11, vzero); + uint32x4_t vr12_1 = vbslq_u32(vmask3, vr12, vzero); + + uint32x4_t vr30 = 
vld1q_u32(ptr3); + uint32x4_t vr31 = vld1q_u32(ptr3 + 4); + uint32x4_t vr32 = vld1q_u32(ptr3 + 8); + + vst1q_u32(outptr_row_col + 12, vr10_1); + vst1q_u32(outptr_row_col + 16, vr11_1); + vst1q_u32(outptr_row_col + 20, vr12_1); + + uint32x4_t vr20_1 = vbslq_u32(vmask1, vr20, vzero); + uint32x4_t vr21_1 = vbslq_u32(vmask2, vr21, vzero); + uint32x4_t vr22_1 = vbslq_u32(vmask3, vr22, vzero); + + uint32x4_t vr30_1 = vbslq_u32(vmask1, vr30, vzero); + uint32x4_t vr31_1 = vbslq_u32(vmask2, vr31, vzero); + uint32x4_t vr32_1 = vbslq_u32(vmask3, vr32, vzero); + + vst1q_u32(outptr_row_col + 24, vr20_1); + vst1q_u32(outptr_row_col + 28, vr21_1); + vst1q_u32(outptr_row_col + 32, vr22_1); + + vst1q_u32(outptr_row_col + 36, vr30_1); + vst1q_u32(outptr_row_col + 40, vr31_1); + vst1q_u32(outptr_row_col + 44, vr32_1); + } + } + +#pragma omp parallel for + for (int y = 4 * (y_len / 4); y < y_len; ++y) { + const uint32_t *ptr0 = inptr + y * ldin; + uint32_t *outptr_row_col = outptr_row + y * 12; + + int i = 0; + for (; i < x_len - 11; i += 12) { + uint32x4_t vr0 = vld1q_u32(ptr0); + uint32x4_t vr1 = vld1q_u32(ptr0 + 4); + uint32x4_t vr2 = vld1q_u32(ptr0 + 8); + vst1q_u32(outptr_row_col, vr0); + vst1q_u32(outptr_row_col + 4, vr1); + vst1q_u32(outptr_row_col + 8, vr2); + + ptr0 += 12; + + outptr_row_col += stride_out; + } + if (right_remain > 0) { + uint32x4_t vr0 = vld1q_u32(ptr0); + uint32x4_t vr1 = vld1q_u32(ptr0 + 4); + uint32x4_t vr2 = vld1q_u32(ptr0 + 8); + + uint32x4_t vr0_1 = vbslq_u32(vmask1, vr0, vzero); + uint32x4_t vr1_1 = vbslq_u32(vmask2, vr1, vzero); + uint32x4_t vr2_1 = vbslq_u32(vmask3, vr2, vzero); + + vst1q_u32(outptr_row_col, vr0_1); + vst1q_u32(outptr_row_col + 4, vr1_1); + vst1q_u32(outptr_row_col + 8, vr2_1); + } + } +} + +void loadb_trans(float *out, const float *in, const int ldin, const int k0, + const int kmax, const int n0, const int nmax) { + int x_len = kmax - k0; + uint32_t zerobuff[x_len]; // NOLINT + memset(zerobuff, 0, sizeof(uint32_t) * x_len); + 
uint32_t *outptr = reinterpret_cast(out); + const uint32_t *inptr = reinterpret_cast(in); + + //! data B is not transposed, transpose B to k * 12 + for (int y = n0; y < nmax; y += 12) { + const uint32_t *inptr0 = inptr + y * ldin + k0; + const uint32_t *inptr1 = inptr0 + ldin; + const uint32_t *inptr2 = inptr1 + ldin; + const uint32_t *inptr3 = inptr2 + ldin; + const uint32_t *inptr4 = inptr3 + ldin; + const uint32_t *inptr5 = inptr4 + ldin; + const uint32_t *inptr6 = inptr5 + ldin; + const uint32_t *inptr7 = inptr6 + ldin; + const uint32_t *inptr8 = inptr7 + ldin; + const uint32_t *inptr9 = inptr8 + ldin; + const uint32_t *inptr10 = inptr9 + ldin; + const uint32_t *inptr11 = inptr10 + ldin; + + asm volatile( + "prfm pldl1keep, [%[ptr0]] \n" + "prfm pldl1keep, [%[ptr0], #64] \n" + "prfm pldl1keep, [%[ptr1]] \n" + "prfm pldl1keep, [%[ptr1], #64] \n" + "prfm pldl1keep, [%[ptr2]] \n" + "prfm pldl1keep, [%[ptr2], #64] \n" + "prfm pldl1keep, [%[ptr3]] \n" + "prfm pldl1keep, [%[ptr3], #64] \n" + "prfm pldl1keep, [%[ptr4]] \n" + "prfm pldl1keep, [%[ptr4], #64] \n" + "prfm pldl1keep, [%[ptr5]] \n" + "prfm pldl1keep, [%[ptr5], #64] \n" + "prfm pldl1keep, [%[ptr6]] \n" + "prfm pldl1keep, [%[ptr6], #64] \n" + "prfm pldl1keep, [%[ptr7]] \n" + "prfm pldl1keep, [%[ptr7], #64] \n" + "prfm pldl1keep, [%[ptr8]] \n" + "prfm pldl1keep, [%[ptr8], #64] \n" + "prfm pldl1keep, [%[ptr9]] \n" + "prfm pldl1keep, [%[ptr9], #64] \n" + "prfm pldl1keep, [%[ptr10]] \n" + "prfm pldl1keep, [%[ptr10], #64] \n" + "prfm pldl1keep, [%[ptr11]] \n" + "prfm pldl1keep, [%[ptr11], #64] \n" + : + : [ptr0] "r"(inptr0), [ptr1] "r"(inptr1), [ptr2] "r"(inptr2), + [ptr3] "r"(inptr3), [ptr4] "r"(inptr4), [ptr5] "r"(inptr5), + [ptr6] "r"(inptr6), [ptr7] "r"(inptr7), [ptr8] "r"(inptr8), + [ptr9] "r"(inptr9), [ptr10] "r"(inptr10), [ptr11] "r"(inptr11) + : "memory"); + + int x = x_len; + + //! 
cope with row index exceed real size, set to zero buffer + if ((y + 11) >= nmax) { + switch ((y + 11) - nmax) { + case 10: + inptr1 = zerobuff; + case 9: + inptr2 = zerobuff; + case 8: + inptr3 = zerobuff; + case 7: + inptr4 = zerobuff; + case 6: + inptr5 = zerobuff; + case 5: + inptr6 = zerobuff; + case 4: + inptr7 = zerobuff; + case 3: + inptr8 = zerobuff; + case 2: + inptr9 = zerobuff; + case 1: + inptr10 = zerobuff; + case 0: + inptr11 = zerobuff; + default: + break; + } + } + for (; x > 7; x -= 8) { + asm volatile( + // Load up 12 elements (3 vectors) from each of 8 sources. + "LDP q0, q1, [%[inptr0]], #32\n" // q0=A0A1A2A3 + "LDP q2, q3, [%[inptr1]], #32\n" // q2=B0B1B2B3 + "LDP q4, q5, [%[inptr2]], #32\n" // q4=C0C1C2C3 + "ZIP1 v16.4s, v0.4s, v4.4s\n" // q16=A0C0A1C1 + "prfm pldl1keep, [%[inptr0], #128] \n" + "LDP q6, q7, [%[inptr3]], #32\n" // q6=D0D1D2D3 + "ZIP1 v17.4s, v2.4s, v6.4s\n" // q17=B0D0B1D1 + "LDP q8, q9, [%[inptr4]], #32\n" + "LDP q10, q11, [%[inptr5]], #32\n" + "LDP q12, q13, [%[inptr6]], #32\n" + "ZIP1 v18.4s, v8.4s, v12.4s\n" + "prfm pldl1keep, [%[inptr1], #128]\n" + "LDP q14, q15, [%[inptr7]], #32\n" + "ZIP1 v19.4s, v10.4s, v14.4s\n" + + "ZIP1 v20.4s, v16.4s, v17.4s\n" // q20=A0B0C0D0 + "prfm pldl1keep, [%[inptr2], #128]\n" + "ZIP1 v21.4s, v18.4s, v19.4s\n" + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + + "LDP q24, q25, [%[inptr8]], #32\n" // q24=A0A1A2A3 + "LDP q26, q27, [%[inptr9]], #32\n" // q26=B0B1B2B3 + "LDP q28, q29, [%[inptr10]], #32\n" // q28=C0C1C2C3 + "LDP q30, q31, [%[inptr11]], #32\n" // q30=D0D1D2D3 + "prfm pldl1keep, [%[inptr3], #128]\n" + "prfm pldl1keep, [%[inptr4], #128]\n" + "ZIP1 v16.4s, v24.4s, v28.4s\n" // q16=A0C0A1C1 + "ZIP1 v17.4s, v26.4s, v30.4s\n" // q17=B0D0B1D1 + "STP q20, q21, [%[outptr]], #32\n" // Write back the first + // element of each source + "ZIP1 v18.4s, v16.4s, v17.4s\n" // q20=A0B0C0D0 + "ZIP2 v19.4s, v16.4s, v17.4s\n" // q20=A0B0C0D0 + + "ZIP2 v16.4s, v0.4s, v4.4s\n" + "prfm 
pldl1keep, [%[inptr5], #128]\n" + "ZIP2 v17.4s, v2.4s, v6.4s\n" + "STR q18, [%[outptr]], #16\n" // Write back the second element + // of each source + + "STP q22, q23, [%[outptr]], #32\n" // Write back the second + // element of each source + "ZIP2 v18.4s, v8.4s, v12.4s\n" + "prfm pldl1keep, [%[inptr6], #128]\n" + "STR q19, [%[outptr]], #16\n" // Write back the second element + // of each source + "ZIP2 v19.4s, v10.4s, v14.4s\n" + + "ZIP1 v20.4s, v16.4s, v17.4s\n" + "prfm pldl1keep, [%[inptr7], #128]\n" + "ZIP1 v21.4s, v18.4s, v19.4s\n" + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + + "ZIP2 v16.4s, v24.4s, v28.4s\n" // q16=A0C0A1C1 + "ZIP2 v17.4s, v26.4s, v30.4s\n" // q17=B0D0B1D1 + "prfm pldl1keep, [%[inptr8], #128]\n" + "STP q20, q21, [%[outptr]], #32\n" // Third element + "ZIP1 v18.4s, v16.4s, v17.4s\n" + "ZIP2 v19.4s, v16.4s, v17.4s\n" + + "ZIP1 v16.4s, v1.4s, v5.4s\n" + "prfm pldl1keep, [%[inptr9], #128]\n" + "ZIP1 v17.4s, v3.4s, v7.4s\n" + "STR q18, [%[outptr]], #16\n" // Write back the second element + // of each source + + "STP q22, q23, [%[outptr]], #32\n" // Fourth element + "ZIP1 v18.4s, v9.4s, v13.4s\n" + "prfm pldl1keep, [%[inptr10], #128]\n" + "STR q19, [%[outptr]], #16\n" // Write back the second element + // of each source + "ZIP1 v19.4s, v11.4s, v15.4s\n" + + "ZIP1 v20.4s, v16.4s, v17.4s\n" + "ZIP1 v21.4s, v18.4s, v19.4s\n" + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "prfm pldl1keep, [%[inptr11], #128]\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + + "ZIP1 v16.4s, v25.4s, v29.4s\n" + "ZIP1 v17.4s, v27.4s, v31.4s\n" + "STP q20, q21, [%[outptr]], #32\n" // Fifth element + "ZIP1 v18.4s, v16.4s, v17.4s\n" + "ZIP2 v19.4s, v16.4s, v17.4s\n" + + "ZIP2 v16.4s, v1.4s, v5.4s\n" + "ZIP2 v17.4s, v3.4s, v7.4s\n" + "STR q18, [%[outptr]], #16\n" + + "STP q22, q23, [%[outptr]], #32\n" // Sixth element + "ZIP2 v18.4s, v9.4s, v13.4s\n" + "STR q19, [%[outptr]], #16\n" // Sixth element + + "ZIP2 v19.4s, v11.4s, v15.4s\n" + "ZIP1 v20.4s, v16.4s, v17.4s\n" + 
"ZIP1 v21.4s, v18.4s, v19.4s\n" + + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + + "ZIP2 v16.4s, v25.4s, v29.4s\n" + "ZIP2 v17.4s, v27.4s, v31.4s\n" + "STP q20, q21, [%[outptr]], #32\n" // Seventh element + + "ZIP1 v18.4s, v16.4s, v17.4s\n" + "ZIP2 v19.4s, v16.4s, v17.4s\n" + "STR q18, [%[outptr]], #16\n" + "STP q22, q23, [%[outptr]], #32\n" // Eighth element + "STR q19, [%[outptr]], #16\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8), + [inptr9] "+r"(inptr9), [inptr10] "+r"(inptr10), + [inptr11] "+r"(inptr11), [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", "cc", "memory"); + } + + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + *outptr++ = *inptr6++; + *outptr++ = *inptr7++; + *outptr++ = *inptr8++; + *outptr++ = *inptr9++; + *outptr++ = *inptr10++; + *outptr++ = *inptr11++; + } + } +} + +#else // __aarch64__ +void loadb(float* out, const float* in, const int ldin, const int k0, + const int kmax, const int n0, const int nmax) { + uint32_t* outptr = reinterpret_cast(out); + const uint32_t* inptr = + reinterpret_cast(in) + k0 * ldin + n0; + uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + int x_len = nmax - n0; + int y_len = kmax - k0; + int right_remain = x_len - 8 * (x_len / 8); + int right_pad = 8 - right_remain; + const size_t copy_len_remain = sizeof(float) * right_remain; + const size_t copy_len_pad = sizeof(float) * right_pad; + const size_t size_ldin = sizeof(float) * ldin; + + uint32_t* outptr_row = outptr; + int stride_out = 8 * y_len; + + 
uint32x4_t vzero = vdupq_n_u32(0); + uint32x4_t vmask1 = + vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain)); + uint32x4_t vmask2 = + vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain)); + +#pragma omp parallel for + for (int y = 0; y < y_len - 3; y += 4) { + const uint32_t* ptr0 = inptr + y * ldin; + const uint32_t* ptr1 = ptr0 + ldin; + const uint32_t* ptr2 = ptr1 + ldin; + const uint32_t* ptr3 = ptr2 + ldin; + uint32_t* outptr_row_col = outptr_row + y * 8; + int i = 0; + for (; i < x_len - 7; i += 8) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d3}, [%[ptr0]]! @ load r0, 8 elements\n" + "vld1.32 {d4-d7}, [%[ptr1]]! @ load r1, 8 elements\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d7}, [%[outptr]]! @ write to output ptr\n" + + "vld1.32 {d0-d3}, [%[ptr2]]! @ load r2, 8 elements\n" + "vld1.32 {d4-d7}, [%[ptr3]]! @ load r3, 8 elements\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d7}, [%[outptr]]! @ write to output ptr\n" + : [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3) + : + : "q0", "q1", "q2", "q3", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_remain > 0) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d3}, [%[ptr0]]! @ load r0, 8 elements\n" + "vld1.32 {d4-d7}, [%[ptr1]]! @ load r1, 8 elements\n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q1, %q[vzero], %q[vmask2] @ bit select, pad zero\n" + //"vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q3, %q[vzero], %q[vmask2] @ bit select, pad zero\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d7}, [%[outptr]]! @ write to output ptr\n" + + "vld1.32 {d0-d3}, [%[ptr2]]! @ load r2, 8 elements\n" + "vld1.32 {d4-d7}, [%[ptr3]]! 
@ load r3, 8 elements\n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q1, %q[vzero], %q[vmask2] @ bit select, pad zero\n" + //"vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q3, %q[vzero], %q[vmask2] @ bit select, pad zero\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d7}, [%[outptr]]! @ write to output ptr\n" + : [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero) + : "q0", "q1", "q2", "q3", "cc", "memory"); + } + } +#pragma omp parallel for + for (int y = 4 * (y_len / 4); y < y_len; ++y) { + const uint32_t* ptr0 = inptr + y * ldin; + uint32_t* outptr_row_col = outptr_row + y * 8; + int i = 0; + for (; i < x_len - 7; i += 8) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d3}, [%[ptr0]]! @ load r0, 8 elements\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : + : "q0", "q1", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_remain > 0) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d3}, [%[ptr0]]! @ load r0, 8 elements\n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q1, %q[vzero], %q[vmask2] @ bit select, pad zero\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero) + : "q0", "q1", "cc", "memory"); + } + } +} + +void loadb_trans(float* out, const float* in, const int ldin, const int k0, + const int kmax, const int n0, const int nmax) { + int x_len = kmax - k0; + uint32_t zerobuff[x_len]; // NOLINT + memset(zerobuff, 0, sizeof(uint32_t) * x_len); + + uint32_t* outptr = reinterpret_cast(out); + const uint32_t* inptr = reinterpret_cast(in); + //! 
data B is not transposed, transpose B to k * 8 + for (int y = n0; y < nmax; y += 8) { + const uint32_t* inptr0 = inptr + y * ldin + k0; + const uint32_t* inptr1 = inptr0 + ldin; + const uint32_t* inptr2 = inptr1 + ldin; + const uint32_t* inptr3 = inptr2 + ldin; + const uint32_t* inptr4 = inptr3 + ldin; + const uint32_t* inptr5 = inptr4 + ldin; + const uint32_t* inptr6 = inptr5 + ldin; + const uint32_t* inptr7 = inptr6 + ldin; + + int x = x_len; + + //! cope with row index exceed real size, set to zero buffer + if ((y + 7) >= nmax) { + switch ((y + 7) - nmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + default: + break; + } + } + + for (; x > 7; x -= 8) { + //! zip load 8 elements (2 neon Q registers) from each of 8 rows + asm volatile( + "vld4.32 {d0-d3}, [%[inptr0]]! @ zip load r0, " + "q0,q1=r00,r04,r01,r05,r02,r06,r03,r07\n" + "vld4.32 {d4-d7}, [%[inptr1]]! @ zip load r1, " + "q2,q3=r10,r14,r11,r15,r12,r16,r13,r17\n" + "vtrn.32 q0, q2 @ trans data: q0=r00,r10,r01,r11; " + "q2=r04,r14,r05,r15\n" + "vst1.32 {d0}, [%[outptr]]! @ write d0(q0,low),r00,r10\n" + + "vld4.32 {d8-d11}, [%[inptr2]]! @ zip load r2, " + "q4,q5=r20,r24,r21,r25,r22,r26,r23,r27\n" + "vld4.32 {d12-d15}, [%[inptr3]]! @ zip load r3, " + "q6,q7=r30,r34,r31,r35,r32,r36,r33,r37\n" + "vtrn.32 q4, q6 @ trans data: q4=r20,r30,r21,r31; " + "q6=r24,r34,r25,r35\n" + "vst1.32 {d8}, [%[outptr]]! @ write d8(q4,low),r20,r30\n" + + "vld4.32 {d16-d19}, [%[inptr4]]! @ zip load r4, " + "q8,q9=r40,r44,r41,r45,r42,r46,r43,r47\n" + "vld4.32 {d20-d23}, [%[inptr5]]! @ zip load r5, " + "q10,q11=r50,r54,r51,r55,r52,r56,r53,r57\n" + "vtrn.32 q8, q10 @ trans data: q8=r40,r50,r41,r51; " + "q10=r44,r54,r45,r55\n" + "vst1.32 {d16}, [%[outptr]]! @ write d16(q8,low),r40,r50\n" + + "vld4.32 {d24-d27}, [%[inptr6]]! 
@ zip load r6, " + "q12,q13=r60,r64,r61,r65,r62,r66,r63,r67\n" + "vld4.32 {d28-d31}, [%[inptr7]]! @ zip load r7, " + "q14,q15=r70,r74,r71,r75,r72,r76,r73,r77\n" + "vtrn.32 q12, q14 @ trans data:q12=r60,r70,r61,r71; " + "q14=r64,r74,r65,r75\n" + "vst1.32 {d24}, [%[outptr]]! @ write d24(q8,low),r60,r70\n" + + //"pld [%[inptr0], #128] @ preload r0 data to cache, fill + // pipeline\n" + "vst1.32 {d1}, [%[outptr]]! @ write d1(q0,high),r01,r11\n" + "vst1.32 {d9}, [%[outptr]]! @ write d9(q4,high),r21,r31\n" + "vst1.32 {d17}, [%[outptr]]! @ write d17(q8,high),r41,r51\n" + "vst1.32 {d25}, [%[outptr]]! @ write d25(q12,high),r61,r71\n" + + "vtrn.32 q1, q3 @ trans data: q1=r02,r12,r03,r13; " + "q3=r06,r16,r07,r17\n" + "vst1.32 {d2}, [%[outptr]]! @ write d2(q1,low),r02,r12\n" + "vtrn.32 q5, q7 @ trans data: q5=r22,r32,r23,r33; " + "q7=r26,r36,r27,r37\n" + "vst1.32 {d10}, [%[outptr]]! @ write d10(q5,low),r22,r32\n" + "vtrn.32 q9, q11 @ trans data: q9=r42,r52,r43,r53; " + "q11=r46,r56,r47,r57\n" + "vst1.32 {d18}, [%[outptr]]! @ write d18(q9,low),r42,r52\n" + "vtrn.32 q13, q15 @ trans data:q13=r62,r72,r63,r73; " + "q15=r66,r76,r67,r77\n" + "vst1.32 {d26}, [%[outptr]]! @ write d18(q9,low),r62,r72\n" + + //"pld [%[inptr1], #128] @ preload r1 data to cache, fill + // pipeline\n" + "vst1.32 {d3}, [%[outptr]]! @ write d3(q1,high),r03,r13\n" + "vst1.32 {d11}, [%[outptr]]! @ write d11(q5,high),r23,r33\n" + "vst1.32 {d19}, [%[outptr]]! @ write d19(q9,high),r43,r53\n" + "vst1.32 {d27}, [%[outptr]]! @ write d27(q13,high),r63,r73\n" + + //"pld [%[inptr2], #128] @ preload r2 data to cache, fill + // pipeline\n" + "vst1.32 {d4}, [%[outptr]]! @ write d4(q2,low),r04,r14\n" + "vst1.32 {d12}, [%[outptr]]! @ write d12(q6,low),r24,r34\n" + "vst1.32 {d20}, [%[outptr]]! @ write d20(q10,low),r44,r54\n" + "vst1.32 {d28}, [%[outptr]]! @ write d28(q14,low),r64,r74\n" + + //"pld [%[inptr3], #128] @ preload r3 data to cache, fill + // pipeline\n" + "vst1.32 {d5}, [%[outptr]]! 
@ write d5(q2,high),r05,r15\n" + "vst1.32 {d13}, [%[outptr]]! @ write d13(q6,high),r25,r35\n" + "vst1.32 {d21}, [%[outptr]]! @ write d21(q10,high),r45,r55\n" + "vst1.32 {d29}, [%[outptr]]! @ write d29(q14,high),r65,r75\n" + + //"pld [%[inptr4], #128] @ preload r4 data to cache, fill + // pipeline\n" + "vst1.32 {d6}, [%[outptr]]! @ write d6(q3,low),r06,r16\n" + "vst1.32 {d14}, [%[outptr]]! @ write d14(q7,low),r26,r36\n" + "vst1.32 {d22}, [%[outptr]]! @ write d22(q11,low),r46,r56\n" + "vst1.32 {d30}, [%[outptr]]! @ write d30(q15,low),r66,r76\n" + + //"pld [%[inptr5], #128] @ preload r5 data to cache, fill + // pipeline\n" + "vst1.32 {d7}, [%[outptr]]! @ write d7(q3,high),r07,r17\n" + "vst1.32 {d15}, [%[outptr]]! @ write d15(q7,high),r27,r37\n" + "vst1.32 {d23}, [%[outptr]]! @ write d23(q11,high),r47,r57\n" + "vst1.32 {d31}, [%[outptr]]! @ write d31(q15,high),r67,r77\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "q12", "q13", "q14", "q15", "cc", "memory"); + } + + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + *outptr++ = *inptr6++; + *outptr++ = *inptr7++; + } + } +} + +#endif // __aarch64__ + +#ifdef __aarch64__ +void sgemm_conv_8x12(const float *A_packed, const float *B, const float *bias, + float *C, int M, int N, int K, bool is_bias, bool is_relu, + bool transB, ARMContext *ctx) { + size_t l2_cache = + ctx->l2_cache_size() > 0 ? ctx->l2_cache_size() : 512 * 1024; + float *workspace = ctx->workspace_data(); + int threads = ctx->threads(); + //! 
MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2 + int x_block = (l2_cache - (MBLOCK * K)) / (sizeof(float) * (K + MBLOCK)); + x_block /= NBLOCK; + x_block *= NBLOCK; + int x_num = (N + (x_block - 1)) / x_block; + x_block = (N + x_num - 1) / x_num; + x_block = (x_block + NBLOCK - 1) / NBLOCK; + x_block *= NBLOCK; + x_block = x_block < NBLOCK ? NBLOCK : x_block; + + // unroll 2 loop + int tail_pre = (K & (KBLOCK - 1)); + int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1; + + bool flag_p_remain = false; + int remain = 0; + + //! apanel is pre_compute outside gemm + for (unsigned int x0 = 0; x0 < N; x0 += x_block) { + unsigned int xmax = x0 + x_block; + if (xmax > N) { + xmax = N; + } + int bblocks = (xmax - x0 + NBLOCK - 1) / NBLOCK; + remain = xmax - x0 - (bblocks - 1) * NBLOCK; + if (remain > 0) { + flag_p_remain = true; + } + //! load bpanel + float *b_pannel = workspace; + if (transB) { + loadb_trans(b_pannel, B, K, 0, K, x0, xmax); + } else { + loadb(b_pannel, B, N, 0, K, x0, xmax); + } +#pragma omp parallel for num_threads(threads) + for (unsigned int y = 0; y < M; y += MBLOCK) { + unsigned int ymax = y + MBLOCK; + if (ymax > M) { + ymax = M; + } + + float bias_local[8] = {0}; + if (is_bias) { + bias_local[0] = bias[y]; + bias_local[1] = bias[y + 1]; + bias_local[2] = bias[y + 2]; + bias_local[3] = bias[y + 3]; + bias_local[4] = bias[y + 4]; + bias_local[5] = bias[y + 5]; + bias_local[6] = bias[y + 6]; + bias_local[7] = bias[y + 7]; + } + + float cout0[NBLOCK]; + float cout1[NBLOCK]; + float cout2[NBLOCK]; + float cout3[NBLOCK]; + float cout4[NBLOCK]; + float cout5[NBLOCK]; + float cout6[NBLOCK]; + float cout7[NBLOCK]; + + float *c_ptr0 = C + y * N + x0; + float *c_ptr1 = c_ptr0 + N; + float *c_ptr2 = c_ptr1 + N; + float *c_ptr3 = c_ptr2 + N; + float *c_ptr4 = c_ptr3 + N; + float *c_ptr5 = c_ptr4 + N; + float *c_ptr6 = c_ptr5 + N; + float *c_ptr7 = c_ptr6 + N; + + float *pout0 = c_ptr0; + float *pout1 = c_ptr1; + float *pout2 = c_ptr2; + float *pout3 = c_ptr3; + 
float *pout4 = c_ptr4; + float *pout5 = c_ptr5; + float *pout6 = c_ptr6; + float *pout7 = c_ptr7; + + const float *a_ptr_l = A_packed + y * K; + const float *b_ptr = b_pannel; + for (int xb = 0; xb < bblocks; xb++) { + if ((y + 7) >= ymax) { + switch ((y + 7) - ymax) { + case 6: + c_ptr1 = cout1; + case 5: + c_ptr2 = cout2; + case 4: + c_ptr3 = cout3; + case 3: + c_ptr4 = cout4; + case 2: + c_ptr5 = cout5; + case 1: + c_ptr6 = cout6; + case 0: + c_ptr7 = cout7; + default: + break; + } + } + if (flag_p_remain && (xb == bblocks - 1)) { + pout0 = c_ptr0; + pout1 = c_ptr1; + pout2 = c_ptr2; + pout3 = c_ptr3; + pout4 = c_ptr4; + pout5 = c_ptr5; + pout6 = c_ptr6; + pout7 = c_ptr7; + + c_ptr0 = cout0; + c_ptr1 = cout1; + c_ptr2 = cout2; + c_ptr3 = cout3; + c_ptr4 = cout4; + c_ptr5 = cout5; + c_ptr6 = cout6; + c_ptr7 = cout7; + } + const float *a_ptr = a_ptr_l; + int tail = tail_pre; + int k = k_pre; + + asm volatile( + // Initialize result registers, load initial operands, prime + // prefetches. 
+ "ldp q2, q3, [%[bias_ptr]]\n" /* load bias to q2, q3*/ + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00,a01 to q0, q1*/ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b0, b1 to q4, q5*/ + "dup v8.4s, v2.s[0]\n" /* out0 = 0 */ + "dup v9.4s, v2.s[0]\n" /* out1 = 0*/ + "dup v10.4s, v2.s[0]\n" /* out2 = 0*/ + "dup v11.4s, v2.s[1]\n" /* out3 = 0*/ + "dup v12.4s, v2.s[1]\n" /* out4 = 0*/ + "prfm pldl1keep, [%[b_ptr], #64]\n" /* preload b*/ + "dup v13.4s, v2.s[1]\n" /* out5 = 0*/ + "prfm pldl1keep, [%[a_ptr], #64]\n" /* preload a*/ + "dup v14.4s, v2.s[2]\n" /* out6 = 0*/ + "prfm pldl1keep, [%[b_ptr], #128]\n" /* preload b*/ + "dup v15.4s, v2.s[2]\n" /* out7 = 0*/ + "prfm pldl1keep, [%[a_ptr], #128]\n" /* preload a*/ + "dup v16.4s, v2.s[2]\n" /* out8 = 0*/ + "prfm pldl1keep, [%[b_ptr], #192]\n" /* preload b*/ + "dup v17.4s, v2.s[3]\n" /* out9 = 0*/ + "prfm pldl1keep, [%[b_ptr], #256]\n" /* preload b*/ + "dup v18.4s, v2.s[3]\n" /* out10 = 0*/ + "prfm pldl1keep, [%[a_ptr], #192]\n" /* preload a*/ + "dup v19.4s, v2.s[3]\n" /* out11 = 0*/ + "prfm pldl1keep, [%[b_ptr], #320]\n" /* preload b*/ + "dup v20.4s, v3.s[0]\n" /* out12 = 0*/ + "prfm pldl1keep, [%[a_ptr], #256]\n" /* preload a*/ + "dup v21.4s, v3.s[0]\n" /* out13 = 0*/ + "prfm pldl1keep, [%[b_ptr], #384]\n" /* preload b*/ + "dup v22.4s, v3.s[0]\n" /* out14 = 0*/ + "dup v23.4s, v3.s[1]\n" /* out15 = 0*/ + "dup v24.4s, v3.s[1]\n" /* out16 = 0*/ + "dup v25.4s, v3.s[1]\n" /* out17 = 0*/ + "dup v26.4s, v3.s[2]\n" /* out18 = 0*/ + "dup v27.4s, v3.s[2]\n" /* out19 = 0*/ + "dup v28.4s, v3.s[2]\n" /* out20 = 0*/ + "dup v29.4s, v3.s[3]\n" /* out21 = 0*/ + "dup v30.4s, v3.s[3]\n" /* out22 = 0*/ + "dup v31.4s, v3.s[3]\n" /* out23 = 0*/ + "cbz %w[k], 2f\n" /* check loop count > 0 */ + /* main loop */ + /* unrool 0*/ + "1:\n" /* main loop */ + "fmla v8.4s , v4.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 = + q4 */ + "fmla v11.4s , v4.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 = + q4 */ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b0 to q6, q7 */ 
+ "fmla v14.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 = + q4 */ + "fmla v17.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 = + q4 */ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4 */ + "fmla v20.4s, v4.4s, v1.s[0]\n" /* out4 = b0 * a01[0], b0 = + q4 */ + "fmla v23.4s, v4.4s, v1.s[1]\n" /* out5 = b0 * a01[1], b0 = + q4 */ + "fmla v26.4s, v4.4s, v1.s[2]\n" /* out6 = b0 * a01[2], b0 = + q4 */ + "fmla v29.4s, v4.4s, v1.s[3]\n" /* out7 = b0 * a01[3], b0 = + q4 */ + + "fmla v9.4s, v5.4s, v0.s[0]\n" /* out8 = b1 * a00[0], b1 = + q5 */ + "fmla v12.4s, v5.4s, v0.s[1]\n" /* out9 = b1 * a00[1], b1 = + q5 */ + "fmla v15.4s, v5.4s, v0.s[2]\n" /* out10 = b1 * a00[2], b1 = + q5*/ + "fmla v18.4s, v5.4s, v0.s[3]\n" /* out11 = b1 * a00[3], b1 = + q5*/ + "fmla v21.4s, v5.4s, v1.s[0]\n" /* out12 = b1 * a01[0], b1 = + q5*/ + "fmla v24.4s, v5.4s, v1.s[1]\n" /* out13 = b1 * a01[1], b1 = + q5*/ + "fmla v27.4s, v5.4s, v1.s[2]\n" /* out14 = b1 * a01[2], b1 = + q5*/ + "fmla v30.4s, v5.4s, v1.s[3]\n" /* out15 = b1 * a01[3], b1 = + q5*/ + + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b1, b2 to q4, q5 */ + + "fmla v10.4s, v6.4s, v0.s[0]\n" /* out16 = b2 * a00[0], b2 = + q6*/ + "fmla v13.4s, v6.4s, v0.s[1]\n" /* out17 = b2 * a00[1], b2 = + q6*/ + "prfm pldl1keep, [%[b_ptr], #384]\n" + "fmla v16.4s, v6.4s, v0.s[2]\n" /* out18 = b2 * a00[2], b2 = + q6*/ + "fmla v19.4s, v6.4s, v0.s[3]\n" /* out19 = b2 * a00[3], b2 = + q6*/ + "fmla v22.4s, v6.4s, v1.s[0]\n" /* out20 = b2 * a00[0], b2 = + q6*/ + "fmla v25.4s, v6.4s, v1.s[1]\n" /* out21 = b2 * a00[1], b2 = + q6*/ + "fmla v28.4s, v6.4s, v1.s[2]\n" /* out22 = b2 * a00[2], b2 = + q6*/ + "fmla v31.4s, v6.4s, v1.s[3]\n" /* out23 = b2 * a00[3], b2 = + q6*/ + + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1 */ + + /* unrool 1 */ + "fmla v8.4s , v7.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 = + q7 */ + "fmla v11.4s , v7.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 = + q7 */ + "fmla v14.4s, v7.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 = + 
q7 */ + "prfm pldl1keep, [%[a_ptr], #256]\n" + "fmla v17.4s, v7.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 = + q7 */ + "fmla v20.4s, v7.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 = + q7 */ + "fmla v23.4s, v7.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 = q7 + */ + "fmla v26.4s, v7.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 = + q7 */ + "fmla v29.4s, v7.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 = + q7 */ + + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b0, b1 to q6, q7 */ + + "fmla v9.4s, v4.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 = + q4 */ + "fmla v12.4s, v4.4s, v2.s[1]\n" /* out9 = b0 * a10[1], b1 = + q4 */ + "fmla v15.4s, v4.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 = + q4*/ + "fmla v18.4s, v4.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 = + q4*/ + "fmla v21.4s, v4.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 = + q4*/ + "fmla v24.4s, v4.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 = + q4*/ + "fmla v27.4s, v4.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 = + q4*/ + "fmla v30.4s, v4.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 = + q4*/ + + "fmla v10.4s, v5.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 = + q5*/ + "fmla v13.4s, v5.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 = + q5*/ + "fmla v16.4s, v5.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 = + q5*/ + "fmla v19.4s, v5.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 = + q5*/ + "fmla v22.4s, v5.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 = + q5*/ + "fmla v25.4s, v5.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 = + q5*/ + "fmla v28.4s, v5.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 = + q5*/ + "fmla v31.4s, v5.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 = + q5*/ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b2, b0 to q4, q5 */ + /* unrool 2*/ + "fmla v8.4s , v6.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 = + q6 */ + "fmla v11.4s , v6.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 = + q6 */ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4*/ + "fmla v14.4s, v6.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 = + q6*/ + "fmla v17.4s, v6.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 = + 
q6*/ + "fmla v20.4s, v6.4s, v1.s[0]\n" /* out4 = b0 * a01[0], b0 = + q6*/ + "fmla v23.4s, v6.4s, v1.s[1]\n" /* out5 = b0 * a01[1], b0 = + q6*/ + "fmla v26.4s, v6.4s, v1.s[2]\n" /* out6 = b0 * a01[2], b0 = + q6*/ + "fmla v29.4s, v6.4s, v1.s[3]\n" /* out7 = b0 * a01[3], b0 = + q6*/ + "fmla v9.4s, v7.4s, v0.s[0]\n" /* out8 = b1 * a00[0], b1 = + q7*/ + "fmla v12.4s, v7.4s, v0.s[1]\n" /* out9 = b1 * a00[1], b1 = + q7*/ + "prfm pldl1keep, [%[b_ptr], #384]\n" + "fmla v15.4s, v7.4s, v0.s[2]\n" /* out10 = b1 * a00[2], b1 = + q7*/ + "fmla v18.4s, v7.4s, v0.s[3]\n" /* out11 = b1 * a00[3], b1 = + q7*/ + "fmla v21.4s, v7.4s, v1.s[0]\n" /* out12 = b1 * a01[0], b1 = + q7*/ + "fmla v24.4s, v7.4s, v1.s[1]\n" /* out13 = b1 * a01[1], b1 = + q7*/ + "fmla v27.4s, v7.4s, v1.s[2]\n" /* out14 = b1 * a01[2], b1 = + q7*/ + "fmla v30.4s, v7.4s, v1.s[3]\n" /* out15 = b1 * a01[3], b1 = + q7*/ + + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b1, b2 to q6, q7*/ + + "fmla v10.4s, v4.4s, v0.s[0]\n" /* out16 = b2 * a00[0], b2 = + q4*/ + "fmla v13.4s, v4.4s, v0.s[1]\n" /* out17 = b2 * a00[1], b2 = + q4*/ + "fmla v16.4s, v4.4s, v0.s[2]\n" /* out18 = b2 * a00[2], b2 = + q4*/ + "fmla v19.4s, v4.4s, v0.s[3]\n" /* out19 = b2 * a00[3], b2 = + q4*/ + "fmla v22.4s, v4.4s, v1.s[0]\n" /* out20 = b2 * a00[0], b2 = + q4*/ + "fmla v25.4s, v4.4s, v1.s[1]\n" /* out21 = b2 * a00[1], b2 = + q4*/ + "fmla v28.4s, v4.4s, v1.s[2]\n" /* out22 = b2 * a00[2], b2 = + q4*/ + "fmla v31.4s, v4.4s, v1.s[3]\n" /* out23 = b2 * a00[3], b2 = + q4*/ + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1*/ + /* unrool 3*/ + "fmla v8.4s , v5.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 = + q5*/ + "fmla v11.4s , v5.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 = + q5*/ + "fmla v14.4s, v5.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 = + q5*/ + "fmla v17.4s, v5.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 = + q5*/ + "fmla v20.4s, v5.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 = + q5*/ + "fmla v23.4s, v5.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ + 
"fmla v26.4s, v5.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 = + q5*/ + "fmla v29.4s, v5.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 = + q5*/ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b0, b1 to q4, q5*/ + "fmla v9.4s, v6.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 = + q6*/ + "fmla v12.4s, v6.4s, v2.s[1]\n" /* out9 = b0 * a10[1], b1 = + q6*/ + "prfm pldl1keep, [%[a_ptr], #256]\n" + "fmla v15.4s, v6.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 = + q6*/ + "fmla v18.4s, v6.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 = + q6*/ + "fmla v21.4s, v6.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 = + q6*/ + "fmla v24.4s, v6.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 = + q6*/ + "fmla v27.4s, v6.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 = + q6*/ + "prfm pldl1keep, [%[b_ptr], #384]\n" + "fmla v30.4s, v6.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 = + q6*/ + "fmla v10.4s, v7.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 = + q7*/ + "fmla v13.4s, v7.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 = + q7*/ + "fmla v16.4s, v7.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 = + q7*/ + "fmla v19.4s, v7.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 = + q7*/ + "fmla v22.4s, v7.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 = + q7*/ + "fmla v25.4s, v7.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 = + q7*/ + "subs %w[k], %w[k], #1\n" /* loop count - 1*/ + "fmla v28.4s, v7.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 = + q7*/ + "fmla v31.4s, v7.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 = + q7*/ + "bne 1b\n" + /* Target to use when K is 1 or 2 (i.e. 
zero iterations of main + loop)*/ + "2:\n" /* process tail*/ + "subs %w[tail], %w[tail], #1\n" /* tail--*/ + "beq 3f\n" /*jump to tail = 1*/ + /* final unrool 0*/ + /* unrool 0, tail > 1*/ + "fmla v8.4s , v4.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 = + q4*/ + "fmla v11.4s , v4.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 = + q4*/ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b0 to q6, q7*/ + "fmla v14.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 = + q4*/ + "fmla v17.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 = + q4*/ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q2, q3*/ + "fmla v20.4s, v4.4s, v1.s[0]\n" /* out4 = b0 * a01[0], b0 = + q4*/ + "fmla v23.4s, v4.4s, v1.s[1]\n" /* out5 = b0 * a01[1], b0 = + q4*/ + "fmla v26.4s, v4.4s, v1.s[2]\n" /* out6 = b0 * a01[2], b0 = + q4*/ + "fmla v29.4s, v4.4s, v1.s[3]\n" /* out7 = b0 * a01[3], b0 = + q4*/ + "subs %w[tail], %w[tail], #1\n" /* tail--*/ + "fmla v9.4s, v5.4s, v0.s[0]\n" /* out8 = b1 * a00[0], b1 = + q5*/ + "fmla v12.4s, v5.4s, v0.s[1]\n" /* out9 = b1 * a00[1], b1 = + q5*/ + "fmla v15.4s, v5.4s, v0.s[2]\n" /* out10 = b1 * a00[2], b1 = + q5*/ + "fmla v18.4s, v5.4s, v0.s[3]\n" /* out11 = b1 * a00[3], b1 = + q5*/ + "fmla v21.4s, v5.4s, v1.s[0]\n" /* out12 = b1 * a01[0], b1 = + q5*/ + "fmla v24.4s, v5.4s, v1.s[1]\n" /* out13 = b1 * a01[1], b1 = + q5*/ + "fmla v27.4s, v5.4s, v1.s[2]\n" /* out14 = b1 * a01[2], b1 = + q5*/ + "fmla v30.4s, v5.4s, v1.s[3]\n" /* out15 = b1 * a01[3], b1 = + q5*/ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b1, b2 to q4, q5*/ + "fmla v10.4s, v6.4s, v0.s[0]\n" /* out16 = b2 * a00[0], b2 = + q6*/ + "fmla v13.4s, v6.4s, v0.s[1]\n" /* out17 = b2 * a00[1], b2 = + q6*/ + "fmla v16.4s, v6.4s, v0.s[2]\n" /* out18 = b2 * a00[2], b2 = + q6*/ + "fmla v19.4s, v6.4s, v0.s[3]\n" /* out19 = b2 * a00[3], b2 = + q6*/ + "fmla v22.4s, v6.4s, v1.s[0]\n" /* out20 = b2 * a00[0], b2 = + q6*/ + "fmla v25.4s, v6.4s, v1.s[1]\n" /* out21 = b2 * a00[1], b2 = + q6*/ + "fmla v28.4s, v6.4s, v1.s[2]\n" /* out22 = b2 * 
a00[2], b2 = + q6*/ + "fmla v31.4s, v6.4s, v1.s[3]\n" /* out23 = b2 * a00[3], b2 = + q6*/ + "beq 4f\n" /*jump to tail = 2*/ + /* unrool 1, tail > 2*/ + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1*/ + "fmla v8.4s , v7.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 = + q7*/ + "fmla v11.4s , v7.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 = + q7*/ + "fmla v14.4s, v7.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 = + q7*/ + "fmla v17.4s, v7.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 = + q7*/ + "fmla v20.4s, v7.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 = + q7*/ + "fmla v23.4s, v7.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 = q7*/ + "fmla v26.4s, v7.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 = + q7*/ + "fmla v29.4s, v7.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 = + q7*/ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b0, b1 to q6, q7*/ + "fmla v9.4s, v4.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 = + q4*/ + "fmla v12.4s, v4.4s, v2.s[1]\n" /* out9 = b0 * a10[1], b1 = + q4*/ + "fmla v15.4s, v4.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 = + q4*/ + "fmla v18.4s, v4.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 = + q4*/ + "fmla v21.4s, v4.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 = + q4*/ + "fmla v24.4s, v4.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 = + q4*/ + "fmla v27.4s, v4.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 = + q4*/ + "fmla v30.4s, v4.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 = + q4*/ + "subs %w[tail], %w[tail], #1\n" /* tail--*/ + "fmla v10.4s, v5.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 = + q5*/ + "fmla v13.4s, v5.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 = + q5*/ + "fmla v16.4s, v5.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 = + q5*/ + "fmla v19.4s, v5.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 = + q5*/ + "fmla v22.4s, v5.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 = + q5*/ + "fmla v25.4s, v5.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 = + q5*/ + "fmla v28.4s, v5.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 = + q5*/ + "fmla v31.4s, v5.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 = + q5*/ + "beq 
5f\n" /*jump to tail = 3*/ + /* unrool 2, tail = 4*/ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b2, b0 to q4, q5*/ + "fmla v8.4s , v6.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 = + q6*/ + "fmla v11.4s , v6.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 = + q6*/ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4*/ + "fmla v14.4s, v6.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 = + q6*/ + "fmla v17.4s, v6.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 = + q6*/ + "fmla v20.4s, v6.4s, v1.s[0]\n" /* out4 = b0 * a01[0], b0 = + q6*/ + "fmla v23.4s, v6.4s, v1.s[1]\n" /* out5 = b0 * a01[1], b0 = + q6*/ + "fmla v26.4s, v6.4s, v1.s[2]\n" /* out6 = b0 * a01[2], b0 = + q6*/ + "fmla v29.4s, v6.4s, v1.s[3]\n" /* out7 = b0 * a01[3], b0 = + q6*/ + "fmla v9.4s, v7.4s, v0.s[0]\n" /* out8 = b1 * a00[0], b1 = + q7*/ + "fmla v12.4s, v7.4s, v0.s[1]\n" /* out9 = b1 * a00[1], b1 = + q7*/ + "fmla v15.4s, v7.4s, v0.s[2]\n" /* out10 = b1 * a00[2], b1 = + q7*/ + "fmla v18.4s, v7.4s, v0.s[3]\n" /* out11 = b1 * a00[3], b1 = + q7*/ + "fmla v21.4s, v7.4s, v1.s[0]\n" /* out12 = b1 * a01[0], b1 = + q7*/ + "fmla v24.4s, v7.4s, v1.s[1]\n" /* out13 = b1 * a01[1], b1 = + q7*/ + "fmla v27.4s, v7.4s, v1.s[2]\n" /* out14 = b1 * a01[2], b1 = + q7*/ + "fmla v30.4s, v7.4s, v1.s[3]\n" /* out15 = b1 * a01[3], b1 = + q7*/ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b1, b2 to q6, q7*/ + "fmla v10.4s, v4.4s, v0.s[0]\n" /* out16 = b2 * a00[0], b2 = + q4*/ + "fmla v13.4s, v4.4s, v0.s[1]\n" /* out17 = b2 * a00[1], b2 = + q4*/ + "fmla v16.4s, v4.4s, v0.s[2]\n" /* out18 = b2 * a00[2], b2 = + q4*/ + "fmla v19.4s, v4.4s, v0.s[3]\n" /* out19 = b2 * a00[3], b2 = + q4*/ + "fmla v22.4s, v4.4s, v1.s[0]\n" /* out20 = b2 * a00[0], b2 = + q4*/ + "fmla v25.4s, v4.4s, v1.s[1]\n" /* out21 = b2 * a00[1], b2 = + q4*/ + "fmla v28.4s, v4.4s, v1.s[2]\n" /* out22 = b2 * a00[2], b2 = + q4*/ + "fmla v31.4s, v4.4s, v1.s[3]\n" /* out23 = b2 * a00[3], b2 = + q4*/ + /* unrool 3, tail = 4*/ + "fmla v8.4s , v5.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 = + 
q5*/ + "fmla v11.4s , v5.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 = + q5*/ + "fmla v14.4s, v5.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 = + q5*/ + "fmla v17.4s, v5.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 = + q5*/ + "fmla v20.4s, v5.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 = + q5*/ + "fmla v23.4s, v5.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ + "fmla v26.4s, v5.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 = + q5*/ + "fmla v29.4s, v5.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 = + q5*/ + "fmla v9.4s, v6.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 = + q6*/ + "fmla v12.4s, v6.4s, v2.s[1]\n" /* out9 = b1 * a10[1], b1 = + q6*/ + "fmla v15.4s, v6.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 = + q6*/ + "fmla v18.4s, v6.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 = + q6*/ + "fmla v21.4s, v6.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 = + q6*/ + "fmla v24.4s, v6.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 = + q6*/ + "fmla v27.4s, v6.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 = + q6*/ + "fmla v30.4s, v6.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 = + q6*/ + "fmla v10.4s, v7.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 = + q7*/ + "fmla v13.4s, v7.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 = + q7*/ + "fmla v16.4s, v7.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 = + q7*/ + "fmla v19.4s, v7.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 = + q7*/ + "fmla v22.4s, v7.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 = + q7*/ + "fmla v25.4s, v7.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 = + q7*/ + "fmla v28.4s, v7.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 = + q7*/ + "fmla v31.4s, v7.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 = + q7*/ + "b 11f\n" + /* tails==1 final tail*/ + "3: \n" /* tail=1*/ + "ldr q6, [%[b_ptr]], #16\n" /* load b2 to q6*/ + "fmla v8.4s , v4.4s, v0.s[0]\n" /* out0 = b0 * a10[0], b0 = + q5*/ + "fmla v11.4s , v4.4s, v0.s[1]\n" /* out1 = b0 * a10[1], b0 = + q5*/ + "fmla v14.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a10[2], b0 = + q5*/ + "fmla v17.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a10[3], b0 = + q5*/ + "fmla 
v20.4s, v4.4s, v1.s[0]\n" /* out4 = b0 * a11[0], b0 = + q5*/ + "fmla v23.4s, v4.4s, v1.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ + "fmla v26.4s, v4.4s, v1.s[2]\n" /* out6 = b0 * a11[2], b0 = + q5*/ + "fmla v29.4s, v4.4s, v1.s[3]\n" /* out7 = b0 * a11[3], b0 = + q5*/ + "fmla v9.4s, v5.4s, v0.s[0]\n" /* out8 = b0 * a10[0], b1 = + q6*/ + "fmla v12.4s, v5.4s, v0.s[1]\n" /* out9 = b1 * a10[1], b1 = + q6*/ + "fmla v15.4s, v5.4s, v0.s[2]\n" /* out10 = b1 * a10[2], b1 = + q6*/ + "fmla v18.4s, v5.4s, v0.s[3]\n" /* out11 = b1 * a10[3], b1 = + q6*/ + "fmla v21.4s, v5.4s, v1.s[0]\n" /* out12 = b1 * a10[0], b1 = + q6*/ + "fmla v24.4s, v5.4s, v1.s[1]\n" /* out13 = b1 * a10[1], b1 = + q6*/ + "fmla v27.4s, v5.4s, v1.s[2]\n" /* out14 = b1 * a10[2], b1 = + q6*/ + "fmla v30.4s, v5.4s, v1.s[3]\n" /* out15 = b1 * a10[3], b1 = + q6*/ + "fmla v10.4s, v6.4s, v0.s[0]\n" /* out16 = b2 * a10[0], b2 = + q7*/ + "fmla v13.4s, v6.4s, v0.s[1]\n" /* out17 = b2 * a10[0], b2 = + q7*/ + "fmla v16.4s, v6.4s, v0.s[2]\n" /* out18 = b2 * a10[0], b2 = + q7*/ + "fmla v19.4s, v6.4s, v0.s[3]\n" /* out19 = b2 * a10[0], b2 = + q7*/ + "fmla v22.4s, v6.4s, v1.s[0]\n" /* out20 = b2 * a10[0], b2 = + q7*/ + "fmla v25.4s, v6.4s, v1.s[1]\n" /* out21 = b2 * a10[0], b2 = + q7*/ + "fmla v28.4s, v6.4s, v1.s[2]\n" /* out22 = b2 * a10[0], b2 = + q7*/ + "fmla v31.4s, v6.4s, v1.s[3]\n" /* out23 = b2 * a10[0], b2 = + q7*/ + "b 11f\n" + /* tails==2 final tail*/ + "4:\n" /* tail = 2*/ + "fmla v8.4s , v7.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 = + q5*/ + "fmla v11.4s , v7.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 = + q5*/ + "fmla v14.4s, v7.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 = + q5*/ + "fmla v17.4s, v7.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 = + q5*/ + "fmla v20.4s, v7.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 = + q5*/ + "fmla v23.4s, v7.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ + "fmla v26.4s, v7.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 = + q5*/ + "fmla v29.4s, v7.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 = + q5*/ + 
"fmla v9.4s, v4.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 = + q6*/ + "fmla v12.4s, v4.4s, v2.s[1]\n" /* out9 = b1 * a10[1], b1 = + q6*/ + "fmla v15.4s, v4.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 = + q6*/ + "fmla v18.4s, v4.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 = + q6*/ + "fmla v21.4s, v4.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 = + q6*/ + "fmla v24.4s, v4.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 = + q6*/ + "fmla v27.4s, v4.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 = + q6*/ + "fmla v30.4s, v4.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 = + q6*/ + "fmla v10.4s, v5.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 = + q7*/ + "fmla v13.4s, v5.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 = + q7*/ + "fmla v16.4s, v5.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 = + q7*/ + "fmla v19.4s, v5.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 = + q7*/ + "fmla v22.4s, v5.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 = + q7*/ + "fmla v25.4s, v5.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 = + q7*/ + "fmla v28.4s, v5.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 = + q7*/ + "fmla v31.4s, v5.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 = + q7*/ + "b 11f\n" + /* tails==3 final tail*/ + "5:\n" /* tail = 3*/ + "ldr q4, [%[b_ptr]], #16\n" /* load b2, b0 to q4*/ + "fmla v8.4s , v6.4s, v0.s[0]\n" /* out0 = b0 * a10[0], b0 = + q5*/ + "fmla v11.4s , v6.4s, v0.s[1]\n" /* out1 = b0 * a10[1], b0 = + q5*/ + "fmla v14.4s, v6.4s, v0.s[2]\n" /* out2 = b0 * a10[2], b0 = + q5*/ + "fmla v17.4s, v6.4s, v0.s[3]\n" /* out3 = b0 * a10[3], b0 = + q5*/ + "fmla v20.4s, v6.4s, v1.s[0]\n" /* out4 = b0 * a11[0], b0 = + q5*/ + "fmla v23.4s, v6.4s, v1.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ + "fmla v26.4s, v6.4s, v1.s[2]\n" /* out6 = b0 * a11[2], b0 = + q5*/ + "fmla v29.4s, v6.4s, v1.s[3]\n" /* out7 = b0 * a11[3], b0 = + q5*/ + "fmla v9.4s, v7.4s, v0.s[0]\n" /* out8 = b0 * a10[0], b1 = + q6*/ + "fmla v12.4s, v7.4s, v0.s[1]\n" /* out9 = b1 * a10[1], b1 = + q6*/ + "fmla v15.4s, v7.4s, v0.s[2]\n" /* out10 = b1 * a10[2], b1 = + q6*/ + "fmla 
v18.4s, v7.4s, v0.s[3]\n" /* out11 = b1 * a10[3], b1 = + q6*/ + "fmla v21.4s, v7.4s, v1.s[0]\n" /* out12 = b1 * a10[0], b1 = + q6*/ + "fmla v24.4s, v7.4s, v1.s[1]\n" /* out13 = b1 * a10[1], b1 = + q6*/ + "fmla v27.4s, v7.4s, v1.s[2]\n" /* out14 = b1 * a10[2], b1 = + q6*/ + "fmla v30.4s, v7.4s, v1.s[3]\n" /* out15 = b1 * a10[3], b1 = + q6*/ + "fmla v10.4s, v4.4s, v0.s[0]\n" /* out16 = b2 * a10[0], b2 = + q7*/ + "fmla v13.4s, v4.4s, v0.s[1]\n" /* out17 = b2 * a10[0], b2 = + q7*/ + "fmla v16.4s, v4.4s, v0.s[2]\n" /* out18 = b2 * a10[0], b2 = + q7*/ + "fmla v19.4s, v4.4s, v0.s[3]\n" /* out19 = b2 * a10[0], b2 = + q7*/ + "fmla v22.4s, v4.4s, v1.s[0]\n" /* out20 = b2 * a10[0], b2 = + q7*/ + "fmla v25.4s, v4.4s, v1.s[1]\n" /* out21 = b2 * a10[0], b2 = + q7*/ + "fmla v28.4s, v4.4s, v1.s[2]\n" /* out22 = b2 * a10[0], b2 = + q7*/ + "fmla v31.4s, v4.4s, v1.s[3]\n" /* out23 = b2 * a10[0], b2 = + q7*/ + "11: \n" /* check if relu */ + "cbz %w[relu], 12f\n" /* skip relu */ + "movi v2.4s, #0\n" /* for relu*/ + "fmax v8.4s, v8.4s, v2.4s\n" /* relu*/ + "fmax v9.4s, v9.4s, v2.4s\n" /* relu*/ + "fmax v10.4s, v10.4s, v2.4s\n" /* relu*/ + "fmax v11.4s, v11.4s, v2.4s\n" /* relu*/ + "fmax v12.4s, v12.4s, v2.4s\n" /* relu*/ + "fmax v13.4s, v13.4s, v2.4s\n" /* relu*/ + "fmax v14.4s, v14.4s, v2.4s\n" /* relu*/ + "fmax v15.4s, v15.4s, v2.4s\n" /* relu*/ + "fmax v16.4s,v16.4s,v2.4s\n" /* relu*/ + "fmax v17.4s,v17.4s,v2.4s\n" /* relu*/ + "fmax v18.4s, v18.4s, v2.4s\n" /* relu*/ + "fmax v19.4s, v19.4s, v2.4s\n" /* relu*/ + "fmax v20.4s, v20.4s, v2.4s\n" /* relu*/ + "fmax v21.4s, v21.4s, v2.4s\n" /* relu*/ + "fmax v22.4s, v22.4s, v2.4s\n" /* relu*/ + "fmax v23.4s, v23.4s, v2.4s\n" /* relu*/ + "fmax v24.4s,v24.4s,v2.4s\n" /* relu*/ + "fmax v25.4s,v25.4s,v2.4s\n" /* relu*/ + "fmax v26.4s, v26.4s, v2.4s\n" /* relu*/ + "fmax v27.4s, v27.4s, v2.4s\n" /* relu*/ + "fmax v28.4s, v28.4s, v2.4s\n" /* relu*/ + "fmax v29.4s, v29.4s, v2.4s\n" /* relu*/ + "fmax v30.4s, v30.4s, v2.4s\n" /* relu*/ + "fmax 
v31.4s, v31.4s, v2.4s\n" /* relu*/ + "12: \n" + "st1 {v8.4s, v9.4s, v10.4s},[%[c_ptr0]], #48\n" /* store r0 */ + "st1 {v11.4s, v12.4s, v13.4s},[%[c_ptr1]], #48\n" /* store r1 */ + "st1 {v14.4s, v15.4s, v16.4s},[%[c_ptr2]], #48\n" /* store r2 */ + "st1 {v17.4s, v18.4s, v19.4s},[%[c_ptr3]], #48\n" /* store r3 */ + "st1 {v20.4s, v21.4s, v22.4s},[%[c_ptr4]], #48\n" /* store r4 */ + "st1 {v23.4s, v24.4s, v25.4s},[%[c_ptr5]], #48\n" /* store r5 */ + "st1 {v26.4s, v27.4s, v28.4s},[%[c_ptr6]], #48\n" /* store r6 */ + "st1 {v29.4s, v30.4s, v31.4s},[%[c_ptr7]], #48\n" /* store r7 */ + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), + [tail] "+r"(tail), [c_ptr0] "+r"(c_ptr0), [c_ptr1] "+r"(c_ptr1), + [c_ptr2] "+r"(c_ptr2), [c_ptr3] "+r"(c_ptr3), + [c_ptr4] "+r"(c_ptr4), [c_ptr5] "+r"(c_ptr5), + [c_ptr6] "+r"(c_ptr6), [c_ptr7] "+r"(c_ptr7) + : [bias_ptr] "r"(bias_local), [relu] "r"(is_relu) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31"); + if (flag_p_remain && (xb == bblocks - 1)) { + for (int i = 0; i < remain; ++i) { + *pout0++ = cout0[i]; + *pout1++ = cout1[i]; + *pout2++ = cout2[i]; + *pout3++ = cout3[i]; + *pout4++ = cout4[i]; + *pout5++ = cout5[i]; + *pout6++ = cout6[i]; + *pout7++ = cout7[i]; + } + } + } + } + } +} +#else // __aarch64__ +/** + * \brief gemm with ablock = 6, bblock = 8, output 6x8 + * @param A + * @param B + * @param C + * @param M + * @param N + * @param K + * @param threads + * @param workspace + */ +void sgemm_conv_6x8(const float* A_packed, const float* B, const float* bias, + float* C, int M, int N, int K, bool is_bias, bool is_relu, + bool transB, ARMContext* ctx) { + size_t l2_cache = + ctx->l2_cache_size() > 0 ? ctx->l2_cache_size() : 512 * 1024; + auto* workspace = ctx->workspace_data(); + int threads = ctx->threads(); + //! 
MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2 + int x_block = + (l2_cache - (MBLOCK_OTH * K)) / (sizeof(float) * (K + MBLOCK_OTH)); + x_block /= NBLOCK; + x_block *= NBLOCK; + int x_num = (N + (x_block - 1)) / x_block; + x_block = (N + x_num - 1) / x_num; + x_block = (x_block + NBLOCK - 1) / NBLOCK; + x_block *= NBLOCK; + x_block = x_block < NBLOCK ? NBLOCK : x_block; + + int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1; + int tail_pre = (K & (KBLOCK - 1)); + if (tail_pre == 0) { + tail_pre = KBLOCK; + } + + bool flag_p_remain = false; + int remain = 0; + + //! apanel is pre_compute outside gemm + for (unsigned int x0 = 0; x0 < N; x0 += x_block) { + unsigned int xmax = x0 + x_block; + if (xmax > N) { + xmax = N; + } + int bblocks = (xmax - x0 + NBLOCK - 1) / NBLOCK; + remain = xmax - x0 - (bblocks - 1) * NBLOCK; + if (remain > 0) { + flag_p_remain = true; + } + //! load bpanel + float* b_pannel = workspace; + if (transB) { + loadb_trans(b_pannel, B, K, 0, K, x0, xmax); + } else { + loadb(b_pannel, B, N, 0, K, x0, xmax); + } +#pragma omp parallel for num_threads(threads) + for (unsigned int y = 0; y < M; y += MBLOCK_OTH) { + unsigned int ymax = y + MBLOCK_OTH; + if (ymax > M) { + ymax = M; + } + float* c_ptr0 = C + y * N + x0; + float* c_ptr1 = c_ptr0 + N; + float* c_ptr2 = c_ptr1 + N; + float* c_ptr3 = c_ptr2 + N; + float* c_ptr4 = c_ptr3 + N; + float* c_ptr5 = c_ptr4 + N; + + float* pout0 = c_ptr0; + float* pout1 = c_ptr1; + float* pout2 = c_ptr2; + float* pout3 = c_ptr3; + float* pout4 = c_ptr4; + float* pout5 = c_ptr5; + + float bias_local[6] = {0}; + if (is_bias) { + bias_local[0] = bias[y]; + bias_local[1] = bias[y + 1]; + bias_local[2] = bias[y + 2]; + bias_local[3] = bias[y + 3]; + bias_local[4] = bias[y + 4]; + bias_local[5] = bias[y + 5]; + } + + float cout0[NBLOCK]; + float cout1[NBLOCK]; + float cout2[NBLOCK]; + float cout3[NBLOCK]; + float cout4[NBLOCK]; + float cout5[NBLOCK]; + + const float* a_ptr_l = A_packed + y * K; + const float* b_ptr = 
b_pannel; + for (int xb = 0; xb < bblocks; xb++) { + if ((y + 5) >= ymax) { + switch ((y + 5) - ymax) { + case 4: + c_ptr1 = cout1; + case 3: + c_ptr2 = cout2; + case 2: + c_ptr3 = cout3; + case 1: + c_ptr4 = cout4; + case 0: + c_ptr5 = cout5; + default: + break; + } + } + if (flag_p_remain && (xb == bblocks - 1)) { + pout0 = c_ptr0; + pout1 = c_ptr1; + pout2 = c_ptr2; + pout3 = c_ptr3; + pout4 = c_ptr4; + pout5 = c_ptr5; + + c_ptr0 = cout0; + c_ptr1 = cout1; + c_ptr2 = cout2; + c_ptr3 = cout3; + c_ptr4 = cout4; + c_ptr5 = cout5; + } + const float* a_ptr = a_ptr_l; + int tails = tail_pre; + int k = k_pre; + asm volatile( + // sgemm 6x8 + "vld1.32 {d2-d4}, [%[bias_ptr]] @ load bias 6 elements\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a0~a3\n" + "pld [%[a_ptr]] @ preload a\n" + "vdup.i32 q12,d4[0] @ out40=0\n" + "pld [%[b_ptr]] @ preload b\n" + "vdup.i32 q13,d4[0] @ out41=0\n" + "pld [%[a_ptr], #64] @ preload a\n" + "vdup.i32 q14,d4[1] @ out50=0\n" + "pld [%[b_ptr], #64] @ preload b\n" + "vdup.i32 q15,d4[1] @ out51=0\n" + "pld [%[a_ptr], #128] @ preload a\n" + "vdup.i32 q4, d2[0] @ out00=0\n" + "pld [%[b_ptr], #128] @ preload b\n" + "vdup.i32 q5, d2[0] @ out01=0\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vdup.i32 q6, d2[1] @ out10=0\n" + "pld [%[a_ptr], #192] @ preload a\n" + "vdup.i32 q7, d2[1] @ out11=0\n" + "pld [%[b_ptr], #192] @ preload a\n" + "vdup.i32 q8, d3[0] @ out20=0\n" + "pld [%[a_ptr], #256] @ preload a\n" + "vdup.i32 q9, d3[0] @ out21=0\n" + "pld [%[b_ptr], #256] @ preload a\n" + "vdup.i32 q10,d3[1] @ out30=0\n" + "pld [%[b_ptr], #320] @ preload b\n" + "vdup.i32 q11,d3[1] @ out31=0\n" + "pld [%[b_ptr], #384] @ preload b\n" + "cmp %[k], #0 @ check weather k is " + "bigger than 0\n" + "beq 0f @ jump to tail\n" + "1: @ main loop for k\n" + /* Unroll 0*/ + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a4, a5, and next " + "a0, " + "a1\n" + "vmla.f32 q4, q2, d0[0] @ out0 += b1 * a0\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! 
@ load b2\n" + "vmla.f32 q6, q2, d0[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d2[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d2[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vmla.f32 q5, q3, d0[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d0[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d1[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d1[1] @ out9 += b2 * a3\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a2~a5\n" + "vmla.f32 q13, q3, d2[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d2[1] @ out11 += b2 * a5\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + /* Unroll 1 */ + "vmla.f32 q4, q2, d3[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d3[1] @ out1 += b1 * a1\n" + /*"pld [%[a_ptr], #64] @ preload a\n"*/ + "vmla.f32 q8, q2, d0[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d0[1] @ out3 += b1 * a3\n" + /*"pld [%[b_ptr], #192]\n"*/ + "vmla.f32 q12, q2, d1[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d1[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vmla.f32 q5, q3, d3[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d3[1] @ out7 += b2 * a1\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a0~a3\n" + "vmla.f32 q9, q3, d0[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d0[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d1[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d1[1] @ out11 += b2 * a5\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a4, a5, a0, a1\n" + /* Unroll 2 */ + "vmla.f32 q4, q2, d2[0] @ out0 += b1 * a0\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + "vmla.f32 q6, q2, d2[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d3[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d3[1] @ out3 += b1 * a3\n" + /*"pld [%[a_ptr], #240] @ preload\n"*/ + "vmla.f32 q12, q2, d0[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d0[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! 
@ load b1\n" + "vmla.f32 q5, q3, d2[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d2[1] @ out7 += b2 * a1\n" + /*"pld [%[b_ptr], #208]\n"*/ + "vmla.f32 q9, q3, d3[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d3[1] @ out9 += b2 * a3\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a2~a5\n" + "vmla.f32 q13, q3, d0[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d0[1] @ out11 += b2 * a5\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + /* Unroll 3 */ + "vmla.f32 q4, q2, d1[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d1[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d2[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d2[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d3[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d3[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vmla.f32 q5, q3, d1[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d1[1] @ out7 += b2 * a1\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a0~a3\n" + "vmla.f32 q9, q3, d2[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d2[1] @ out9 += b2 * a3\n" + "subs %[k], %[k], #1 @ k--\n" + "vmla.f32 q13, q3, d3[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d3[1] @ out11 += b2 * a5\n" + "bne 1b @ jump to main " + "loop\n" + "0: @ process tail\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "beq 3f @ jump to tail = " + "1\n" + /* Unroll 0*/ + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + "vmla.f32 q4, q2, d0[0] @ out0 += b1 * a0\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a4,5, a0, a1\n" + "vmla.f32 q6, q2, d0[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d2[0] @ out4 += b1 * a4\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "vmla.f32 q14, q2, d2[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! 
@ load b1\n" + "vmla.f32 q5, q3, d0[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d0[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d1[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d1[1] @ out9 += b2 * a3\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a2~a5\n" + "vmla.f32 q13, q3, d2[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d2[1] @ out11 += b2 * a5\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + "beq 4f @ jump to tail==2\n" + /* Unroll 1*/ + "vmla.f32 q4, q2, d3[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d3[1] @ out1 += b1 * a1\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "vmla.f32 q8, q2, d0[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d0[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d1[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d1[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vmla.f32 q5, q3, d3[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d3[1] @ out7 += b2 * a1\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a0~a3\n" + "vmla.f32 q9, q3, d0[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d0[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d1[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d1[1] @ out11 += b2 * a5\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + "beq 5f @ jump to tail==3\n" + /* Unroll 2 */ + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a4,a5, a0,a1\n" + "vmla.f32 q4, q2, d2[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d2[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d3[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d3[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d0[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d0[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vmla.f32 q5, q3, d2[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d2[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d3[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d3[1] @ out9 += b2 * a3\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]! 
@ load a2~a5\n" + "vmla.f32 q13, q3, d0[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d0[1] @ out11 += b2 * a5\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + /* Unroll 3*/ + "vmla.f32 q4, q2, d1[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d1[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d2[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d2[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d3[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d3[1] @ out5 += b1 * a5\n" + "vmla.f32 q5, q3, d1[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d1[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d2[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d2[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d3[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d3[1] @ out11 += b2 * a5\n" + "b 2f\n" + /* tails==1 final tail*/ + "3: @ tail=1\n" + "vmla.f32 q4, q2, d0[0] @ out0 += b1 * a0\n" + "vld1.32 {d2}, [%[a_ptr] :64]! @ load a4,a5\n" + "vmla.f32 q6, q2, d0[1] @ out1 += b1 * a1\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + "vmla.f32 q8, q2, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d2[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d2[1] @ out5 += b1 * a5\n" + "vmla.f32 q5, q3, d0[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d0[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d1[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d1[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d2[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d2[1] @ out11 += b2 * a5\n" + "b 2f @ jump to end\n" + /* tails==2 final tail*/ + "4: @ tail == 2\n" + "vmla.f32 q4, q2, d3[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d3[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d0[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d0[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d1[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d1[1] @ out5 += b1 * a5\n" + "vmla.f32 q5, q3, d3[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d3[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d0[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, 
d0[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d1[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d1[1] @ out11 += b2 * a5\n" + "b 2f @ jump to end\n" + /* tails==3 final tail*/ + "5: @ tail=3\n" + "vmla.f32 q4, q2, d2[0] @ out0 += b1 * a0\n" + "vld1.32 {d0}, [%[a_ptr] :64]! @ load a4,a5\n" + "vmla.f32 q6, q2, d2[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d3[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d3[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d0[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d0[1] @ out5 += b1 * a5\n" + "vmla.f32 q5, q3, d2[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d2[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d3[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d3[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d0[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d0[1] @ out11 += b2 * a5\n" + "2: @ check relu\n" + "cmp %[relu], #0 @ check if has relu\n" + "ble 6f @ skip relu if relu <= 0\n" + "vmov.u32 q0, #0 @ for relu\n" + "vmax.f32 q4, q4, q0 @ for relu\n" + "vmax.f32 q5, q5, q0 @ for relu\n" + "vmax.f32 q6, q6, q0 @ for relu\n" + "vmax.f32 q7, q7, q0 @ for relu\n" + "vmax.f32 q8, q8, q0 @ for relu\n" + "vmax.f32 q9, q9, q0 @ for relu\n" + "vmax.f32 q10, q10, q0 @ for relu\n" + "vmax.f32 q11, q11, q0 @ for relu\n" + "vmax.f32 q12, q12, q0 @ for relu\n" + "vmax.f32 q13, q13, q0 @ for relu\n" + "vmax.f32 q14, q14, q0 @ for relu\n" + "vmax.f32 q15, q15, q0 @ for relu\n" + "6: @ store result\n" + "vst1.32 {d8-d11}, [%[c_ptr0]]! @ store r0\n" + "vst1.32 {d12-d15}, [%[c_ptr1]]! @ store r1\n" + "vst1.32 {d16-d19}, [%[c_ptr2]]! @ store r2\n" + "vst1.32 {d20-d23}, [%[c_ptr3]]! @ store r3\n" + "vst1.32 {d24-d27}, [%[c_ptr4]]! @ store r4\n" + "vst1.32 {d28-d31}, [%[c_ptr5]]! 
@ store r5\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr0] "+r"(c_ptr0), + [c_ptr1] "+r"(c_ptr1), [c_ptr2] "+r"(c_ptr2), + [c_ptr3] "+r"(c_ptr3), [c_ptr4] "+r"(c_ptr4), + [c_ptr5] "+r"(c_ptr5), [k] "+r"(k), [tails] "+r"(tails) + : [bias_ptr] "r"(bias_local), [relu] "r"(is_relu) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "q12", "q13", "q14", "q15", "cc", "memory"); + + if (flag_p_remain && (xb == bblocks - 1)) { + for (int i = 0; i < remain; ++i) { + *pout0++ = cout0[i]; + *pout1++ = cout1[i]; + *pout2++ = cout2[i]; + *pout3++ = cout3[i]; + *pout4++ = cout4[i]; + *pout5++ = cout5[i]; + } + } + } + } + } +} + +void sgemm_conv_4x8(const float* A_packed, const float* B, const float* bias, + float* C, int M, int N, int K, bool is_bias, bool is_relu, + bool transB, ARMContext* ctx) { + size_t l2_cache = + ctx->l2_cache_size() > 0 ? ctx->l2_cache_size() : 512 * 1024; + auto* workspace = ctx->workspace_data(); + int threads = ctx->threads(); + //! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2 + int x_block = + (l2_cache - (MBLOCK_A73 * K)) / (sizeof(float) * (K + MBLOCK_A73)); + x_block /= NBLOCK; + x_block *= NBLOCK; + int x_num = (N + (x_block - 1)) / x_block; + x_block = (N + x_num - 1) / x_num; + x_block = (x_block + NBLOCK - 1) / NBLOCK; + x_block *= NBLOCK; + x_block = x_block < NBLOCK ? NBLOCK : x_block; + + int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1; + int tail_pre = (K & (KBLOCK - 1)); + if (tail_pre == 0) { + tail_pre = KBLOCK; + } + + bool flag_p_remain = false; + int remain = 0; + + //! apanel is pre_compute outside gemm + for (unsigned int x0 = 0; x0 < N; x0 += x_block) { + unsigned int xmax = x0 + x_block; + if (xmax > N) { + xmax = N; + } + int bblocks = (xmax - x0 + NBLOCK - 1) / NBLOCK; + remain = xmax - x0 - (bblocks - 1) * NBLOCK; + if (remain > 0) { + flag_p_remain = true; + } + //! 
load bpanel + float* b_pannel = workspace; + if (transB) { + loadb_trans(b_pannel, B, K, 0, K, x0, xmax); + } else { + loadb(b_pannel, B, N, 0, K, x0, xmax); + } +#pragma omp parallel for num_threads(threads) + for (unsigned int y = 0; y < M; y += MBLOCK_A73) { + unsigned int ymax = y + MBLOCK_A73; + if (ymax > M) { + ymax = M; + } + + float cout0[NBLOCK]; + float cout1[NBLOCK]; + float cout2[NBLOCK]; + float cout3[NBLOCK]; + + float bias_local[4] = {0}; + if (is_bias) { + bias_local[0] = bias[y]; + bias_local[1] = bias[y + 1]; + bias_local[2] = bias[y + 2]; + bias_local[3] = bias[y + 3]; + } + + float* c_ptr0 = C + y * N + x0; + float* c_ptr1 = c_ptr0 + N; + float* c_ptr2 = c_ptr1 + N; + float* c_ptr3 = c_ptr2 + N; + + float* pout0 = c_ptr0; + float* pout1 = c_ptr1; + float* pout2 = c_ptr2; + float* pout3 = c_ptr3; + + const float* a_ptr_l = A_packed + y * K; + const float* b_ptr = b_pannel; + for (int xb = 0; xb < bblocks; xb++) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + case 2: + c_ptr1 = cout1; + case 1: + c_ptr2 = cout1; + case 0: + c_ptr3 = cout1; + default: + break; + } + } + if (flag_p_remain && (xb == bblocks - 1)) { + pout0 = c_ptr0; + pout1 = c_ptr1; + pout2 = c_ptr2; + pout3 = c_ptr3; + + c_ptr0 = cout0; + c_ptr1 = cout1; + c_ptr2 = cout2; + c_ptr3 = cout3; + } + const float* a_ptr = a_ptr_l; + int tails = tail_pre; + int k = k_pre; + asm volatile( + "vld1.32 {d4-d5}, [%[bias_ptr]] @ load bias\n" + "vld1.32 {d0-d3}, [%[a_ptr] :128]! @ load a0~a3\n" + "vdup.32 q8, d4[0] @ add bias to out00\n" + "pld [%[a_ptr]] @ preload a, 64byte\n" + "vdup.32 q9, d4[0] @ add bias to out01\n" + "pld [%[b_ptr]] @ preload b\n" + "vdup.32 q10, d4[1] @ add bias to out10\n" + "pld [%[a_ptr], #64] @ preload a\n" + "vdup.32 q11, d4[1] @ add bias to out11\n" + "vld1.32 {d8-d11}, [%[b_ptr] :128]! 
@ load b1\n" + "vdup.32 q12, d5[0] @ add bias to out20\n" + "pld [%[b_ptr], #64] @ preload b\n" + "vdup.32 q13, d5[0] @ add bias to out21\n" + "pld [%[a_ptr], #128] @ preload a\n" + "vdup.32 q14, d5[1] @ add bias to out30\n" + "pld [%[b_ptr], #128] @ preload b\n" + "vdup.32 q15, d5[1] @ add bias to out31\n" + "pld [%[b_ptr], #192] @ preload b\n" + "cmp %[k], #0 @ check weather k is " + "bigger than 0\n" + "beq 0f @ jump to tail\n" + "1: @ main loop for k\n" + /* Unroll 0*/ + "vld1.32 {d12-d15}, [%[b_ptr] :128]! @ load next b1, b2\n" + "vmla.f32 q8, q4, d0[0] @ out0 += b1 * a0\n" + "vld1.32 {d4-d7}, [%[a_ptr] :128]! @ load next 2xa0~a3\n" + "vmla.f32 q10, q4, d0[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q4, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d0[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d0[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d1[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d1[1] @ out7 += b2 * a3\n" + "vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load next b1, b2\n" + /* Unroll 1 */ + "vmla.f32 q8, q6, d2[0] @ out0 += b1 * a0\n" + "pld [%[b_ptr], #64] @ preload b\n" + "vmla.f32 q10, q6, d2[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q6, d3[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q6, d3[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q7, d2[0] @ out6 += b2 * a0\n" + "vmla.f32 q11, q7, d2[1] @ out7 += b2 * a1\n" + "vmla.f32 q13, q7, d3[0] @ out8 += b2 * a2\n" + "vmla.f32 q15, q7, d3[1] @ out9 += b2 * a3\n" + "vld1.32 {d12-d15}, [%[b_ptr] :128]! @ load next b1,b2\n" + /* Unroll 2 */ + "vmla.f32 q8, q4, d4[0] @ out0 += b1 * a0\n" + "vld1.32 {d0-d3}, [%[a_ptr] :128]! 
@ load next a0~a3\n" + "vmla.f32 q10, q4, d4[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q4, d5[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d5[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d4[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d4[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d5[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d5[1] @ out7 += b2 * a3\n" + "vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load next b1, b2\n" + /* Unroll 3 */ + "vmla.f32 q8, q6, d6[0] @ out0 += b1 * a0\n" + "pld [%[a_ptr], #64] @ preload a\n" + "vmla.f32 q10, q6, d6[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q6, d7[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q6, d7[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q7, d6[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q7, d6[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q7, d7[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q7, d7[1] @ out7 += b2 * a3\n" + "subs %[k], %[k], #1 @ k--\n" + "bne 1b @ jump to main " + "loop\n" + "0: @ process tail\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "beq 3f @ jump to tail = " + "1\n" + /* Unroll 0*/ + "vld1.32 {d12-d15}, [%[b_ptr] :128]! @ load next b1, b2\n" + "vmla.f32 q8, q4, d0[0] @ out0 += b1 * a0\n" + "vmla.f32 q10, q4, d0[1] @ out1 += b1 * a1\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "vmla.f32 q12, q4, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d0[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d0[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d1[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d1[1] @ out7 += b2 * a3\n" + "beq 4f @ jump to tail==2\n" + /* Unroll 1 */ + "vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load next b1, b2\n" + "vmla.f32 q8, q6, d2[0] @ out0 += b1 * a0\n" + "vld1.32 {d4-d7}, [%[a_ptr] :128]! 
@ load next 2xa0~a3\n" + "vmla.f32 q10, q6, d2[1] @ out1 += b1 * a1\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "vmla.f32 q12, q6, d3[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q6, d3[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q7, d2[0] @ out6 += b2 * a0\n" + "vmla.f32 q11, q7, d2[1] @ out7 += b2 * a1\n" + "vmla.f32 q13, q7, d3[0] @ out8 += b2 * a2\n" + "vmla.f32 q15, q7, d3[1] @ out9 += b2 * a3\n" + "beq 5f @ jump to tail==3\n" + /* Unroll 2 */ + "vld1.32 {d12-d15}, [%[b_ptr] :128]! @ load next b1,b2\n" + "vmla.f32 q8, q4, d4[0] @ out0 += b1 * a0\n" + "vmla.f32 q10, q4, d4[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q4, d5[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d5[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d4[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d4[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d5[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d5[1] @ out7 += b2 * a3\n" + /* Unroll 3 */ + "vmla.f32 q8, q6, d6[0] @ out0 += b1 * a0\n" + "vmla.f32 q10, q6, d6[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q6, d7[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q6, d7[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q7, d6[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q7, d6[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q7, d7[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q7, d7[1] @ out7 += b2 * a3\n" + "b 2f\n" + /* tails==1 final tail */ + "3: @ tail=1\n" + "vmla.f32 q8, q4, d0[0] @ out0 += b1 * a0\n" + "vmla.f32 q10, q4, d0[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q4, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d0[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d0[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d1[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d1[1] @ out7 += b2 * a3\n" + /*aptr - 16 */ + "sub %[a_ptr], %[a_ptr], #16 @ tail--\n" + "b 2f @ jump to end\n" + /* tails==2 final tail*/ + "4: @ tail == 2\n" + "vmla.f32 q8, q6, d2[0] @ out0 += b1 * a0\n" + "vmla.f32 q10, q6, d2[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q6, d3[0] @ out2 += b1 * 
a2\n" + "vmla.f32 q14, q6, d3[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q7, d2[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q7, d2[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q7, d3[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q7, d3[1] @ out7 += b2 * a3\n" + "b 2f @ jump to end\n" + /* tails==3 final tail*/ + "5: @ tail=3\n" + "vmla.f32 q8, q4, d4[0] @ out0 += b1 * a0\n" + "vmla.f32 q10, q4, d4[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q4, d5[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d5[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d4[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d4[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d5[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d5[1] @ out7 += b2 * a3\n" + /*aptr - 16*/ + "sub %[a_ptr], %[a_ptr], #16 @ tail--\n" + "2: @ check relu\n" + "cmp %[relu], #0 @ check if has relu\n" + "ble 6f @ skip relu if relu <= 0\n" + "vmov.u32 q0, #0 @ for relu\n" + "vmax.f32 q8, q8, q0 @ for relu\n" + "vmax.f32 q9, q9, q0 @ for relu\n" + "vmax.f32 q10, q10, q0 @ for relu\n" + "vmax.f32 q11, q11, q0 @ for relu\n" + "vmax.f32 q12, q12, q0 @ for relu\n" + "vmax.f32 q13, q13, q0 @ for relu\n" + "vmax.f32 q14, q14, q0 @ for relu\n" + "vmax.f32 q15, q15, q0 @ for relu\n" + "6: @ store result\n" + "vst1.32 {d16-d19}, [%[c_ptr0]]! @ store r0\n" + "vst1.32 {d20-d23}, [%[c_ptr1]]! @ store r1\n" + "vst1.32 {d24-d27}, [%[c_ptr2]]! @ store r2\n" + "vst1.32 {d28-d31}, [%[c_ptr3]]! 
@ store r3\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr0] "+r"(c_ptr0), + [c_ptr1] "+r"(c_ptr1), [c_ptr2] "+r"(c_ptr2), + [c_ptr3] "+r"(c_ptr3), [k] "+r"(k), [tails] "+r"(tails) + : [bias_ptr] "r"(bias_local), [relu] "r"(is_relu) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "q12", "q13", "q14", "q15", "cc", "memory"); + + if (flag_p_remain && (xb == bblocks - 1)) { + for (int i = 0; i < remain; ++i) { + *pout0++ = cout0[i]; + *pout1++ = cout1[i]; + *pout2++ = cout2[i]; + *pout3++ = cout3[i]; + } + } + } + } + } +} +#endif // __aarch64__ + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/arm/math/packed_sgemm.h b/paddle/fluid/lite/arm/math/packed_sgemm.h new file mode 100644 index 00000000000000..160b432c8d80fe --- /dev/null +++ b/paddle/fluid/lite/arm/math/packed_sgemm.h @@ -0,0 +1,60 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
#pragma once

#include <cmath>  // NOTE(review): include target lost in extraction -- TODO confirm
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/cpu_info.h"
#include "paddle/fluid/lite/core/lite_tensor.h"

namespace paddle {
namespace lite {
namespace arm {
namespace math {

// Register-blocking parameters for the packed SGEMM kernels.
// MBLOCK: rows of A per micro-kernel; NBLOCK: cols of B; KBLOCK: depth step.
#ifdef __aarch64__
constexpr int MBLOCK = 8;
constexpr int NBLOCK = 12;
constexpr int KBLOCK = 4;
// Height (rows of A) handled per packed block; fixed on aarch64.
inline int get_hblock(ARMArch arch) { return MBLOCK; }
#else
constexpr int MBLOCK_A73 = 4;  // Cortex-A73 uses a narrower M block
constexpr int MBLOCK_OTH = 6;  // all other armv7 cores
constexpr int NBLOCK = 8;
constexpr int KBLOCK = 4;
// Height (rows of A) handled per packed block; depends on the detected core.
inline int get_hblock(ARMArch arch) {
  if (arch == kA73) {
    return MBLOCK_A73;
  } else {
    return MBLOCK_OTH;
  }
}
#endif  // __aarch64__

// Pack rows [m0, mmax) x cols [k0, kmax) of `in` (leading dim `ldin`) into
// `out` in the layout expected by sgemm_prepack. `is_trans` selects the
// transposed-input packing path.
void prepackA(float* out, const float* in, const int ldin, const int m0,
              const int mmax, const int k0, const int kmax, bool is_trans,
              ARMContext* ctx);

// Tensor-based convenience overload of the raw-pointer prepackA above.
void prepackA(TensorLite* tout, const TensorLite& tin, int m, int k, int group,
              bool is_trans, ARMContext* ctx);

// C(MxN) = A_packed(MxK) * B(KxN) [+ bias] [then ReLU], B optionally
// transposed. A must already be packed with prepackA.
void sgemm_prepack(const float* A_packed, const float* B, const float* bias,
                   float* C, int M, int N, int K, bool is_bias, bool is_relu,
                   bool is_transB, ARMContext* ctx);

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle

// ---- next file in patch: paddle/fluid/lite/arm/math/scale.cc (new) ----
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/lite/arm/math/scale.h"
#include "paddle/fluid/lite/arm/math/funcs.h"

namespace paddle {
namespace lite {
namespace arm {
namespace math {

// Elementwise affine transform: dout[i] = din[i] * scale + bias.
// Vectorized 16 floats per iteration (4x float32x4) with an OpenMP-parallel
// main loop; the final num % 16 elements are handled scalar.
// NOTE(review): `<float>` in the specialization name was lost in extraction;
// restored -- it is required by the `template <>` explicit-specialization
// syntax together with the declaration in scale.h.
template <>
void scale<float>(const float* din, float* dout, int num, float scale,
                  float bias) {
  int cnt = num >> 4;     // number of 16-float chunks
  int remain = num % 16;  // scalar tail length
  float32x4_t vscale = vdupq_n_f32(scale);
  float32x4_t vbias = vdupq_n_f32(bias);
#pragma omp parallel for
  for (int i = 0; i < cnt; i++) {
    const float* din_ptr = din + (i << 4);
    float* dout_ptr = dout + (i << 4);

    float32x4_t din0 = vld1q_f32(din_ptr);
    float32x4_t din1 = vld1q_f32(din_ptr + 4);
    float32x4_t din2 = vld1q_f32(din_ptr + 8);
    float32x4_t din3 = vld1q_f32(din_ptr + 12);

    // fused multiply-add: vbias + din * vscale
    float32x4_t vsum1 = vmlaq_f32(vbias, din0, vscale);
    float32x4_t vsum2 = vmlaq_f32(vbias, din1, vscale);
    float32x4_t vsum3 = vmlaq_f32(vbias, din2, vscale);
    float32x4_t vsum4 = vmlaq_f32(vbias, din3, vscale);

    vst1q_f32(dout_ptr, vsum1);
    vst1q_f32(dout_ptr + 4, vsum2);
    vst1q_f32(dout_ptr + 8, vsum3);
    vst1q_f32(dout_ptr + 12, vsum4);
  }
  if (remain > 0) {
    // scalar tail
    const float* din_ptr = din + (cnt << 4);
    float* dout_ptr = dout + (cnt << 4);
    for (int i = 0; i < remain; i++) {
      *dout_ptr = *din_ptr * scale + bias;
      dout_ptr++;
      din_ptr++;
    }
  }
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle

// ---- next file in patch: paddle/fluid/lite/arm/math/scale.h (new) ----
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

namespace paddle {
namespace lite {
namespace arm {
namespace math {

// Elementwise affine transform: dout[i] = din[i] * scale + bias for
// i in [0, num). Specialized per element type in scale.cc.
// NOTE(review): the template parameter list was lost in extraction;
// restored as `template <typename T>` -- TODO confirm.
template <typename T>
void scale(const T* din, T* dout, int num, float scale, float bias);

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle

// ---- next file in patch: paddle/fluid/lite/arm/math/softmax.cc (new) ----
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
+ +#include "paddle/fluid/lite/arm/math/softmax.h" +#include +#include "paddle/fluid/lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace arm { +namespace math { + +template <> +void softmax_basic(const float* din, float* dout, const int axis_size, + const int inner_num, const int outer_num) { + int compute_size = inner_num * outer_num; +#pragma omp parallel for + for (int i = 0; i < compute_size; ++i) { + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + + float max_data = din[real_index]; + // get max + for (int j = 1; j < axis_size; ++j) { + real_index += inner_num; + max_data = din[real_index] > max_data ? din[real_index] : max_data; + } + + real_index = idx_outer * inner_num + idx_inner; + // sub, exp and sum + dout[real_index] = expf(din[real_index] - max_data); + float sum_data = dout[real_index]; + for (int j = 1; j < axis_size; ++j) { + real_index += inner_num; + dout[real_index] = expf(din[real_index] - max_data); + sum_data += dout[real_index]; + } + + float sum_inv = 1.f / sum_data; + real_index = idx_outer * inner_num + idx_inner; + // get softmax result + for (int j = 0; j < axis_size; ++j) { + dout[real_index] *= sum_inv; + real_index += inner_num; + } + } +} + +template <> +void softmax_inner8_axis4(const float* din, float* dout, + const int axis_size, const int inner_num, + const int outer_num) { + int compute_size = inner_num * outer_num; + int cmp_cnt = compute_size >> 3; + int remain = compute_size % 8; + float32x4_t vone = vdupq_n_f32(1.0f); + +#pragma omp parallel for + for (int c = 0; c < cmp_cnt; ++c) { + int i = c * 8; + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + + // get max axis_size == 4 + const float* din_ptr = din + real_index; + const float* din_ptr1 = din_ptr + inner_num; + const float* din_ptr2 = din_ptr1 + inner_num; + const float* 
din_ptr3 = din_ptr2 + inner_num; + float32x4_t vdata0 = vld1q_f32(din_ptr); + float32x4_t vdata1 = vld1q_f32(din_ptr1); + float32x4_t vdata2 = vld1q_f32(din_ptr2); + float32x4_t vdata3 = vld1q_f32(din_ptr3); + + float32x4_t vdata01 = vld1q_f32(din_ptr + 4); + float32x4_t vdata11 = vld1q_f32(din_ptr1 + 4); + float32x4_t vdata21 = vld1q_f32(din_ptr2 + 4); + float32x4_t vdata31 = vld1q_f32(din_ptr3 + 4); + + float* dout_ptr0 = dout + real_index; + float* dout_ptr1 = dout_ptr0 + inner_num; + float32x4_t vmax1 = vmaxq_f32(vdata0, vdata1); + float32x4_t vmax2 = vmaxq_f32(vdata2, vdata3); + float32x4_t vmax11 = vmaxq_f32(vdata01, vdata11); + float32x4_t vmax21 = vmaxq_f32(vdata21, vdata31); + float* dout_ptr2 = dout_ptr1 + inner_num; + float* dout_ptr3 = dout_ptr2 + inner_num; + float32x4_t vmax = vmaxq_f32(vmax1, vmax2); + float32x4_t vmax_1 = vmaxq_f32(vmax11, vmax21); + + // sub, exp and sum + float32x4_t vsum0 = exp_ps(vsubq_f32(vdata0, vmax)); + float32x4_t vsum1 = exp_ps(vsubq_f32(vdata1, vmax)); + float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax)); + float32x4_t vsum3 = exp_ps(vsubq_f32(vdata3, vmax)); + + float32x4_t vsum01 = exp_ps(vsubq_f32(vdata01, vmax_1)); + float32x4_t vsum11 = exp_ps(vsubq_f32(vdata11, vmax_1)); + float32x4_t vsum21 = exp_ps(vsubq_f32(vdata21, vmax_1)); + float32x4_t vsum31 = exp_ps(vsubq_f32(vdata31, vmax_1)); + + float32x4_t vsum_1 = vaddq_f32(vsum0, vsum1); + float32x4_t vsum_2 = vaddq_f32(vsum2, vsum3); + float32x4_t vsum_11 = vaddq_f32(vsum01, vsum11); + float32x4_t vsum_21 = vaddq_f32(vsum21, vsum31); + + float32x4_t vsum = vaddq_f32(vsum_1, vsum_2); + float32x4_t vsum111 = vaddq_f32(vsum_11, vsum_21); + + float32x4_t vinf = div_ps(vone, vsum); + float32x4_t vinf1 = div_ps(vone, vsum111); + + vsum0 = vmulq_f32(vsum0, vinf); + vsum1 = vmulq_f32(vsum1, vinf); + vsum2 = vmulq_f32(vsum2, vinf); + vsum3 = vmulq_f32(vsum3, vinf); + + vsum01 = vmulq_f32(vsum01, vinf1); + vsum11 = vmulq_f32(vsum11, vinf1); + vsum21 = vmulq_f32(vsum21, 
vinf1); + vsum31 = vmulq_f32(vsum31, vinf1); + + vst1q_f32(dout_ptr0, vsum0); + vst1q_f32(dout_ptr1, vsum1); + vst1q_f32(dout_ptr2, vsum2); + vst1q_f32(dout_ptr3, vsum3); + + vst1q_f32(dout_ptr0 + 4, vsum01); + vst1q_f32(dout_ptr1 + 4, vsum11); + vst1q_f32(dout_ptr2 + 4, vsum21); + vst1q_f32(dout_ptr3 + 4, vsum31); + } + + int i = cmp_cnt * 8; + + if (remain > 4) { + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + // get max axis_size == 4 + const float* din_ptr = din + real_index; + const float* din_ptr1 = din_ptr + inner_num; + const float* din_ptr2 = din_ptr1 + inner_num; + const float* din_ptr3 = din_ptr2 + inner_num; + float32x4_t vdata0 = vld1q_f32(din_ptr); + float32x4_t vdata1 = vld1q_f32(din_ptr1); + float32x4_t vdata2 = vld1q_f32(din_ptr2); + float32x4_t vdata3 = vld1q_f32(din_ptr3); + + float* dout_ptr0 = dout + real_index; + float* dout_ptr1 = dout_ptr0 + inner_num; + float32x4_t vmax1 = vmaxq_f32(vdata0, vdata1); + float32x4_t vmax2 = vmaxq_f32(vdata2, vdata3); + float* dout_ptr2 = dout_ptr1 + inner_num; + float* dout_ptr3 = dout_ptr2 + inner_num; + float32x4_t vmax = vmaxq_f32(vmax1, vmax2); + + // sub, exp and sum + float32x4_t vsum0 = exp_ps(vsubq_f32(vdata0, vmax)); + float32x4_t vsum1 = exp_ps(vsubq_f32(vdata1, vmax)); + float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax)); + float32x4_t vsum3 = exp_ps(vsubq_f32(vdata3, vmax)); + + float32x4_t vsum_1 = vaddq_f32(vsum0, vsum1); + float32x4_t vsum_2 = vaddq_f32(vsum2, vsum3); + + float32x4_t vsum = vaddq_f32(vsum_1, vsum_2); + + float32x4_t vone = vdupq_n_f32(1.0f); + float32x4_t vinf = div_ps(vone, vsum); + + vsum0 = vmulq_f32(vsum0, vinf); + vsum1 = vmulq_f32(vsum1, vinf); + vsum2 = vmulq_f32(vsum2, vinf); + vsum3 = vmulq_f32(vsum3, vinf); + + vst1q_f32(dout_ptr0, vsum0); + vst1q_f32(dout_ptr1, vsum1); + vst1q_f32(dout_ptr2, vsum2); + vst1q_f32(dout_ptr3, vsum3); + + i += 4; + } + for (; i < compute_size; i++) { + 
int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + + float max_data = din[real_index]; + // get max + for (int j = 1; j < axis_size; ++j) { + real_index += inner_num; + max_data = din[real_index] > max_data ? din[real_index] : max_data; + } + + real_index = idx_outer * inner_num + idx_inner; + // sub, exp and sum + dout[real_index] = expf(din[real_index] - max_data); + float sum_data = dout[real_index]; + for (int j = 1; j < axis_size; ++j) { + real_index += inner_num; + dout[real_index] = expf(din[real_index] - max_data); + sum_data += dout[real_index]; + } + + float sum_inv = 1.f / sum_data; + real_index = idx_outer * inner_num + idx_inner; + // get softmax result + for (int j = 0; j < axis_size; ++j) { + dout[real_index] *= sum_inv; + real_index += inner_num; + } + } +} + +template <> +void softmax_inner4_axis4(const float* din, float* dout, + const int axis_size, const int inner_num, + const int outer_num) { + int compute_size = inner_num * outer_num; + int cmp_cnt = compute_size >> 2; + int remain = compute_size % 4; + float32x4_t vone = vdupq_n_f32(1.0f); + +#pragma omp parallel for + for (int c = 0; c < cmp_cnt; ++c) { + int i = c * 4; + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + + // get max axis_size == 4 + const float* din_ptr = din + real_index; + const float* din_ptr1 = din_ptr + inner_num; + const float* din_ptr2 = din_ptr1 + inner_num; + const float* din_ptr3 = din_ptr2 + inner_num; + float32x4_t vdata0 = vld1q_f32(din_ptr); + float32x4_t vdata1 = vld1q_f32(din_ptr1); + float32x4_t vdata2 = vld1q_f32(din_ptr2); + float32x4_t vdata3 = vld1q_f32(din_ptr3); + + float* dout_ptr0 = dout + real_index; + float* dout_ptr1 = dout_ptr0 + inner_num; + float32x4_t vmax1 = vmaxq_f32(vdata0, vdata1); + float32x4_t vmax2 = vmaxq_f32(vdata2, vdata3); + float* dout_ptr2 = dout_ptr1 + inner_num; + 
float* dout_ptr3 = dout_ptr2 + inner_num; + float32x4_t vmax = vmaxq_f32(vmax1, vmax2); + + // sub, exp and sum + float32x4_t vsum0 = exp_ps(vsubq_f32(vdata0, vmax)); + float32x4_t vsum1 = exp_ps(vsubq_f32(vdata1, vmax)); + float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax)); + float32x4_t vsum3 = exp_ps(vsubq_f32(vdata3, vmax)); + + float32x4_t vsum_1 = vaddq_f32(vsum0, vsum1); + float32x4_t vsum_2 = vaddq_f32(vsum2, vsum3); + + float32x4_t vsum = vaddq_f32(vsum_1, vsum_2); + + float32x4_t vinf = div_ps(vone, vsum); + + vsum0 = vmulq_f32(vsum0, vinf); + vsum1 = vmulq_f32(vsum1, vinf); + vsum2 = vmulq_f32(vsum2, vinf); + vsum3 = vmulq_f32(vsum3, vinf); + + vst1q_f32(dout_ptr0, vsum0); + vst1q_f32(dout_ptr1, vsum1); + vst1q_f32(dout_ptr2, vsum2); + vst1q_f32(dout_ptr3, vsum3); + } + + int i = cmp_cnt * 8; + for (; i < compute_size; i++) { + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + + float max_data = din[real_index]; + // get max + for (int j = 1; j < axis_size; ++j) { + real_index += inner_num; + max_data = din[real_index] > max_data ? 
din[real_index] : max_data; + } + + real_index = idx_outer * inner_num + idx_inner; + // sub, exp and sum + dout[real_index] = expf(din[real_index] - max_data); + float sum_data = dout[real_index]; + for (int j = 1; j < axis_size; ++j) { + real_index += inner_num; + dout[real_index] = expf(din[real_index] - max_data); + sum_data += dout[real_index]; + } + + float sum_inv = 1.f / sum_data; + real_index = idx_outer * inner_num + idx_inner; + // get softmax result + for (int j = 0; j < axis_size; ++j) { + dout[real_index] *= sum_inv; + real_index += inner_num; + } + } +} + +template <> +void softmax_inner8(const float* din, float* dout, const int axis_size, + const int inner_num, const int outer_num) { + int compute_size = inner_num * outer_num; + int cmp_cnt = compute_size >> 3; +#pragma omp parallel for + for (int c = 0; c < cmp_cnt; ++c) { + int i = c * 8; + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + + const float* din_ptr = din + real_index; + float32x4_t vmax = vld1q_f32(din_ptr); + float32x4_t vmax2 = vld1q_f32(din_ptr + 4); + // get max + for (int j = 1; j < axis_size; ++j) { + din_ptr += inner_num; + float32x4_t vdata = vld1q_f32(din_ptr); + float32x4_t vdata2 = vld1q_f32(din_ptr + 4); + vmax = vmaxq_f32(vmax, vdata); + vmax2 = vmaxq_f32(vmax2, vdata2); + } + + // sub, exp and sum + din_ptr = din + real_index; + float* dout_ptr = dout + real_index; + float32x4_t vdata = vld1q_f32(din_ptr); + float32x4_t vdata2 = vld1q_f32(din_ptr + 4); + float32x4_t vsum = exp_ps(vsubq_f32(vdata, vmax)); + float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax2)); + din_ptr += inner_num; + vst1q_f32(dout_ptr, vsum); + vst1q_f32(dout_ptr + 4, vsum2); + dout_ptr += inner_num; + for (int j = 1; j < axis_size; ++j) { + float32x4_t vdata0 = vld1q_f32(din_ptr); + float32x4_t vdata1 = vld1q_f32(din_ptr + 4); + vdata0 = exp_ps(vsubq_f32(vdata0, vmax)); + vdata1 = exp_ps(vsubq_f32(vdata1, vmax2)); + 
din_ptr += inner_num; + vsum = vaddq_f32(vsum, vdata0); + vsum2 = vaddq_f32(vsum2, vdata1); + vst1q_f32(dout_ptr, vdata0); + vst1q_f32(dout_ptr + 4, vdata1); + dout_ptr += inner_num; + } + + float32x4_t vone = vdupq_n_f32(1.0f); + float32x4_t vinf = div_ps(vone, vsum); + float32x4_t vinf2 = div_ps(vone, vsum2); + dout_ptr = dout + real_index; + // get softmax result + for (int j = 0; j < axis_size; ++j) { + float32x4_t vdata0 = vld1q_f32(dout_ptr); + float32x4_t vdata1 = vld1q_f32(dout_ptr + 4); + vdata0 = vmulq_f32(vdata0, vinf); + vdata1 = vmulq_f32(vdata1, vinf2); + vst1q_f32(dout_ptr, vdata0); + vst1q_f32(dout_ptr + 4, vdata1); + dout_ptr += inner_num; + } + } + + for (int i = cmp_cnt * 8; i < compute_size; i++) { + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + + float max_data = din[real_index]; + // get max + for (int j = 1; j < axis_size; ++j) { + real_index += inner_num; + max_data = din[real_index] > max_data ? 
din[real_index] : max_data; + } + + real_index = idx_outer * inner_num + idx_inner; + // sub, exp and sum + dout[real_index] = expf(din[real_index] - max_data); + float sum_data = dout[real_index]; + for (int j = 1; j < axis_size; ++j) { + real_index += inner_num; + dout[real_index] = expf(din[real_index] - max_data); + sum_data += dout[real_index]; + } + + float sum_inv = 1.f / sum_data; + real_index = idx_outer * inner_num + idx_inner; + // get softmax result + for (int j = 0; j < axis_size; ++j) { + dout[real_index] *= sum_inv; + real_index += inner_num; + } + } +} + +template <> +void softmax_inner4(const float* din, float* dout, const int axis_size, + const int inner_num, const int outer_num) { + int compute_size = inner_num * outer_num; + int cmp_cnt = compute_size >> 2; +#pragma omp parallel for + for (int c = 0; c < cmp_cnt; ++c) { + int i = c * 4; + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + + // float max_data = din[real_index]; + const float* din_ptr = din + real_index; + float32x4_t vmax = vld1q_f32(din_ptr); + // get max + for (int j = 1; j < axis_size; ++j) { + din_ptr += inner_num; + float32x4_t vdata = vld1q_f32(din_ptr); + vmax = vmaxq_f32(vmax, vdata); + } + // sub, exp and sum + din_ptr = din + real_index; + float* dout_ptr = dout + real_index; + float32x4_t vdata = vld1q_f32(din_ptr); + float32x4_t vsum = exp_ps(vsubq_f32(vdata, vmax)); + din_ptr += inner_num; + vst1q_f32(dout_ptr, vsum); + dout_ptr += inner_num; + for (int j = 1; j < axis_size; ++j) { + // real_index += inner_num; + float32x4_t vdata0 = vld1q_f32(din_ptr); + vdata0 = exp_ps(vsubq_f32(vdata0, vmax)); + din_ptr += inner_num; + vsum = vaddq_f32(vsum, vdata0); + vst1q_f32(dout_ptr, vdata0); + dout_ptr += inner_num; + } + + float32x4_t vone = vdupq_n_f32(1.0f); + float32x4_t vinf = div_ps(vone, vsum); + dout_ptr = dout + real_index; + // get softmax result + for (int j = 0; j < axis_size; ++j) 
{ + float32x4_t vdata0 = vld1q_f32(dout_ptr); + vdata0 = vmulq_f32(vdata0, vinf); + vst1q_f32(dout_ptr, vdata0); + dout_ptr += inner_num; + } + } + + for (int i = cmp_cnt * 4; i < compute_size; i++) { + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + + float max_data = din[real_index]; + // get max + for (int j = 1; j < axis_size; ++j) { + real_index += inner_num; + max_data = din[real_index] > max_data ? din[real_index] : max_data; + } + + real_index = idx_outer * inner_num + idx_inner; + // sub, exp and sum + dout[real_index] = expf(din[real_index] - max_data); + float sum_data = dout[real_index]; + for (int j = 1; j < axis_size; ++j) { + real_index += inner_num; + dout[real_index] = expf(din[real_index] - max_data); + sum_data += dout[real_index]; + } + + float sum_inv = 1.f / sum_data; + real_index = idx_outer * inner_num + idx_inner; + // get softmax result + for (int j = 0; j < axis_size; ++j) { + dout[real_index] *= sum_inv; + real_index += inner_num; + } + } +} + +template <> +void softmax_inner1_large_axis(const float* din, float* dout, + const int outer_size, + const int axis_size) { +#pragma omp parallel for + for (int i = 0; i < outer_size; ++i) { + const float* din_ptr = din + i * axis_size; + float* dout_ptr = dout + i * axis_size; + + const float* din_max_ptr = din_ptr; + int nn = axis_size >> 2; + + // get max + float32x4_t vmax = vld1q_f32(din_max_ptr); + din_max_ptr += 4; + int j = 1; + for (; j < nn; ++j) { + vmax = vmaxq_f32(vmax, vld1q_f32(din_max_ptr)); + din_max_ptr += 4; + } + float32x2_t vhmax = vmax_f32(vget_high_f32(vmax), vget_low_f32(vmax)); + float max_data = std::max(vget_lane_f32(vhmax, 0), vget_lane_f32(vhmax, 1)); + for (j = 4 * j; j < axis_size; ++j) { + max_data = std::max(max_data, din_max_ptr[0]); + din_max_ptr++; + } + + // sub, exp and sum + const float* din_sum_ptr = din_ptr; + float* dout_sum_ptr = dout_ptr; + vmax = vdupq_n_f32(max_data); 
+ float32x4_t vsub_exp = exp_ps(vsubq_f32(vld1q_f32(din_sum_ptr), vmax)); + float32x4_t vsum = vsub_exp; + vst1q_f32(dout_sum_ptr, vsub_exp); + din_sum_ptr += 4; + dout_sum_ptr += 4; + + j = 1; + for (; j < nn; ++j) { + vsub_exp = exp_ps(vsubq_f32(vld1q_f32(din_sum_ptr), vmax)); + vst1q_f32(dout_sum_ptr, vsub_exp); + vsum = vaddq_f32(vsum, vsub_exp); + din_sum_ptr += 4; + dout_sum_ptr += 4; + } + float32x2_t vhsum = vadd_f32(vget_high_f32(vsum), vget_low_f32(vsum)); + float sum_data = vget_lane_f32(vhsum, 0) + vget_lane_f32(vhsum, 1); + + for (j = 4 * j; j < axis_size; ++j) { + dout_sum_ptr[0] = expf(din_sum_ptr[0] - max_data); + sum_data += dout_sum_ptr[0]; + din_sum_ptr++; + dout_sum_ptr++; + } + + float sum_inv = 1.f / sum_data; + float* dout_res_ptr = dout_ptr; + float32x4_t vinv = vdupq_n_f32(sum_inv); + // get softmax result + j = 0; + for (; j < nn; ++j) { + float32x4_t vout = vld1q_f32(dout_res_ptr); + float32x4_t vres = vmulq_f32(vout, vinv); + vst1q_f32(dout_res_ptr, vres); + dout_res_ptr += 4; + } + for (j = nn * 4; j < axis_size; ++j) { + dout_ptr[j] *= sum_inv; + } + } +} + +template <> +void softmax_inner1_small_axis(const float* din, float* dout, + const int outer_size, + const int axis_size) { +#pragma omp parallel for + for (int i = 0; i < outer_size; ++i) { + const float* din_ptr = din + i * axis_size; + float* dout_ptr = dout + i * axis_size; + // get max + float max_data = din_ptr[0]; + for (int j = 1; j < axis_size; ++j) { + max_data = std::max(max_data, din_ptr[j]); + } + + // sub, exp and sum + float sum_data = 0.f; + for (int j = 0; j < axis_size; ++j) { + dout_ptr[j] = expf(din_ptr[j] - max_data); + sum_data += dout_ptr[j]; + } + + float sum_inv = 1.f / sum_data; + for (int j = 0; j < axis_size; ++j) { + dout_ptr[j] *= sum_inv; + } + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/arm/math/softmax.h b/paddle/fluid/lite/arm/math/softmax.h new file mode 100644 index 
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

namespace paddle {
namespace lite {
namespace arm {
namespace math {

// Softmax kernel declarations. Layout convention: element
// (outer, axis, inner) lives at (outer * axis_size + axis) * inner_num +
// inner; the variants below specialize on the (inner_num, axis_size) shape.
// NOTE(review): the `template <...>` parameter lists were lost in extraction;
// restored as `template <typename T>` -- TODO confirm.

// Generic scalar softmax over the axis dimension.
template <typename T>
void softmax_basic(const T* din, T* dout, const int axis_size,
                   const int inner_num, const int outer_num);

// 8-wide vectorized variant specialized for axis_size == 4.
template <typename T>
void softmax_inner8_axis4(const T* din, T* dout, const int axis_size,
                          const int inner_num, const int outer_num);

// 4-wide vectorized variant specialized for axis_size == 4.
template <typename T>
void softmax_inner4_axis4(const T* din, T* dout, const int axis_size,
                          const int inner_num, const int outer_num);

// 8-wide vectorized variant for arbitrary axis_size.
template <typename T>
void softmax_inner8(const T* din, T* dout, const int axis_size,
                    const int inner_num, const int outer_num);

// 4-wide vectorized variant for arbitrary axis_size.
template <typename T>
void softmax_inner4(const T* din, T* dout, const int axis_size,
                    const int inner_num, const int outer_num);

// inner_num == 1 (contiguous rows), large axis: vectorize along the axis.
template <typename T>
void softmax_inner1_large_axis(const T* din, T* dout, const int outer_size,
                               const int axis_size);

// inner_num == 1, small axis: plain scalar path.
template <typename T>
void softmax_inner1_small_axis(const T* din, T* dout, const int outer_size,
                               const int axis_size);

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
-cc_library(lite_gtest_main SRCS lite_gtest_main.cc DEPS gtest) -cc_library(memory_lite SRCS memory.cc DEPS target_wrapper_lite target_wrapper_host) -cc_library(target_wrapper_lite SRCS target_wrapper.cc) -cc_library(lite_tensor SRCS lite_tensor.cc DEPS memory_lite target_wrapper_lite) +if (WITH_TESTING) + cc_library(lite_gtest_main SRCS lite_gtest_main.cc DEPS gtest) +endif() +lite_cc_library(target_wrapper_lite SRCS target_wrapper.cc + DEPS target_wrapper_host + X86_DEPS target_wrapper_x86 + CUDA_DEPS target_wrapper_cuda) +lite_cc_library(memory_lite SRCS memory.cc DEPS target_wrapper_lite) +lite_cc_library(lite_tensor SRCS lite_tensor.cc DEPS memory_lite target_wrapper_lite) if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) - cc_library(hvy_tensor SRCS hvy_tensor.cc DEPS lod_tensor) + lite_cc_library(hvy_tensor SRCS hvy_tensor.cc DEPS lod_tensor HVY_DEPS framework_proto) endif() if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) @@ -14,19 +19,25 @@ endif() proto_library(framework_proto_lite SRCS framework.proto) -cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite) +cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite any_lite op_params_lite framework_proto_lite ${tensor_lite}) cc_library(variable_lite SRCS variable.cc) -cc_library(op_registry_lite SRCS op_registry.cc) -cc_library(scope_lite SRCS scope.cc) -cc_library(context_lite SRCS context.cc DEPS any_lite) -cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite compatible_pb_lite) +cc_library(op_registry_lite SRCS op_registry.cc DEPS framework_proto_lite) +cc_library(scope_lite SRCS scope.cc DEPS ${tensor_lite}) +cc_library(cpu_info_lite SRCS cpu_info.cc) +cc_library(context_lite SRCS context.cc DEPS ${tensor_lite} any_lite cpu_info_lite) +cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapper_lite + cpp_op_desc_lite ${tensor_lite}) cc_library(types_lite SRCS types.cc) cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} 
target_wrapper_lite) -cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite) +lite_cc_library(program_lite SRCS program.cc + DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite + HVY_DEPS framework_proto + PROFILE_DEPS basic_profiler_lite) cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite) add_subdirectory(mir) +add_subdirectory(profile) # for mobile, unnecessary to compile the following testings. if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) @@ -40,9 +51,11 @@ cc_library(program_fake_utils SRCS program_fake_utils.cc DEPS mir_ssa_graph ) lite_cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite) -lite_cc_test(test_kernel_lite SRCS kernel_test.cc DEPS kernel_lite target_wrapper_x86) +lite_cc_test(test_kernel_lite SRCS kernel_test.cc DEPS kernel_lite target_wrapper_lite any_lite) lite_cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite) lite_cc_test(test_tensor_lite SRCS lite_tensor_test.cc DEPS lite_tensor) lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_lite) #lite_cc_test(test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes optimizer_lite fc_op_lite) lite_cc_test(test_types_lite SRCS types_test.cc DEPS types_lite) +lite_cc_test(test_memory_lite SRCS memory_test.cc DEPS memory_lite) +lite_cc_test(test_context_lite SRCS context_test.cc DEPS context_lite X86_DEPS operator) diff --git a/paddle/fluid/lite/core/context.cc b/paddle/fluid/lite/core/context.cc index fa01f1d3e19dc9..cd7006f4724cca 100644 --- a/paddle/fluid/lite/core/context.cc +++ b/paddle/fluid/lite/core/context.cc @@ -12,8 +12,318 @@ // See the License for the specific language governing permissions and // limitations under the License. -// -// Created by chunwei on 19-2-22. 
-// - #include "paddle/fluid/lite/core/context.h" +#include "paddle/fluid/lite/core/cpu_info.h" + +#ifdef LITE_WITH_LINUX +#include +#include +#endif +#if __APPLE__ +#include "TargetConditionals.h" +#if TARGET_OS_IPHONE +#include +#include +#include +#endif // TARGET_OS_IPHONE +#endif // __APPLE__ + +namespace paddle { +namespace lite { + +#ifdef LITE_WITH_ARM + +void Context::SetCache(int l1size, int l2size, int l3size) { + DeviceInfo& dev = DeviceInfo::Global(); + int cpu_count = arm_get_cpucount(); + dev.L1_cache_.resize(cpu_count); + dev.L2_cache_.resize(cpu_count); + dev.L3_cache_.resize(cpu_count); + for (int i = 0; i < cpu_count; ++i) { + dev.L1_cache_[i] = l1size; + dev.L2_cache_[i] = l2size; + dev.L3_cache_[i] = l3size; + } + workspace_.Resize({2 * (l1size + l2size)}); +} + +Context::Context() { + active_ids_ = {0}; + mode_ = LITE_POWER_HIGH; + DeviceInfo& dev = DeviceInfo::Global(); + workspace_.Resize( + {static_cast(dev.L2_cache_[active_ids_[0]] / sizeof(float))}); +#ifdef TARGET_IOS + arch_ = APPLE; // use 6x8 +#else + if (dev.big_core_ids_.size() > 0) { + arch_ = dev.archs_[dev.big_core_ids_[0]]; + } +#endif +} + +PowerMode Context::mode() const { return mode_; } + +int Context::threads() const { return active_ids_.size(); } + +Context::Context(const ARMContext& ctx) { + mode_ = ctx.mode_; + active_ids_ = ctx.active_ids_; + workspace_ = ctx.workspace_; + arch_ = ctx.arch_; + count_ = ctx.count_; +} + +ARMContext& Context::operator=(const ARMContext& ctx) { + mode_ = ctx.mode_; + active_ids_ = ctx.active_ids_; + workspace_ = ctx.workspace_; + arch_ = ctx.arch_; + count_ = ctx.count_; + return *this; +} + +void Context::BindDev() { +#ifdef USE_OPENMP + int num_threads = active_ids_.size(); + omp_set_num_threads(num_threads); +#ifdef LITE_WITH_LINUX + std::vector ssarets; + for (int j = 0; j < num_threads; ++j) { + ssarets.push_back(0); + } +#pragma omp parallel for + for (int i = 0; i < num_threads; i++) { + ssarets[i] = set_sched_affinity(active_ids_); 
+ } + for (int i = 0; i < num_threads; i++) { + if (ssarets[i] != 0) { + LOGE("set cpu affinity failed, cpuID: %d\n", active_ids_[i]); + return; + } + } +#endif // LITE_WITH_LINUX +#else // USE_OPENMP +#ifdef LITE_WITH_LINUX + std::vector cpuid1; + cpuid1.push_back(active_ids_[0]); + int ssaret = set_sched_affinity(cpuid1); + if (ssaret != 0) { + printf("set cpu affinity failed, cpuID: %d\n", active_ids_[0]); + return; + } +#endif // LITE_WITH_LINUX +#endif // USE_OPENMP +} + +void Context::SetRunMode(PowerMode mode, int threads) { + DeviceInfo& dev = DeviceInfo::Global(); + int big_core_size = dev.big_core_ids_.size(); + int small_core_size = dev.little_core_ids_.size(); + if (threads > big_core_size + small_core_size) { + threads = big_core_size + small_core_size; + } +#ifdef USE_OPENMP + count_++; + int shift_num = (count_ / 10) % big_core_size; + switch (mode) { + case LITE_POWER_FULL: + mode_ = mode; + active_ids_.clear(); + for (int i = 0; i < threads; ++i) { + if (i < big_core_size) { + active_ids_.push_back(dev.big_core_ids_[i]); + } else { + active_ids_.push_back(dev.little_core_ids_[i - big_core_size]); + } + } + if (active_ids_.size() == 0) { + active_ids_.push_back(0); + } + break; + case LITE_POWER_HIGH: + active_ids_.clear(); + if (big_core_size > 0) { + mode_ = LITE_POWER_HIGH; + if (threads > big_core_size) { + LOGE("threads: %d, exceed the big cores size: %d\n", threads, + big_core_size); + active_ids_ = dev.big_core_ids_; + } else { + for (int i = 0; i < threads; ++i) { + active_ids_.push_back(dev.big_core_ids_[i]); + } + } + } else { + mode_ = LITE_POWER_LOW; + LOGE("HIGH POWER MODE is not support, switch to little cores\n"); + if (threads > small_core_size) { + active_ids_ = dev.little_core_ids_; + } else { + for (int i = 0; i < threads; ++i) { + active_ids_.push_back(dev.little_core_ids_[i]); + } + } + } + if (active_ids_.size() == 0) { + active_ids_.push_back(0); + } + break; + case LITE_POWER_LOW: + active_ids_.clear(); + if (small_core_size 
> 0) { + mode_ = LITE_POWER_LOW; + if (threads > small_core_size) { + LOGW("threads: %d, exceed the little cores size: %d\n", threads, + small_core_size); + active_ids_ = dev.little_core_ids_; + } else { + for (int i = 0; i < threads; ++i) { + active_ids_.push_back(dev.little_core_ids_[i]); + } + } + } else { + mode_ = LITE_POWER_HIGH; + LOGW("LOW POWER MODE is not support, switch to big cores\n"); + if (threads > big_core_size) { + active_ids_ = dev.big_core_ids_; + } else { + for (int i = 0; i < threads; ++i) { + active_ids_.push_back(dev.big_core_ids_[i]); + } + } + } + if (active_ids_.size() == 0) { + active_ids_.push_back(0); + } + break; + case LITE_POWER_NO_BIND: + mode_ = LITE_POWER_NO_BIND; + active_ids_.clear(); + if (threads > dev.core_ids_.size()) { + active_ids_.resize(dev.core_ids_.size()); + } else { + active_ids_.resize(threads); + } + break; + case LITE_POWER_RAND_HIGH: + active_ids_.clear(); + if (big_core_size > 0) { + mode_ = LITE_POWER_RAND_HIGH; + if (threads > big_core_size) { + LOGW("threads: %d, exceed the big cores size: %d\n", threads, + big_core_size); + active_ids_ = dev.big_core_ids_; + } else { + for (int i = 0; i < threads; ++i) { + active_ids_.push_back( + dev.big_core_ids_[(i + shift_num) % big_core_size]); + } + } + } else { + mode_ = LITE_POWER_LOW; + LOGW("HIGH POWER MODE is not support, switch to little cores\n"); + if (threads > small_core_size) { + active_ids_ = dev.little_core_ids_; + } else { + for (int i = 0; i < threads; ++i) { + active_ids_.push_back(dev.little_core_ids_[i]); + } + } + } + if (active_ids_.size() == 0) { + active_ids_.push_back(0); + } + break; + case LITE_POWER_RAND_LOW: + active_ids_.clear(); + if (small_core_size > 0) { + mode_ = LITE_POWER_RAND_LOW; + if (threads > small_core_size) { + LOGW("threads: %d, exceed the little cores size: %d\n", threads, + small_core_size); + active_ids_ = dev.little_core_ids_; + } else { + for (int i = 0; i < threads; ++i) { + active_ids_.push_back( + 
dev.little_core_ids_[(i + shift_num) % small_core_size]); + } + } + } else { + mode_ = LITE_POWER_HIGH; + LOGW("LOW POWER MODE is not support, switch to big cores\n"); + if (threads > big_core_size) { + active_ids_ = dev.big_core_ids_; + } else { + for (int i = 0; i < threads; ++i) { + active_ids_.push_back(dev.big_core_ids_[i]); + } + } + } + if (active_ids_.size() == 0) { + active_ids_.push_back(0); + } + break; + } + //! fix multi-threads LITE_POWER_HIGH mode + if (mode_ == LITE_POWER_NO_BIND || threads > 1) { + int threads = active_ids_.size(); + omp_set_num_threads(threads); + } else { + if (check_online(active_ids_)) { + BindDev(); + } else { + LOG(ERROR) << "core id " << active_ids_[0] + << " is offline, switch to NO BIND MODE"; + int threads = active_ids_.size(); + omp_set_num_threads(threads); + } + } +#else + if (big_core_size > 0) { + active_ids_ = {dev.big_core_ids_[0]}; + } else { + active_ids_ = {0}; + } +#endif + //! alloc memory for sgemm in this context + int temp_mem_size = + DeviceInfo::Global().L2_cache_[active_ids_[0]] / sizeof(float); + workspace_.Resize({temp_mem_size}); + arch_ = DeviceInfo::Global().archs_[active_ids_[0]]; +} + +ARMArch Context::arch() const { return arch_; } + +void Context::SetArch(ARMArch arch) { arch_ = arch; } + +int Context::l1_cache_size() const { + DeviceInfo& dev = DeviceInfo::Global(); + return dev.L1_cache_[active_ids_[0]]; +} + +int Context::l2_cache_size() const { + DeviceInfo& dev = DeviceInfo::Global(); + return dev.L2_cache_[active_ids_[0]]; +} + +int Context::l3_cache_size() const { + DeviceInfo& dev = DeviceInfo::Global(); + return dev.L3_cache_[active_ids_[0]]; +} + +bool Context::ExtendWorkspace(DDimLite dims) { + auto count = dims.product(); + auto old = workspace_.dims(); + if (count == old.product()) { + return false; + } + + workspace_.Resize( + {static_cast(count + l2_cache_size() / sizeof(float))}); + return true; +} +#endif // LITE_WITH_ARM + +} // namespace lite +} // namespace paddle diff --git 
a/paddle/fluid/lite/core/context.h b/paddle/fluid/lite/core/context.h index 01253e0de19527..483f51541440fe 100644 --- a/paddle/fluid/lite/core/context.h +++ b/paddle/fluid/lite/core/context.h @@ -23,45 +23,188 @@ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/device_context.h" #endif +#include #include #include +#include +#include #include +#include "paddle/fluid/lite/core/cpu_info.h" +#include "paddle/fluid/lite/core/lite_tensor.h" #include "paddle/fluid/lite/core/target_wrapper.h" +#include "paddle/fluid/lite/utils/all.h" namespace paddle { namespace lite { -struct HostContext {}; +template +class Context; + +using HostContext = Context; +using X86Context = Context; +using CUDAContext = Context; +using ARMContext = Context; + +template <> +class Context { + public: + // NOTE: InitOnce should only be used by ContextScheduler + void InitOnce() {} + + void CopyShared(const HostContext* ctx) {} + + std::string name() const { return "HostContext"; } +}; #ifdef LITE_WITH_ARM -struct ARMContext {}; + +template <> +class Context { + public: + Context(); + Context(PowerMode mode, int threads); + explicit Context(const ARMContext& ctx); + + ARMContext& operator=(const ARMContext& ctx); + + // NOTE: InitOnce should only be used by ContextScheduler + void InitOnce() { DeviceInfo::Init(); } + + void CopyShared(const ARMContext* ctx) {} + + void SetRunMode(PowerMode mode, int threads); + void SetCache(int l1size, int l2size, int l3size); + void SetArch(ARMArch arch); + void BindDev(); + + PowerMode mode() const; + int threads() const; + ARMArch arch() const; + + template + T* workspace_data() { + return workspace_.mutable_data(); + } + + int l1_cache_size() const; + int l2_cache_size() const; + int l3_cache_size() const; + bool ExtendWorkspace(DDimLite dims); + + std::string name() const { return "ARMContext"; } + + private: + // LITE_POWER_HIGH stands for using big cores, + // LITE_POWER_LOW stands for using small core, + // LITE_POWER_FULL 
stands for using all cores + ARMArch arch_; + PowerMode mode_; + std::vector active_ids_; + TensorLite workspace_; + int64_t count_{0}; +}; #endif #ifdef LITE_WITH_CUDA // Only works with CUDA kernels. -struct CUDAContext { +template <> +class Context { + public: + // NOTE: InitOnce should only be used by ContextScheduler + void InitOnce() { + cublas_fp32_ = std::make_shared>(); + } + + void CopyShared(const CUDAContext* ctx) { + CHECK(ctx); + CHECK(cublas_fp32_) << "cublas_fp32 should be set first"; + ctx->cublas_fp32_ = cublas_fp32_; + } + + const cudaStream_t exec_stream() { return exec_stream_; } + void SetExecStream(cudaStream_t stream) { exec_stream_ = stream; } + + const cudaStream_t io_stream() { return io_stream_; } + void SetIoStream(cudaStream_t stream) { io_stream_ = stream; } + + std::shared_ptr> cublas_fp32() { return cublas_fp32_; } + void SetCuBlasFP32(std::shared_ptr> cublas_fp32) { + cublas_fp32_ = cublas_fp32; + } + + const std::vector& input_events() { return input_events_; } + void SetInputEvents(const std::vector& input_events) { + input_events_.clear(); + input_events_.assign(input_events.begin(), input_events.end()); + } + + const std::vector& output_events() { return output_events_; } + void SetOutputEvents(const std::vector& output_events) { + output_events_.clear(); + output_events_.assign(output_events.begin(), output_events.end()); + } + + std::string name() const { return "CUDAContext"; } + + private: // overall information - cudaStream_t exec_stream; - cudaStream_t io_stream; + cudaStream_t exec_stream_; + cudaStream_t io_stream_; // not thread-safe, should allocate for each thread. 
- std::shared_ptr> blas_fp32; + std::shared_ptr> cublas_fp32_; // kernel information - std::vector input_events; - std::vector output_events; + std::vector input_events_; + std::vector output_events_; }; #endif #ifdef LITE_WITH_X86 -struct X86Context { - // overall information +template <> +class Context { + public: + using device_ctx_t = ::paddle::platform::CPUDeviceContext; + using execution_ctx_t = ::paddle::framework::ExecutionContext; + + Context() { + x86_device_context_.reset(new ::paddle::platform::CPUDeviceContext); + x86_execution_context_.reset( + new ::paddle::framework::ExecutionContext(*x86_device_context_)); + } + Context(Context&& ctx) { + x86_device_context_ = std::move(ctx.x86_device_context_); + x86_execution_context_ = std::move(ctx.x86_execution_context_); + } + + // NOTE: InitOnce should only be used by ContextScheduler + void InitOnce() {} + + void CopyShared(const X86Context* ctx) {} + + const device_ctx_t* x86_device_context() { return x86_device_context_.get(); } + void SetX86DeviceContext(std::unique_ptr&& ctx) { + x86_device_context_ = std::move(ctx); + } + + const execution_ctx_t* x86_execution_context() { + return x86_execution_context_.get(); + } + void SetX86ExecutionContext(std::unique_ptr&& ctx) { + x86_execution_context_ = std::move(ctx); + } + + std::string name() const { return "X86Context"; } + + private: + // overall information + // // kernel information // legacy info. - std::unique_ptr<::paddle::platform::CPUDeviceContext> x86_device_context; - std::unique_ptr<::paddle::framework::ExecutionContext> x86_execution_context; + std::unique_ptr x86_device_context_; + std::unique_ptr x86_execution_context_; }; #endif @@ -81,5 +224,67 @@ class KernelContext { Any ctx_; }; +// The ContextScheduler helps to assign different context for each kernel. 
+class ContextScheduler { + public: + static ContextScheduler& Global() { + static auto* x = new ContextScheduler; + return *x; + } + + std::unique_ptr NewContext(TargetType target) { + std::unique_ptr ctx(new KernelContext); + switch (target) { + case TARGET(kHost): + kernel_contexts_[TargetType::kHost].As().CopyShared( + &ctx->As()); + break; +#ifdef LITE_WITH_X86 + case TARGET(kX86): + kernel_contexts_[TargetType::kX86].As().CopyShared( + &ctx->As()); + break; +#endif +#ifdef LITE_WITH_CUDA + case TARGET(kCUDA): + kernel_contexts_[TargetType::kCUDA].As().CopyShared( + &ctx->As()); + break; +#endif +#ifdef LITE_WITH_ARM + case TARGET(kARM): + kernel_contexts_[TargetType::kARM].As().CopyShared( + &ctx->As()); + break; +#endif + default: + LOG(FATAL) << "unsupported target " << TargetToStr(target); + } + return ctx; + } + + private: + template + void InitContext() { + kernel_contexts_[Type].As().InitOnce(); + } + + ContextScheduler() { + InitContext(); +#ifdef LITE_WITH_X86 + InitContext(); +#endif +#ifdef LITE_WITH_CUDA + InitContext(); +#endif +#ifdef LITE_WITH_ARM + InitContext(); +#endif + } + + private: + std::map kernel_contexts_; +}; + } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/core/context_test.cc b/paddle/fluid/lite/core/context_test.cc new file mode 100644 index 00000000000000..0952aec33f37e4 --- /dev/null +++ b/paddle/fluid/lite/core/context_test.cc @@ -0,0 +1,51 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/context.h" +#include + +namespace paddle { +namespace lite { + +#ifdef LITE_WITH_X86 +TEST(ContextScheduler, NewContext) { + auto ctx1_p = ContextScheduler::Global().NewContext(TargetType::kX86); + auto ctx2_p = ContextScheduler::Global().NewContext(TargetType::kX86); + ASSERT_FALSE(ctx1_p.get() == ctx2_p.get()); + + auto& ctx1 = ctx1_p->As(); + auto& ctx2 = ctx2_p->As(); + + ASSERT_EQ(ctx1.name(), "X86Context"); + ASSERT_EQ(ctx2.name(), "X86Context"); + + ASSERT_FALSE(ctx1.x86_device_context() == nullptr || + ctx2.x86_device_context() == nullptr); + ASSERT_FALSE(ctx1.x86_execution_context() == nullptr || + ctx2.x86_execution_context() == nullptr); + + ASSERT_TRUE(ctx1.x86_device_context() != ctx2.x86_device_context()); + ASSERT_TRUE(ctx1.x86_execution_context() != ctx2.x86_execution_context()); + + using device_ctx_t = ::paddle::platform::CPUDeviceContext; + using exec_ctx_t = ::paddle::framework::ExecutionContext; + auto* device_ctx = new device_ctx_t; + ctx1.SetX86DeviceContext(std::unique_ptr(device_ctx)); + ctx1.SetX86ExecutionContext( + std::unique_ptr(new exec_ctx_t(*device_ctx))); +} +#endif + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/cpu_info.cc b/paddle/fluid/lite/core/cpu_info.cc new file mode 100644 index 00000000000000..df80f1c857688f --- /dev/null +++ b/paddle/fluid/lite/core/cpu_info.cc @@ -0,0 +1,629 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/cpu_info.h" +#include + +namespace paddle { +namespace lite { + +#ifdef LITE_WITH_ARM + +void DeviceInfo::InitInternal(DeviceInfo* dev) { + set_default_cache(dev); + dev->compute_core_num_ = arm_get_cpucount(); + dev->max_memory_ = arm_get_meminfo(); + +// get max freq +#ifdef LITE_WITH_LINUX + std::vector max_freq(dev->compute_core_num_); + for (int i = 0; i < dev->compute_core_num_; ++i) { + max_freq[i] = get_max_freq_khz(i) / 1000; + } + std::string cpu_name = arm_get_cpu_name(); + if (get_cpu_info_from_name(dev, cpu_name) != true) { + arm_sort_cpuid_by_max_frequency(dev->compute_core_num_, &dev->core_ids_, + max_freq, &dev->cluster_ids_); + dev->big_core_ids_.clear(); + dev->little_core_ids_.clear(); + for (int i = 0; i < dev->cluster_ids_.size(); ++i) { + if (dev->cluster_ids_[i] == 0) { + dev->big_core_ids_.push_back(dev->core_ids_[i]); + } else { + dev->little_core_ids_.push_back(dev->core_ids_[i]); + } + } + arm_get_cpu_arch(&dev->archs_); + } + + LOG(INFO) << "ARM multiprocessors number: " << dev->compute_core_num_; + for (int i = 0; i < dev->compute_core_num_; ++i) { + LOG(INFO) << "ARM multiprocessors ID: " << dev->core_ids_[i] + << ", frequence: " << max_freq[i] + << ", cluster ID: " << dev->cluster_ids_[dev->core_ids_[i]] + << ", CPU ARCH: A" << dev->archs_[i]; + } + LOG(INFO) << "L1 DataCache size is: "; + for (int i = 0; i < dev->compute_core_num_; ++i) { + LOG(INFO) << dev->L1_cache_[i] / 1024 << " KB"; + } + LOG(INFO) << "L2 Cache size is: "; + for (int i = 0; i < dev->compute_core_num_; ++i) { + 
LOG(INFO) << dev->L2_cache_[i] / 1024 << " KB"; + } + LOG(INFO) << "Total memory: " << dev->max_memory_ << "KB"; + + dev->max_freq_ = max_freq[0]; + for (int j = 1; j < dev->compute_core_num_; ++j) { + if (dev->max_freq_ < max_freq[j]) { + dev->max_freq_ = max_freq[j]; + } + } +#elif defined(TARGET_IOS) + arm_get_cpu_arch(&dev->archs_); +#endif +} + +// cache_id : 0 -> L1, 1 -> L2, 2 -> L3 +void set_cache_info(DeviceInfo* cpu_info, int cache_id, int argc, ...) { + va_list arg_ptr; + va_start(arg_ptr, argc); + std::vector* cache; + switch (cache_id) { + case 0: + cache = &cpu_info->L1_cache_; + break; + case 1: + cache = &cpu_info->L2_cache_; + break; + case 2: + cache = &cpu_info->L3_cache_; + break; + default: + break; + } + int core_num = cpu_info->compute_core_num_; + cache->resize(core_num); + if (argc == 1) { + int cache_size = va_arg(arg_ptr, int); + for (int i = 0; i < core_num; ++i) { + (*cache)[i] = cache_size; + } + } else { + int big_core_num = cpu_info->big_core_ids_.size(); + int little_core_num = cpu_info->little_core_ids_.size(); + int big_core_cache_size = va_arg(arg_ptr, int); + int little_core_cache_size = va_arg(arg_ptr, int); + for (int i = 0; i < big_core_num; ++i) { + (*cache)[cpu_info->big_core_ids_[i]] = big_core_cache_size; + } + for (int i = 0; i < little_core_num; ++i) { + (*cache)[cpu_info->little_core_ids_[i]] = little_core_cache_size; + } + } + va_end(arg_ptr); +} + +void set_arch_info(DeviceInfo* cpu_info, int argc, ...) 
{ + va_list arg_ptr; + va_start(arg_ptr, argc); + int core_num = cpu_info->compute_core_num_; + cpu_info->archs_.resize(core_num); + if (argc == 1) { + ARMArch arch = (ARMArch)va_arg(arg_ptr, int); + for (int i = 0; i < core_num; ++i) { + cpu_info->archs_[i] = arch; + } + } else { + ARMArch big_core_arch = (ARMArch)va_arg(arg_ptr, int); + ARMArch little_core_arch = (ARMArch)va_arg(arg_ptr, int); + int big_core_num = cpu_info->big_core_ids_.size(); + int little_core_num = cpu_info->little_core_ids_.size(); + for (int i = 0; i < big_core_num; ++i) { + cpu_info->archs_[cpu_info->big_core_ids_[i]] = big_core_arch; + } + for (int i = 0; i < little_core_num; ++i) { + cpu_info->archs_[cpu_info->little_core_ids_[i]] = little_core_arch; + } + } + va_end(arg_ptr); +} + +bool get_cpu_info_from_name(DeviceInfo* cpu_info, std::string hardware_name) { + /* Snapdragon */ + if (hardware_name.find("SDM845") != std::string::npos) { // 845 + cpu_info->compute_core_num_ = 8; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->big_core_ids_ = {4, 5, 6, 7}; + cpu_info->little_core_ids_ = {0, 1, 2, 3}; + cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0}; + set_arch_info(cpu_info, 2, kA75, kA55); + set_cache_info(cpu_info, 0, 1, 32 * 1024); + set_cache_info(cpu_info, 1, 2, 256 * 1024, 128 * 1024); + set_cache_info(cpu_info, 2, 1, 2048 * 1024); + return true; + + } else if (hardware_name.find("SDM710") != std::string::npos) { // 710 + cpu_info->compute_core_num_ = 8; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->big_core_ids_ = {6, 7}; + cpu_info->little_core_ids_ = {0, 1, 2, 3, 4, 5}; + cpu_info->cluster_ids_ = {1, 1, 1, 1, 1, 1, 0, 0}; + set_arch_info(cpu_info, 2, kA75, kA55); + return true; + } else if (hardware_name.find("MSM8998") != std::string::npos) { // 835 + cpu_info->compute_core_num_ = 8; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->big_core_ids_ = {4, 5, 6, 7}; + cpu_info->little_core_ids_ = {0, 1, 2, 3}; + cpu_info->cluster_ids_ = 
{1, 1, 1, 1, 0, 0, 0, 0}; + set_arch_info(cpu_info, 2, kA73, kA53); + set_cache_info(cpu_info, 0, 2, 64 * 1024); + set_cache_info(cpu_info, 1, 2, 1024 * 1024, + /*real cache size is 2M, while that will get bad performace + on conv3x3s1 or gemm, set to 1M or 512K*/ + 1024 * 1024); + return true; + + } else if (hardware_name.find("MSM8996") != std::string::npos) { // 820 + cpu_info->compute_core_num_ = 4; + cpu_info->core_ids_ = {0, 1, 2, 3}; + cpu_info->big_core_ids_ = {2, 3}; + cpu_info->little_core_ids_ = {0, 1}; + cpu_info->cluster_ids_ = {1, 1, 0, 0}; + set_arch_info(cpu_info, 1, kA72); + set_cache_info(cpu_info, 0, 1, 24 * 1024); + set_cache_info(cpu_info, 1, 2, 1024 * 1024, 512 * 1024); + return true; + + } else if (hardware_name.find("SDM660") != std::string::npos || + hardware_name.find("SDM636") != std::string::npos) { // 660, 636 + cpu_info->compute_core_num_ = 8; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->big_core_ids_ = {4, 5, 6, 7}; + cpu_info->little_core_ids_ = {0, 1, 2, 3}; + cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0}; + set_arch_info(cpu_info, 1, kA73); + set_cache_info(cpu_info, 0, 2, 64 * 1024, 32 * 1024); + set_cache_info(cpu_info, 1, 1, 1024 * 1024); + return true; + + } else if (hardware_name.find("MSM8976") != std::string::npos) { // 652,653 + cpu_info->compute_core_num_ = 8; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->big_core_ids_ = {4, 5, 6, 7}; + cpu_info->little_core_ids_ = {0, 1, 2, 3}; + cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0}; + set_arch_info(cpu_info, 2, kA72, kA53); + set_cache_info(cpu_info, 0, 1, 32 * 1024); + set_cache_info(cpu_info, 1, 2, 1024 * 1024, 512 * 1024); + return true; + + } else if (hardware_name.find("MSM8953") != std::string::npos) { // 625 + cpu_info->compute_core_num_ = 8; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->big_core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->little_core_ids_ = {}; + cpu_info->cluster_ids_ = {0, 0, 0, 0, 0, 0, 0, 0}; 
+ set_arch_info(cpu_info, 1, kA53); + set_cache_info(cpu_info, 0, 1, 32 * 1024); + set_cache_info(cpu_info, 1, 1, 1024 * 1024); + return true; + + } else if (hardware_name.find("MSM8939") != std::string::npos) { // 615 + cpu_info->compute_core_num_ = 8; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->big_core_ids_ = {0, 1, 2, 3}; + cpu_info->little_core_ids_ = {4, 5, 6, 7}; + cpu_info->cluster_ids_ = {0, 0, 0, 0, 1, 1, 1, 1}; + set_arch_info(cpu_info, 1, kA53); + set_cache_info(cpu_info, 0, 1, 32 * 1024); + set_cache_info(cpu_info, 1, 2, 512 * 1024, 256 * 1024); + return true; + + /* MediaTek */ + + } else if (hardware_name.find("MT6797") != + std::string::npos) { // X20/X23/X25/X27 + cpu_info->compute_core_num_ = 10; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + cpu_info->big_core_ids_ = {8, 9}; + cpu_info->little_core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->cluster_ids_ = {1, 1, 1, 1, 1, 1, 1, 1, 0, 0}; + set_arch_info(cpu_info, 2, kA72, kA53); + set_cache_info(cpu_info, 0, 1, 32 * 1024); + set_cache_info(cpu_info, 1, 2, 1024 * 1024, 512 * 1024); + return true; + + } else if (hardware_name.find("MT6799") != std::string::npos) { // X30 + cpu_info->compute_core_num_ = 10; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + cpu_info->big_core_ids_ = {8, 9}; + cpu_info->little_core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->cluster_ids_ = {1, 1, 1, 1, 1, 1, 1, 1, 0, 0}; + set_arch_info(cpu_info, 2, kA73, kA53); + return true; + + } else if (hardware_name.find("MT6795") != std::string::npos || + hardware_name.find("MT6762") != std::string::npos || + hardware_name.find("MT6755T") != std::string::npos || + hardware_name.find("MT6755S") != std::string::npos || + hardware_name.find("MT6753") != std::string::npos || + hardware_name.find("MT6752") != std::string::npos || + hardware_name.find("MT6750") != std::string::npos) { + // X10, P22, P15/P18, MT6753, MT6752/MT6752M, MT6750 + cpu_info->compute_core_num_ = 8; + 
cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->big_core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->little_core_ids_ = {}; + cpu_info->cluster_ids_ = {0, 0, 0, 0, 0, 0, 0, 0}; + set_arch_info(cpu_info, 1, kA53); + return true; + + } else if (hardware_name.find("MT6758") != std::string::npos || + hardware_name.find("MT6757") != std::string::npos || + hardware_name.find("MT6763") != std::string::npos || + hardware_name.find("MT6755M") != std::string::npos || + hardware_name.find("MT6755") != + std::string::npos) { // P30, P20/P25, P23, P10 + cpu_info->compute_core_num_ = 8; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->big_core_ids_ = {4, 5, 6, 7}; + cpu_info->little_core_ids_ = {0, 1, 2, 3}; + cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0}; + set_arch_info(cpu_info, 1, kA53); + return true; + + } else if (hardware_name.find("MT6771") != std::string::npos) { // P60 + cpu_info->compute_core_num_ = 8; + cpu_info->core_ids_ = {0, 1, 2, 3, 4, 5, 6, 7}; + cpu_info->big_core_ids_ = {4, 5, 6, 7}; + cpu_info->little_core_ids_ = {0, 1, 2, 3}; + cpu_info->cluster_ids_ = {1, 1, 1, 1, 0, 0, 0, 0}; + set_arch_info(cpu_info, 2, kA73, kA53); + return true; + + } else if (hardware_name.find("MT6765") != std::string::npos || + hardware_name.find("MT6739") != std::string::npos || + hardware_name.find("MT6738") != std::string::npos || + hardware_name.find("MT6737") != + std::string::npos) { // A22, MT6739, MT6738, MT6767 + cpu_info->compute_core_num_ = 4; + cpu_info->core_ids_ = {0, 1, 2, 3}; + cpu_info->big_core_ids_ = {0, 0, 0, 0}; + cpu_info->little_core_ids_ = {}; + cpu_info->cluster_ids_ = {0, 0, 0, 0}; + set_arch_info(cpu_info, 1, kA53); + return true; + } + return false; +} + +size_t arm_get_meminfo() { +#ifdef LITE_WITH_LINUX + // get cpu count from /proc/cpuinfo + FILE* fp = fopen("/proc/meminfo", "rb"); + if (!fp) { + return 1; + } + + size_t memsize = 0; + char line[1024]; + while (!feof(fp)) { + char* s = fgets(line, 1024, fp); + if (!s) { + 
break; + } + sscanf(s, "MemTotal: %d kB", &memsize); + } + + fclose(fp); + + return memsize; +#elif defined(TARGET_IOS) + // to be implemented + printf("not implemented\n"); + return 0; +#endif +} + +int arm_get_cpucount() { +#ifdef LITE_WITH_LINUX + // get cpu count from /sys/devices/system/cpu/cpunum/uevent + int max_cpu_count = 20; + int count = 0; + for (int i = 0; i < max_cpu_count; ++i) { + char path[256]; + snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%d/uevent", i); + FILE* fp = fopen(path, "rb"); + if (!fp) { + break; + } + count++; + fclose(fp); + } + if (count < 1) { + count = 1; + } + return count; +#elif defined(TARGET_IOS) + int count = 0; + size_t len = sizeof(count); + sysctlbyname("hw.ncpu", &count, &len, NULL, 0); + if (count < 1) { + count = 1; + } + return count; +#else + return 1; +#endif +} + +void arm_get_cpu_arch(std::vector* archs) { +#ifdef LITE_WITH_LINUX + archs->clear(); + //! get CPU ARCH + FILE* fp = fopen("/proc/cpuinfo", "rb"); + if (!fp) { + return; + } + char line[1024]; + while (!feof(fp)) { + char* s = fgets(line, 1024, fp); + if (!s) { + break; + } + if (strstr(line, "part") != NULL) { + int arch_id = 0; + sscanf(s, "CPU part\t: %x", &arch_id); + switch (arch_id) { + case 0xd03: + archs->push_back(kA53); + break; + case 0xd05: + archs->push_back(kA55); + break; + case 0xd07: + archs->push_back(kA57); + break; + case 0xd08: + archs->push_back(kA72); + break; + case 0xd09: + archs->push_back(kA73); + break; + case 0xd0a: + archs->push_back(kA75); + break; + case 0x800: + // 835 + archs->push_back(kA73); + break; + case 0x205: + // 820 + archs->push_back(kA72); + break; + default: + LOG(ERROR) << "unknow type"; + archs->push_back(kARMArch_UNKOWN); + } + } + } + fclose(fp); + int cpu_count = arm_get_cpucount(); + if (archs->size() < cpu_count) { + for (int i = archs->size(); i < cpu_count; ++i) { + archs->push_back(archs->at(i - 1)); + } + } +#endif +#ifdef TARGET_IOS + int cpu_count = arm_get_cpucount(); + for (int i = 
0; i < cpu_count; ++i) { + archs->push_back(APPLE); + } +#endif +} + +#ifdef LITE_WITH_LINUX + +void set_default_cache(DeviceInfo* dev) { + int cpu_count = arm_get_cpucount(); + dev->L1_cache_.resize(cpu_count); + dev->L2_cache_.resize(cpu_count); + dev->L3_cache_.resize(cpu_count); +#ifdef TARGET_IOS + for (int i = 0; i < cpu_count; ++i) { + dev->L1_cache_[i] = 64 * 1024; + dev->L2_cache_[i] = 2048 * 1024; + dev->L3_cache_[i] = 0; + } +#else + for (int i = 0; i < cpu_count; ++i) { + dev->L1_cache_[i] = 32 * 1024; + dev->L2_cache_[i] = 512 * 1024; + dev->L3_cache_[i] = 0; + } +#endif +} +std::string arm_get_cpu_name() { + FILE* fp = fopen("/proc/cpuinfo", "rb"); + if (!fp) { + return ""; + } + char line[1024]; + while (!feof(fp)) { + char* s = fgets(line, 1024, fp); + if (!s) { + break; + } + if (strstr(line, "Hardware") != NULL) { + fclose(fp); + return std::string(line); + } + } + fclose(fp); + return ""; +} + +int get_max_freq_khz(int cpuid) { + // first try, for all possible cpu + char path[256]; + snprintf(path, sizeof(path), + "/sys/devices/system/cpu/cpufreq/stats/cpu%d/time_in_state", cpuid); + + FILE* fp = fopen(path, "rb"); + + if (!fp) { + // second try, for online cpu + snprintf(path, sizeof(path), + "/sys/devices/system/cpu/cpu%d/cpufreq/stats/time_in_state", + cpuid); + fp = fopen(path, "rb"); + + if (!fp) { + // third try, for online cpu + snprintf(path, sizeof(path), + "/sys/devices/system/cpu/cpu%d/cpufreq/cpuinfo_max_freq", cpuid); + fp = fopen(path, "rb"); + + if (!fp) { + return -1; + } + + int max_freq_khz = -1; + fscanf(fp, "%d", &max_freq_khz); + + fclose(fp); + + return max_freq_khz; + } + } + + int max_freq_khz = 0; + while (!feof(fp)) { + int freq_khz = 0; + int nscan = fscanf(fp, "%d %*d", &freq_khz); + if (nscan != 1) { + break; + } + + if (freq_khz > max_freq_khz) { + max_freq_khz = freq_khz; + } + } + + fclose(fp); + + return max_freq_khz; +} + +int arm_sort_cpuid_by_max_frequency(int cpu_count, std::vector* cpuids, + const 
std::vector& cpu_freq, + std::vector* cluster_ids) { + if (cpu_count == 0) { + return 0; + } + + cpuids->resize(cpu_count); + cluster_ids->resize(cpu_count); + + for (int i = 0; i < cpu_count; i++) { + cpuids->at(i) = i; + } + + // sort cpuid as big core first + // simple bubble sort + + for (int i = 0; i < cpu_count; i++) { + for (int j = i + 1; j < cpu_count; j++) { + if (cpu_freq[i] < cpu_freq[j]) { + // swap + int tmp = cpuids->at(i); + cpuids->at(i) = cpuids->at(j); + cpuids->at(j) = tmp; + } + } + } + // SMP + int mid_max_freq_khz = + (cpu_freq[cpuids->at(0)] + cpu_freq[cpuids->at(cpu_count - 1)]) / 2; + + for (int i = 0; i < cpu_count; i++) { + cpuids->at(i) = i; + if (cpu_freq[i] >= mid_max_freq_khz) { + cluster_ids->at(i) = 0; + } else { + cluster_ids->at(i) = 1; + } + } + return 0; +} + +int check_online(const std::vector& core_ids) { + if (core_ids.size() == 0) { + return 0; + } + char path[256]; + int online = 1; + for (int i = 0; i < core_ids.size(); ++i) { + snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%d/online", + core_ids[i]); + FILE* fp = fopen(path, "rb"); + if (!fp) { + return 0; + } + int cur_online = 0; + fscanf(fp, "%d", &cur_online); + online &= cur_online; + fclose(fp); + } + return online; +} + +int set_sched_affinity(const std::vector& cpuids) { +// #define CPU_SETSIZE 1024 +// #define __NCPUBITS (8 * sizeof (unsigned long)) +// typedef struct +// { +// unsigned long __bits[CPU_SETSIZE / __NCPUBITS]; +// } cpu_set_t; + +// set affinity for thread +#ifdef __GLIBC__ + pid_t pid = syscall(SYS_gettid); +#else + pid_t pid = gettid(); +#endif + cpu_set_t mask; + CPU_ZERO(&mask); + for (int i = 0; i < cpuids.size(); i++) { + CPU_SET(cpuids[i], &mask); + } + + int syscallret = syscall(__NR_sched_setaffinity, pid, sizeof(mask), &mask); + if (syscallret) { + LOG(ERROR) << "syscall error " << syscallret; + return -1; + } + + return 0; +} + +#endif // LITE_WITH_LINUX + +#endif // LITE_WITH_ARM + +} // namespace lite +} // namespace paddle 
diff --git a/paddle/fluid/lite/core/cpu_info.h b/paddle/fluid/lite/core/cpu_info.h new file mode 100644 index 00000000000000..385954e6d8e480 --- /dev/null +++ b/paddle/fluid/lite/core/cpu_info.h @@ -0,0 +1,125 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/lite/utils/cp_logging.h" + +#ifdef LITE_WITH_LINUX +#include +#include +#endif + +#if __APPLE__ +#include "TargetConditionals.h" +#if TARGET_OS_IPHONE +#include +#include +#include +#endif // TARGET_OS_IPHONE +#endif // __APPLE__ + +namespace paddle { +namespace lite { + +#ifdef LITE_WITH_ARM + +typedef enum { + LITE_POWER_HIGH = 0, + LITE_POWER_LOW = 1, + LITE_POWER_FULL = 2, + LITE_POWER_NO_BIND = 3, + LITE_POWER_RAND_HIGH = 4, + LITE_POWER_RAND_LOW = 5 +} PowerMode; + +typedef enum { + kAPPLE = 0, + kA53 = 53, + kA55 = 55, + kA57 = 57, + kA72 = 72, + kA73 = 73, + kA75 = 75, + kA76 = 76, + kARMArch_UNKOWN = -1 +} ARMArch; + +class DeviceInfo { + public: + int idx_; + int max_freq_; + int min_freq_; + int generate_arch_; + int compute_core_num_; + int max_memory_; + int sharemem_size_; + + std::string device_name_; + std::string compute_ability_; + + std::vector L1_cache_; + std::vector L2_cache_; + std::vector L3_cache_; + std::vector core_ids_; + std::vector big_core_ids_; + std::vector little_core_ids_; + std::vector cluster_ids_; + std::vector archs_; + + static 
DeviceInfo& Global() { + static auto* x = new DeviceInfo; + return *x; + } + + static void Init() { + auto& info = Global(); + InitInternal(&info); + } + + private: + DeviceInfo() = default; + static void InitInternal(DeviceInfo* dev); +}; + +size_t arm_get_meminfo(); + +int arm_get_cpucount(); + +void arm_get_cpu_arch(std::vector* archs); + +bool get_cpu_info_from_name(DeviceInfo* cpu_info, std::string hardware_name); + +#ifdef LITE_WITH_LINUX + +void set_default_cache(DeviceInfo* dev); + +std::string arm_get_cpu_name(); + +int get_max_freq_khz(int cpuid); + +int arm_sort_cpuid_by_max_frequency(int cpu_count, std::vector* cpuids, + const std::vector& cpu_freq, + std::vector* cluster_ids); +int check_online(const std::vector& core_ids); +int set_sched_affinity(const std::vector& cpuids); + +#endif // LITE_WITH_LINUX + +#endif // LITE_WITH_ARM + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/hvy_tensor.h b/paddle/fluid/lite/core/hvy_tensor.h index 1fa8dbbee33a87..16172a80035e65 100644 --- a/paddle/fluid/lite/core/hvy_tensor.h +++ b/paddle/fluid/lite/core/hvy_tensor.h @@ -21,6 +21,7 @@ #pragma once #include #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/lite/core/target_wrapper.h" #include "paddle/fluid/lite/core/tensor.h" namespace paddle { @@ -39,6 +40,7 @@ class DDimHvy : public DDimBase { } value_type operator[](int offset) const { return data_[offset]; } + value_type& operator[](int offset) { return data_[offset]; } std::vector Vectorize() const { return framework::vectorize(data_); } @@ -64,6 +66,14 @@ class TensorHvy : public TensorBase { using DDimT = DDimHvy; using LoDT = framework::LoD; + template + void Assign(DType* data, const DimT& dim) { + Resize(dim); + auto* dst = mutable_data(Target); + CopySync(dst, data, dim.production() * sizeof(DType), + IoDirection::HtoD); + } + TargetType target() const { if (platform::is_gpu_place(data_.place())) { return TARGET(kCUDA); @@ -94,15 +104,18 @@ class 
TensorHvy : public TensorBase { const void* raw_data() const { return data_.raw_data(); } void Resize(const DDimHvy& dims) { - LOG(INFO) << "dims.size " << dims.size(); data_.Resize(framework::make_ddim(dims.Vectorize())); } void ShareDataWith(const TensorHvy& other) { data_.ShareDataWith(other.data_); } + void ShareDataWith(const framework::Tensor& other) { + data_.ShareDataWith(other); + } void CopyDataFrom(const TensorHvy& other) { - data_.ShareDataWith(other.data_); + data_.mutable_data(other.data_.place(), other.data_.type()); + TensorCopySync(other.data_, data_.place(), &data_); } DDimT dims() const { return DDimT(framework::vectorize(data_.dims())); } diff --git a/paddle/fluid/lite/core/kernel.h b/paddle/fluid/lite/core/kernel.h index 6846dbb920d2b8..d7b296eec12a27 100644 --- a/paddle/fluid/lite/core/kernel.h +++ b/paddle/fluid/lite/core/kernel.h @@ -41,10 +41,26 @@ class KernelBase { const std::map& input_types, const std::string& out_arg)>; + protected: + /// Run some initialization before `Run`, it will invoke after `SetParam` and + /// `SetContext`, that is both the param_ and context_ are valid. + virtual void PrepareForRun() {} + + /// Run the kernel. Before Run, both the param_ and context_ should be valid. virtual void Run() = 0; + public: + void Launch() { + if (is_first_epoch_) { + PrepareForRun(); + is_first_epoch_ = false; + } + + Run(); + } + void SetContext(std::unique_ptr&& ctx) { - context_ = std::move(ctx); + ctx_ = std::move(ctx); } template void SetParam(T param) { @@ -86,7 +102,7 @@ class KernelBase { virtual TargetType target() const = 0; virtual PrecisionType precision() const = 0; virtual DataLayoutType layout() const = 0; - const KernelContext* context() const { return context_.get(); } + const KernelContext* context() const { return ctx_.get(); } virtual std::string name() const = 0; // Short human-readable document. 
@@ -134,13 +150,14 @@ class KernelBase { void Torch() {} protected: - std::unique_ptr context_; + std::unique_ptr ctx_{nullptr}; mutable operators::param_t param_; // The corresponding op type. std::string op_type_{}; // The extra identity to help defficiate a specific kernel, op_type_ + alias_ // is the unique ID for the kernel. std::string alias_{}; + bool is_first_epoch_{true}; }; // Light-weight kernel implementation. @@ -152,9 +169,6 @@ template class KernelLite : public KernelBase { public: - // Set runtime context. - void SetContext(std::unique_ptr&& ctx) { ctx_ = ctx; } - // Run the kernel. virtual void Run() { CHECK(false) << "Not Implemented"; } @@ -168,9 +182,6 @@ class KernelLite : public KernelBase { KernelLite() = default; virtual ~KernelLite() = default; - - protected: - std::unique_ptr ctx_; }; template diff --git a/paddle/fluid/lite/core/lite_tensor.h b/paddle/fluid/lite/core/lite_tensor.h index 3fe29cc33313e0..6cccdc0dd03527 100644 --- a/paddle/fluid/lite/core/lite_tensor.h +++ b/paddle/fluid/lite/core/lite_tensor.h @@ -14,6 +14,7 @@ #pragma once #include +#include // for multiplies #include #include #include @@ -36,10 +37,15 @@ class DDimLite : public DDimBase { void ConstructFrom(const std::vector &x) { data_ = x; } value_type operator[](int offset) const { return data_[offset]; } - std::vector Vectorize() { return data_; } + value_type &operator[](int offset) { return data_[offset]; } + std::vector Vectorize() const { return data_; } size_t size() const { return data_.size(); } bool empty() const { return data_.empty(); } + value_type product() const { + return std::accumulate(std::begin(data_), std::end(data_), 1, + std::multiplies()); + } const std::vector &data() const { return data_; } private: @@ -55,14 +61,24 @@ class TensorLite : public TensorBase { TensorLite() : buffer_(std::make_shared()) {} + template + void Assign(DType *data, const DimT &dim) { + Resize(dim); + auto *dst = mutable_data(Target); + CopySync(dst, data, dim.product() * 
sizeof(DType), + IoDirection::HtoD); + } + template const T *data() const { return static_cast(buffer_->data()); } void Resize(const DDimLite &ddim) { dims_ = ddim; } + void Resize(const std::vector &x) { dims_ = DDimLite(x); } const DDimLite &dims() const { return dims_; } + int64_t numel() const { return dims_.product(); } const LoD &lod() const { return lod_; } LoD *mutable_lod() { return &lod_; } diff --git a/paddle/fluid/lite/core/memory.cc b/paddle/fluid/lite/core/memory.cc index 39f312be8d4bf3..0224ff1422ab76 100644 --- a/paddle/fluid/lite/core/memory.cc +++ b/paddle/fluid/lite/core/memory.cc @@ -15,5 +15,65 @@ #include "paddle/fluid/lite/core/memory.h" namespace paddle { -namespace lite {} // namespace lite +namespace lite { + +void* TargetMalloc(TargetType target, size_t size) { + void* data{nullptr}; + switch (target) { + case TargetType::kHost: + case TargetType::kX86: + case TargetType::kARM: + data = TargetWrapper::Malloc(size); + break; +#ifdef LITE_WITH_CUDA + case TargetType::kCUDA: + data = + TargetWrapper::Malloc(size); + break; +#endif // LITE_WITH_CUDA + default: + LOG(FATAL) << "Unknown supported target " << TargetToStr(target); + } + return data; +} + +void TargetFree(TargetType target, void* data) { + switch (target) { + case TargetType::kHost: + case TargetType::kX86: + case TargetType::kARM: + TargetWrapper::Free(data); + break; + +#ifdef LITE_WITH_CUDA + case TargetType::kCUDA: + TargetWrapper::Free(data); + break; +#endif // LITE_WITH_CUDA + default: + LOG(FATAL) << "Unknown type"; + } +} + +void TargetCopy(TargetType target, void* dst, const void* src, size_t size) { + switch (target) { + case TargetType::kHost: + case TargetType::kX86: + case TargetType::kARM: + TargetWrapper::MemcpySync(dst, src, size, + IoDirection::DtoD); + break; + +#ifdef LITE_WITH_CUDA + case TargetType::kCUDA: + TargetWrapper::MemcpySync(dst, src, size, + IoDirection::DtoD); + break; +#endif + default: + LOG(FATAL) << "unsupported type"; + } +} + +} // namespace 
lite } // namespace paddle diff --git a/paddle/fluid/lite/core/memory.h b/paddle/fluid/lite/core/memory.h index 5b332f7e3ac14f..5948f6c4a854d9 100644 --- a/paddle/fluid/lite/core/memory.h +++ b/paddle/fluid/lite/core/memory.h @@ -18,57 +18,16 @@ namespace paddle { namespace lite { -static void* TargetMalloc(TargetType target, size_t size) { - void* data{nullptr}; - switch (target) { - case TargetType::kHost: -#ifdef LITE_WITH_X86 - case TargetType::kX86: -#endif - data = TargetWrapper::Malloc(size); - break; -#ifdef LITE_WITH_CUDA - case TargetType::kCUDA: - data = - TargetWrapper::Malloc(size); - break; -#endif // LITE_WITH_CUDA - default: - LOG(FATAL) << "Unknown supported target " << TargetToStr(target); - } - return data; -} - -static void TargetFree(TargetType target, void* data) { - switch (static_cast(target)) { - case static_cast(TargetType::kX86): - TargetWrapper::Free(data); - break; - case static_cast(TargetType::kCUDA): - TargetWrapper::Free(data); - break; - default: - LOG(FATAL) << "Unknown type"; - } -} +// Malloc memory for a specific Target. All the targets should be an element in +// the `switch` here. +void* TargetMalloc(TargetType target, size_t size); -static void TargetCopy(TargetType target, void* dst, const void* src, - size_t size) { - switch (target) { - case TargetType::kX86: - case TargetType::kHost: - TargetWrapper::MemcpySync(dst, src, size, - IoDirection::DtoD); - break; +// Free memory for a specific Target. All the targets should be an element in +// the `switch` here. +void TargetFree(TargetType target, void* data); - case TargetType::kCUDA: - TargetWrapper::MemcpySync(dst, src, size, - IoDirection::DtoD); - break; - default: - LOG(FATAL) << "unsupported type"; - } -} +// Copy a buffer from host to another target. +void TargetCopy(TargetType target, void* dst, const void* src, size_t size); // Memory buffer manager. 
class Buffer { diff --git a/paddle/fluid/lite/core/memory_test.cc b/paddle/fluid/lite/core/memory_test.cc new file mode 100644 index 00000000000000..191fb3931c177d --- /dev/null +++ b/paddle/fluid/lite/core/memory_test.cc @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/memory.h" +#include + +namespace paddle { +namespace lite { + +TEST(memory, test) { + auto* buf = TargetMalloc(TARGET(kX86), 10); + ASSERT_TRUE(buf); + TargetFree(TARGET(kX86), buf); + +#ifdef LITE_WITH_CUDA + auto* buf_cuda = TargetMalloc(TARGET(kCUDA), 10); + ASSERT_TRUE(buf_cuda); + TargetFree(Target(kCUDA), buf_cuda); +#endif +} + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/CMakeLists.txt b/paddle/fluid/lite/core/mir/CMakeLists.txt index db31aeb58e16b1..ba6c363b86287b 100644 --- a/paddle/fluid/lite/core/mir/CMakeLists.txt +++ b/paddle/fluid/lite/core/mir/CMakeLists.txt @@ -1,10 +1,13 @@ -cc_library(mir_node SRCS node.cc) +cc_library(mir_node SRCS node.cc DEPS framework_proto_lite) cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node) cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph) cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes) cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager) +add_subdirectory(fusion) cc_library(mir_passes - SRCS static_kernel_pick_pass.cc + 
SRCS fc_fuse_pass.cc + conv_elementwise_add_relu_fuse_pass.cc + static_kernel_pick_pass.cc variable_place_inference_pass.cc type_target_transform_pass.cc io_copy_kernel_pick_pass.cc @@ -13,7 +16,7 @@ cc_library(mir_passes argument_type_display_pass.cc demo_pass.cc runtime_context_assign_pass.cc - DEPS mir_pass types_lite context_lite) + DEPS mir_pass types_lite context_lite ${mir_fusers}) # for mobile, unnecessary to compile the following testings. if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) @@ -28,23 +31,52 @@ cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS mir_pass_manager program_fake_utils ) -set(test_variable_place_infrence_pass_DEPS - mul_op_lite - feed_op_lite - fetch_op_lite - io_copy_op_lite - ${host_kernels} - mir_passes - mir_pass_manager - optimizer_lite - program_fake_utils - target_wrapper_host - ) -if (LITE_WITH_CUDA) - set(test_variable_place_infrence_pass_DEPS - ${test_variable_place_infrence_pass_DEPS} target_wrapper_cuda - kernels_cuda - ) +# lite_cc_test(test_variable_place_infrence_pass SRCS variable_place_inference_pass_test.cc +# DEPS +# mul_op_lite +# feed_op_lite +# fetch_op_lite +# io_copy_op_lite +# ${host_kernels} +# mir_passes +# mir_pass_manager +# optimizer_lite +# program_fake_utils +# target_wrapper_host +# PROFILE_DEPS basic_profiler_lite +# CUDA_DEPS target_wrapper_cuda kernels_cuda +# ARM_DEPS mul_compute_arm +# X86_DEPS mul_compute_x86 +# ) + + +lite_cc_library(pattern_matcher_lite SRCS pattern_matcher.cc DEPS mir_node mir_ssa_graph op_lite) +lite_cc_test(test_pattern_matcher_lite SRCS pattern_matcher_test.cc DEPS pattern_matcher_lite) + +lite_cc_library(pattern_matcher_high_api SRCS pattern_matcher_high_api.cc DEPS pattern_matcher_lite) + +# TODO(wz) replace framework/proto to lite proto. +if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + # it depends on the fluid/framework/proto, that is too heavy for mobile execution. 
+ lite_cc_test(test_pattern_matcher_high_api SRCS pattern_matcher_high_api_test.cc DEPS + pattern_matcher_high_api proto_desc mir_pass_manager fc_op_lite mul_op_lite elementwise_ops_lite + mir_passes compatible_pb_lite program_lite ${ops_lite}) endif() -cc_test(test_variable_place_infrence_pass SRCS variable_place_inference_pass_test.cc DEPS - ${test_variable_place_infrence_pass_DEPS}) + +message(STATUS "----> Ops lite: ${ops_lite}") +message(STATUS "----> Host kernels: ${host_kernels}") +message(STATUS "----> X86 kernels: ${x86_kernels}") + +lite_cc_test(test_lite_fc_fuse SRCS fc_fuse_pass_test.cc + DEPS cxx_api_lite mir_passes + ${ops_lite} ${host_kernels} ${x86_kernels} + ARGS --model_dir=${LITE_MODEL_DIR}/lite_fc_model + --optimized_model=${LITE_MODEL_DIR}/lite_fc_model_opt SERIAL) + +lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_fc_model.tar.gz") +add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz) + +lite_cc_test(test_lite_conv_elementwise_add_relu_fuse + SRCS conv_elementwise_add_relu_fuse_pass_test.cc + DEPS cxx_api_lite mir_passes + ${ops_lite} ${host_kernels} ${x86_kernels}) diff --git a/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc new file mode 100644 index 00000000000000..065e8ceca3f5bf --- /dev/null +++ b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc @@ -0,0 +1,36 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h" +#include +#include +#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h" +#include "paddle/fluid/lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void ConvElementwiseAddReLUFusePass::Apply( + const std::unique_ptr& graph) { + fusion::ConvElementwiseAddReLUFuser fuser; + fuser(graph.get()); +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass, + paddle::lite::mir::ConvElementwiseAddReLUFusePass); diff --git a/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h new file mode 100644 index 00000000000000..4276f1ffc8c258 --- /dev/null +++ b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h @@ -0,0 +1,32 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include "paddle/fluid/lite/core/mir/pass.h" + +namespace paddle { +namespace lite { +namespace mir { + +class ConvElementwiseAddReLUFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc new file mode 100644 index 00000000000000..5ada0a2c60dabf --- /dev/null +++ b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc @@ -0,0 +1,153 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h" +#include +#include +#include +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/lite/api/cxx_api.h" +#include "paddle/fluid/lite/core/compatible_tensor.h" +#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h" +#include "paddle/fluid/lite/core/mir/passes.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/program.h" + +DEFINE_string(model_dir, "", ""); +DEFINE_string(optimized_model, "", ""); + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +std::unique_ptr BuildGraph(framework::ProgramDesc* program_desc, + const std::shared_ptr& scope, + const std::vector& valid_places) { + auto* main_block = program_desc->MutableBlock(0); + + auto* conv2d_1 = main_block->AppendOp(); + auto* conv2d_2 = main_block->AppendOp(); + auto* add_1 = main_block->AppendOp(); + auto* add_2 = main_block->AppendOp(); + auto* relu_1 = main_block->AppendOp(); + auto* relu_2 = main_block->AppendOp(); + + main_block->Var("input_1"); + main_block->Var("input_2"); + main_block->Var("filter_1"); + main_block->Var("filter_2"); + main_block->Var("conv2d_1_out"); + main_block->Var("conv2d_2_out"); + main_block->Var("bias_1"); + main_block->Var("add_1_out"); + main_block->Var("add_2_out"); + main_block->Var("relu_1_out"); + main_block->Var("out"); + + scope->Var("input_1")->GetMutable(); + scope->Var("input_2")->GetMutable(); + scope->Var("filter_1")->GetMutable(); + scope->Var("filter_2")->GetMutable(); + scope->Var("conv2d_1_out")->GetMutable(); + scope->Var("conv2d_2_out")->GetMutable(); + scope->Var("bias_1")->GetMutable(); + scope->Var("add_1_out")->GetMutable(); + scope->Var("add_2_out")->GetMutable(); + scope->Var("relu_1_out")->GetMutable(); + scope->Var("out")->GetMutable(); + + conv2d_1->SetType("conv2d"); + conv2d_1->SetInput("Input", {"input_1"}); + conv2d_1->SetInput("Filter", {"filter_1"}); + 
conv2d_1->SetOutput("Output", {"conv2d_1_out"}); + conv2d_1->SetAttr("strides", std::vector({1, 1})); + conv2d_1->SetAttr("paddings", std::vector({0, 0})); + conv2d_1->SetAttr("groups", 1); + conv2d_1->SetAttr("dilations", std::vector({1, 1})); + conv2d_1->SetAttr("fuse_relu", false); + + add_1->SetType("elementwise_add"); + add_1->SetInput("X", {"conv2d_1_out"}); + add_1->SetInput("Y", {"bias_1"}); + add_1->SetOutput("Out", {"add_1_out"}); + add_1->SetAttr("axis", 1); + + relu_1->SetType("relu"); + relu_1->SetInput("Input", {"add_1_out"}); + relu_1->SetOutput("Out", {"relu_1_out"}); + + conv2d_2->SetType("conv2d"); + conv2d_2->SetInput("Input", {"input_2"}); + conv2d_2->SetInput("Filter", {"filter_2"}); + conv2d_2->SetOutput("Output", {"conv2d_2_out"}); + conv2d_2->SetAttr("strides", std::vector({1, 1})); + conv2d_2->SetAttr("paddings", std::vector({0, 0})); + conv2d_2->SetAttr("groups", 1); + conv2d_2->SetAttr("dilations", std::vector({1, 1})); + conv2d_2->SetAttr("fuse_relu", false); + + add_2->SetType("elementwise_add"); + add_2->SetInput("X", {"conv2d_2_out"}); + add_2->SetInput("Y", {"relu_1_out"}); + add_2->SetOutput("Out", {"add_2_out"}); + add_2->SetAttr("axis", 1); + + relu_2->SetType("relu"); + relu_2->SetInput("Input", {"add_2_out"}); + relu_2->SetOutput("Out", {"out"}); + + program_desc->Flush(); + + lite::Program program(*program_desc->Proto(), scope, valid_places); + auto graph = std::unique_ptr(new SSAGraph()); + graph->Build(program, valid_places); + + return graph; +} + +TEST(conv_elementwise_add_relu_fuse_pass, graph_test) { + framework::ProgramDesc program_desc; + std::vector places{{TARGET(kHost), PRECISION(kFloat)}}; + auto scope = std::make_shared(); + auto graph = BuildGraph(&program_desc, scope, places); + + ASSERT_EQ(graph->nodes().size(), + 11UL /*vars*/ + 6UL /*ops*/ + 2UL /*feed op + fetch op*/); + Visualize(graph.get()); +} + +TEST(conv_elementwise_add_relu_fuse_pass, fuse_test_op) { + framework::ProgramDesc program_desc; + std::vector 
places{{TARGET(kHost), PRECISION(kFloat)}}; + auto scope = std::make_shared(); + auto graph = BuildGraph(&program_desc, scope, places); + Visualize(graph.get()); + const int num_nodes = graph->nodes().size(); + auto* fuser = new ConvElementwiseAddReLUFusePass; + fuser->Apply(graph); + Visualize(graph.get()); + ASSERT_EQ(graph->nodes().size(), num_nodes - 5UL * 2 /*nodes removed */ + + 1UL * 2 /* fused fc node*/); +} + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle + +USE_LITE_OP(elementwise_add); +USE_LITE_OP(conv2d); +USE_LITE_OP(depthwise_conv2d); +USE_LITE_OP(relu); diff --git a/paddle/fluid/lite/core/mir/fc_fuse_pass.cc b/paddle/fluid/lite/core/mir/fc_fuse_pass.cc new file mode 100644 index 00000000000000..008f05ce5cbd5f --- /dev/null +++ b/paddle/fluid/lite/core/mir/fc_fuse_pass.cc @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h" +#include +#include +#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h" +#include "paddle/fluid/lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void FcFusePass::Apply(const std::unique_ptr& graph) { + fusion::FcFuser fuser; + fuser(graph.get()); +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass); diff --git a/paddle/fluid/lite/core/mir/fc_fuse_pass.h b/paddle/fluid/lite/core/mir/fc_fuse_pass.h new file mode 100644 index 00000000000000..f1b548c43f9993 --- /dev/null +++ b/paddle/fluid/lite/core/mir/fc_fuse_pass.h @@ -0,0 +1,32 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/lite/core/mir/pass.h" + +namespace paddle { +namespace lite { +namespace mir { + +class FcFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc b/paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc new file mode 100644 index 00000000000000..35efedb57971d1 --- /dev/null +++ b/paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc @@ -0,0 +1,112 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h" +#include +#include +#include +#include "paddle/fluid/lite/api/cxx_api.h" +#include "paddle/fluid/lite/core/mir/passes.h" +#include "paddle/fluid/lite/core/op_registry.h" + +DEFINE_string(model_dir, "", ""); +DEFINE_string(optimized_model, "", ""); + +namespace paddle { +namespace lite { +namespace mir { + +TEST(fc_fuse_pass, fuse_test) { + lite::ExecutorLite predictor; +#ifndef LITE_WITH_CUDA + std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kX86), PRECISION(kFloat)}}); +#else + std::vector valid_places({ + Place{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)}, + Place{TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)}, + Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kNCHW)}, + Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kNCHW)}, + Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kAny)}, + Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)}, + }); +#endif + + predictor.Build(FLAGS_model_dir, + Place{TARGET(kX86), PRECISION(kFloat)}, // origin cuda + valid_places); + + auto* input_tensor = predictor.GetInput(0); + input_tensor->Resize(DDim(std::vector({100, 100}))); + auto* data = input_tensor->mutable_data(); + for (int i = 0; i < 100 * 100; i++) { + data[i] = i; + } + + predictor.Run(); + + auto* out = predictor.GetOutput(0); + LOG(INFO) << out << " memory size " << out->data_size(); + LOG(INFO) << 
"out " << out->data()[0]; + LOG(INFO) << "out " << out->data()[1]; + LOG(INFO) << "dims " << out->dims(); + EXPECT_NEAR(out->data()[0], 38.120617f, 1e-5); + EXPECT_NEAR(out->data()[1], 10.109812f, 1e-5); + CHECK_EQ(out->dims()[0], 100); + CHECK_EQ(out->dims()[1], 500); +} + +#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +TEST(fc_fuse_pass, save_model_test) { + lite::ExecutorLite predictor; + std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kX86), PRECISION(kFloat)}}); + predictor.Build(FLAGS_model_dir, Place{TARGET(kX86), PRECISION(kFloat)}, + valid_places); + + LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model; + predictor.SaveModel(FLAGS_optimized_model); +} +#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK + +} // namespace mir +} // namespace lite +} // namespace paddle + +USE_LITE_OP(mul); +USE_LITE_OP(elementwise_add); +USE_LITE_OP(elementwise_sub); +USE_LITE_OP(fc); +USE_LITE_OP(feed); +USE_LITE_OP(fetch); +USE_LITE_OP(io_copy); +USE_LITE_OP(softmax); +USE_LITE_OP(scale); +USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); +USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); + +#ifdef LITE_WITH_X86 +USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(scale, kX86, kFloat, kNCHW, def); +#endif + +#ifdef LITE_WITH_CUDA +USE_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, def); +USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device); +USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host); +#endif diff --git a/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt b/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt new file mode 100644 index 00000000000000..e0816c1be56665 --- /dev/null +++ b/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt @@ -0,0 +1,11 @@ +cc_library(fuse_fc + SRCS fc_fuser.cc + DEPS 
pattern_matcher_high_api) +cc_library(conv_elementwise_add_relu + SRCS conv_elementwise_add_relu_fuser.cc + DEPS pattern_matcher_high_api) + +set(mir_fusers + fuse_fc + conv_elementwise_add_relu + CACHE INTERNAL "fusers") diff --git a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc new file mode 100644 index 00000000000000..c1322386348581 --- /dev/null +++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc @@ -0,0 +1,104 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h" +#include +#include + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +void ConvElementwiseAddReLUFuser::BuildPattern() { + // create input nodes. + auto* input = VarNode("input"); + auto* filter = VarNode("filter"); + auto* bias = VarNode("bias"); + + // create op nodes + auto* conv2d = OpNode("conv2d", "conv2d"); + auto* add = OpNode("add", "elementwise_add"); + auto* relu = OpNode("relu", "relu"); + + // create intermediate nodes + auto* conv2d_out = VarNode("conv2d_out"); + auto* add_out = VarNode("add_out"); + + // create output node + auto* out = VarNode("output"); + + // create topology. 
+ std::vector conv2d_inputs{filter, input}; + std::vector add_inputs{conv2d_out, bias}; + conv2d_inputs >> *conv2d >> *conv2d_out; + add_inputs >> *add >> *add_out; + *add_out >> *relu >> *out; + + // Some op specialities. + conv2d_out->AsIntermediate(); + add_out->AsIntermediate(); + conv2d->AsIntermediate(); + add->AsIntermediate(); + relu->AsIntermediate(); +} + +void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + auto op_desc = GenOpDesc(matched); + auto conv_op = LiteOpRegistry::Global().Create("conv2d"); + auto conv_old = matched.at("conv2d")->stmt()->op; + auto* scope = conv_old->scope(); + auto& valid_places = conv_old->valid_places(); + conv_op->Attach(op_desc, scope); + + auto* new_op_node = graph->GraphCreateInstructNode(conv_op, valid_places); + + IR_NODE_LINK_TO(matched.at("input"), new_op_node); + IR_NODE_LINK_TO(matched.at("filter"), new_op_node); + IR_NODE_LINK_TO(matched.at("bias"), new_op_node); + IR_NODE_LINK_TO(new_op_node, matched.at("output")); +} + +cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) { + auto* desc = matched.at("conv2d")->stmt()->op_info(); + + cpp::OpDesc op_desc; + op_desc.SetType("conv2d"); + op_desc.SetInput("Input", {matched.at("input")->arg()->name}); + op_desc.SetInput("Filter", {matched.at("filter")->arg()->name}); + op_desc.SetInput("Bias", {matched.at("bias")->arg()->name}); + op_desc.SetOutput("Output", {matched.at("output")->arg()->name}); + // Other inputs. 
See operators/conv_op.h + std::vector input_arg_names = desc->InputArgumentNames(); + for (auto name : input_arg_names) LOG(INFO) << name; + + if (std::find(input_arg_names.begin(), input_arg_names.end(), + "ResidualData") != input_arg_names.end()) { + op_desc.SetInput("ResidualData", desc->Input("ResidualData")); + } + + // Only consider strides, padding, groups, dilations, fuse_relu for now + op_desc.SetAttr("strides", desc->GetAttr>("strides")); + op_desc.SetAttr("paddings", desc->GetAttr>("paddings")); + op_desc.SetAttr("groups", desc->GetAttr("groups")); + op_desc.SetAttr("dilations", desc->GetAttr>("dilations")); + op_desc.SetAttr("fuse_relu", true); + return op_desc; +} + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h new file mode 100644 index 00000000000000..5ba0ee268415c9 --- /dev/null +++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h @@ -0,0 +1,38 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +class ConvElementwiseAddReLUFuser : public FuseBase { + public: + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; +}; + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/fusion/fc_fuser.cc b/paddle/fluid/lite/core/mir/fusion/fc_fuser.cc new file mode 100644 index 00000000000000..a8b6336595c0fe --- /dev/null +++ b/paddle/fluid/lite/core/mir/fusion/fc_fuser.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h" +#include +#include + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +void FcFuser::BuildPattern() { + // create nodes. + auto* x = VarNode("x")->assert_is_op_input("mul", "X"); + auto* W = VarNode("W")->assert_is_op_input("mul", "Y"); + auto* b = VarNode("b"); + auto* mul = OpNode("mul", "mul"); + auto* mul_out = VarNode("mul_out"); + auto* add = OpNode("add", "elementwise_add"); + auto* Out = VarNode("Out"); + + // create topology. 
+ std::vector mul_inputs{W, x}; + std::vector add_inputs{mul_out, b}; + mul_inputs >> *mul >> *mul_out; + add_inputs >> *add >> *Out; + + // Some op specialities. + mul_out->AsIntermediate(); + mul->AsIntermediate(); + add->AsIntermediate(); +} + +void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { + auto op_desc = GenOpDesc(matched); + auto fc_op = LiteOpRegistry::Global().Create("fc"); + auto mul = matched.at("mul")->stmt()->op; + auto* scope = mul->scope(); + auto& valid_places = mul->valid_places(); + fc_op->Attach(op_desc, scope); + + auto* new_op_node = graph->GraphCreateInstructNode(fc_op, valid_places); + + IR_NODE_LINK_TO(matched.at("W"), new_op_node); + IR_NODE_LINK_TO(matched.at("x"), new_op_node); + IR_NODE_LINK_TO(matched.at("b"), new_op_node); + IR_NODE_LINK_TO(new_op_node, matched.at("Out")); +} + +cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) { + cpp::OpDesc op_desc; + op_desc.SetType("fc"); + op_desc.SetInput("Input", {matched.at("x")->arg()->name}); + op_desc.SetInput("W", {matched.at("W")->arg()->name}); + op_desc.SetInput("Bias", {matched.at("b")->arg()->name}); + op_desc.SetOutput("Out", {matched.at("Out")->arg()->name}); + op_desc.SetAttr( + "in_num_col_dims", + matched.at("mul")->stmt()->op_info()->GetAttr("x_num_col_dims")); + return op_desc; +} + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/fusion/fc_fuser.h b/paddle/fluid/lite/core/mir/fusion/fc_fuser.h new file mode 100644 index 00000000000000..0e2bc3bc3c3385 --- /dev/null +++ b/paddle/fluid/lite/core/mir/fusion/fc_fuser.h @@ -0,0 +1,38 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +class FcFuser : public FuseBase { + public: + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; +}; + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/generate_program_pass.h b/paddle/fluid/lite/core/mir/generate_program_pass.h index 498fe01bb6ec17..460c4e90623bb8 100644 --- a/paddle/fluid/lite/core/mir/generate_program_pass.h +++ b/paddle/fluid/lite/core/mir/generate_program_pass.h @@ -41,7 +41,7 @@ class GenerateProgramPass : public ProgramPass { } private: - std::vector insts_; + std::vector insts_; }; } // namespace mir diff --git a/paddle/fluid/lite/core/mir/node.h b/paddle/fluid/lite/core/mir/node.h index e10472c5885982..67ee47a9e12fde 100644 --- a/paddle/fluid/lite/core/mir/node.h +++ b/paddle/fluid/lite/core/mir/node.h @@ -93,6 +93,16 @@ class Node { return x; } + Stmt* stmt() const { + CHECK(IsStmt()); + return stmt_.get(); + } + + Arg* arg() const { + CHECK(IsArg()); + return arg_.get(); + } + // Set roles. 
Arg& AsArg() { if (role_ != Role::kUnk) { diff --git a/paddle/fluid/lite/core/mir/pass_manager.cc b/paddle/fluid/lite/core/mir/pass_manager.cc index 508c2fd5522519..e12246ca83985f 100644 --- a/paddle/fluid/lite/core/mir/pass_manager.cc +++ b/paddle/fluid/lite/core/mir/pass_manager.cc @@ -16,10 +16,6 @@ namespace paddle { namespace lite { -namespace mir { - -PassManager::PassManager() {} - -} // namespace mir +namespace mir {} // namespace mir } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/core/mir/pass_manager.h b/paddle/fluid/lite/core/mir/pass_manager.h index 2fc4654d920583..e80c0c851632ad 100644 --- a/paddle/fluid/lite/core/mir/pass_manager.h +++ b/paddle/fluid/lite/core/mir/pass_manager.h @@ -30,7 +30,7 @@ class PassManager { return x; } - PassManager(); + PassManager() {} void Run(const std::unique_ptr& graph) { for (auto& pass : passes_) { diff --git a/paddle/fluid/lite/core/mir/pass_registry.h b/paddle/fluid/lite/core/mir/pass_registry.h index 5c213169b5242a..0586845f3ceb6d 100644 --- a/paddle/fluid/lite/core/mir/pass_registry.h +++ b/paddle/fluid/lite/core/mir/pass_registry.h @@ -15,7 +15,6 @@ #pragma once #include -#include "paddle/fluid/lite/core/mir/pass.h" #include "paddle/fluid/lite/core/mir/pass_manager.h" namespace paddle { @@ -32,6 +31,10 @@ class PassRegistry { bool Touch() const { return true; } }; +} // namespace mir +} // namespace lite +} // namespace paddle + #define REGISTER_MIR_PASS(name__, class__) \ paddle::lite::mir::PassRegistry mir_pass_registry##name__(#name__, \ new class__); \ @@ -43,7 +46,3 @@ class PassRegistry { extern bool mir_pass_registry##name__##_fake(); \ static bool mir_pass_usage##name__ __attribute__((unused)) = \ mir_pass_registry##name__##_fake(); - -} // namespace mir -} // namespace lite -} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/passes.h b/paddle/fluid/lite/core/mir/passes.h index 60e53257ba0100..e0110a8e3b27fd 100644 --- a/paddle/fluid/lite/core/mir/passes.h +++ 
b/paddle/fluid/lite/core/mir/passes.h @@ -22,6 +22,8 @@ namespace mir {} // namespace mir } // namespace paddle USE_MIR_PASS(demo); +USE_MIR_PASS(lite_fc_fuse_pass); +USE_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass); USE_MIR_PASS(static_kernel_pick_pass); USE_MIR_PASS(variable_place_inference_pass); USE_MIR_PASS(type_target_transform_pass); diff --git a/paddle/fluid/lite/core/mir/pattern_matcher.cc b/paddle/fluid/lite/core/mir/pattern_matcher.cc new file mode 100644 index 00000000000000..8a83bd242bd0b9 --- /dev/null +++ b/paddle/fluid/lite/core/mir/pattern_matcher.cc @@ -0,0 +1,492 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "paddle/fluid/inference/analysis/dot.h" +#include "paddle/fluid/lite/core/mir/pattern_matcher.h" +#include "paddle/fluid/lite/core/op_lite.h" + +namespace paddle { +namespace lite { +namespace mir { + +size_t PMPattern::id_ = 0UL; + +PMNode &PMNode::operator>>(PMNode &right) { + pattern_->AddEdge(this, &right); + // automatically add out op link relation. 
+ if (right.IsOp()) { + CHECK(!right.op_type_.empty()); + this->assert_is_op_input(right.op_type_); + } + + return right; +} + +PMNode &PMNode::operator>>(std::vector &nodes) { + for (auto *node : nodes) { + *this >> *node; + } + return *this; +} + +PMNode &operator>>(std::vector &others, PMNode &me) { + for (auto *o : others) { + *o >> me; + } + return me; +} + +PMNode *PMPattern::NewNode(const std::string &name) { + if (!name.empty()) { + CHECK_EQ(node_map_.count(name), 0UL) + << "PMNode's name should be unique, get duplicate " << name; + } + + nodes_.emplace_back(new PMNode(this, name)); + auto *cur = nodes_.back().get(); + node_map_[name] = cur; + return cur; +} + +PMNode *PMPattern::NewNode(PMNode::teller_t &&teller, const std::string &name) { + if (!name.empty()) { + CHECK_EQ(node_map_.count(name), 0UL) + << "PMNode's name should be unique, get duplicate " << name; + } + + nodes_.emplace_back(new PMNode(std::move(teller), this, name)); + auto *cur = nodes_.back().get(); + node_map_[name] = cur; + return cur; +} + +PMNode *PMPattern::RetrieveNode(const std::string &id) const { + auto it = node_map_.find(id); + if (it == node_map_.end()) { + return nullptr; + } + + return it->second; +} + +void PMPattern::AddEdge(PMNode *a, PMNode *b) { + CHECK(a); + CHECK(b); + CHECK_NE(a, b) << "Can't connect to the same nodes."; + edges_.emplace_back(a, b); +} + +void PatternMatcher::operator()(SSAGraph *graph, + PatternMatcher::handle_t handler) { + if (!MarkPMNodesInGraph(graph)) { + return; + } + + auto subgraphs = DetectPatterns(); + UniquePatterns(&subgraphs); + RemoveOverlappedMatch(&subgraphs); + ValidateByNodeRole(&subgraphs); + + if (subgraphs.empty()) return; + LOG(INFO) << "--- detected " << subgraphs.size() << " subgraphs."; + int id = 0; + for (auto &g : subgraphs) { + VLOG(3) << "optimizing #" << id++ << " subgraph"; + handler(g, graph); + } +} + +bool PatternMatcher::MarkPMNodesInGraph(SSAGraph *graph) { + VLOG(3) << "mark pmnodes in graph"; + if 
(graph->nodes().empty()) return false; + + for (auto &node : graph->mutable_nodes()) { + for (const auto &pmnode : pattern_.nodes()) { + if (pmnode->Tell(&node)) { + pmnodes2nodes_[pmnode.get()].insert(&node); + } + } + } + // Check to early stop if some PMNode can't find matched Node. + for (auto &pmnode : pattern_.nodes()) { + if (!pmnodes2nodes_.count(pmnode.get())) { + VLOG(4) << pmnode->name() << " can't find matched Node, early stop"; + // return false; + } + } + VLOG(3) << pmnodes2nodes_.size() << " nodes marked"; + + return !pmnodes2nodes_.empty(); +} + +// The intermediate Nodes can only link to the nodes inside the pattern, or this +// subgraph will be droped. +void PatternMatcher::ValidateByNodeRole( + std::vector *subgraphs) { + std::vector result; + + subgraphs->erase( + std::remove_if(subgraphs->begin(), subgraphs->end(), + [](const PatternMatcher::subgraph_t &subgraph) -> bool { + // Collect the inlinks and outlinks. + std::unordered_set ios; + for (auto &item : subgraph) { + ios.insert(item.second); + } + for (auto &item : subgraph) { + if (item.first->IsIntermediate()) { + for (auto *x : item.second->inlinks) { + if (!ios.count(x)) { + return true; + } + } + for (auto *x : item.second->outlinks) { + if (!ios.count(x)) { + return true; + } + } + } + } + return false; + }), + subgraphs->end()); +} + +struct HitGroup { + std::unordered_map roles; + + bool Match(Node *node, PMNode *pat) { + if (nodes_.count(node)) { + if (roles.count(pat) && roles[pat] == node) return true; + return false; + } else { + if (roles.count(pat) && roles[pat] != node) return false; + return true; + } + } + + void Register(Node *node, PMNode *pat) { + roles[pat] = node; + nodes_.insert(node); + } + + private: + std::unordered_set nodes_; +}; + +// Tell whether Node a links to b. 
+bool IsNodesLink(Node *a, Node *b) { + for (auto *node : a->outlinks) { + if (b == node) { + return true; + } + } + return false; +} + +std::vector PatternMatcher::DetectPatterns() { + // Init empty subgraphs. + std::vector result; + std::vector init_groups; + std::array, 2> bi_records; + auto *first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get() + : pattern_.edges().front().first; + if (!pmnodes2nodes_.count(first_pnode)) return result; + for (auto *node : pmnodes2nodes_[first_pnode]) { + HitGroup group; + group.roles[first_pnode] = node; + init_groups.emplace_back(group); + } + + int step = 0; + bi_records[0] = std::move(init_groups); + + // Extend a PMNode to subgraphs by deducing the connection relations defined + // in edges of PMNodes. + for (const auto &edge : pattern_.edges()) { + VLOG(4) << "check " << edge.first->name() << " -> " << edge.second->name(); + // TODO(Superjomn) Fix bug here, the groups might be duplicate here. + // Each role has two PMNodes, which indicates two roles. + // Detect two Nodes that can match these two roles and they are connected. + auto &pre_groups = bi_records[step % 2]; + auto &cur_groups = bi_records[1 - (step++ % 2)]; + cur_groups.clear(); + if (pre_groups.empty()) break; + // source -> target + for (Node *source : pmnodes2nodes_[edge.first]) { + for (Node *target : pmnodes2nodes_[edge.second]) { + // TODO(Superjomn) add some prune strategies. 
+ for (const auto &group : pre_groups) { + if (IsNodesLink(source, target)) { + HitGroup new_group = group; + bool flag = new_group.Match(source, edge.first) && + new_group.Match(target, edge.second); + if (flag) { + new_group.Register(source, edge.first); + new_group.Register(target, edge.second); + cur_groups.push_back(new_group); + // TODO(Superjomn) need to unique + } + } + } + } + } + VLOG(3) << "step " << step << " get records: " << cur_groups.size(); + } + + for (auto &group : bi_records[step % 2]) { + PatternMatcher::subgraph_t subgraph; + for (auto &role : group.roles) { + subgraph.emplace(role.first, role.second); + } + result.emplace_back(subgraph); + } + return result; +} + +struct GraphItemLessThan { + bool operator()(const std::pair &a, + const std::pair &b) { + if (a.first != b.first) { + return a.first < b.first; + } else { + return a.second < b.second; + } + } +}; + +// TODO(Superjomn) enhance the function as it marks unique unique as duplicates +// see https://github.com/PaddlePaddle/Paddle/issues/13550 +void PatternMatcher::UniquePatterns( + std::vector *subgraphs) { + if (subgraphs->empty()) return; + std::vector result; + + std::unordered_set set; + std::hash hasher; + for (auto &g : *subgraphs) { + // Sort the items in the sub-graph, and transform to a string key. 
+ std::vector> sorted_keys(g.begin(), g.end()); + std::sort(sorted_keys.begin(), sorted_keys.end(), GraphItemLessThan()); + std::stringstream ss; + for (auto &item : sorted_keys) { + ss << item.first << ":" << item.second; + } + auto key = hasher(ss.str()); + if (!set.count(key)) { + result.emplace_back(g); + set.insert(key); + } + } + *subgraphs = result; +} + +void PatternMatcher::RemoveOverlappedMatch(std::vector *subgraphs) { + std::vector result; + std::unordered_set node_set; + + for (const auto &subgraph : *subgraphs) { + bool valid = true; + for (auto &item : subgraph) { + if (item.first->IsIntermediate() && node_set.count(item.second)) { + valid = false; + break; + } + } + if (valid) { + for (auto &item : subgraph) { + node_set.insert(item.second); + } + result.push_back(subgraph); + } + } + *subgraphs = result; +} + +std::string PMPattern::DotString() const { + using inference::analysis::Dot; + Dot dot; + int id = 0; + // Create Nodes + std::unordered_map node2dot; + for (const auto &node : nodes()) { + std::string node_id = "Node" + std::to_string(id++); + dot.AddNode(node_id, {}, node->name()); + node2dot[node.get()] = node_id; + } + // Create Edges + for (const auto &edge : edges()) { + if (!node2dot.count(edge.first) || !node2dot.count(edge.second)) { + LOG(ERROR) << "no node " << edge.first << " " << edge.second; + continue; + } + auto &src = node2dot.at(edge.first); + auto &trg = node2dot.at(edge.second); + dot.AddEdge(src, trg, {}); + } + return dot.Build(); +} + +PMNode &PMNode::LinksTo(const std::vector &others) { + // extend outlinks. + for (PMNode *x : others) { + pattern_->AddEdge(this, x); + } + return *this; +} + +PMNode &PMNode::LinksFrom(const std::vector &others) { + // extend outlinks. 
+ for (PMNode *x : others) { + pattern_->AddEdge(x, this); + } + return *this; +} + +PMNode *PMNode::assert_is_op() { + asserts_.emplace_back([](const Node *x) { return x && x->IsStmt(); }); + return this; +} + +PMNode *PMNode::assert_is_op(const std::string &op_type) { + asserts_.emplace_back([op_type](const Node *x) { + if (x && x->IsStmt()) { + auto *op_info = x->stmt()->op_info(); + return op_info->Type() == op_type; + } else { + return false; + } + }); + return this; +} + +PMNode *PMNode::assert_is_var() { + asserts_.emplace_back([](const Node *x) { return x && x->IsArg(); }); + return this; +} + +PMNode *PMNode::assert_var_not_persistable() { + assert_is_var(); + asserts_.emplace_back([](const Node *x) { return !x->arg()->is_weight; }); + return this; +} + +PMNode *PMNode::assert_is_persistable_var() { + assert_is_var(); + asserts_.emplace_back([=](const Node *x) { return x->arg()->is_weight; }); + return this; +} + +PMNode *PMNode::assert_is_op_output(const std::string &op_type) { + assert_is_var(); + asserts_.emplace_back([=](const Node *x) { + for (auto *op : x->inlinks) { + if (op && op->IsStmt()) { + auto *op_info = x->stmt()->op_info(); + if (op_info->Type() == op_type) return true; + } + } + return false; + }); + return this; +} + +PMNode *PMNode::assert_is_op_input(const std::string &op_type) { + assert_is_var(); + asserts_.emplace_back([=](const Node *x) { + for (auto *op : x->outlinks) { + if (op && op->IsStmt()) { + auto *op_info = op->stmt()->op_info(); + if (op_info->Type() == op_type) { + return true; + } + } + } + return false; + }); + return this; +} + +PMNode *PMNode::assert_is_op_input(const std::string &op_type, + const std::string &argument) { + assert_is_var(); + assert_is_op_nth_input(op_type, argument, 0); + return this; +} + +PMNode *PMNode::assert_is_op_nth_input(const std::string &op_type, + const std::string &argument, int nth) { + assert_is_var(); + assert_is_op_input(op_type); + asserts_.emplace_back([=](const Node *x) { + for 
(auto *op : x->outlinks) { + if (op->IsStmt() && op->stmt()->op_info()->Type() == op_type && + IsNthInput(*x, *op, argument, nth)) + return true; + } + return false; + }); + return this; +} + +bool IsNthInput(const Node &var, const Node &op, const std::string &argument, + int nth) { + CHECK(var.IsArg()); + CHECK(op.IsStmt()); + if (!HasInput(op, argument) || + static_cast(op.stmt()->op_info()->Input(argument).size()) <= nth) + return false; + return var.arg()->name == op.stmt()->op_info()->Input(argument)[nth]; +} + +bool HasInput(const Node &op, const std::string &argument) { + CHECK(op.IsStmt()); + auto const &names = op.stmt()->op_info()->input_argnames(); + if (std::find(names.begin(), names.end(), argument) == names.end()) + return false; + return true; +} + +void GraphSafeRemoveNodes(SSAGraph *graph, + const std::unordered_set &nodes) { + for (auto *node : nodes) { + graph->RemoveNode(node); + } + + for (auto &node : graph->mutable_nodes()) { + for (auto it = node.inlinks.begin(); it != node.inlinks.end();) { + if (nodes.count(*it)) { + it = node.inlinks.erase(it); + } else { + it++; + } + } + for (auto it = node.outlinks.begin(); it != node.outlinks.end();) { + if (nodes.count(*it)) { + it = node.outlinks.erase(it); + } else { + it++; + } + } + } +} + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/pattern_matcher.h b/paddle/fluid/lite/core/mir/pattern_matcher.h new file mode 100644 index 00000000000000..f2862a229e3eea --- /dev/null +++ b/paddle/fluid/lite/core/mir/pattern_matcher.h @@ -0,0 +1,410 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef PADDLE_WITH_TESTING +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/lite/core/mir/node.h" +#include "paddle/fluid/lite/core/mir/ssa_graph.h" +#include "paddle/fluid/lite/model_parser/pb/op_desc.h" + +namespace paddle { +namespace lite { +namespace mir { +class PMPattern; + +// Some basic terminologies: +// - PMPattern: a pattern defined as a data flow graph. +// - PMNode: the node in the pattern, each PMNode represents an `mir::Node` +// that meets some conditions defined in `PMNode.teller`. +// - A pattern is defined with PMNodes with edges. + +// Pattern matcher node. This node helps to build a pattern. +struct PMNode { + // tell whether an mir::Node* is a candidation for a PMNode. + using teller_t = std::function; + enum class Type { kOp, kVar }; + enum class Role { + kUnknown, // No role, + kInput, // an input and will be retained, + kOutput, // an output and will be retained, + kIntermediate // will be removed after handler. + }; + + // this link to others + PMNode& LinksTo(const std::vector& others); + PMNode& LinksFrom(const std::vector& others); + + // Link this to another node. + PMNode& operator>>(PMNode& right); + + // Link many nodes to this node. + friend PMNode& operator>>(std::vector& others, PMNode& me); + + // Link this to many other nodes. 
+ PMNode& operator>>(std::vector& nodes); + + bool Tell(const Node* node) const { + if (teller_) return teller_(node); + + for (auto& asrt : asserts_) { + if (!asrt(node)) return false; + } + return true; + } + + bool IsOp() const { return type_ == Type::kOp; } + bool IsVar() const { return type_ == Type::kVar; } + + const std::string& name() const { return name_; } + + PMNode& operator=(const PMNode&) = delete; + PMNode(const PMNode&) = delete; + + // Mark this node is an Input of a subgraph and will be retained. + PMNode* AsInput() { + role_ = Role::kInput; + return this; + } + // Mark this node is an Output of a subgraph and will be retained. + PMNode* AsOutput() { + role_ = Role::kOutput; + return this; + } + // Mark this node will be removed, so all the links should be inside a matched + // sub-graph. + PMNode* AsIntermediate() { + role_ = Role::kIntermediate; + return this; + } + + PMNode* AsVar() { + type_ = Type::kVar; + assert_is_var(); + return this; + } + + PMNode* AsOp(const std::string& op_type) { + type_ = Type::kOp; + assert_is_op(op_type); + return this; + } + + void set_op_type(const std::string& op_type) { op_type_ = op_type; } + + bool IsIntermediate() const { return role_ == Role::kIntermediate; } + bool IsInput() const { return role_ == Role::kInput; } + bool IsOutput() const { return role_ == Role::kOutput; } + + // Assertions, helper functions to simplify the pattern definition. 
+ PMNode* assert_is_op(); + PMNode* assert_is_op(const std::string& op_type); + PMNode* assert_is_var(); + PMNode* assert_var_not_persistable(); + PMNode* assert_is_persistable_var(); + PMNode* assert_is_op_output(const std::string& op_type); + PMNode* assert_is_op_input(const std::string& op_type); + PMNode* assert_is_op_input(const std::string& op_type, + const std::string& argument); + PMNode* assert_is_op_nth_input(const std::string& op_type, + const std::string& argument, int nth); + + template + PMNode* assert_op_attr(const std::string& attr_name, const T& attr) { + asserts_.emplace_back([=](Node* x) { + if (x && x->IsStmt()) { + auto* op_info = x->stmt()->op_info(); + return op_info->HasAttr(attr_name) && + op_info->GetAttr(attr_name) == attr; + } else { + return false; + } + }); + return this; + } + + private: + PMNode(PMPattern* pattern, const std::string& name = "", + Type type = Type::kVar) + : pattern_(pattern), name_(name), type_(type) {} + PMNode(teller_t&& teller, PMPattern* pattern, const std::string& name = "", + Type type = Type::kVar) + : teller_(std::move(teller)), + pattern_(pattern), + name_(name), + type_(type) { + CHECK(teller_ != nullptr) << "invalid teller functer is set."; + } + + PMNode(PMNode&& other) = default; + + friend class PMPattern; + + // Will removed latter. + teller_t teller_; + std::vector asserts_; + PMPattern* pattern_; + std::string name_; + std::string op_type_; + Type type_; + Role role_{Role::kUnknown}; +}; + +/* + * A pattern in a graph, which defined with PMNode and edges. Most graph + * patterns can be divided into PMNodes and link relations between them. + * + * For example, the FC fusion need to filter the MUL and ELEMENTWISE_ADD + * operators from the computation graph, the MUL's output should have only one + * consumer which is the ELEMENTWISE_ADD. + * This pattern can be defined as with the following pseudo codes + * + * // Create two operator PMNodes. 
+ * MUL = PMPattern.NewNode().assert_is_op("mul"); + * ELE = PMPattern.NewNode().assert_is_op("elementwise_add"); + * // Create the variable PMNodes. + * MUL_out = PMPattern.NewNode().assert_is_op_output("mul") \ + * .assert_is_op_input("elementwise_add") \ + * .AsIntermediate(); + * // Add relations. + * MUL->LinksTo({MUL_out}); + * MUL_out->LinksTo({ELE}); + * + * One can add more specific asserts for PMNodes or edges, both the Operator + * and Variable Nodes can be ruled in PMNode.assert_more(...). + * + * PMPattern can record the general patterns, such as the pattern represents + * - Op in CPU -> Op in GPU -> Op in CPU, to findout the IO abnormal place. + * - Ops whose inputs and outputs share the same variables + */ +class PMPattern { + public: + using edge_t = std::pair; + + void AddEdge(PMNode* a, PMNode* b); + + PMNode* NewNode(PMNode::teller_t&& teller, const std::string& name = NewID()); + PMNode* NewNode(const std::string& name = NewID()); + PMNode* NewNode(const std::string& prefix, const std::string& name) { + return NewNode(prefix + "/" + name); + } + PMNode* RetrieveNode(const std::string& id) const; + + const std::vector>& nodes() const { return nodes_; } + const std::vector& edges() const { return edges_; } + + std::string DotString() const; + + private: +#ifdef PADDLE_WITH_TESTING + FRIEND_TEST(PMPattern, AddEdge); + FRIEND_TEST(PMPattern, NewNode); +#endif + + static std::string NewID() { return "pmnode-" + std::to_string(id_++); } + + std::vector> nodes_; + std::vector edges_; + std::unordered_map node_map_; + static size_t id_; +}; + +/* + * PatternMatcher helps to detect the specific patterns in the graph. + * Input a pattern, output a list of the matched subgraphs/nodes. + * This helper can be used to support fuse(conv+batchnorm => batchnorm e.g.). + * + * The algorithm has three phases: + * 1. Mark the nodes that match the defined PMNodes in a PMPattern, + * 2. 
Extend a PMNode to subgraphs by deducing the connection relation defined + * in PAPattern(the edges), + * 3. Get the filtered subgraphs and treat them with a pre-defined handler. + * + * Usage: + * // Create a matcher + * PatternMatcher matcher; + * // Define the matcher's pattern, by adding PMNode and define the edges. + * auto* node0 = matcher.mutable_pattern().AddNode(...) + * auto* node1 = matcher.mutable_pattern().AddNode(...) + * node0->teller = some lambda. + * node1->teller = some lambda. + * matcher.mutable_pattern().AddEdge(node0, node1); + * // Create an handler, to define the behavior of treating the filtered + * // subgraphs that comply with the patterns. + * PatternMatcher::handle_t handler = some labmda + * // Execute the matcher. + * matcher(&graph, handler); + */ +class PatternMatcher { + public: + using subgraph_t = std::unordered_map; + + // Operate on the detected pattern. + using handle_t = + std::function; + + void operator()(SSAGraph* graph, handle_t handler); + + const PMPattern& pattern() const { return pattern_; } + PMPattern* mutable_pattern() { return &pattern_; } + + private: + // Mark the nodes that fits the pattern. + bool MarkPMNodesInGraph(SSAGraph* graph); + + // Detect all the pattern and output the hit records. + std::vector DetectPatterns(); + + // Remove duplicate patterns. + void UniquePatterns(std::vector* subgraphs); + + // Remove overlapped match subgraphs, when overlapped, keep the previous one. + // The intermediate PMNodes will be removed, so can't shared by multiple + // patterns. + void RemoveOverlappedMatch(std::vector* subgraphs); + + // Validate whether the intermediate nodes are linked by external nodes. 
+ void ValidateByNodeRole(std::vector* subgraphs); + +#ifdef PADDLE_WITH_TESTING + FRIEND_TEST(PatternMatcher, MarkPMNodesInGraph); + FRIEND_TEST(PatternMatcher, DetectPatterns); +#endif + + private: + using hit_rcd_t = + std::pair; + PMPattern pattern_; + std::unordered_map> pmnodes2nodes_; +}; + +// Check whether a var node is a op node's nth input. +bool IsNthInput(const Node& var, const Node& op, const std::string& argument, + int nth); + +// Check whether the op node has input of given name. +bool HasInput(const Node& op, const std::string& argument); + +// Graph safely remove some nodes, will automatically clean up the edges. +void GraphSafeRemoveNodes(SSAGraph* graph, + const std::unordered_set& nodes); + +// Some pre-defined patterns those can be reused in multiple passes. +// The related Fluid Layer or Op should be one pattern here for better re-usage +// across different fusion. +namespace patterns { + +struct KeyCounter { + static KeyCounter& Instance() { + static KeyCounter x; + return x; + } + + int IncCounter(const std::string& key) { return dic_[key]++; } + + private: + std::unordered_map dic_; +}; + +// Generate a unique PMNode's name with name_scope and id. +// The format is {name_scope}/{repr}/{id}/{name} +static std::string PMNodeName(const std::string& name_scope, + const std::string& repr, size_t id, + const std::string& name) { + std::stringstream ss; + ss << name_scope << "/" << repr << "/" << id << "/" << name; + return ss.str(); +} +// Generate a unique PMNode's name. +// The format is {name_scope}/{repr}/{id} +static std::string PMNodeName(const std::string& name_scope, + const std::string& repr) { + std::stringstream ss; + ss << name_scope << "/" << repr << "/" + << KeyCounter::Instance().IncCounter(repr); + return ss.str(); +} +// Generate a unique key. It can be used for a universally unique temporary +// name. 
+// The format is {repr}/{id} +static std::string UniqueKey(const std::string& repr) { + std::stringstream ss; + ss << repr << "/" << KeyCounter::Instance().IncCounter(repr); + return ss.str(); +} + +// Declare a PMNode in a pattern, will create two methods: +// std::string xxx_repr(); return this PMNode's string id. +// PMNode* xxx_n(); return the corresponding PMNode. +#define PATTERN_DECL_NODE(name__) \ + std::string name__##_repr() const { \ + return PMNodeName(name_scope_, repr_, id_, #name__); \ + } \ + PMNode* name__##_n() const { return pattern->RetrieveNode(name__##_repr()); } + +// Get an mir::Node* from the matched subgraph. +// var: variable. +// arg: the argument declared by PATTERN_DECL_NODE in a pattern definition. +// pat: the pattern object. +#define GET_IR_NODE_FROM_SUBGRAPH(var, arg, pat) \ + CHECK(subgraph.count(pat.arg##_n())) \ + << "Node not found for PMNode " pat.arg##_repr(); \ + Node* var = subgraph.at(pat.arg##_n()); \ + CHECK(var) << "node " << #arg << "not exists in the sub-graph" + +// The base class of all the patterns. +struct PatternBase { + PatternBase(PMPattern* pattern, const std::string& name_scope, + const std::string& repr) + : pattern(pattern), + name_scope_(name_scope), + repr_(repr), + id_(KeyCounter::Instance().IncCounter(repr)) {} + + PMPattern* pattern; + + protected: + std::string name_scope_; + std::string repr_; + size_t id_; +}; + +} // namespace patterns + +// Link two mir::Nodes from each other. 
+#define IR_NODE_LINK_TO(a, b) \ + a->outlinks.push_back(b); \ + b->inlinks.push_back(a); + +// Set the out_var as the output of the op +#define IR_OP_VAR_LINK(op, out_var) \ + op->outlinks.push_back(out_var); \ + out_var->inlinks.clear(); \ + out_var->inlinks.push_back(op); + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc b/paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc new file mode 100644 index 00000000000000..5dc929cda5ee29 --- /dev/null +++ b/paddle/fluid/lite/core/mir/pattern_matcher_high_api.cc @@ -0,0 +1,82 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h" +#include + +namespace paddle { +namespace lite { +namespace mir { + +void FuseBase::PerformPatternMatcher(SSAGraph *graph) { + LOG(INFO) << "\n" << matcher_.pattern().DotString(); + // Get subgraphs and record the mir::Node pointers for each PMNode. + auto handler = [&](const PatternMatcher::subgraph_t &subgraph, SSAGraph *g) { + // get all the reigistered nodes. 
+ key2nodes_.emplace_back(); + for (auto &item : nodes_) { + key2nodes_.back()[item.first] = subgraph.at(item.second); + } + }; + + matcher_(graph, handler); +} + +void FuseBase::DeleteInterNodes(SSAGraph *graph) { + std::set keys; + for (auto &node : nodes_) { + if (node.second->IsIntermediate()) { + keys.insert(node.first); + } + } + + LOG(INFO) << "keys.size " << keys.size(); + + std::unordered_set nodes2rm; + for (auto &matched : key2nodes_) { + LOG(INFO) << "get matched " << matched.size(); + for (const auto &key : keys) { + nodes2rm.insert(matched.at(key)); + } + } + + LOG(INFO) << "clean nodes " << nodes2rm.size(); + GraphSafeRemoveNodes(graph, nodes2rm); +} + +PMNode *FuseBase::GetOrCreateNode(const std::string &key) { + auto it = nodes_.find(key); + if (it != nodes_.end()) { + return it->second; + } + nodes_.emplace(key, + matcher_.mutable_pattern()->NewNode(patterns::UniqueKey(key))); + it = nodes_.find(key); + return it->second; +} + +PMNode *FuseBase::OpNode(const std::string &key, const std::string &op_type) { + GetOrCreateNode(key)->set_op_type(op_type); + GetOrCreateNode(key)->AsOp(op_type); + return GetOrCreateNode(key); +} + +PMNode *FuseBase::VarNode(const std::string &key) { + GetOrCreateNode(key)->AsVar(); + return GetOrCreateNode(key); +} + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_high_api.h b/paddle/fluid/lite/core/mir/pattern_matcher_high_api.h new file mode 100644 index 00000000000000..b3a23c654bdb36 --- /dev/null +++ b/paddle/fluid/lite/core/mir/pattern_matcher_high_api.h @@ -0,0 +1,77 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/lite/core/mir/node.h" +#include "paddle/fluid/lite/core/mir/pattern_matcher.h" +#include "paddle/fluid/lite/core/mir/ssa_graph.h" + +namespace paddle { +namespace lite { +namespace mir { + +class FuseBase { + public: + using key2nodes_t = std::map; + + virtual ~FuseBase() = default; + + void operator()(SSAGraph* graph) { + BuildPattern(); + PerformPatternMatcher(graph); + + for (const auto& matched : key2nodes_) { + InsertNewNode(graph, matched); + } + + DeleteInterNodes(graph); + } + + // Build a PMPattern using PMNode. + virtual void BuildPattern() = 0; + + // Generate an operator desc with a matched subgraph. 
+ virtual cpp::OpDesc GenOpDesc(const key2nodes_t& matched) = 0; + + PMNode* OpNode(const std::string& key, const std::string& op_type); + + PMNode* VarNode(const std::string& key); + + protected: + virtual void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) = 0; + + private: + void PerformPatternMatcher(SSAGraph* graph); + + // Delete nodes that are marked as Intermediate + void DeleteInterNodes(SSAGraph* graph); + + PMNode* GetOrCreateNode(const std::string& key); + + protected: + PatternMatcher matcher_; + std::map nodes_; + std::vector key2nodes_; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc b/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc new file mode 100644 index 00000000000000..beee4d32acb733 --- /dev/null +++ b/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc @@ -0,0 +1,151 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h" +#include +#include +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/lite/core/compatible_tensor.h" +#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h" +#include "paddle/fluid/lite/core/program.h" + +namespace paddle { +namespace lite { +namespace mir { + +// An demo. 
+class FcFuser : public FuseBase { + public: + void BuildPattern() override { + // create nodes. + auto* x = VarNode("x")->assert_is_op_input("mul", "X"); + auto* W = VarNode("W")->assert_is_op_input("mul", "Y"); + auto* b = VarNode("b"); + auto* mul = OpNode("mul", "mul"); + auto* mul_out = VarNode("mul_out"); + auto* add = OpNode("add", "elementwise_add"); + auto* Out = VarNode("Out"); + + // create topology. + std::vector mul_inputs{W, x}; + std::vector add_inputs{mul_out, b}; + mul_inputs >> *mul >> *mul_out; + add_inputs >> *add >> *Out; + + // Some op specialities. + mul_out->AsIntermediate(); + mul->AsIntermediate(); + add->AsIntermediate(); + } + + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { + auto op_desc = GenOpDesc(matched); + auto fc_op = LiteOpRegistry::Global().Create("fc"); + auto mul = matched.at("mul")->stmt()->op; + auto* scope = mul->scope(); + auto& valid_places = mul->valid_places(); + fc_op->Attach(op_desc, scope); + + auto* new_op_node = graph->GraphCreateInstructNode(fc_op, valid_places); + + IR_NODE_LINK_TO(matched.at("W"), new_op_node); + IR_NODE_LINK_TO(matched.at("x"), new_op_node); + IR_NODE_LINK_TO(matched.at("b"), new_op_node); + IR_NODE_LINK_TO(new_op_node, matched.at("Out")); + } + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override { + cpp::OpDesc op_desc; + op_desc.SetType("fc"); + op_desc.SetInput("Input", {matched.at("x")->arg()->name}); + op_desc.SetInput("W", {matched.at("W")->arg()->name}); + op_desc.SetInput("Bias", {matched.at("b")->arg()->name}); + op_desc.SetOutput("Out", {matched.at("Out")->arg()->name}); + op_desc.SetAttr("in_num_col_dims", 1); + return op_desc; + } +}; + +std::unique_ptr BuildGraph(framework::ProgramDesc* program_desc, + const std::shared_ptr& scope, + const std::vector& valid_places) { + auto* main_block = program_desc->MutableBlock(0); + auto* mul = main_block->AppendOp(); + auto* add = main_block->AppendOp(); + main_block->Var("x"); + 
main_block->Var("b"); + main_block->Var("mul_out"); + main_block->Var("w"); + main_block->Var("out"); + + scope->Var("w")->GetMutable(); + scope->Var("b")->GetMutable(); + scope->Var("mul_out")->GetMutable(); + scope->Var("w")->GetMutable(); + scope->Var("out")->GetMutable(); + + mul->SetInput("X", {"x"}); + mul->SetInput("Y", {"w"}); + mul->SetOutput("Out", {"mul_out"}); + mul->SetType("mul"); + mul->SetAttr("x_num_col_dims", 1); + mul->SetAttr("y_num_col_dims", 1); + + add->SetInput("X", {"mul_out"}); + add->SetInput("Y", {"b"}); + add->SetOutput("Out", {"out"}); + add->SetType("elementwise_add"); + add->SetAttr("axis", 1); + + program_desc->Flush(); + + lite::Program program(*program_desc->Proto(), scope, valid_places); + auto graph = std::unique_ptr(new SSAGraph()); + graph->Build(program, valid_places); + + return graph; +} + +TEST(pattern_matcher_high_api, graph_test) { + framework::ProgramDesc program_desc; + std::vector places{{TARGET(kHost), PRECISION(kFloat)}}; + auto scope = std::make_shared(); + auto graph = BuildGraph(&program_desc, scope, places); + + ASSERT_EQ(graph->nodes().size(), + 7UL /*real nodes*/ + 2UL /*feed op + fetch op*/); + Visualize(graph.get()); +} + +TEST(pattern_matcher_high_api, fuse_test) { + framework::ProgramDesc program_desc; + std::vector places{{TARGET(kHost), PRECISION(kFloat)}}; + auto scope = std::make_shared(); + auto graph = BuildGraph(&program_desc, scope, places); + const int num_nodes = graph->nodes().size(); + FcFuser fuser; + fuser(graph.get()); + ASSERT_EQ(graph->nodes().size(), + num_nodes - 3UL /*nodes removed */ + 1UL /* fused fc node*/); + Visualize(graph.get()); +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +USE_LITE_OP(fc); +USE_LITE_OP(mul); +USE_LITE_OP(elementwise_add); diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_test.cc b/paddle/fluid/lite/core/mir/pattern_matcher_test.cc new file mode 100644 index 00000000000000..3b082060fe2173 --- /dev/null +++ 
b/paddle/fluid/lite/core/mir/pattern_matcher_test.cc @@ -0,0 +1,233 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/pattern_matcher.h" + +#include + +namespace paddle { +namespace lite { +namespace mir { + +void BuildGraph(SSAGraph* g) { + g->mutable_nodes().emplace_back(); + Node& o1 = g->mutable_nodes().back(); + o1.AsStmt().op_type = "op1"; + g->mutable_nodes().emplace_back(); + Node& o2 = g->mutable_nodes().back(); + o2.AsStmt().op_type = "op2"; + g->mutable_nodes().emplace_back(); + Node& o3 = g->mutable_nodes().back(); + o3.AsStmt().op_type = "op3"; + g->mutable_nodes().emplace_back(); + Node& o4 = g->mutable_nodes().back(); + o4.AsStmt().op_type = "op4"; + g->mutable_nodes().emplace_back(); + Node& o5 = g->mutable_nodes().back(); + o5.AsStmt().op_type = "op5"; + g->mutable_nodes().emplace_back(); + Node& v1 = g->mutable_nodes().back(); + v1.AsArg("var1"); + g->mutable_nodes().emplace_back(); + Node& v2 = g->mutable_nodes().back(); + v2.AsArg("var2"); + g->mutable_nodes().emplace_back(); + Node& v3 = g->mutable_nodes().back(); + v3.AsArg("var3"); + g->mutable_nodes().emplace_back(); + Node& v4 = g->mutable_nodes().back(); + v4.AsArg("var4"); + + // o1->v1->o2 + o1.outlinks.push_back(&v1); + o2.inlinks.push_back(&v1); + v1.inlinks.push_back(&o1); + v1.outlinks.push_back(&o2); + // o2->v2->o3 + // o2->v2->o4 + o2.outlinks.push_back(&v2); + 
o3.inlinks.push_back(&v2); + o4.inlinks.push_back(&v2); + v2.inlinks.push_back(&o2); + v2.outlinks.push_back(&o3); + v2.outlinks.push_back(&o4); + // o2->v3->o5 + o2.outlinks.push_back(&v3); + o5.inlinks.push_back(&v3); + v3.inlinks.push_back(&o2); + v3.outlinks.push_back(&o5); + // o3-v4->o5 + o3.outlinks.push_back(&v4); + o5.inlinks.push_back(&v4); + v4.inlinks.push_back(&o3); + v4.outlinks.push_back(&o5); +} + +TEST(PMPattern, NewNode) { + PMPattern x; + auto* n = x.NewNode([](const Node* x) { return true; }); + ASSERT_TRUE(n); + ASSERT_EQ(x.nodes_.size(), 1UL); +} + +TEST(PMPattern, AddEdge) { + PMPattern x; + auto* a = x.NewNode([](const Node* x) { return true; }); + auto* b = x.NewNode([](const Node* x) { return true; }); + ASSERT_TRUE(a); + ASSERT_TRUE(b); + x.AddEdge(a, b); + ASSERT_EQ(x.nodes_.size(), 2UL); + ASSERT_EQ(x.edges_.size(), 1UL); + ASSERT_EQ(x.edges_.front().first, a); + ASSERT_EQ(x.edges_.front().second, b); + + ASSERT_EQ(x.nodes().size(), 2UL); + ASSERT_EQ(x.edges().size(), 1UL); + ASSERT_EQ(x.edges().front().first, a); + ASSERT_EQ(x.edges().front().second, b); +} + +TEST(PatternMatcher, MarkPMNodesInGraph) { + PatternMatcher x; + // mark o2, o3, v2 + + // The pattern is a graph: + // o2(a node named o2) -> v2(a node named v2) + // v2 -> o3(a node named o3) + auto* o2 = x.pattern_.NewNode([](const Node* node) { + // The teller can be any condition, such as op type, or variable's shape. + return node && node->IsStmt() && node->stmt()->op_type == "op2"; + }); + auto* o3 = x.pattern_.NewNode([](const Node* node) { + // The teller can be any condition, such as op type, or variable's shape. + return node && node->IsStmt() && node->stmt()->op_type == "op3"; + }); + auto* v2 = x.pattern_.NewNode([](const Node* node) { + // The teller can be any condition, such as op type, or variable's shape. 
+ return node && node->IsArg() && node->arg()->name == "var2"; + }); + + ASSERT_FALSE(o2->Tell(nullptr)); + ASSERT_FALSE(o3->Tell(nullptr)); + ASSERT_FALSE(v2->Tell(nullptr)); + + x.pattern_.AddEdge(o2, v2); + x.pattern_.AddEdge(v2, o3); + + ASSERT_EQ(x.pattern_.edges().size(), 2UL); + ASSERT_EQ(x.pattern_.edges()[0].first, o2); + ASSERT_EQ(x.pattern_.edges()[0].second, v2); + ASSERT_EQ(x.pattern_.edges()[1].first, v2); + ASSERT_EQ(x.pattern_.edges()[1].second, o3); + + SSAGraph graph; + BuildGraph(&graph); + + x.MarkPMNodesInGraph(&graph); + + ASSERT_EQ(x.pmnodes2nodes_.size(), 3UL); + + auto subgraphs = x.DetectPatterns(); + ASSERT_EQ(subgraphs.size(), 1UL); +} + +TEST(PatternMatcher, MultiSubgraph) { + SSAGraph graph; + BuildGraph(&graph); + + PatternMatcher x; + + // The pattern is a graph: + // op -> var + auto* any_op = x.mutable_pattern()->NewNode( + [](const Node* node) { + return node->IsStmt() && (node->stmt()->op_type == "op2" || + node->stmt()->op_type == "op3"); + }, + "OP0"); + auto* any_var = + x.mutable_pattern() + ->NewNode([](const Node* node) { return node->IsArg(); }, "VAR") + ->AsIntermediate(); + auto* any_op1 = x.mutable_pattern()->NewNode( + [](const Node* node) { return node->IsStmt(); }, "OP1"); + + x.mutable_pattern()->AddEdge(any_op, any_var); + x.mutable_pattern()->AddEdge(any_var, any_op1); + + int count = 0; + PatternMatcher::handle_t handle = [&](const PatternMatcher::subgraph_t& s, + SSAGraph* g) { + LOG(INFO) << "Detect " << s.at(any_op)->stmt()->op_type << " -> " + << s.at(any_var)->arg()->name << " -> " + << s.at(any_op1)->stmt()->op_type; + count++; + }; + + x(&graph, handle); + + // 1. Detect op3 -> var4 -> op5 + // 2. Detect op2 -> var2 -> op3 + // 3. Detect op2 -> var2 -> op4 + // 4. 
Detect op2 -> var3 -> op5 + // But 2 and 3 and 4 overlapped, so keep 2, so the final choices are 1 and 2 + ASSERT_GE(count, 1); + ASSERT_LE(count, 2); +} + +TEST(PatternMatcher, IntermediateCheck) { + SSAGraph graph; + BuildGraph(&graph); + + // o2->v2->o3 + // o2->v2->o4 + // check o2+o3 fuse, should fail because v2 also link to o4. + PatternMatcher matcher; + auto* op2 = matcher.mutable_pattern()->NewNode( + [](const Node* x) { + return x && x->IsStmt() && x->stmt()->op_type == "op2"; + }, + "op2"); + auto* op3 = matcher.mutable_pattern()->NewNode( + [](const Node* x) { + return x && x->IsStmt() && x->stmt()->op_type == "op3"; + }, + "op3"); + auto* v2 = matcher.mutable_pattern() + ->NewNode( + [](const Node* x) { + return x && x->IsArg() && x->arg()->name == "var2"; + }, + "var2") + ->AsIntermediate(); + v2->LinksFrom({op2}).LinksTo({op3}); + + int count = 0; + matcher(&graph, [&](const PatternMatcher::subgraph_t& g, SSAGraph* graph) { + ++count; + }); + EXPECT_EQ(count, 0); + + count = 0; + v2->AsInput(); + matcher(&graph, [&](const PatternMatcher::subgraph_t& g, SSAGraph* graph) { + ++count; + }); + ASSERT_EQ(count, 1); +} + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc b/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc index 3d2012306b9fd4..257766945af20c 100644 --- a/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc +++ b/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc @@ -21,63 +21,16 @@ namespace mir { class RuntimeContextAssignPass : public StmtPass { public: - RuntimeContextAssignPass() { -#ifdef LITE_WITH_CUDA - InitCudaBlas(); -#endif - } + RuntimeContextAssignPass() {} void Apply(const std::unique_ptr& graph) override { for (auto& node : graph->mutable_nodes()) { if (!node.IsStmt()) continue; - auto& inst = node.AsStmt(); - - switch (inst.picked_kernel().target()) { - case TARGET(kHost): - case TARGET(kX86): - 
inst.picked_kernel().SetContext(NewHostContext()); - break; -#ifdef LITE_WITH_CUDA - case TARGET(kCUDA): - inst.picked_kernel().SetContext(NewCudaContext()); - break; -#endif - default: - LOG(FATAL) << "unsupported target " - << TargetToStr(inst.picked_kernel().target()); - } + inst.picked_kernel().SetContext( + ContextScheduler::Global().NewContext(inst.picked_kernel().target())); } } - - std::unique_ptr NewHostContext() { - std::unique_ptr ctx(new KernelContext); - ctx->As(); - // Some initialization here. - return ctx; - } - -#ifdef LITE_WITH_CUDA - std::unique_ptr NewCudaContext() { - std::unique_ptr ctx(new KernelContext); - auto& cuda = ctx->As(); - // Some initialization here. - CHECK(cublas_fp32_) << "cublas_fp32 should be set first"; - cuda.blas_fp32 = cublas_fp32_; - return ctx; - } -#endif - -#ifdef LITE_WITH_CUDA - void InitCudaBlas() { - cublas_fp32_ = std::make_shared>(); - } -#endif - - private: -#ifdef LITE_WITH_CUDA - std::shared_ptr> cublas_fp32_; -#endif }; } // namespace mir diff --git a/paddle/fluid/lite/core/mir/ssa_graph.cc b/paddle/fluid/lite/core/mir/ssa_graph.cc index b565b6fd3ecbff..82507067c4726b 100644 --- a/paddle/fluid/lite/core/mir/ssa_graph.cc +++ b/paddle/fluid/lite/core/mir/ssa_graph.cc @@ -94,7 +94,7 @@ std::vector SSAGraph::StmtTopologicalOrder() { } void SSAGraph::GraphCreateTmpVarNodes(const Program &program) { - for (const auto &name : program.tmp_vars) { + for (const auto &name : program.tmp_vars()) { CHECK(!arguments_.count(name)) << "duplicate creating temp variable: " << name; VLOG(5) << "create arg node " << name; @@ -107,7 +107,7 @@ void SSAGraph::GraphCreateTmpVarNodes(const Program &program) { void SSAGraph::GraphCreateWeightVarNodes(const Program &program) { // create weight nodes. 
- for (const auto &name : program.weights) { + for (const auto &name : program.weights()) { CHECK(!arguments_.count(name)) << "duplicate creating weight variable: " << name; VLOG(5) << "create arg node " << name; @@ -119,8 +119,7 @@ void SSAGraph::GraphCreateWeightVarNodes(const Program &program) { } Node *SSAGraph::GraphCreateInstructNode( - const Program &program, const std::shared_ptr &op, - const std::vector &valid_places) { + const std::shared_ptr &op, const std::vector &valid_places) { node_storage_.emplace_back(); // TODO(Superjomn) remove one valid_places here. op->SetValidPlaces(valid_places); @@ -140,8 +139,8 @@ void SSAGraph::Build(const Program &program, GraphCreateWeightVarNodes(program); CHECK(CheckNodesRoleSet()); - for (auto &op : program.ops) { - auto *op_node = GraphCreateInstructNode(program, op, valid_places); + for (auto &op : program.ops()) { + auto *op_node = GraphCreateInstructNode(op, valid_places); for (const std::string &name : op->op_info()->input_names()) { auto *arg = Argument(name); CHECK(arg->IsRoleSet()); @@ -162,6 +161,13 @@ void SSAGraph::Build(const Program &program, CheckValid(); } +void SSAGraph::RemoveNode(const mir::Node *node) { + auto pos = std::find_if(node_storage_.begin(), node_storage_.end(), + [&node](mir::Node &n) { return &n == node; }); + CHECK(pos != node_storage_.end()); + node_storage_.erase(pos); +} + mir::Node *SSAGraph::Argument(const std::string &name) { auto it = arguments_.find(name); CHECK(it != arguments_.end()) << "no argument called " << name; diff --git a/paddle/fluid/lite/core/mir/ssa_graph.h b/paddle/fluid/lite/core/mir/ssa_graph.h index 73a1fd36e9fc6f..5cad1478c225a6 100644 --- a/paddle/fluid/lite/core/mir/ssa_graph.h +++ b/paddle/fluid/lite/core/mir/ssa_graph.h @@ -38,6 +38,7 @@ class SSAGraph : GraphBase { // @param program: the op program // @param valid_places: the valid places user set for the system. 
void Build(const Program &program, const std::vector &valid_places); + void RemoveNode(const mir::Node *node); mir::Node *Argument(const std::string &name); @@ -63,12 +64,12 @@ class SSAGraph : GraphBase { CHECK(CheckLinksRoleSet()); } + Node *GraphCreateInstructNode(const std::shared_ptr &op, + const std::vector &valid_places); + private: void GraphCreateTmpVarNodes(const Program &program); void GraphCreateWeightVarNodes(const Program &program); - Node *GraphCreateInstructNode(const Program &program, - const std::shared_ptr &op, - const std::vector &valid_places); // Check the bidirectional connection. bool CheckBidirectionalConnection(); @@ -77,7 +78,7 @@ class SSAGraph : GraphBase { bool CheckLinksRoleSet(); void MarkArgumentWeights(const Program &program) { - for (const auto &name : program.weights) { + for (const auto &name : program.weights()) { arguments_[name]->AsArg().is_weight = true; } } diff --git a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc index a3599664438ad6..9d48c123a0c8e3 100644 --- a/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc +++ b/paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc @@ -37,6 +37,8 @@ void StaticKernelPickPass::Apply(const std::unique_ptr& graph) { if (!node.IsStmt()) continue; auto& instruct = node.AsStmt(); std::vector>> scored; + CHECK(!instruct.valid_kernels.empty()) << "No kernels found for " + << instruct.op_type; for (auto&& kernel : instruct.valid_kernels) { size_t score = KernelGrade(*kernel); scored.emplace_back(score, std::move(kernel)); diff --git a/paddle/fluid/lite/core/mir/type_target_transform_pass.cc b/paddle/fluid/lite/core/mir/type_target_transform_pass.cc index ddd07970166282..25789d34dca2fa 100644 --- a/paddle/fluid/lite/core/mir/type_target_transform_pass.cc +++ b/paddle/fluid/lite/core/mir/type_target_transform_pass.cc @@ -90,7 +90,7 @@ void TypeTargetTransformPass::AddIoCopyInst( 
inst_node->AsStmt().op->scope()->Var(io_copy_output_name); // Create IoCopy Instruction. - lite::OpDesc op_desc; + cpp::OpDesc op_desc; op_desc.SetType("io_copy"); op_desc.SetInput("Input", {var}); op_desc.SetOutput("Out", {io_copy_output_name}); @@ -104,8 +104,6 @@ void TypeTargetTransformPass::AddIoCopyInst( // Update the original instruction OpDesc. // Update its input to the io_copy_output_name - auto& inst = inst_node->AsStmt(); - auto inst_program_desc = inst.op_info()->desc(); // Add new link, var -> new_inst, new_inst->newarg, newarg->inst DirectedLink(graph->Argument(var), io_copy_inst); @@ -113,11 +111,11 @@ void TypeTargetTransformPass::AddIoCopyInst( DirectedLink(io_copy_output_arg, inst_node); // reset opdesc and update kernel information - auto desc_dummy = inst_node->AsStmt().op->op_info()->desc(); - UpdateInputTo(&desc_dummy, var, io_copy_output_name); + UpdateInputTo(inst_node->AsStmt().op->mutable_op_info(), var, + io_copy_output_name); - lite::OpDesc desc_fake(desc_dummy); - inst_node->AsStmt().op->Attach(desc_fake, inst_node->AsStmt().op->scope()); + inst_node->AsStmt().op->Attach(*inst_node->AsStmt().op->op_info(), + inst_node->AsStmt().op->scope()); std::string tmp; if (inst_node->AsStmt().op_info()->GetInputArgname("a", &tmp)) { diff --git a/paddle/fluid/lite/core/mir/type_target_transform_pass.h b/paddle/fluid/lite/core/mir/type_target_transform_pass.h index f8557f44e3c975..838c0bcdabc927 100644 --- a/paddle/fluid/lite/core/mir/type_target_transform_pass.h +++ b/paddle/fluid/lite/core/mir/type_target_transform_pass.h @@ -24,10 +24,10 @@ namespace paddle { namespace lite { namespace mir { -static void UpdateInputTo(framework::proto::OpDesc* desc, - const std::string& from, const std::string& to) { +static void UpdateInputTo(cpp::OpDesc* desc, const std::string& from, + const std::string& to) { for (auto& item : *desc->mutable_inputs()) { - for (auto& input : *item.mutable_arguments()) { + for (auto& input : item.second) { if (input == from) { 
input = to; } diff --git a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h index 79c977b331f85f..4d555d638a91e1 100644 --- a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h +++ b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h @@ -65,7 +65,7 @@ class VariablePlaceInferencePass : public DebugPass { // check if inputs's place is set, if not set, update them with the // kernel's declaration. auto type = inst.picked_kernel().GetInputDeclType(arg_name); - auto arg_names = inst.op_info()->input_argument().at(arg_name); + auto arg_names = inst.op_info()->inputs().at(arg_name); for (auto& arg_name : arg_names) { VLOG(3) << "--- var " << arg_name; @@ -82,7 +82,7 @@ class VariablePlaceInferencePass : public DebugPass { for (auto& arg_name : inst.op_info()->output_argnames()) { VLOG(3) << "-- output arg_name " << arg_name; auto type = inst.picked_kernel().GetOutputDeclType(arg_name); - auto arg_names = inst.op_info()->output_argument().at(arg_name); + auto arg_names = inst.op_info()->outputs().at(arg_name); // check if outputs's place is set, if not set, update them with the // kernel's declaration. 
for (auto& arg_name : arg_names) { diff --git a/paddle/fluid/lite/core/mir/variable_place_inference_pass_test.cc b/paddle/fluid/lite/core/mir/variable_place_inference_pass_test.cc index 9c33ff698acda2..d6b8561c378cb2 100644 --- a/paddle/fluid/lite/core/mir/variable_place_inference_pass_test.cc +++ b/paddle/fluid/lite/core/mir/variable_place_inference_pass_test.cc @@ -42,6 +42,12 @@ TEST(variable_place_inference_pass, test) { Place{ TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW), }, + Place{ + TARGET(kX86), PRECISION(kFloat), DATALAYOUT(kNCHW), + }, + Place{ + TARGET(kX86), PRECISION(kAny), DATALAYOUT(kAny), + }, }); Program program(*desc->Proto(), scope, places); @@ -58,7 +64,15 @@ TEST(variable_place_inference_pass, test) { }); Place prefered_place{ +#ifdef PADDLE_WITH_CUDA TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW), +#else +#ifdef PADDLE_WITH_ARM + TARGET(kARM), PRECISION(kFloat), DATALAYOUT(kNCHW), +#else // X86 + TARGET(kX86), PRECISION(kFloat), DATALAYOUT(kNCHW), +#endif // ARM +#endif }; optimizer.KernelPickPreferPlace(prefered_place); optimizer.Run(std::move(program), places, factor, passes); @@ -72,3 +86,16 @@ USE_LITE_OP(mul); USE_LITE_OP(feed); USE_LITE_OP(fetch); USE_LITE_OP(io_copy); + +#ifdef LITE_WITH_X86 +USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def); +#endif + +#ifdef LITE_WITH_ARM +USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def); +#endif + +#ifdef LITE_WITH_CUDA +USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device); +USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host); +#endif diff --git a/paddle/fluid/lite/core/naive_test_model.py b/paddle/fluid/lite/core/naive_test_model.py index f4bbdefceca143..832661e5ee86f2 100644 --- a/paddle/fluid/lite/core/naive_test_model.py +++ b/paddle/fluid/lite/core/naive_test_model.py @@ -1,3 +1,17 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import numpy import sys, os import numpy as np @@ -26,8 +40,6 @@ #fluid.default_main_program().desc. - - #prog = fluid.compiler.CompiledProgram(fluid.default_main_program()) prog = fluid.default_main_program() @@ -36,11 +48,9 @@ with open('main_program.pb', 'wb') as f: f.write(prog.desc.serialize_to_string()) - #outs = exe.run(program=prog, feed={'a':data_1, }, fetch_list=[cost]) sys.exit(0) fluid.io.save_inference_model("./model2", [a.name], [a1], exe) print(numpy.array(outs)) - diff --git a/paddle/fluid/lite/core/op_lite.cc b/paddle/fluid/lite/core/op_lite.cc index e54053026d9070..484d22abf52dda 100644 --- a/paddle/fluid/lite/core/op_lite.cc +++ b/paddle/fluid/lite/core/op_lite.cc @@ -28,15 +28,23 @@ std::vector> OpLite::CreateKernels( CHECK(!op_type_.empty()) << "op_type_ should be set first"; auto pick_kernel = [&](const Place &place) { - auto ks = KernelRegistry::Global().Create( - (kernel_type.empty() ? 
op_type_ : kernel_type), place.target, - place.precision, place.layout); + auto ks = KernelRegistry::Global().Create(op_type_, place.target, + place.precision, place.layout); for (auto &&it : ks) { AttachKernel(it.get()); kernels.emplace_back(std::move(it)); } }; + if (!kernel_type.empty()) { + Place place; + std::string op_type, alias; + KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place); + pick_kernel(place); + CHECK(!kernels.empty()) << "no kernel for kernel type " << kernel_type; + return kernels; + } + std::set place_set; for (auto place : places) { place_set.insert(place); @@ -53,7 +61,7 @@ std::vector> OpLite::CreateKernels( targets.insert(place.target); } - CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_; + // CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_; VLOG(2) << "op " << op_type_ << " get " << kernels.size() << " kernels"; return kernels; } @@ -62,19 +70,19 @@ bool OpLite::Run() { CHECK(kernel_); SyncInputEvents(); - kernel_->Run(); + kernel_->Launch(); RecordOutputEvents(); return true; } -bool OpLite::Attach(const OpDesc &opdesc, lite::Scope *scope) { +bool OpLite::Attach(const cpp::OpDesc &opdesc, lite::Scope *scope) { // valid_places_.clear(); CHECK(scope != nullptr); - //CHECK(!op_info_.get()); + // CHECK(!op_info_.get()); scope_ = scope; - op_info_.reset(new OpInfo); // Force clean the out-of-date infomation. - op_info_->Build(opdesc.ReadonlyProto()); + op_info_.reset( + new OpInfo(opdesc)); // Force clean the out-of-date infomation. 
return AttachImpl(opdesc, scope); } @@ -92,94 +100,5 @@ Tensor *OpLite::GetMutableTensor(lite::Scope *scope, return var->GetMutable(); } -bool OpInfo::GetInputArgname(const std::string &value_name, - std::string *out) const { - for (auto &item : input_argument_) { - auto it = std::find(item.second.begin(), item.second.end(), value_name); - if (it != item.second.end()) { - *out = item.first; - return true; - } - } - return false; -} -bool OpInfo::GetOutputArgname(const std::string &value_name, - std::string *out) const { - for (auto &item : output_argument_) { - auto it = std::find(item.second.begin(), item.second.end(), value_name); - if (it != item.second.end()) { - *out = item.first; - return true; - } - } - return false; -} - -void OpInfo::ExtractInputsAndOutputs(const framework::proto::OpDesc &opdesc) { - for (const auto &item : opdesc.inputs()) { - for (const auto &x : item.arguments()) { - input_names_.push_back(x); - } - } - for (const auto &item : opdesc.outputs()) { - for (const auto &x : item.arguments()) { - output_names_.push_back(x); - } - } -} - -void OpInfo::CollectInputAndOutputArgnames( - const framework::proto::OpDesc &opdesc) { - for (const auto &item : opdesc.inputs()) { - input_argnames_.push_back(item.parameter()); - } - for (const auto &item : opdesc.outputs()) { - output_argnames_.push_back(item.parameter()); - } -} - -void OpInfo::CollectArguments(const framework::proto::OpDesc &opdesc) { - for (const auto &item : opdesc.inputs()) { - for (auto &x : item.arguments()) { - input_argument_[item.parameter()].push_back(x); - } - } - for (const auto &item : opdesc.outputs()) { - for (auto &x : item.arguments()) { - output_argument_[item.parameter()].push_back(x); - } - } -} - -void OpInfo::Build(const framework::proto::OpDesc &desc) { - ExtractInputsAndOutputs(desc); - CollectInputAndOutputArgnames(desc); - CollectArguments(desc); - desc_.reset(new framework::proto::OpDesc(desc)); -} - -const std::map> &OpInfo::input_argument() - const { - return 
input_argument_; -} - -const std::map> &OpInfo::output_argument() - const { - return output_argument_; -} - -const std::list &OpInfo::input_argnames() const { - return input_argnames_; -} - -const std::list &OpInfo::output_argnames() const { - return output_argnames_; -} - -const framework::proto::OpDesc &OpInfo::desc() const { - CHECK(desc_) << "desc has't set"; - return *desc_; -} - } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/core/op_lite.h b/paddle/fluid/lite/core/op_lite.h index 2f878905ca326a..922aa2304e43a9 100644 --- a/paddle/fluid/lite/core/op_lite.h +++ b/paddle/fluid/lite/core/op_lite.h @@ -23,7 +23,7 @@ #include "paddle/fluid/lite/core/context.h" #include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/scope.h" -#include "paddle/fluid/lite/model_parser/compatible_pb.h" +#include "paddle/fluid/lite/model_parser/cpp/op_desc.h" namespace paddle { namespace lite { @@ -71,7 +71,7 @@ class OpLite : public Registry { virtual bool Run(); // Link the external execution environ to internal context. - bool Attach(const OpDesc &opdesc, lite::Scope *scope); + bool Attach(const cpp::OpDesc &opdesc, lite::Scope *scope); const OpInfo *op_info() const { return op_info_.get(); } OpInfo *mutable_op_info() { return op_info_.get(); } @@ -94,7 +94,7 @@ class OpLite : public Registry { protected: // Attach it with the runtime environment. - virtual bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) = 0; + virtual bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) = 0; // Specify the kernel to run by default. This will specify the value of // `kernel_place_`. @@ -131,7 +131,6 @@ class OpLite : public Registry { return var->GetMutable(); } - protected: lite::Scope *scope_{}; std::unique_ptr kernel_; @@ -145,40 +144,61 @@ class OpLite : public Registry { * Operator Information, such as some description. It will be shared by all the * kernels of the same operator. 
*/ -class OpInfo { +class OpInfo : public cpp::OpDesc { public: - // To avoid the bugs from legancy framework::OpDesc, we use the ProtoBuf - // message instead. - void Build(const framework::proto::OpDesc &desc); - - const framework::proto::OpDesc &desc() const; - framework::proto::OpDesc *mutable_desc() { return desc_.get(); } - const std::list &input_names() const { return input_names_; } - const std::list &output_names() const { return output_names_; } - const std::map> &input_argument() const; - const std::map> &output_argument() const; - bool GetInputArgname(const std::string &value_name, std::string *out) const; - bool GetOutputArgname(const std::string &value_name, std::string *out) const; - - const std::list &input_argnames() const; - const std::list &output_argnames() const; - - private: - void ExtractInputsAndOutputs(const framework::proto::OpDesc &opdesc); - - void CollectInputAndOutputArgnames(const framework::proto::OpDesc &opdesc); - - void CollectArguments(const framework::proto::OpDesc &opdesc); - - private: - std::list input_names_; - std::list output_names_; - std::list input_argnames_; - std::list output_argnames_; - std::map> input_argument_; - std::map> output_argument_; - // NOTE too heavy. - std::unique_ptr desc_; + OpInfo(const OpInfo &) = default; + explicit OpInfo(const cpp::OpDesc &other) : cpp::OpDesc(other) {} + + // Collect all the input variable's name. + std::vector input_names() const { + std::vector res; + for (auto ¶m : InputArgumentNames()) { + for (auto &x : Input(param)) { + res.push_back(x); + } + } + return res; + } + + // Collect all the output variable's name. 
+ std::vector output_names() const { + std::vector res; + for (auto ¶m : OutputArgumentNames()) { + for (auto &x : Output(param)) { + res.push_back(x); + } + } + return res; + } + + std::vector input_argnames() const { + return InputArgumentNames(); + } + + std::vector output_argnames() const { + return OutputArgumentNames(); + } + + bool GetInputArgname(const std::string &value_name, std::string *out) const { + for (auto &item : inputs_) { + auto it = std::find(item.second.begin(), item.second.end(), value_name); + if (it != item.second.end()) { + *out = item.first; + return true; + } + } + return false; + } + bool GetOutputArgname(const std::string &value_name, std::string *out) const { + for (auto &item : outputs_) { + auto it = std::find(item.second.begin(), item.second.end(), value_name); + if (it != item.second.end()) { + *out = item.first; + return true; + } + } + return false; + } }; } // namespace lite diff --git a/paddle/fluid/lite/core/op_registry.cc b/paddle/fluid/lite/core/op_registry.cc index 681cbdafcdeee3..8c3e44733df0f8 100644 --- a/paddle/fluid/lite/core/op_registry.cc +++ b/paddle/fluid/lite/core/op_registry.cc @@ -59,6 +59,9 @@ std::list> KernelRegistry::Create( case TARGET(kCUDA): { CREATE_KERNEL(kCUDA); } break; + case TARGET(kARM): { + CREATE_KERNEL(kARM); + } break; default: CHECK(false) << "not supported kernel target " << TargetToStr(target); } @@ -67,7 +70,10 @@ std::list> KernelRegistry::Create( return std::list>(); } -KernelRegistry::KernelRegistry() { +KernelRegistry::KernelRegistry() + : registries_(static_cast(TARGET(NUM)) * + static_cast(PRECISION(NUM)) * + static_cast(DATALAYOUT(NUM))) { #define INIT_FOR(target__, precision__, layout__) \ registries_[KernelRegistry::GetKernelOffset #include #include +#include #include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/target_wrapper.h" @@ -75,7 +76,11 @@ class KernelRegistry final { KernelRegistryForTarget *, // 
KernelRegistryForTarget * // + DATALAYOUT(kAny)> *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget * // >; KernelRegistry(); @@ -86,14 +91,15 @@ class KernelRegistry final { void Register(const std::string &name, typename KernelRegistryForTarget::creator_t &&creator) { - VLOG(3) << "register for " << TargetToStr(Target) << ":" - << PrecisionToStr(Precision) << "//" - << GetKernelOffset(); + // VLOG(3) << "register for " << TargetToStr(Target) << ":" + //<< PrecisionToStr(Precision) << "//" + //<< GetKernelOffset(); using kernel_registor_t = KernelRegistryForTarget; auto &varient = registries_[GetKernelOffset()]; - varient.template get()->Register(name, - std::move(creator)); + auto *reg = varient.template get(); + CHECK(reg) << "Can not be empty of " << name; + reg->Register(name, std::move(creator)); } template :" << std::endl; - ss << registries_[GetKernelOffset()] - .get *>() - ->DebugString(); - ss << std::endl; + constexpr TargetType tgt = TARGET(kHost); + constexpr PrecisionType dt = PRECISION(kFloat); + constexpr DataLayoutType lt = DATALAYOUT(kNCHW); + constexpr DataLayoutType kany = DATALAYOUT(kAny); + using kernel_registor_t = KernelRegistryForTarget; + auto *reg = registries_[GetKernelOffset()] + .template get(); + ss << reg->DebugString() << std::endl; return ss.str(); } private: - mutable std::array(TARGET(NUM)) * - static_cast(PRECISION(NUM)) * - static_cast(DATALAYOUT(NUM))> - registries_; + mutable std::vector registries_; }; template { public: KernelRegistor(const std::string &op_type, const std::string &alias) : Registor([=] { - VLOG(3) << "Register kernel " << op_type << " for " - << TargetToStr(target) << " " << PrecisionToStr(precision) - << " " << DataLayoutToStr(layout) << " alias " << alias; KernelRegistry::Global().Register( op_type, [=]() -> std::unique_ptr { std::unique_ptr x(new KernelType); diff --git a/paddle/fluid/lite/core/optimizer.h b/paddle/fluid/lite/core/optimizer.h index f585224a07f6ca..b78408a6740145 100644 --- 
a/paddle/fluid/lite/core/optimizer.h +++ b/paddle/fluid/lite/core/optimizer.h @@ -46,8 +46,10 @@ class Optimizer { SpecifyKernelPickTactic(kernel_pick_factor); InitTargetTypeTransformPass(); +#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK if (passes.empty()) { RunPasses(std::vector{{ + "lite_fc_fuse_pass", // "static_kernel_pick_pass", // "variable_place_inference_pass", // "argument_type_display_pass", // @@ -62,7 +64,8 @@ class Optimizer { } else { RunPasses(passes); } - exec_scope_ = program.exec_scope; +#endif + exec_scope_ = program.exec_scope(); } void KernelPickPreferPlace(const Place& place) { diff --git a/paddle/fluid/lite/core/profile/CMakeLists.txt b/paddle/fluid/lite/core/profile/CMakeLists.txt new file mode 100644 index 00000000000000..43731e8a414cff --- /dev/null +++ b/paddle/fluid/lite/core/profile/CMakeLists.txt @@ -0,0 +1,6 @@ +if (NOT LITE_WITH_PROFILE) + return() +endif() + +lite_cc_library(basic_profiler_lite SRCS basic_profiler.cc) +lite_cc_test(test_basic_profiler SRCS basic_profiler_test.cc DEPS basic_profiler_lite) diff --git a/paddle/fluid/lite/core/profile/basic_profiler.cc b/paddle/fluid/lite/core/profile/basic_profiler.cc new file mode 100644 index 00000000000000..86d5cd39ea99a3 --- /dev/null +++ b/paddle/fluid/lite/core/profile/basic_profiler.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/lite/core/profile/basic_profiler.h" + +namespace paddle { +namespace lite { +namespace profile { + +const int BasicTimer::data_w = 10; +const int BasicTimer::name_w = 10; + +} // namespace profile +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/profile/basic_profiler.h b/paddle/fluid/lite/core/profile/basic_profiler.h new file mode 100644 index 00000000000000..16a9905f1ae6d4 --- /dev/null +++ b/paddle/fluid/lite/core/profile/basic_profiler.h @@ -0,0 +1,201 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * This file implements BasicProfile, a profiler that helps to profile the basic + * CPU execution. It can display the min, max, average lantency of the execution + * of each kernel. 
+ */ +#pragma once +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include +#include +#include + +namespace paddle { +namespace lite { +namespace profile { + +/* Base class of all the profile records */ +template +class TimerBase { + public: + void Start() { self()->Start(); } + void Stop() { self()->Stop(); } + void Log(uint32_t x) { return self()->Log(x); } + std::string basic_repr() const { return const_self()->basic_repr(); } + + void SetId(int id) { self()->SetId(id); } + void SetKey(const std::string &key) { self()->SetKey(key); } + + int id() const { return const_self()->id(); } + + protected: + ChildT *self() { return reinterpret_cast(this); } + const ChildT *const_self() const { + return reinterpret_cast(this); + } +}; + +class BasicTimer : TimerBase { + uint64_t total_{}; + uint64_t count_{}; + uint32_t max_{std::numeric_limits::min()}; + uint32_t min_{std::numeric_limits::max()}; + int id_{-1}; + std::string key_; + std::chrono::time_point timer_{}; + + // TODO(Superjomn) make static + static const int name_w; + static const int data_w; + + public: + BasicTimer() = default; + BasicTimer(int id, const std::string &key) : id_(id), key_(key) {} + + void SetId(int id) { id_ = id; } + void SetKey(const std::string &key) { key_ = key; } + void Start() { timer_ = std::chrono::high_resolution_clock::now(); } + void Stop() { + auto duration = std::chrono::duration_cast( + std::chrono::high_resolution_clock::now() - timer_); + Log(duration.count()); + } + + int count() const { return count_; } + + void Log(uint32_t timespan) { + total_ += timespan; + max_ = std::max(max_, timespan); + min_ = std::min(min_, timespan); + count_++; + } + + static std::string basic_repr_header() { + std::stringstream ss; + ss << std::setw(name_w) << "kernel" // + << std::setw(data_w) << "average" // + << std::setw(data_w) << "min" // + << std::setw(data_w) << "max" // + << std::setw(data_w) << "count"; + return ss.str(); + } + + std::string 
basic_repr() const { + std::stringstream ss; + ss << std::setw(name_w) << key() // + << std::setw(data_w) << ave() // + << std::setw(data_w) << min() // + << std::setw(data_w) << max() // + << std::setw(data_w) << count_; + return ss.str(); + } + + const std::string &key() const { return key_; } + + int id() const { + CHECK_GE(id_, 0) << "id is not inited"; + return id_; + } + + double ave() const { return total_ * 1. / count_; } + double max() const { return max_; } + double min() const { return min_; } + + // BasicRecord(const BasicRecord &) = delete; + void operator=(const BasicTimer &) = delete; +}; + +/* + * A basic profiler, with each record logs the total latency. + */ +template +class BasicProfiler { + public: + explicit BasicProfiler(const std::string &name) : name_(name) {} + using record_t = TimerT; + + static BasicProfiler &Global() { + static std::unique_ptr x(new BasicProfiler("[global]")); + return *x; + } + + record_t &NewRcd(const std::string &key) { + records_.emplace_back(); + records_.back().SetId(records_.size() - 1); + records_.back().SetKey(key); + return records_.back(); + } + + const record_t &record(int id) { + CHECK_LT(id, records_.size()); + CHECK_GE(id, 0); + return records_[id]; + } + + record_t *mutable_record(int id) { + CHECK_GE(id, 0); + CHECK_LT(static_cast(id), records_.size()); + return &records_[id]; + } + + std::string basic_repr() const { + std::stringstream ss; + for (const auto &rcd : records_) { + ss << rcd.basic_repr() << "\n"; + } + return ss.str(); + } + + ~BasicProfiler() { + LOG(INFO) << "Profile dumps:"; + LOG(INFO) << "\n" + BasicTimer::basic_repr_header() + "\n" + basic_repr(); + } + + private: + std::string name_; + std::vector records_; +}; + +struct ProfileBlock { + explicit ProfileBlock(int id) : id_(id) { + BasicProfiler::Global().mutable_record(id_)->Start(); + } + + ~ProfileBlock() { + BasicProfiler::Global().mutable_record(id_)->Stop(); + } + + private: + int id_{}; +}; + +#define LITE_PROFILE_ONE(key__) \ 
+ static int key__##__profiler_id = \ + ::paddle::lite::profile::BasicProfiler< \ + ::paddle::lite::profile::BasicTimer>::Global() \ + .NewRcd(#key__) \ + .id(); \ + ::paddle::lite::profile::ProfileBlock key__##profiler__(key__##__profiler_id); + +} // namespace profile +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/profile/basic_profiler_test.cc b/paddle/fluid/lite/core/profile/basic_profiler_test.cc new file mode 100644 index 00000000000000..0154e02ff65c58 --- /dev/null +++ b/paddle/fluid/lite/core/profile/basic_profiler_test.cc @@ -0,0 +1,46 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/lite/core/profile/basic_profiler.h" +#include +#include +#include // NOLINT +#include // NOLINT + +namespace paddle { +namespace lite { +namespace profile { + +TEST(basic_record, init) { + BasicTimer timer; + timer.SetKey("hello"); +} + +TEST(basic_profile, init) { + auto& rcd = BasicProfiler::Global().NewRcd("fc"); + for (int i = 11; i < 100; i++) { + rcd.Log(i); + } + + LOG(INFO) << BasicProfiler::Global().basic_repr(); +} + +TEST(basic_profile, real_latency) { + LITE_PROFILE_ONE(test0); + std::this_thread::sleep_for(std::chrono::milliseconds(1200)); +} + +} // namespace profile +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/program.cc b/paddle/fluid/lite/core/program.cc index 0ec9590d09c966..9f12f4b87d84d0 100644 --- a/paddle/fluid/lite/core/program.cc +++ b/paddle/fluid/lite/core/program.cc @@ -39,11 +39,11 @@ std::string RuntimeProgram::SerializeProgram( auto program_dummy = desc; program_dummy.mutable_blocks(0)->clear_ops(); for (auto &node : instructions_) { - auto desc_dummy = node.op()->op_info()->desc(); - OpDesc desc(desc_dummy); - desc.SetAttr(kKernelTypeAttr, node.kernel()->SerializedKernelType()); + pb::OpDesc pb_desc; + TransformOpDescCppToPb(*node.op()->op_info(), &pb_desc); + pb_desc.SetAttr(kKernelTypeAttr, node.kernel()->SerializedKernelType()); // append new opdesc - *program_dummy.mutable_blocks(0)->add_ops() = *desc.Proto(); + *program_dummy.mutable_blocks(0)->add_ops() = *pb_desc.Proto(); } return program_dummy.SerializeAsString(); } @@ -62,5 +62,45 @@ void RuntimeProgram::SaveParams(const std::string &dir, } } +void Program::Build(const framework::proto::ProgramDesc &program) { + CHECK(ops_.empty()) << "Executor duplicate Build found"; + // Create operators. 
+ for (const auto &proto_op_desc : program.blocks(0).ops()) { + lite::OpDesc op_desc_dummy(proto_op_desc); + cpp::OpDesc op_desc; + TransformOpDescPbToCpp(op_desc_dummy, &op_desc); + auto op_type = op_desc.Type(); + // if (op_type == "feed" || op_type == "fetch") continue; + VLOG(4) << "create Op [" << op_type << "]"; + LOG(INFO) << "create Op [" << op_type << "]"; + auto op = LiteOpRegistry::Global().Create(op_type); + CHECK(op) << "no Op found for " << op_type; + ops_.emplace_back(std::move(op)); + ops_.back()->Attach(op_desc, exec_scope_); + } +} + +void Program::PrepareWorkspace(const framework::proto::ProgramDesc &program) { + CHECK(!exec_scope_) << "Duplicate PrepareWorkspace found"; + exec_scope_ = &scope_->NewScope(); + // Create Feed and Fetch var. + scope_->Var("feed")->GetMutable>(); + scope_->Var("fetch")->GetMutable>(); + + tmp_vars_.push_back("feed"); + tmp_vars_.push_back("fetch"); + CHECK(!program.blocks().empty()); + for (auto proto_var_desc : program.blocks(0).vars()) { + lite::VarDesc var_desc(proto_var_desc); + if (!var_desc.Persistable()) { + tmp_vars_.push_back(var_desc.Name()); + exec_scope_->Var(var_desc.Name()); + } else { + if (var_desc.Name() == "feed" || var_desc.Name() == "fetch") continue; + weights_.push_back(var_desc.Name()); + } + } +} + } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/core/program.h b/paddle/fluid/lite/core/program.h index b25b6ae7d16347..4f2f65d3fa714d 100644 --- a/paddle/fluid/lite/core/program.h +++ b/paddle/fluid/lite/core/program.h @@ -22,6 +22,10 @@ #include "paddle/fluid/lite/core/mir/node.h" #include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/model_parser/compatible_pb.h" +#ifdef LITE_WITH_PROFILE +#include "paddle/fluid/lite/core/profile/basic_profiler.h" +#endif // LITE_WITH_PROFILE namespace paddle { namespace lite { @@ -33,79 +37,66 @@ static const char kKernelTypeAttr[] = "__@kernel_type_attr@__"; // - main 
block, which is a list of OpLite // - scope: which contains all the weights struct Program { - std::list tmp_vars; - std::list weights; - std::list> ops; - // the scope to run the kernels, NOTE this is the execution scope. - std::shared_ptr scope; - std::vector valid_places; - // Runtime scope. - lite::Scope* exec_scope{}; - const framework::proto::ProgramDesc desc; - - explicit Program(const std::shared_ptr& root) { scope = root; } + public: + explicit Program(const std::shared_ptr& root) { scope_ = root; } Program(const framework::proto::ProgramDesc& desc, const std::shared_ptr& root, const std::vector& valid_places) - : scope(root), valid_places(valid_places), desc(desc) { - CHECK(scope) << "scope should be init first"; + : scope_(root), valid_places_(valid_places), desc_(desc) { + CHECK(scope_) << "scope should be init first"; PrepareWorkspace(desc); Build(desc); } std::unique_ptr Clone() const { - std::unique_ptr res(new Program(desc, scope, valid_places)); + std::unique_ptr res(new Program(desc_, scope_, valid_places_)); return res; } + const std::list& weights() const { return weights_; } + const std::list& tmp_vars() const { return tmp_vars_; } + std::list* mutable_weights() { return &weights_; } + std::list* mutable_tmp_vars() { return &tmp_vars_; } + + const std::list>& ops() const { return ops_; } + std::list>* mutable_ops() { return &ops_; } + + lite::Scope* exec_scope() { return exec_scope_; } + lite::Scope* scope() { return scope_.get(); } + private: // Build from a program and scope. - void Build(const framework::proto::ProgramDesc& program) { - CHECK(ops.empty()) << "Executor duplicate Build found"; - // Create operators. 
- for (const auto& proto_op_desc : program.blocks(0).ops()) { - lite::OpDesc op_desc(proto_op_desc); - auto op_type = op_desc.Type(); - // if (op_type == "feed" || op_type == "fetch") continue; - VLOG(4) << "create Op [" << op_type << "]"; - LOG(INFO) << "create Op [" << op_type << "]"; - auto op = LiteOpRegistry::Global().Create(op_type); - CHECK(op) << "no Op found for " << op_type; - ops.emplace_back(std::move(op)); - ops.back()->Attach(op_desc, exec_scope); - } - } - + void Build(const framework::proto::ProgramDesc& program); // Create temporary variables. - void PrepareWorkspace(const framework::proto::ProgramDesc& program) { - CHECK(!exec_scope) << "Duplicate PrepareWorkspace found"; - exec_scope = &scope->NewScope(); - // Create Feed and Fetch var. - scope->Var("feed")->GetMutable>(); - scope->Var("fetch")->GetMutable>(); - - tmp_vars.push_back("feed"); - tmp_vars.push_back("fetch"); - CHECK(!program.blocks().empty()); - for (auto proto_var_desc : program.blocks(0).vars()) { - lite::VarDesc var_desc(proto_var_desc); - if (!var_desc.Persistable()) { - tmp_vars.push_back(var_desc.Name()); - exec_scope->Var(var_desc.Name()); - } else { - if (var_desc.Name() == "feed" || var_desc.Name() == "fetch") continue; - weights.push_back(var_desc.Name()); - } - } - } + void PrepareWorkspace(const framework::proto::ProgramDesc& program); + + private: + std::list tmp_vars_; + std::list weights_; + std::list> ops_; + // the scope to run the kernels, NOTE this is the execution scope. + std::shared_ptr scope_; + std::vector valid_places_; + // Runtime scope. 
+ lite::Scope* exec_scope_{}; + const framework::proto::ProgramDesc desc_; }; -struct Instruct { - Instruct(const std::shared_ptr& op, - std::unique_ptr&& kernel) - : op_(op), kernel_(std::move(kernel)) {} +struct Instruction { + Instruction(const std::shared_ptr& op, + std::unique_ptr&& kernel) + : op_(op), kernel_(std::move(kernel)) { +#ifdef LITE_WITH_PROFILE + profile_id_ = profile::BasicProfiler::Global() + .NewRcd(kernel_->SerializedKernelType()) + .id(); +#endif // LITE_WITH_PROFILE + } void Run() { +#ifdef LITE_WITH_PROFILE + profile::ProfileBlock x(profile_id_); +#endif // LITE_WITH_PROFILE CHECK(op_); CHECK(kernel_); if (first_epoch_) { @@ -113,10 +104,10 @@ struct Instruct { CHECK(op_->CheckShape()); } op_->InferShape(); - kernel_->Run(); + kernel_->Launch(); } - friend std::ostream& operator<<(std::ostream& os, const Instruct& other) { + friend std::ostream& operator<<(std::ostream& os, const Instruction& other) { os << other.kernel_->summary() << "\t(" << other.kernel_->doc() << ")"; return os; } @@ -128,6 +119,11 @@ struct Instruct { std::shared_ptr op_; std::unique_ptr kernel_; bool first_epoch_{true}; + +#ifdef LITE_WITH_PROFILE + // for profiler + int profile_id_{-1}; +#endif // LITE_WITH_PROFILE }; /* @@ -135,7 +131,7 @@ struct Instruct { */ class RuntimeProgram { public: - explicit RuntimeProgram(std::vector&& insts) + explicit RuntimeProgram(std::vector&& insts) : instructions_(std::move(insts)) { if (instructions_.empty()) { LOG(FATAL) << "no instructions"; @@ -165,7 +161,7 @@ class RuntimeProgram { private: RuntimeProgram(const RuntimeProgram&) = delete; - std::vector instructions_; + std::vector instructions_; lite::Scope* exec_scope_{}; }; diff --git a/paddle/fluid/lite/core/program_fake_utils.h b/paddle/fluid/lite/core/program_fake_utils.h index 6d24897ca6f5cd..b36e47bf1f25ef 100644 --- a/paddle/fluid/lite/core/program_fake_utils.h +++ b/paddle/fluid/lite/core/program_fake_utils.h @@ -33,11 +33,11 @@ Program FakeProgram() { std::string w1 = 
"w" + std::to_string(id); std::string b1 = "b" + std::to_string(id); std::string out1 = "out" + std::to_string(id); - auto w1v = program.scope->Var(w1)->GetMutable(); - auto b1v = program.scope->Var(b1)->GetMutable(); - auto out1v = program.scope->Var(out1)->GetMutable(); + auto w1v = program.scope()->Var(w1)->GetMutable(); + auto b1v = program.scope()->Var(b1)->GetMutable(); + auto out1v = program.scope()->Var(out1)->GetMutable(); - lite::OpDesc desc; + cpp::OpDesc desc; desc.SetInput("Input", {x}); desc.SetInput("W", {w1}); desc.SetInput("Bias", {b1}); @@ -46,12 +46,12 @@ Program FakeProgram() { desc.SetAttr("in_num_col_dims", 1); // add to input - program.tmp_vars.push_back(w1); - program.tmp_vars.push_back(b1); + program.mutable_tmp_vars()->push_back(w1); + program.mutable_tmp_vars()->push_back(b1); auto fc_op = LiteOpRegistry::Global().Create("fc"); - fc_op->Attach(desc, program.scope.get()); - program.ops.emplace_back(std::move(fc_op)); + fc_op->Attach(desc, program.scope()); + program.mutable_ops()->emplace_back(std::move(fc_op)); w1v->Resize(DDimHvy(std::vector({100, 100}))); b1v->Resize(DDimHvy(std::vector({100, 1}))); @@ -64,8 +64,8 @@ Program FakeProgram() { // out1, w2, b2 -fc-> out2 std::string x = "x"; - program.tmp_vars.push_back(x); - auto* xv = program.scope->Var(x)->GetMutable(); + program.mutable_tmp_vars()->push_back(x); + auto* xv = program.scope()->Var(x)->GetMutable(); xv->Resize(DDimHvy(std::vector({100, 100}))); for (int i = 0; i < 3; i++) { diff --git a/paddle/fluid/lite/core/scope.cc b/paddle/fluid/lite/core/scope.cc index 053803b00a082a..fbb837aedd369d 100644 --- a/paddle/fluid/lite/core/scope.cc +++ b/paddle/fluid/lite/core/scope.cc @@ -17,7 +17,13 @@ namespace paddle { namespace lite { -Scope::~Scope() {} +Scope::~Scope() { + for (auto *x : kids_) { + if (x) { + delete x; + } + } +} Scope &Scope::NewScope() const { kids_.push_back(new Scope); diff --git a/paddle/fluid/lite/core/target_wrapper.h b/paddle/fluid/lite/core/target_wrapper.h 
index df21c406e83b58..1029bf5300e678 100644 --- a/paddle/fluid/lite/core/target_wrapper.h +++ b/paddle/fluid/lite/core/target_wrapper.h @@ -30,6 +30,7 @@ enum class TargetType : int { kHost, kX86, kCUDA, + kARM, kAny, // any target NUM, // number of fields. }; @@ -62,7 +63,8 @@ static const std::string& TargetToStr(TargetType target) { } static const std::string& PrecisionToStr(PrecisionType precision) { - static const std::string precision2string[] = {"unk", "float", "int8", "any"}; + static const std::string precision2string[] = {"unk", "float", "int8_t", + "any"}; auto x = static_cast(precision); CHECK_LT(x, static_cast(PRECISION(NUM))); return precision2string[x]; @@ -75,6 +77,29 @@ static const std::string& DataLayoutToStr(DataLayoutType layout) { return datalayout2string[x]; } +static const std::string& TargetRepr(TargetType target) { + static const std::string target2string[] = {"kUnk", "kHost", "kX86", "kCUDA", + "kAny"}; + auto x = static_cast(target); + CHECK_LT(x, static_cast(TARGET(NUM))); + return target2string[x]; +} + +static const std::string& PrecisionRepr(PrecisionType precision) { + static const std::string precision2string[] = {"kUnk", "kFloat", "kInt8", + "kAny"}; + auto x = static_cast(precision); + CHECK_LT(x, static_cast(PRECISION(NUM))); + return precision2string[x]; +} + +static const std::string& DataLayoutRepr(DataLayoutType layout) { + static const std::string datalayout2string[] = {"kUnk", "kNCHW", "kAny"}; + auto x = static_cast(layout); + CHECK_LT(x, static_cast(DATALAYOUT(NUM))); + return datalayout2string[x]; +} + /* * Place specifies the execution context of a Kernel or input/output for a * kernel. It is used to make the analysis of the MIR more clear and accurate. 
@@ -227,5 +252,20 @@ class TargetWrapper { }; #endif // LITE_WITH_CUDA +template +void CopySync(void* dst, void* src, size_t size, IoDirection dir) { + switch (Target) { + case TARGET(kX86): + case TARGET(kHost): + case TARGET(kARM): + TargetWrapperX86::MemcpySync(dst, src, size, IoDirection::HtoH); + break; +#ifdef LITE_WITH_CUDA + case TARGET(kCUDA): + TargetWrapperCuda::MemcpySync(dst, src, size, dir); +#endif + } +} + } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/core/tensor.h b/paddle/fluid/lite/core/tensor.h index 807fbfc6a62350..d6980ff8898374 100644 --- a/paddle/fluid/lite/core/tensor.h +++ b/paddle/fluid/lite/core/tensor.h @@ -47,8 +47,9 @@ class DDimBase { DDimBase() = default; explicit DDimBase(const std::vector &x) { self()->ConstructFrom(x); } - value_type operator[](int offset) const { return (*self())[offset]; } - std::vector Vectorize() { return self()->Vectorize(); } + value_type operator[](int offset) const { return (*const_self())[offset]; } + value_type &operator[](int offset) { return (*self())[offset]; } + std::vector Vectorize() const { return self()->Vectorize(); } size_t size() const { return const_self()->size(); } bool empty() const { return const_self()->empty(); } @@ -73,18 +74,19 @@ class DDimBase { {Slice(0, col).production(), Slice(col, size()).production()})); } - friend std::ostream &operator<<(std::ostream &os, const DDimT &dims) { - if (dims.empty()) { - os << "[]"; - return os; + std::string repr() const { + std::stringstream ss; + ss << "{"; + for (size_t i = 0; i < this->size() - 1; i++) { + ss << (*this)[i] << ","; } + if (!this->empty()) ss << (*this)[size() - 1]; + ss << "}"; + return ss.str(); + } - os << "["; - for (size_t i = 0; i < dims.size() - 1; i++) { - os << dims[i] << " "; - } - if (!dims.empty()) os << dims[dims.size() - 1]; - os << "]"; + friend std::ostream &operator<<(std::ostream &os, const DDimT &dims) { + os << dims.repr(); return os; } @@ -102,6 +104,12 @@ template class 
TensorBase { public: TensorBase() = default; + + template + void Assign(T *data, const DimT &dim) { + self()->Assign(data, dim); + } + TargetType target() const { return self()->target(); } template diff --git a/paddle/fluid/lite/core/variable.h b/paddle/fluid/lite/core/variable.h index c83871446d254a..d52a813a09c70d 100644 --- a/paddle/fluid/lite/core/variable.h +++ b/paddle/fluid/lite/core/variable.h @@ -24,7 +24,7 @@ namespace lite { class Variable { public: template - const T& Get() { + const T& Get() const { return blob_.get(); } diff --git a/paddle/fluid/lite/gen_code/CMakeLists.txt b/paddle/fluid/lite/gen_code/CMakeLists.txt new file mode 100644 index 00000000000000..bacfc3e988e603 --- /dev/null +++ b/paddle/fluid/lite/gen_code/CMakeLists.txt @@ -0,0 +1,27 @@ +lite_cc_library(gen_code_lite SRCS gen_code.cc + DEPS program_lite op_lite scope_lite + cpp_op_desc_lite + HVY_DEPS operator) +lite_cc_library(paddle_infer_gencode SRCS paddle_infer.cc DEPS program_lite utils_lite) + +if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + lite_cc_test(test_gen_code_lite SRCS gen_code_test.cc DEPS gen_code_lite ${tensor_lite} + mul_op_lite + compatible_pb_lite + model_parser_lite + X86_DEPS mul_compute_x86 + ARM_DEPS mul_compute_arm + ARGS --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) + + lite_cc_library(__generated_code__ + SRCS ${CMAKE_BINARY_DIR}/paddle/fluid/lite/gen_code/__generated_code__.cc + DEPS scope_lite op_lite kernel_lite paddle_infer_gencode + ) + + lite_cc_test(test_generated_code SRCS generated_code_test.cc DEPS __generated_code__ + ${ops_lite} ${host_kernels} + X86_DEPS ${x86_kernels} + ) + + add_dependencies(__generated_code__ test_gen_code_lite) +endif() diff --git a/paddle/fluid/lite/gen_code/gen_code.cc b/paddle/fluid/lite/gen_code/gen_code.cc new file mode 100644 index 00000000000000..a50241bb715ceb --- /dev/null +++ b/paddle/fluid/lite/gen_code/gen_code.cc @@ -0,0 +1,193 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/gen_code/gen_code.h" +#include +#include +#include + +namespace paddle { +namespace lite { +namespace gencode { + +void Module::AddWeight(const std::string &name, const TensorRepr &tensor) { + auto w_name = WeightUniqueName(); + Line(string_format("// Create weight: %s", name.c_str())); + // auto* w0 = scope.Var("w0")->GetMutable(); + Line(string_format("auto* %s = scope->Var(%s)->GetMutable();", + w_name.c_str(), Repr(name).c_str())); + // lite::DDim w_ddim({1, 2}) + Line(string_format("lite::DDim %s_ddim(std::vector(%s));", + w_name.c_str(), tensor.ddim.repr().c_str())); + // std::vector w_data({}); + auto w_data_repr = DataRepr( + std::string(static_cast(tensor.raw_data), tensor.num_bytes), + tensor.dtype); + Line(string_format("std::vector<%s> %s_data({%s});", + PrecisionToStr(tensor.dtype).c_str(), w_name.c_str(), + w_data_repr.c_str())); + // w0->Assign(w0_data.data(), w0_ddim); + Line(string_format( + "%s->Assign<%s, lite::DDim, TARGET(kX86)>(%s_data.data(), %s_ddim);", + w_name.c_str(), PrecisionToStr(tensor.dtype).c_str(), w_name.c_str(), + w_name.c_str())); + Line(""); +} + +void Module::AddHeaderIncludeGenCode() { + Line(""); + Line("#include "); + Line("#include "); + Line("#include \"paddle/fluid/lite/core/compatible_tensor.h\""); + Line("#include \"paddle/fluid/lite/core/context.h\""); + Line("#include \"paddle/fluid/lite/gen_code/paddle_infer.h\""); + 
Line("#include \"paddle/fluid/lite/core/op_registry.h\""); + Line("#include \"paddle/fluid/lite/core/scope.h\""); + Line("#include \"paddle/fluid/lite/model_parser/cpp/op_desc.h\""); + Line(""); + Line(""); +} + +std::string Module::DataRepr(const std::string &raw_data, PrecisionType dtype) { + std::stringstream ss; + switch (dtype) { + case PRECISION(kFloat): { + const float *raw = reinterpret_cast(raw_data.c_str()); + int num_elems = raw_data.size() / sizeof(float); + if (num_elems) { + for (int i = 0; i < num_elems - 1; i++) { + ss << raw[i] << ","; + } + ss << raw[num_elems - 1]; + } + } break; + + default: + LOG(FATAL) << "Unsupported type " << PrecisionToStr(dtype); + } + return ss.str(); +} + +void Module::AddOpDescHelper(const std::string &op_id, + const cpp::OpDesc &desc) { + std::string desc_var = op_id + "_desc"; + Line(string_format("lite::cpp::OpDesc %s;", desc_var.c_str())); + auto vec_str_repr = [](const std::vector &vec) { + return Repr(vec); + }; + for (auto &item : desc.inputs()) { + Line(string_format("%s.SetInput(%s, %s);", desc_var.c_str(), + Repr(item.first).c_str(), + vec_str_repr(item.second).c_str())); + } + + for (auto &item : desc.outputs()) { + Line(string_format("%s.SetOutput(%s, %s);", desc_var.c_str(), + Repr(item.first).c_str(), + vec_str_repr(item.second).c_str())); + } + + auto attr_repr = [&](const std::string &name) -> std::string { + using AttrType = OpDescAPI::AttrType; + auto type = desc.GetAttrType(name); + + switch (type) { + case AttrType::INT: + return std::to_string(desc.GetAttr(name)); + case AttrType::FLOAT: + return std::to_string(desc.GetAttr(name)); + case AttrType::BOOLEAN: + return std::to_string(desc.GetAttr(name)); + case AttrType::STRING: + return "\"" + desc.GetAttr(name) + "\""; + case AttrType::STRINGS: { + std::vector tmp; + auto vals = desc.GetAttr>(name); + std::transform(vals.begin(), vals.end(), std::back_inserter(tmp), + [](const std::string &x) { return Repr(x); }); + return "{" + Join(tmp, ",") + "}"; 
+ } + default: + LOG(FATAL) << "Unsupported attribute type: " << static_cast(type); + } + return ""; + }; + + auto attr_type_repr = [&](const std::string &name) -> std::string { + using AttrType = OpDescAPI::AttrType; + auto type = desc.GetAttrType(name); + + switch (type) { + case AttrType::INT: + return "int"; + case AttrType::FLOAT: + return "float"; + case AttrType::BOOLEAN: + return "bool"; + case AttrType::STRING: + return "std::string"; + case AttrType::STRINGS: + return "std::vector"; + default: + LOG(FATAL) << "Unsupported attribute type: " << static_cast(type); + } + + return "unk_t"; + }; + for (auto &item : desc.AttrNames()) { + // Drop the python information. + if (item == "op_callstack") continue; + auto attr_type = attr_type_repr(item); + auto attr_val = attr_repr(item); + Line(string_format("%s.SetAttr<%s>(%s, %s);", // + desc_var.c_str(), attr_type.c_str(), Repr(item).c_str(), + attr_val.c_str())); + } +} + +void Module::AddOp(const cpp::OpDesc &op) { + auto op_name = OpUniqueName(); + AddOpDescHelper(op_name, op); + + Line(string_format("// Create Op: %s", op.Type().c_str())); + + Line(string_format("auto %s = lite::LiteOpRegistry::Global().Create(\"%s\");", + op_name.c_str(), op.Type().c_str())); + + CHECK(op.HasAttr(kKernelTypeAttr)) + << "the kernel type should be specified before generate code."; + auto kernel_type = op.GetAttr(kKernelTypeAttr); + Line(string_format("%s->Attach(%s, exec_scope);", op_name.c_str(), + (op_name + "_desc").c_str())); + + // Create kernel + auto kernel_name = KernelUniqueName(); + Line(string_format( + "auto %s = std::move(%s->CreateKernels(valid_places, \"%s\").front());", + kernel_name.c_str(), op_name.c_str(), kernel_type.c_str())); + + // Set Context for kernel + // clang-format off + Line(string_format("%s->SetContext(lite::ContextScheduler::Global().NewContext(%s->target()));", kernel_name.c_str(), kernel_name.c_str())); // NOLINT + // clang-format on + + Line(string_format("ops.push_back(%s);", 
op_name.c_str())); + Line(string_format("kernels.push_back(std::move(%s));", kernel_name.c_str())); + + op_kinds_.insert(op.Type()); + kernel_kinds_.insert(kernel_type); +} +} // namespace gencode +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/gen_code/gen_code.h b/paddle/fluid/lite/gen_code/gen_code.h new file mode 100644 index 00000000000000..1a55483f03a357 --- /dev/null +++ b/paddle/fluid/lite/gen_code/gen_code.h @@ -0,0 +1,254 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include +#include +#include "paddle/fluid/lite/core/compatible_tensor.h" +#include "paddle/fluid/lite/core/framework.pb.h" +#include "paddle/fluid/lite/core/program.h" +#include "paddle/fluid/lite/core/target_wrapper.h" +#include "paddle/fluid/lite/model_parser/cpp/op_desc.h" +#include "paddle/fluid/lite/model_parser/desc_apis.h" +#include "paddle/fluid/lite/utils/string.h" + +namespace paddle { +namespace lite { +namespace gencode { + +struct TensorRepr { + TensorRepr() = default; + TensorRepr(PrecisionType dtype, const std::vector &ddim, + void *raw_data, size_t num_bytes) + : dtype(dtype), ddim(ddim), raw_data(raw_data), num_bytes(num_bytes) {} + + PrecisionType dtype; + lite::DDim ddim; + const void *raw_data; + size_t num_bytes{}; +}; + +class Module { + std::vector ops; + std::vector weights; + std::vector tmp_vars_; + std::stringstream stream_; + std::set kernel_kinds_; + std::set op_kinds_; + + int line_indent_{}; + const int indent_unit_{2}; + + public: + void NewOp(const cpp::OpDesc &desc) { ops.push_back(desc); } + void NewWeight(const TensorRepr &x) { weights.push_back(x); } + void NewTmpVar(const std::string &x) { tmp_vars_.push_back(x); } + + std::stringstream &stream() { return stream_; } + + void AddHeaderIncludeGenCode(); + + void AddNamespaceBegin() { + Line("namespace paddle {"); + Line("namespace gencode{"); + Line(""); + } + + void AddNamespaceEnd() { + Line(""); + Line("} // namespace gencode"); + Line("} // namespace paddle"); + } + + void AddInitFuncBegin() { + Line("void PaddlePredictor::Init() {"); + Line(""); + IncIndent(); + } + + void AddInitFuncEnd() { + DecIndent(); + Line(""); + Line("}"); + } + + void AddScopeDecl() { + Line("lite::Scope* scope = static_cast(raw_scope_);"); + + // clang-format off + Line("lite::Scope* exec_scope = static_cast(raw_exe_scope_);"); // NOLINT + // clang-format on + + // Create feed and fetch in exec_scope. 
+ Line(string_format("exec_scope->Var(%s);", Repr("feed").c_str())); + Line(string_format("exec_scope->Var(%s);", Repr("fetch").c_str())); + } + + void AddValidPlaceDecl() { + // clang-format off + Line("std::vector valid_places({lite::Place({TARGET(kX86), PRECISION(kFloat), DATALAYOUT(kNCHW)}), lite::Place({TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)})});"); // NOLINT + // clang-format on + } + + void AddMemberCast() { + Line("// Cast the raw members"); + // clang-format off + Line(string_format("auto& ops = *static_cast>*>(raw_ops_);")); // NOLINT + Line(string_format("auto& kernels = *static_cast>*>(raw_kernels_);")); // NOLINT + // clang-format on + Line(""); + } + + void AddWeight(const std::string &name, const TensorRepr &tensor); + + void AddTmpVar(const std::string &x) { + Line(string_format("// Create temporary variable: %s", x.c_str())); + Line(string_format("exec_scope->Var(%s);", Repr(x).c_str())); + Line(""); + } + + void AddOp(const cpp::OpDesc &op); + + void AddOpDescHelper(const std::string &op_id, const cpp::OpDesc &desc); + + void AddOpCompileDeps() { + Line(""); + Line("// Add Operator compile deps"); + for (auto &op_type : op_kinds_) { + Line(string_format("USE_LITE_OP(%s)", op_type.c_str())); + } + Line(""); + } + void AddKernelCompileDeps() { + Line("// Add Kernel compile deps"); + + std::string op_type, alias; + Place place; + for (auto &kernel_type : kernel_kinds_) { + KernelBase::ParseKernelType(kernel_type, &op_type, &alias, &place); + Line(string_format("USE_LITE_KERNEL(%s, %s, %s, %s, %s)", // + op_type.c_str(), // + TargetRepr(place.target).c_str(), + PrecisionRepr(place.precision).c_str(), + DataLayoutRepr(place.layout).c_str(), alias.c_str())); + } + } + + private: + std::string WeightUniqueName() const { + return "w_" + std::to_string(weight_counter_++); + } + std::string TmpVarUniqueName() const { + return "tmp_" + std::to_string(tmp_var_counter_++); + } + std::string OpUniqueName() const { + return "op_" + 
std::to_string(op_counter_++); + } + std::string KernelUniqueName() const { + return "kernel_" + std::to_string(kernel_counter_++); + } + + std::string DataRepr(const std::string &raw_data, PrecisionType dtype); + + void IncIndent() { line_indent_++; } + void DecIndent() { line_indent_--; } + + void Line(const std::string &x) { + std::string indent_str(line_indent_ * indent_unit_, ' '); + stream() << indent_str << x << "\n"; + } + + private: + mutable int weight_counter_{}; + mutable int tmp_var_counter_{}; + mutable int op_counter_{}; + mutable int kernel_counter_{}; +}; + +class ProgramCodeGenerator { + public: + ProgramCodeGenerator(const framework::proto::ProgramDesc &program, + const lite::Scope &exec_scope) + : program_(program), exec_scope_(exec_scope) { + LOG(INFO) << program.DebugString(); + } + + std::string GenCode() { + Module m; + m.AddHeaderIncludeGenCode(); + m.AddNamespaceBegin(); + m.AddInitFuncBegin(); + m.AddMemberCast(); + m.AddScopeDecl(); + m.AddValidPlaceDecl(); + + AddWeights(&m); + AddTmpVars(&m); + AddOps(&m); + + m.AddInitFuncEnd(); + m.AddNamespaceEnd(); + + m.AddOpCompileDeps(); + m.AddKernelCompileDeps(); + + return m.stream().str(); + } + + void AddWeights(Module *m) { + for (auto &var : program_.blocks(0).vars()) { + if (var.persistable()) { + auto name = var.name(); + if (name == "feed" || name == "fetch") continue; + const auto &tensor = exec_scope_.FindVar(name)->Get(); + TensorRepr repr; + TensorToRepr(tensor, &repr); + m->AddWeight(name, repr); + } + } + } + void AddTmpVars(Module *m) { + for (auto &var : program_.blocks(0).vars()) { + if (!var.persistable()) { + m->AddTmpVar(var.name()); + } + } + } + void AddOps(Module *m) { + for (auto &op : program_.blocks(0).ops()) { + pb::OpDesc pb_desc(op); + cpp::OpDesc cpp_desc; + TransformOpDescPbToCpp(pb_desc, &cpp_desc); + m->AddOp(cpp_desc); + } + } + + private: + void TensorToRepr(const lite::Tensor &tensor, TensorRepr *repr) { + repr->ddim = tensor.dims(); + // TODO(Superjomn) 
support other types. + repr->dtype = PRECISION(kFloat); + repr->raw_data = tensor.data(); + repr->num_bytes = repr->ddim.production() * sizeof(float); + } + + private: + const framework::proto::ProgramDesc &program_; + const lite::Scope &exec_scope_; +}; + +} // namespace gencode +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/gen_code/gen_code_test.cc b/paddle/fluid/lite/gen_code/gen_code_test.cc new file mode 100644 index 00000000000000..c27b775c061bcd --- /dev/null +++ b/paddle/fluid/lite/gen_code/gen_code_test.cc @@ -0,0 +1,145 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/gen_code/gen_code.h" +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/lite/core/compatible_tensor.h" +#include "paddle/fluid/lite/core/context.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/scope.h" +#include "paddle/fluid/lite/model_parser/cpp/op_desc.h" +#include "paddle/fluid/lite/model_parser/model_parser.h" + +DEFINE_string(optimized_model, "", ""); +DEFINE_string(generated_code_file, "__generated_code__.cc", ""); + +namespace paddle { +namespace lite { +namespace gencode { + +// Manually construct a program. +TEST(gen_code, manual) { + // For holding the weights. + lite::Scope scope; + // For holding the temporary variables. 
+ auto &tmp_scope = scope.NewScope(); + + // Create weight variables. + auto *w0 = scope.Var("w0")->GetMutable(); + // Create temporary variables. + auto *a = tmp_scope.Var("x")->GetMutable(); + tmp_scope.Var("out")->GetMutable(); + + // Set weights. + std::vector w0_data({0, 1, 2, 3}); + w0->Assign( + w0_data.data(), lite::DDim{std::vector({2, 2})}); + + std::vector a_data({0, 1, 2, 3}); + a->Assign( + a_data.data(), lite::DDim{std::vector({2, 2})}); + + std::vector valid_places({ + Place{TARGET(kX86), PRECISION(kFloat)}, + Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kHost), PRECISION(kAny)}, + }); + auto mul_op = LiteOpRegistry::Global().Create("mul"); + cpp::OpDesc mul_op_desc; + mul_op_desc.SetType("mul"); + mul_op_desc.SetInput("X", {"x"}); + mul_op_desc.SetInput("Y", {"w0"}); + mul_op_desc.SetAttr("x_num_col_dims", 1); + mul_op_desc.SetAttr("y_num_col_dims", 1); + mul_op_desc.SetOutput("Out", {"out"}); + + mul_op->Attach(mul_op_desc, &tmp_scope); + auto mul_kernel = std::move(mul_op->CreateKernels(valid_places).front()); + auto fc_ctx = ContextScheduler::Global().NewContext(TARGET(kX86)); + mul_op->CheckShape(); + mul_op->InferShape(); + mul_kernel->SetContext(std::move(fc_ctx)); + mul_kernel->Launch(); +} + +TEST(gen_code, auto_gen) { + std::vector w0_data({0, 1, 2, 3}); + TensorRepr w0(PRECISION(kFloat), std::vector({2, 2}), w0_data.data(), + w0_data.size() * sizeof(float)); + + std::vector w1_data({0.01, 1.2, 2.3, 3.4, 1.1, 2.2}); + TensorRepr w1(PRECISION(kFloat), std::vector({3, 2}), w1_data.data(), + w1_data.size() * sizeof(float)); + + cpp::OpDesc op0; + op0.SetType("mul"); + op0.SetInput("X", {"a", "b"}); + op0.SetOutput("Out", {"out0"}); + op0.SetAttr("desc", "this is a desc"); + op0.SetAttr("x_col", 1); + op0.SetAttr("y_col", 2); + op0.SetAttr(kKernelTypeAttr, "x86"); + + gencode::Module module; + module.AddHeaderIncludeGenCode(); + + module.AddNamespaceBegin(); + module.AddInitFuncBegin(); + + module.AddMemberCast(); + + 
module.AddWeight("w0", w0); + module.AddWeight("w1", w1); + module.AddTmpVar("a"); + module.AddTmpVar("b"); + + module.AddOp(op0); + + module.AddInitFuncEnd(); + module.AddNamespaceEnd(); + + LOG(INFO) << module.stream().str(); +} + +TEST(gen_code, optimized_program) { + lite::Scope scope; + framework::proto::ProgramDesc desc; + LoadModel(FLAGS_optimized_model, &scope, &desc); + + ProgramCodeGenerator codegen(desc, scope); + + std::ofstream file(FLAGS_generated_code_file); + + file << codegen.GenCode(); + + file.close(); +} + +} // namespace gencode +} // namespace lite +} // namespace paddle + +USE_LITE_OP(mul); +#ifdef LITE_WITH_X86 +USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def); +#endif + +#ifdef LITE_WITH_ARM +USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def); +#endif diff --git a/paddle/fluid/lite/gen_code/generated_code_test.cc b/paddle/fluid/lite/gen_code/generated_code_test.cc new file mode 100644 index 00000000000000..e5874a2e149fce --- /dev/null +++ b/paddle/fluid/lite/gen_code/generated_code_test.cc @@ -0,0 +1,46 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include "paddle/fluid/lite/gen_code/paddle_infer.h" + +namespace paddle { +namespace lite { + +TEST(PaddlePredictor, Init) { + gencode::PaddlePredictor predictor; + predictor.Init(); +} + +TEST(PaddlePredictor, Run) { + gencode::PaddlePredictor predictor; + predictor.Init(); + + LOG(INFO) << "run the generated code"; + auto input_tensor = predictor.GetInput(0); + input_tensor->Resize(std::vector({100, 100})); + auto* data = input_tensor->mutable_data(); + for (int i = 0; i < 100 * 100; i++) { + data[i] = i; + } + + predictor.Run(); + + auto output_tensor = predictor.GetOutput(0); + LOG(INFO) << "output: " << output_tensor->data()[0]; +} + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/gen_code/paddle_infer.cc b/paddle/fluid/lite/gen_code/paddle_infer.cc new file mode 100644 index 00000000000000..ac4e99cb714dc1 --- /dev/null +++ b/paddle/fluid/lite/gen_code/paddle_infer.cc @@ -0,0 +1,139 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/lite/gen_code/paddle_infer.h" +#include "paddle/fluid/lite/core/compatible_tensor.h" +#include "paddle/fluid/lite/core/op_lite.h" + +namespace paddle { +namespace gencode { + +void Tensor::Resize(const Tensor::ddim_t &shape) { + CHECK(raw_mutable_tensor_); + auto *tensor = static_cast(raw_mutable_tensor_); + tensor->Resize(shape); +} + +#define FOR_EACH_TYPE(HANDLE) \ + HANDLE(int); \ + HANDLE(float); \ + HANDLE(int8_t); \ + HANDLE(int64_t); + +#define IMPL_DATA(T) \ + template <> \ + const T *Tensor::data() const { \ + CHECK(raw_tensor_); \ + const auto *tensor = static_cast(raw_tensor_); \ + return tensor->data(); \ + } +FOR_EACH_TYPE(IMPL_DATA); +#undef IMPL_DATA + +#define IMPL_MUTABLE_DATA(T) \ + template <> \ + T *Tensor::mutable_data() { \ + CHECK(raw_mutable_tensor_); \ + auto *tensor = static_cast(raw_mutable_tensor_); \ + return tensor->mutable_data(); \ + } +FOR_EACH_TYPE(IMPL_MUTABLE_DATA); +#undef IMPL_MUTABLE_DATA + +PaddlePredictor::PaddlePredictor() { + raw_ops_ = new std::vector>; + raw_kernels_ = new std::vector>; + raw_scope_ = new lite::Scope; + raw_exe_scope_ = &(static_cast(raw_scope_)->NewScope()); +} + +std::unique_ptr PaddlePredictor::GetTensor( + const std::string &id) const { + auto *exe_scope = static_cast(raw_exe_scope_); + const auto *var = exe_scope->FindVar(id); + const auto &tensor = var->Get(); + return std::unique_ptr(new Tensor(&tensor, nullptr)); +} + +std::unique_ptr PaddlePredictor::GetMutableTensor( + const std::string &id) { + auto *exe_scope = static_cast(raw_exe_scope_); + auto *var = exe_scope->FindVar(id); + auto *tensor = var->GetMutable(); + return std::unique_ptr(new Tensor(nullptr, tensor)); +} + +#define CAST_OPS \ + auto *ops = \ + static_cast> *>(raw_ops_); +#define CAST_KERNELS \ + auto *kernels = \ + static_cast> *>( \ + raw_kernels_); +#define CAST_SCOPE auto *scope = static_cast(raw_scope_); + +PaddlePredictor::~PaddlePredictor() { + CAST_OPS + CAST_KERNELS + CAST_SCOPE + + if (ops) { 
+ delete ops; + } + if (kernels) { + delete kernels; + } + if (scope) { + delete scope; + } +} + +void PaddlePredictor::Run() { + CAST_OPS + CAST_KERNELS + + CHECK(ops); + CHECK(kernels); + CHECK_EQ(ops->size(), kernels->size()); + + for (size_t i = 0; i < ops->size(); i++) { + LOG(INFO) << "Running the " << i << "-th operator"; + ops->at(i)->InferShape(); + kernels->at(i)->Launch(); + } +} + +std::unique_ptr PaddlePredictor::GetInput(size_t offset) { + auto *exec_scope = static_cast(raw_exe_scope_); + auto *_feed_list = exec_scope->FindVar("feed"); + CHECK(_feed_list) << "no feed variable in exec_scope"; + auto *feed_list = _feed_list->GetMutable>(); + if (offset >= feed_list->size()) { + feed_list->resize(offset + 1); + } + + return std::unique_ptr(new Tensor(nullptr, &feed_list->at(offset))); +} + +std::unique_ptr PaddlePredictor::GetOutput(size_t offset) { + auto *exec_scope = static_cast(raw_exe_scope_); + auto *_fetch_list = exec_scope->FindVar("fetch"); + CHECK(_fetch_list) << "no fatch variable in exec_scope"; + auto &fetch_list = *_fetch_list->GetMutable>(); + CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow"; + return std::unique_ptr(new Tensor(&fetch_list.at(offset), nullptr)); +} + +} // namespace gencode +} // namespace paddle diff --git a/paddle/fluid/lite/gen_code/paddle_infer.h b/paddle/fluid/lite/gen_code/paddle_infer.h new file mode 100644 index 00000000000000..99158b0503c8b7 --- /dev/null +++ b/paddle/fluid/lite/gen_code/paddle_infer.h @@ -0,0 +1,70 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +namespace paddle { +namespace gencode { + +/// Zero Copy Tensor. +class Tensor { + public: + using ddim_t = std::vector; + + Tensor(const void *raw_tensor, void *raw_mutable_tensor) + : raw_tensor_(raw_tensor), raw_mutable_tensor_(raw_mutable_tensor) {} + + void Resize(const ddim_t &shape); + template + const T *data() const; + template + T *mutable_data(); + + private: + const void *raw_tensor_; + void *raw_mutable_tensor_{}; +}; + +/* + * Predictor for the generated code. + */ +class PaddlePredictor { + public: + void Init(); + + std::unique_ptr GetTensor(const std::string &id) const; + std::unique_ptr GetMutableTensor(const std::string &id); + + // Get offset-th col of feed. + std::unique_ptr GetInput(size_t offset); + + std::unique_ptr GetOutput(size_t offset); + + void Run(); + + PaddlePredictor(); + ~PaddlePredictor(); + + private: + void *raw_ops_; + void *raw_kernels_; + void *raw_scope_{}; + void *raw_exe_scope_{}; // raw_exe_scope is not owned. 
+}; + +} // namespace gencode +} // namespace paddle diff --git a/paddle/fluid/lite/host/CMakeLists.txt b/paddle/fluid/lite/host/CMakeLists.txt index efc29d0e830abd..90812f3f3cd712 100644 --- a/paddle/fluid/lite/host/CMakeLists.txt +++ b/paddle/fluid/lite/host/CMakeLists.txt @@ -1 +1 @@ -cc_library(target_wrapper_host SRCS target_wrapper.cc DEPS target_wrapper_lite) +cc_library(target_wrapper_host SRCS target_wrapper.cc) diff --git a/paddle/fluid/lite/kernels/CMakeLists.txt b/paddle/fluid/lite/kernels/CMakeLists.txt index 877ac7e05e333d..ce22ba1216664c 100644 --- a/paddle/fluid/lite/kernels/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/CMakeLists.txt @@ -1,6 +1,7 @@ message(STATUS "add lite kernels") -set(lite_kernel_deps type_system kernel_lite op_lite op_registry_lite ${tensor_lite}) +set(lite_kernel_deps type_system kernel_lite op_lite op_registry_lite context_lite ${tensor_lite}) add_subdirectory(host) add_subdirectory(arm) add_subdirectory(cuda) add_subdirectory(x86) + diff --git a/paddle/fluid/lite/kernels/arm/CMakeLists.txt b/paddle/fluid/lite/kernels/arm/CMakeLists.txt index a7060dbd62367d..ff3cab02ee8b7e 100644 --- a/paddle/fluid/lite/kernels/arm/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/arm/CMakeLists.txt @@ -1 +1,27 @@ +if(NOT (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)) + return() +endif() + message(STATUS "compile with lite ARM kernels") + +cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm) +cc_library(relu_compute_arm SRCS relu_compute.cc DEPS ${lite_kernel_deps}) +cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3) +cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm) +cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) +cc_library(elementwise_add_compute_arm SRCS elementwise_add_compute.cc DEPS ${lite_kernel_deps} math_arm) + +lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm 
math_arm) +lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm) +lite_cc_test(test_softmax_compute_arm SRCS softmax_compute_test.cc DEPS softmax_compute_arm) +lite_cc_test(test_elementwise_add_compute_arm SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_arm) + +set(arm_kernels + fc_compute_arm + relu_compute_arm + mul_compute_arm + scale_compute_arm + softmax_compute_arm + elementwise_add_compute_arm) + +set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels") diff --git a/paddle/fluid/lite/kernels/arm/elementwise_add_compute.cc b/paddle/fluid/lite/kernels/arm/elementwise_add_compute.cc new file mode 100644 index 00000000000000..310cde17bbd2f2 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/elementwise_add_compute.cc @@ -0,0 +1,42 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/lite/kernels/arm/elementwise_add_compute.h" +#include "paddle/fluid/lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void ElementwiseAddCompute::Run() { + auto& param = Param(); + const float* x_data = param.X->data(); + const float* y_data = param.Y->data(); + float* out_data = param.Out->mutable_data(); + int n = param.X->dims().production(); + lite::arm::math::elementwise_add(x_data, y_data, out_data, n); +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, + paddle::lite::kernels::arm::ElementwiseAddCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/arm/elementwise_add_compute.h b/paddle/fluid/lite/kernels/arm/elementwise_add_compute.h new file mode 100644 index 00000000000000..9939509d0be25e --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/elementwise_add_compute.h @@ -0,0 +1,36 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class ElementwiseAddCompute + : public KernelLite { + public: + void Run() override; + + virtual ~ElementwiseAddCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/kernels/arm/elementwise_add_compute_test.cc b/paddle/fluid/lite/kernels/arm/elementwise_add_compute_test.cc new file mode 100644 index 00000000000000..7156d08ce77df9 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/elementwise_add_compute_test.cc @@ -0,0 +1,86 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/lite/kernels/arm/elementwise_add_compute.h" +#include +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +TEST(elementwise_add_arm, retrive_op) { + auto elementwise_add = + KernelRegistry::Global().Create( + "elementwise_add"); + ASSERT_FALSE(elementwise_add.empty()); + ASSERT_TRUE(elementwise_add.front()); +} + +TEST(elementwise_add_arm, init) { + ElementwiseAddCompute elementwise_add; + ASSERT_EQ(elementwise_add.precision(), PRECISION(kFloat)); + ASSERT_EQ(elementwise_add.target(), TARGET(kARM)); +} + +template +void elementwise_add_compute_ref(const operators::ElementwiseParam& param) { + const dtype* x_data = param.X->data(); + const dtype* y_data = param.Y->data(); + dtype* out_data = param.Out->mutable_data(); + DDim dim = param.X->dims(); + ASSERT_EQ(dim.data(), param.Out->dims().data()); + for (int i = 0; i < dim.production(); i++) { + out_data[i] = x_data[i] + y_data[i]; + } +} + +TEST(elementwise_add, compute) { + ElementwiseAddCompute elementwise_add; + operators::ElementwiseParam param; + + lite::Tensor x, y, out, out_ref; + x.Resize(DDim(std::vector({2, 3, 4, 5}))); + y.Resize(DDim(std::vector({2, 3, 4, 5}))); + out.Resize(DDim(std::vector({2, 3, 4, 5}))); + out_ref.Resize(DDim(std::vector({2, 3, 4, 5}))); + auto* x_data = x.mutable_data(); + auto* y_data = y.mutable_data(); + auto* out_data = out.mutable_data(); + auto* out_ref_data = out_ref.mutable_data(); + for (int i = 0; i < x.dims().production(); i++) { + x_data[i] = y_data[i] = i; + } + + param.X = &x; + param.Y = &y; + param.Out = &out; + elementwise_add.SetParam(param); + elementwise_add.Run(); + + param.Out = &out_ref; + elementwise_add_compute_ref(param); + for (int i = 0; i < out.dims().production(); i++) { + EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5); + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + 
+USE_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, def); diff --git a/paddle/fluid/lite/kernels/arm/fc_compute.cc b/paddle/fluid/lite/kernels/arm/fc_compute.cc new file mode 100644 index 00000000000000..b26551e0533a5a --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/fc_compute.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/kernels/arm/fc_compute.h" +#include "paddle/fluid/lite/arm/math/funcs.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void FcCompute::Run() { + auto& param = this->Param(); + auto x_dims = param.input->dims(); + auto w_dims = param.w->dims(); + + CHECK_GE(x_dims.size(), 2UL); + CHECK_EQ(w_dims.size(), 2UL); + CHECK_EQ(param.output->dims().size(), 2UL); + + const auto* i_data = param.input->data(); + const auto* w_data = param.w->data(); + const auto* b_data = param.bias ? 
param.bias->data() : nullptr; + auto* o_data = param.output->mutable_data(); + + int x_h = x_dims.Slice(0, param.in_num_col_dims).production(); + int x_w = x_dims.Slice(param.in_num_col_dims, x_dims.size()).production(); + int n = w_dims[1]; + CHECK_EQ(x_w, static_cast(w_dims[0])); + auto& ctx = this->ctx_->template As(); + if (x_h > 1) { + float* packed_in = static_cast(ctx.workspace_data()) + + ctx.l2_cache_size() / sizeof(float); + lite::arm::math::prepackA(packed_in, i_data, x_w, 0, x_h, 0, x_w, false, + &ctx); + lite::arm::math::sgemm_prepack(packed_in, w_data, b_data, o_data, x_h, n, + x_w, false, false, false, &ctx); + + if (param.bias) { + CHECK_EQ(param.bias->numel(), n); + lite::arm::math::fill_bias_fc(o_data, b_data, x_h, n); + } + } else { + // use sgemmv + // sgemv((const float*)weights, (const float*)din, (float*)dout, + // false, n, x_w, _param->_flag_bias, (float*)bias, false); + } +} + +TargetType FcCompute::target() const { return TARGET(kARM); } + +PrecisionType FcCompute::precision() const { return PRECISION(kFloat); } + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(fc, kARM, kFloat, kNCHW, + paddle::lite::kernels::arm::FcCompute, def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/arm/fc_compute.h b/paddle/fluid/lite/kernels/arm/fc_compute.h new file mode 100644 index 00000000000000..414517843354f6 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/fc_compute.h @@ -0,0 +1,39 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/operators/fc_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class FcCompute : public KernelLite { + public: + using param_t = operators::FcParam; + + void Run() override; + + TargetType target() const override; + PrecisionType precision() const override; + + virtual ~FcCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/kernels/arm/fc_compute_test.cc b/paddle/fluid/lite/kernels/arm/fc_compute_test.cc new file mode 100644 index 00000000000000..2e85fccf7d66be --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/fc_compute_test.cc @@ -0,0 +1,149 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/lite/kernels/arm/fc_compute.h" +#include +#include +#include "paddle/fluid/lite/arm/math/funcs.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +TEST(fc_arm, retrive_op) { + auto fc = + KernelRegistry::Global().Create("fc"); + ASSERT_FALSE(fc.empty()); + ASSERT_TRUE(fc.front()); +} + +TEST(fc_arm, init) { + FcCompute fc; + ASSERT_EQ(fc.precision(), PRECISION(kFloat)); + ASSERT_EQ(fc.target(), TARGET(kARM)); +} + +TEST(fc_arm, compare_test) { + lite::Tensor x, w, b, out, ref; + constexpr int batch_size = 2; + x.Resize({batch_size, 3}); + w.Resize({3, 4}); + b.Resize({1, 4}); + out.Resize({batch_size, 4}); + ref.Resize({batch_size, 4}); + + auto x_data = x.mutable_data(); + auto w_data = w.mutable_data(); + auto b_data = b.mutable_data(); + auto out_data = out.mutable_data(); + auto ref_data = ref.mutable_data(); + + for (int64_t i = 0; i < x.dims().product(); i++) { + x_data[i] = static_cast(i); + } + for (int64_t i = 0; i < w.dims().product(); i++) { + w_data[i] = static_cast(i); + } + for (int64_t i = 0; i < b.dims().product(); i++) { + b_data[i] = static_cast(i); + } + + lite::arm::math::fc_compute_eigen(x_data, batch_size, 3, // + w_data, 3, 4, // + b_data, ref_data); + + // fc compute kernel + FcCompute fc; + operators::FcParam param; + + param.in_num_col_dims = 1; + param.input = &x; + param.w = &w; + param.bias = &b; + param.output = &out; + param.in_mat_dims = x.dims(); + + DeviceInfo::Init(); + std::unique_ptr ctx(new KernelContext); + ctx->As(); + fc.SetParam(param); + fc.SetContext(std::move(ctx)); + fc.Run(); + + VLOG(3) << "output vs ref"; + for (int i = 0; i < out.dims().product(); i++) { + VLOG(3) << out_data[i] << " vs " << ref_data[i]; + } + + for (int i = 0; i < out.dims().product(); ++i) { + EXPECT_NEAR(out_data[i], ref_data[i], 1e-5); + } +} + +TEST(fc_arm, num_col_dims) { + FcCompute fc; + operators::FcParam param; + + lite::Tensor x; + 
lite::Tensor w; + lite::Tensor bias; + lite::Tensor output; + + x.Resize({1, 2, 3}); + w.Resize({3, 4}); + bias.Resize({1, 4}); + output.Resize({2, 4}); + + auto* x_data = x.mutable_data(); + auto* w_data = w.mutable_data(); + auto* bias_data = bias.mutable_data(); + auto* output_data = output.mutable_data(); + + for (int64_t i = 0; i < x.dims().product(); i++) { + x_data[i] = static_cast(i); + } + for (int64_t i = 0; i < w.dims().product(); i++) { + w_data[i] = static_cast(i); + } + for (int64_t i = 0; i < bias.dims().product(); i++) { + bias_data[i] = static_cast(i); + } + for (int64_t i = 0; i < output.dims().product(); i++) { + output_data[i] = static_cast(i); + } + + param.in_num_col_dims = 2; + param.input = &x; + param.w = &w; + param.bias = &bias; + param.output = &output; + param.in_mat_dims = x.dims(); + + std::unique_ptr ctx(new KernelContext); + ctx->As(); + DeviceInfo::Init(); + + fc.SetParam(param); + fc.SetContext(std::move(ctx)); + fc.Run(); +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def); diff --git a/paddle/fluid/lite/kernels/host/mul_compute.cc b/paddle/fluid/lite/kernels/arm/mul_compute.cc similarity index 80% rename from paddle/fluid/lite/kernels/host/mul_compute.cc rename to paddle/fluid/lite/kernels/arm/mul_compute.cc index 2bb509c86ac836..ff12b236031896 100644 --- a/paddle/fluid/lite/kernels/host/mul_compute.cc +++ b/paddle/fluid/lite/kernels/arm/mul_compute.cc @@ -20,7 +20,7 @@ namespace paddle { namespace lite { namespace kernels { -namespace host { +namespace arm { template void mul_compute_eigen(const T* x, int x_h, int x_w, const T* y, int y_h, @@ -35,7 +35,7 @@ void mul_compute_eigen(const T* x, int x_h, int x_w, const T* y, int y_h, Out = X * Y; } -class MulCompute : public KernelLite { +class MulCompute : public KernelLite { public: using param_t = operators::MulParam; @@ -59,22 +59,19 @@ class MulCompute : public KernelLite { 
mul_compute_eigen(param.x->data(), x_shape.x, x_shape.y, // param.y->data(), y_shape.x, y_shape.y, // param.output->mutable_data()); - LOG(INFO) << "MUL x " << *param.x; - LOG(INFO) << "MUL W " << *param.y; - LOG(INFO) << "MUL out " << *param.output; } virtual ~MulCompute() = default; }; -} // namespace host +} // namespace arm } // namespace kernels } // namespace lite } // namespace paddle -REGISTER_LITE_KERNEL(mul, kHost, kFloat, kNCHW, - paddle::lite::kernels::host::MulCompute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindInput("Y", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) +REGISTER_LITE_KERNEL(mul, kARM, kFloat, kNCHW, + paddle::lite::kernels::arm::MulCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); diff --git a/paddle/fluid/lite/kernels/host/relu_compute.cc b/paddle/fluid/lite/kernels/arm/relu_compute.cc similarity index 91% rename from paddle/fluid/lite/kernels/host/relu_compute.cc rename to paddle/fluid/lite/kernels/arm/relu_compute.cc index 59b9ccd836410f..6e27e8ec669aa4 100644 --- a/paddle/fluid/lite/kernels/host/relu_compute.cc +++ b/paddle/fluid/lite/kernels/arm/relu_compute.cc @@ -12,4 +12,4 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/lite/kernels/host/relu_compute.h" +#include "paddle/fluid/lite/kernels/arm/relu_compute.h" diff --git a/paddle/fluid/lite/kernels/host/relu_compute.h b/paddle/fluid/lite/kernels/arm/relu_compute.h similarity index 81% rename from paddle/fluid/lite/kernels/host/relu_compute.h rename to paddle/fluid/lite/kernels/arm/relu_compute.h index aae9e161aabe30..29d17bf5918e11 100644 --- a/paddle/fluid/lite/kernels/host/relu_compute.h +++ b/paddle/fluid/lite/kernels/arm/relu_compute.h @@ -20,9 +20,9 @@ namespace paddle { namespace lite { namespace kernels { -namespace host { +namespace arm { -class ReluCompute : public KernelLite { +class ReluCompute : public KernelLite { public: void Run() override { auto& param = Param(); @@ -34,15 +34,15 @@ class ReluCompute : public KernelLite { } } - TargetType target() const override { return TARGET(kHost); } + TargetType target() const override { return TARGET(kARM); } PrecisionType precision() const override { return PRECISION(kFloat); } }; -} // namespace host +} // namespace arm } // namespace kernels } // namespace lite } // namespace paddle -REGISTER_LITE_KERNEL(relu, kHost, kFloat, kNCHW, - paddle::lite::kernels::host::ReluCompute, def) +REGISTER_LITE_KERNEL(relu, kARM, kFloat, kNCHW, + paddle::lite::kernels::arm::ReluCompute, def) .Finalize(); diff --git a/paddle/fluid/lite/kernels/arm/scale_compute.cc b/paddle/fluid/lite/kernels/arm/scale_compute.cc new file mode 100644 index 00000000000000..a89e19fb05a412 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/scale_compute.cc @@ -0,0 +1,46 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/kernels/arm/scale_compute.h" +#include "paddle/fluid/lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void ScaleCompute::Run() { + auto& param = Param(); + const float* x_data = param.x->data(); + float* output_data = param.output->mutable_data(); + DDim x_dims = param.x->dims(); + bool bias_after_scale = param.bias_after_scale; + float scale = param.scale; + float bias = param.bias; + if (!bias_after_scale) { + bias *= scale; + } + lite::arm::math::scale(x_data, output_data, x_dims.production(), scale, bias); +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(scale, kARM, kFloat, kNCHW, + paddle::lite::kernels::arm::ScaleCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/arm/scale_compute.h b/paddle/fluid/lite/kernels/arm/scale_compute.h new file mode 100644 index 00000000000000..b0ee41c654d209 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/scale_compute.h @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class ScaleCompute : public KernelLite { + public: + void Run() override; + + virtual ~ScaleCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/kernels/arm/scale_compute_test.cc b/paddle/fluid/lite/kernels/arm/scale_compute_test.cc new file mode 100644 index 00000000000000..fee47d7eb7a6c0 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/scale_compute_test.cc @@ -0,0 +1,106 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/lite/kernels/arm/scale_compute.h" +#include +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +template +void scale_compute_ref(const operators::ScaleParam& param) { + const dtype* x_data = param.x->mutable_data(); + dtype* output_data = param.output->mutable_data(); + DDim x_dims = param.x->dims(); + DDim output_dims = param.output->dims(); + ASSERT_EQ(x_dims.data(), output_dims.data()); + bool bias_after_scale = param.bias_after_scale; + float scale = param.scale; + float bias = param.bias; + if (!bias_after_scale) { + bias *= scale; + } + for (int i = 0; i < output_dims.production(); i++) { + output_data[i] = x_data[i] * scale + bias; + } +} + +TEST(scale_arm, init) { + ScaleCompute scale; + ASSERT_EQ(scale.precision(), PRECISION(kFloat)); + ASSERT_EQ(scale.target(), TARGET(kARM)); +} + +TEST(scale_arm, compute) { + ScaleCompute scale; + operators::ScaleParam param; + + lite::Tensor x; + lite::Tensor output; + lite::Tensor output_ref; + + for (auto n : {1, 3, 4, 11}) { + for (auto c : {1, 3, 11, 4}) { + for (auto h : {3, 1, 11, 4}) { + for (auto w : {1, 3, 4, 12}) { + for (auto bias_after_scale : {true, false}) { + for (auto s : {-100.25f, -1.0f, 0.13f, 3840.975f}) { + for (auto b : {-3075.495f, -15.f, 0.11234f, 128.15f}) { + x.Resize(DDim(std::vector({n, c, h, w}))); + output.Resize(DDim(std::vector({n, c, h, w}))); + output_ref.Resize(DDim(std::vector({n, c, h, w}))); + auto* x_data = x.mutable_data(); + auto* output_data = output.mutable_data(); + auto* output_ref_data = output_ref.mutable_data(); + for (int i = 0; i < x.dims().production(); i++) { + x_data[i] = i; + } + param.x = &x; + param.output = &output; + param.bias_after_scale = bias_after_scale; + param.scale = s; + param.bias = b; + scale.SetParam(param); + scale.Run(); + param.output = &output_ref; + scale_compute_ref(param); + for (int i = 0; i < output.dims().production(); i++) { + 
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); + } + } + } + } + } + } + } + } +} + +TEST(scale, retrive_op) { + auto scale = + KernelRegistry::Global().Create("scale"); + ASSERT_FALSE(scale.empty()); + ASSERT_TRUE(scale.front()); +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def); diff --git a/paddle/fluid/lite/kernels/arm/softmax_compute.cc b/paddle/fluid/lite/kernels/arm/softmax_compute.cc new file mode 100644 index 00000000000000..099385395e2e79 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/softmax_compute.cc @@ -0,0 +1,76 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/lite/kernels/arm/softmax_compute.h" +#include "paddle/fluid/lite/arm/math/funcs.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void SoftmaxCompute::Run() { + auto& param = Param(); + const float* din = param.x->data(); + float* dout = param.output->mutable_data(); + auto x_dims = param.x->dims(); + auto x_rank = x_dims.size(); + int axis = param.axis; + if (axis < 0) { + axis += x_rank; + } + int outer_num = x_dims.Slice(0, axis).production(); + int inner_num = x_dims.Slice(axis + 1, x_rank).production(); + int axis_size = x_dims[axis]; + if (inner_num == 1) { + if (axis_size >= 4) { + lite::arm::math::softmax_inner1_large_axis(din, dout, outer_num, + axis_size); + } else { + lite::arm::math::softmax_inner1_small_axis(din, dout, outer_num, + axis_size); + } + } else { + int compute_size = outer_num * inner_num; + if (axis_size == 4 && inner_num % 8 == 0) { + lite::arm::math::softmax_inner8_axis4(din, dout, axis_size, inner_num, + outer_num); + } else if (axis_size == 4 && inner_num % 4 == 0) { + lite::arm::math::softmax_inner4_axis4(din, dout, axis_size, inner_num, + outer_num); + } else { + if (inner_num % 8 == 0) { + lite::arm::math::softmax_inner8(din, dout, axis_size, inner_num, + outer_num); + } else if (inner_num % 4 == 0) { + lite::arm::math::softmax_inner4(din, dout, axis_size, inner_num, + outer_num); + } else { + lite::arm::math::softmax_basic(din, dout, axis_size, inner_num, + outer_num); + } + } + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, + paddle::lite::kernels::arm::SoftmaxCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/arm/softmax_compute.h b/paddle/fluid/lite/kernels/arm/softmax_compute.h new file mode 100644 index 00000000000000..4d538473ebd89e 
--- /dev/null +++ b/paddle/fluid/lite/kernels/arm/softmax_compute.h @@ -0,0 +1,35 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class SoftmaxCompute : public KernelLite { + public: + void Run() override; + + virtual ~SoftmaxCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/kernels/arm/softmax_compute_test.cc b/paddle/fluid/lite/kernels/arm/softmax_compute_test.cc new file mode 100644 index 00000000000000..80a64f4eaf7428 --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/softmax_compute_test.cc @@ -0,0 +1,128 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/kernels/arm/softmax_compute.h" +#include +#include +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +template +void softmax_compute_ref(const operators::SoftmaxParam& param) { + const dtype* x_data = param.x->mutable_data(); + dtype* output_data = param.output->mutable_data(); + DDim x_dims = param.x->dims(); + ASSERT_EQ(x_dims.data(), param.output->dims().data()); + auto x_rank = x_dims.size(); + int axis = param.axis; + if (axis < 0) { + axis += x_rank; + } + int axis_size = x_dims[axis]; + int outer_num = x_dims.Slice(0, axis).production(); + int inner_num = x_dims.Slice(axis + 1, x_rank).production(); + int compute_size = outer_num * inner_num; + for (int i = 0; i < compute_size; i++) { + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int start = idx_outer * inner_num + idx_inner; + int offset; + + offset = start; + dtype max_data = std::numeric_limits::lowest(); + for (int j = 0; j < axis_size; j++) { + max_data = x_data[offset] > max_data ? 
x_data[offset] : max_data; + offset += inner_num; + } + + offset = start; + dtype sum_data = (dtype)0; + for (int j = 0; j < axis_size; j++) { + output_data[offset] = exp(x_data[offset] - max_data); + sum_data += output_data[offset]; + offset += inner_num; + } + + offset = start; + for (int j = 0; j < axis_size; j++) { + output_data[offset] /= sum_data; + offset += inner_num; + } + } +} + +TEST(softmax_arm, init) { + SoftmaxCompute softmax; + ASSERT_EQ(softmax.precision(), PRECISION(kFloat)); + ASSERT_EQ(softmax.target(), TARGET(kARM)); +} + +TEST(softmax_arm, compute) { + SoftmaxCompute softmax; + operators::SoftmaxParam param; + + lite::Tensor x; + lite::Tensor output; + lite::Tensor output_ref; + + for (auto n : {1, 3, 4, 11}) { + for (auto c : {1, 3, 11, 4}) { + for (auto h : {3, 1, 11, 4}) { + for (auto w : {1, 3, 4, 12}) { + for (auto axis : {-4, -3, -2, -1, 0, 1, 2, 3}) { + x.Resize(DDim(std::vector({n, c, h, w}))); + output.Resize(DDim(std::vector({n, c, h, w}))); + output_ref.Resize(DDim(std::vector({n, c, h, w}))); + auto* x_data = x.mutable_data(); + auto* output_data = output.mutable_data(); + auto* output_ref_data = output_ref.mutable_data(); + for (int i = 0; i < x.dims().production(); i++) { + x_data[i] = i; + } + param.x = &x; + param.axis = axis; + param.output = &output; + softmax.SetParam(param); + softmax.Run(); + param.output = &output_ref; + softmax_compute_ref(param); + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5); + } + } + } + } + } + } +} + +TEST(softmax, retrive_op) { + auto softmax = + KernelRegistry::Global().Create( + "softmax"); + ASSERT_FALSE(softmax.empty()); + ASSERT_TRUE(softmax.front()); +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def); diff --git a/paddle/fluid/lite/kernels/arm/use_kernels.h b/paddle/fluid/lite/kernels/arm/use_kernels.h new file mode 100644 index 
00000000000000..d856950f3a177d --- /dev/null +++ b/paddle/fluid/lite/kernels/arm/use_kernels.h @@ -0,0 +1,23 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/lite/core/op_registry.h" + +USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def); +USE_LITE_KERNEL(feed, kARM, kAny, kAny, def); +USE_LITE_KERNEL(fetch, kARM, kAny, kAny, def); diff --git a/paddle/fluid/lite/kernels/cuda/CMakeLists.txt b/paddle/fluid/lite/kernels/cuda/CMakeLists.txt index 104fb79c703145..f35f634a217fab 100644 --- a/paddle/fluid/lite/kernels/cuda/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/cuda/CMakeLists.txt @@ -8,3 +8,4 @@ nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS ${tensor_lite}) cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS ${tensor_lite}) nv_library(kernels_cuda DEPS mul_compute_cuda io_copy_compute_cuda cuda_blas_lite) + diff --git a/paddle/fluid/lite/kernels/cuda/mul_compute.h b/paddle/fluid/lite/kernels/cuda/mul_compute.h index 597d84683268b4..43ad6ba5f96773 100644 --- a/paddle/fluid/lite/kernels/cuda/mul_compute.h +++ b/paddle/fluid/lite/kernels/cuda/mul_compute.h @@ -35,11 +35,11 @@ class MulCompute : public KernelLite { using param_t = operators::MulParam; void Run() override { - CHECK(context_) << 
"running context should be set first"; - auto& context = context_->As(); - CHECK(context.blas_fp32) << "blas should init first"; + CHECK(ctx_) << "running context should be set first"; + auto& context = ctx_->As(); + CHECK(context.cublas_fp32()) << "blas should init first"; /* - auto& blas = *context.blas_fp32; + auto& blas = *context.cublas_fp32(); CHECK(param.x->target() == TARGET(kCUDA)); auto* x = param.x->data(); int x_h = param.x->dims()[0]; diff --git a/paddle/fluid/lite/kernels/host/CMakeLists.txt b/paddle/fluid/lite/kernels/host/CMakeLists.txt index 9bd2120457a9c8..d1f33477aaa5c6 100644 --- a/paddle/fluid/lite/kernels/host/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/host/CMakeLists.txt @@ -1,19 +1,13 @@ message(STATUS "compile with lite host kernels") -cc_library(fc_compute_host SRCS fc_compute.cc DEPS ${lite_kernel_deps} eigen3) -cc_library(relu_compute_host SRCS relu_compute.cc DEPS ${lite_kernel_deps}) -cc_library(mul_compute_host SRCS mul_compute.cc DEPS ${lite_kernel_deps} eigen3) -cc_library(scale_compute_host SRCS scale_compute.cc DEPS ${lite_kernel_deps} eigen3) cc_library(feed_compute_host SRCS feed_compute.cc DEPS ${lite_kernel_deps}) cc_library(fetch_compute_host SRCS fetch_compute.cc DEPS ${lite_kernel_deps}) +cc_library(reshape_compute_host SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op_lite) + +lite_cc_test(test_reshape_compute_host SRCS reshape_compute_test.cc DEPS reshape_compute_host) set(host_kernels feed_compute_host fetch_compute_host - fc_compute_host - relu_compute_host - mul_compute_host - scale_compute_host - ) - -set(host_kernels "${host_kernels}" CACHE INTERNAL "host kernels") + reshape_compute_host + CACHE INTERNAL "host kernels") diff --git a/paddle/fluid/lite/kernels/host/fc_compute.cc b/paddle/fluid/lite/kernels/host/fc_compute.cc deleted file mode 100644 index ae5b23ce3ece54..00000000000000 --- a/paddle/fluid/lite/kernels/host/fc_compute.cc +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle 
Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/lite/kernels/host/fc_compute.h" -#include -#include "paddle/fluid/lite/core/op_registry.h" -#include "paddle/fluid/lite/core/type_system.h" - -namespace paddle { -namespace lite { -namespace kernels { -namespace host { - -// NOTE should use pure std C++ implementation. -void FcCompute::Run() { - auto& param = this->Param(); - - CHECK_GE(param.input->dims().size(), 2UL); - CHECK_EQ(param.output->dims().size(), 2UL); - - fc_compute_eigen( - param.input->data(), // x - param.input->dims().Slice(0, param.in_num_col_dims).production(), - param.input->dims() - .Slice(param.in_num_col_dims, param.input->dims().size()) - .production(), - param.w->data(), // w - param.w->dims()[1], // w_w - param.w->dims()[0], // w_h - param.bias->data(), // b - param.output->mutable_data()); -} - -// TargetType FcCompute::target() const { return TARGET(kHost); } - -// PrecisionType FcCompute::precision() const { return PRECISION(kFloat); } - -} // namespace host -} // namespace kernels -} // namespace lite -} // namespace paddle - -REGISTER_LITE_KERNEL(fc, kHost, kFloat, kNCHW, - paddle::lite::kernels::host::FcCompute, def) - .BindInput("Input", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindInput("W", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) - .Finalize(); 
diff --git a/paddle/fluid/lite/kernels/host/fc_compute.h b/paddle/fluid/lite/kernels/host/fc_compute.h deleted file mode 100644 index 1a6c4eb4c0fbe1..00000000000000 --- a/paddle/fluid/lite/kernels/host/fc_compute.h +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include "paddle/fluid/lite/core/kernel.h" -#include "paddle/fluid/lite/operators/fc_op.h" - -namespace paddle { -namespace lite { -namespace kernels { -namespace host { - -class FcCompute : public KernelLite { - public: - using param_t = operators::FcParam; - - void Run() override; - - // TargetType target() const override; - // PrecisionType precision() const override; - - virtual ~FcCompute() = default; -}; - -template -void fc_compute_eigen(const T* x, int x_w, int x_h, // - const T* w, int w_w, int w_h, // - const T* b, // - T* out) { - using matrix_t = - Eigen::Matrix; - - Eigen::Map X(x, x_h, x_w); - Eigen::Map W(w, w_h, w_w); - Eigen::Map Out(out, x_h, w_h); - - Out = X * W.transpose(); - - if (b) { - Eigen::Map> B(b, w_h); - Out = Out.array().rowwise() + B.transpose().array(); - } -} - -template -__attribute__((optimize("unroll-loops"))) // -T dot(const T* x, const T* y, int dim) { - T out{}; - for (int i = 0; i < dim; i++) { - out += x[i] * y[i]; - } - return out; -} - -template -void fc_compute_naive(const T* x, int x_w, int x_h, // - const T* w, int w_w, int w_h, // 
- const T* b, // - T* out) { - CHECK_EQ(x_w, w_w); - // out shape: (x_h, w_w) - memset(out, 0, x_h * w_h * sizeof(T)); - - for (int r = 0; r < x_h; r++) { - for (int c = 0; c < w_h; c++) { - out[r * w_h + c] = dot(&x[r * x_w], &w[c * w_w], w_w) + b[c]; - } - } -} - -} // namespace host -} // namespace kernels -} // namespace lite -} // namespace paddle diff --git a/paddle/fluid/lite/kernels/host/fc_compute_test.cc b/paddle/fluid/lite/kernels/host/fc_compute_test.cc deleted file mode 100644 index 69b0450900e34e..00000000000000 --- a/paddle/fluid/lite/kernels/host/fc_compute_test.cc +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/fluid/lite/kernels/host/fc_compute.h" -#include -#include -#include "paddle/fluid/lite/core/op_registry.h" - -namespace paddle { -namespace lite { -namespace kernels { -namespace host { - -TEST(fc_compute_naive, test) { - lite::Tensor x, w, b, out, out1; - const int batch_size = 2; - x.Resize({batch_size, 3}); - w.Resize({4, 3}); - b.Resize({1, 4}); - out.Resize({batch_size, 4}); - out1.Resize({batch_size, 4}); - - auto x_data = x.mutable_data(); - auto w_data = w.mutable_data(); - auto b_data = b.mutable_data(); - auto out_data = out.mutable_data(); - auto out_data1 = out1.mutable_data(); - - for (int i = 0; i < product(x.dims()); i++) x_data[i] = i; - for (int i = 0; i < product(w.dims()); i++) w_data[i] = i; - for (int i = 0; i < product(b.dims()); i++) b_data[i] = i; - - fc_compute_naive(x_data, 3, batch_size, // - w_data, 3, 4, // - b_data, out_data); - fc_compute_eigen(x_data, 3, batch_size, // - w_data, 3, 4, // - b_data, out_data1); - - for (int i = 0; i < product(out.dims()); i++) { - EXPECT_NEAR(out_data[0], out_data1[0], 1e-6); - } -} - -TEST(fc_host, init) { - FcCompute fc; - ASSERT_EQ(fc.precision(), PRECISION(kFloat)); - ASSERT_EQ(fc.target(), TARGET(kHost)); -} - -TEST(fc_host, algorithm) { - using matrix_t = Eigen::Matrix; - using matrix_map_t = Eigen::Map; - - // dim 10, 20 - std::vector input(10 * 20); - std::vector w(20 * 20); - std::vector output(10 * 20); - - Eigen::Map input_mat(input.data(), 10, 20); - Eigen::Map weight_mat(w.data(), 20, 20); - matrix_map_t output_mat(output.data(), 10, 20); - - output_mat = weight_mat.transpose() * input_mat; -} - -TEST(fc_host, compute) { - FcCompute fc; - operators::FcParam param; - - lite::Tensor x; - lite::Tensor w; - lite::Tensor bias; - lite::Tensor output; - - x.Resize(DDim(std::vector({1, 10, 20}))); - w.Resize(DDim(std::vector({20, 20}))); - bias.Resize(DDim(std::vector({1, 10}))); - output.Resize(DDim(std::vector({10, 20}))); - - auto* x_data = x.mutable_data(); - auto* w_data = 
w.mutable_data(); - auto* bias_data = bias.mutable_data(); - auto* output_data = output.mutable_data(); - - for (int i = 0; i < 10 * 20; i++) x_data[i] = i; - for (int i = 0; i < 20 * 20; i++) w_data[i] = i; - for (int i = 0; i < 10; i++) bias_data[i] = i; - for (int i = 0; i < 10 * 20; i++) output_data[i] = 0; - - param.in_num_col_dims = 2; - param.input = &x; - param.w = &w; - param.bias = &bias; - param.output = &output; - param.in_mat_dims = x.dims(); - - fc.SetParam(param); - fc.Run(); - - LOG(INFO) << "x"; - for (int i = 0; i < 10 * 20; i++) LOG(INFO) << x_data[i]; - - LOG(INFO) << "output:"; - for (int i = 0; i < 10 * 20; i++) LOG(INFO) << output.data()[i]; -} - -TEST(fc, retrive_op) { - auto fc = - KernelRegistry::Global().Create("fc"); - ASSERT_TRUE(fc); -} - -} // namespace host -} // namespace kernels -} // namespace lite -} // namespace paddle - -USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def); diff --git a/paddle/fluid/lite/kernels/host/feed_compute.cc b/paddle/fluid/lite/kernels/host/feed_compute.cc index ba503c577f4099..7bbd648c20d3f7 100644 --- a/paddle/fluid/lite/kernels/host/feed_compute.cc +++ b/paddle/fluid/lite/kernels/host/feed_compute.cc @@ -27,12 +27,12 @@ class FeedCompute void Run() override { auto ¶m = Param(); - LOG(INFO) << "feed_list.size: " << param.feed_list->size(); - LOG(INFO) << "col " << param.col; + VLOG(4) << "feed_list.size: " << param.feed_list->size(); + VLOG(4) << "col " << param.col; const lite::Tensor &feed_item = (*param.feed_list)[0]; param.out->ShareDataWith(feed_item); - LOG(INFO) << "FEED input " << feed_item << " col " << param.col; - LOG(INFO) << "FEED output " << *param.out; + VLOG(4) << "FEED input " << feed_item << " col " << param.col; + VLOG(4) << "FEED output " << *param.out; } }; diff --git a/paddle/fluid/lite/kernels/host/reshape_compute.cc b/paddle/fluid/lite/kernels/host/reshape_compute.cc new file mode 100644 index 00000000000000..c797ddf45b4fae --- /dev/null +++ 
b/paddle/fluid/lite/kernels/host/reshape_compute.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/kernels/host/reshape_compute.h" +#include +#include "paddle/fluid/lite/operators/reshape_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +void ReshapeCompute::Run() { + auto& param = Param(); + auto x = param.x; + auto actual_shape = param.actual_shape; + auto output = param.output; + bool inplace = param.inplace; + auto x_dims = x->dims(); + auto output_dims = output->dims(); + if (actual_shape) { + auto actual_shape_dims = actual_shape->dims(); + auto* actual_shape_data = actual_shape->data(); +#ifdef LITE_WITH_CUDA + lite::Tensor cpu_actual_shape; + if (actual_shape->target() == TARGET(kCUDA)) { + cpu_actual_shape.CopyDataFrom(*actual_shape); + actual_shape_data = cpu_actual_shape.data(); + } +#endif + auto shape = std::vector( + actual_shape_data, actual_shape_data + actual_shape_dims.production()); + output_dims = lite::operators::ValidateShape(shape, x_dims); + output->Resize(output_dims); + } + if (inplace) { + output->ShareDataWith(*x); + } else { + output->CopyDataFrom(*x); + } + output->Resize(output_dims); +} + +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(reshape, kHost, kAny, kAny, + paddle::lite::kernels::host::ReshapeCompute, 
def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny), + DATALAYOUT(kAny), -1)}) + .BindInput("Shape", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny), + DATALAYOUT(kAny), -1)}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny), + DATALAYOUT(kAny), -1)}) + .Finalize(); + +REGISTER_LITE_KERNEL(reshape2, kHost, kAny, kAny, + paddle::lite::kernels::host::ReshapeCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny), + DATALAYOUT(kAny), -1)}) + .BindInput("Shape", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny), + DATALAYOUT(kAny), -1)}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny), + DATALAYOUT(kAny), -1)}) + .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny), + DATALAYOUT(kAny), -1)}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/host/reshape_compute.h b/paddle/fluid/lite/kernels/host/reshape_compute.h new file mode 100644 index 00000000000000..423b589d37d015 --- /dev/null +++ b/paddle/fluid/lite/kernels/host/reshape_compute.h @@ -0,0 +1,36 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +class ReshapeCompute + : public KernelLite { + public: + void Run() override; + + virtual ~ReshapeCompute() = default; +}; + +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/kernels/host/reshape_compute_test.cc b/paddle/fluid/lite/kernels/host/reshape_compute_test.cc new file mode 100644 index 00000000000000..07a8101fec631a --- /dev/null +++ b/paddle/fluid/lite/kernels/host/reshape_compute_test.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/lite/kernels/host/reshape_compute.h" +#include +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +TEST(reshape_host, init) { + ReshapeCompute reshape; + ASSERT_EQ(reshape.precision(), PRECISION(kAny)); + ASSERT_EQ(reshape.target(), TARGET(kHost)); +} + +TEST(reshape_host, compute) { + ReshapeCompute reshape; + operators::ReshapeParam param; + + Tensor x; + Tensor actual_shape; + Tensor output; + + x.Resize(DDim(std::vector({1, 2, 4, 6}))); + actual_shape.Resize(DDim(std::vector({2}))); + + auto* x_data = x.mutable_data(); + auto* actual_shape_data = actual_shape.mutable_data(); + for (int i = 0; i < x.dims().production(); i++) { + x_data[i] = i; + } + actual_shape_data[0] = 6; + actual_shape_data[1] = 8; + + param.x = &x; + param.shape = {-1, 0, 3, 2, 1}; + param.output = &output; + param.actual_shape = &actual_shape; + param.inplace = false; + reshape.SetParam(param); + reshape.Run(); + + // check output dims + CHECK_EQ(actual_shape.dims().production(), output.dims().size()); + for (int i = 0; i < output.dims().size(); i++) { + CHECK_EQ(output.dims()[i], actual_shape_data[i]); + } + + // check output data + auto* output_data = output.mutable_data(); + CHECK_NE(output_data, x_data); + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR(output_data[i], x_data[i], 1e-6); + } + + // check output data if inplace = true; + param.inplace = true; + reshape.SetParam(param); + reshape.Run(); + output_data = output.mutable_data(); + CHECK_EQ(output_data, x_data); +} + +TEST(reshape, retrive_op) { + auto reshape = + KernelRegistry::Global() + .Create("reshape"); + ASSERT_FALSE(reshape.empty()); + ASSERT_TRUE(reshape.front()); +} + +TEST(reshape2, retrive_op) { + auto reshape2 = + KernelRegistry::Global() + .Create("reshape2"); + ASSERT_FALSE(reshape2.empty()); + ASSERT_TRUE(reshape2.front()); +} + +} // namespace host +} // namespace kernels +} 
// namespace lite +} // namespace paddle + +USE_LITE_KERNEL(reshape, kHost, kAny, kAny, def); +USE_LITE_KERNEL(reshape2, kHost, kAny, kAny, def); diff --git a/paddle/fluid/lite/kernels/host/use_kernels.h b/paddle/fluid/lite/kernels/host/use_kernels.h index e9e9c88c624d80..b3b534283b3509 100644 --- a/paddle/fluid/lite/kernels/host/use_kernels.h +++ b/paddle/fluid/lite/kernels/host/use_kernels.h @@ -15,8 +15,7 @@ #pragma once #include "paddle/fluid/lite/core/op_registry.h" -USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def); -USE_LITE_KERNEL(mul, kHost, kFloat, kNCHW, def); -USE_LITE_KERNEL(scale, kHost, kFloat, kNCHW, def); USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); +USE_LITE_KERNEL(reshape, kHost, kAny, kAny, def); +USE_LITE_KERNEL(reshape2, kHost, kAny, kAny, def); diff --git a/paddle/fluid/lite/kernels/x86/CMakeLists.txt b/paddle/fluid/lite/kernels/x86/CMakeLists.txt index 90e3d20a27e161..6309267dd0635d 100644 --- a/paddle/fluid/lite/kernels/x86/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/x86/CMakeLists.txt @@ -2,5 +2,34 @@ if(NOT LITE_WITH_X86) return() endif() -cc_library(activation_compute SRCS activation_compute.cc DEPS ${lite_kernel_deps} activation_op) -cc_library(elementwise_compute SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} elementwise_op) +cc_library(activation_compute_x86 SRCS activation_compute.cc DEPS ${lite_kernel_deps} activation_op) +cc_library(mean_compute_x86 SRCS mean_compute.cc DEPS ${lite_kernel_deps}) +cc_library(fill_constant_compute_x86 SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps}) +cc_library(sgd_compute_x86 SRCS sgd_compute.cc DEPS ${lite_kernel_deps}) + +cc_library(fc_compute_x86 SRCS fc_compute.cc DEPS ${lite_kernel_deps}) +cc_library(mul_compute_x86 SRCS mul_compute.cc DEPS ${lite_kernel_deps}) +cc_library(relu_compute_x86 SRCS relu_compute.cc DEPS ${lite_kernel_deps}) +cc_library(scale_compute_x86 SRCS scale_compute.cc DEPS ${lite_kernel_deps}) 
+cc_library(elementwise_compute_x86 SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} elementwise_sub_op elementwise_add_op) +cc_library(softmax_compute_x86 SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax) +cc_library(dropout_compute_x86 SRCS dropout_compute.cc DEPS ${lite_kernel_deps} ) +cc_library(concat_compute_x86 SRCS concat_compute.cc DEPS ${lite_kernel_deps} ) +cc_library(conv_compute_x86 SRCS conv_compute.cc DEPS ${lite_kernel_deps} blas im2col vol2col) +cc_library(pool_compute_x86 SRCS pool_compute.cc DEPS ${lite_kernel_deps} pooling) + +set(x86_kernels + activation_compute_x86 + elementwise_compute_x86 + mean_compute_x86 + fill_constant_compute_x86 + mul_compute_x86 + relu_compute_x86 + fc_compute_x86 + scale_compute_x86 + softmax_compute_x86 + dropout_compute_x86 + concat_compute_x86 + conv_compute_x86 + pool_compute_x86 + CACHE INTERNAL "x86 kernels") diff --git a/paddle/fluid/lite/kernels/x86/activation_compute.cc b/paddle/fluid/lite/kernels/x86/activation_compute.cc index 4873a30ba4cc74..a07a69af2d1194 100644 --- a/paddle/fluid/lite/kernels/x86/activation_compute.cc +++ b/paddle/fluid/lite/kernels/x86/activation_compute.cc @@ -1,3 +1,17 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/lite/core/kernel.h" @@ -41,47 +55,41 @@ void ActivateGrad(const platform::CPUDeviceContext& context, } template -class SquareCompute : public KernelLite { +class SquareCompute : public KernelLite { public: using param_t = operators::ActivationParam; void Run() override { - auto& context = context_->As(); + auto& context = ctx_->As(); auto& param = *param_.get_mutable(); - CHECK(context.x86_device_context); + CHECK(context.x86_device_context()); param.Out->template mutable_data(); - Activate>(*context.x86_device_context, + Activate>(*context.x86_device_context(), ¶m.X->raw_tensor(), ¶m.Out->raw_tensor()); } - // TargetType target() const override; - // PrecisionType precision() const override; - virtual ~SquareCompute() = default; }; template -class SquareGradCompute : public KernelLite { +class SquareGradCompute : public KernelLite { public: using param_t = operators::ActivationGradParam; void Run() override { - auto& context = context_->As(); + auto& context = ctx_->As(); auto& param = *param_.get_mutable(); - CHECK(context.x86_device_context); + CHECK(context.x86_device_context()); param.X_grad->template mutable_data(); ActivateGrad>( - *context.x86_device_context, ¶m.X->raw_tensor(), + *context.x86_device_context(), ¶m.X->raw_tensor(), ¶m.Out->raw_tensor(), ¶m.Out_grad->raw_tensor(), ¶m.X_grad->raw_tensor()); } - // TargetType target() const override; - // PrecisionType precision() const override; - virtual ~SquareGradCompute() = default; }; @@ -93,16 +101,16 @@ class SquareGradCompute : public KernelLite { // float REGISTER_LITE_KERNEL(square, kX86, kFloat, kNCHW, paddle::lite::kernels::x86::SquareCompute, def) - .BindInput("Input", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindInput("W", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) + 
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); REGISTER_LITE_KERNEL(square_grad, kX86, kFloat, kNCHW, paddle::lite::kernels::x86::SquareGradCompute, def) - .BindInput("Input", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindInput("W", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput(paddle::framework::GradVarName("Out"), + {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput(paddle::framework::GradVarName("X"), + {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); diff --git a/paddle/fluid/lite/kernels/x86/concat_compute.cc b/paddle/fluid/lite/kernels/x86/concat_compute.cc new file mode 100644 index 00000000000000..23ae8ca505559c --- /dev/null +++ b/paddle/fluid/lite/kernels/x86/concat_compute.cc @@ -0,0 +1,102 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/types.h" +#include "paddle/fluid/operators/strided_memcpy.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class ConcatCompute : public KernelLite { + public: + using param_t = operators::ConcatParam; + + void Run() override { + auto& param = *param_.get_mutable(); + int64_t axis = static_cast(param.axis); + auto out = param.output; + + if (axis == 0 && param.x.size() < 10) { + size_t output_offset = 0; + for (auto* in : param.x) { + if (!in || in->dims().production() == 0UL) { + continue; + } + auto in_stride = framework::stride_numel(in->dims().data()); + auto out_stride = framework::stride_numel(out->dims().data()); + paddle::operators::StridedNumelCopyWithAxis( + platform::CPUDeviceContext(), axis, + out->mutable_data() + output_offset, out_stride, in->data(), + in_stride, in_stride[axis]); + + output_offset += in_stride[axis]; + } + } else { + std::vector inputs; + for (size_t j = 0; j < param.x.size(); ++j) { + if (param.x[j] && param.x[j]->dims().production() > 0) { + inputs.push_back(*param.x[j]); + } else { + continue; + } + } + + int num = inputs.size(); + int rows = 1; + auto dim_0 = inputs[0].dims(); + for (int i = 0; i < axis; ++i) { + rows *= dim_0[i]; + } + int out_rows = rows, out_cols = 0; + + std::vector input_cols(inputs.size()); + for (int i = 0; i < num; ++i) { + int t_cols = inputs[i].dims().production() / rows; + out_cols += t_cols; + input_cols[i] = t_cols; + } + // computation + auto output_data = param.output->template mutable_data(); + int col_idx = 0; + for (int j = 0; j < num; ++j) { + int col_len = input_cols[j]; + auto input_data = inputs[j].data(); + for (int k = 0; k < out_rows; ++k) { + std::memcpy(output_data + k * out_cols + col_idx, + input_data + k * col_len, sizeof(T) * col_len); + } + col_idx += col_len; + } + } + } + + virtual ~ConcatCompute() 
= default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(concat, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::ConcatCompute, def) + .BindInput("X", {LiteType::GetTensorListTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/x86/conv_compute.cc b/paddle/fluid/lite/kernels/x86/conv_compute.cc new file mode 100644 index 00000000000000..9d2de5be452c7e --- /dev/null +++ b/paddle/fluid/lite/kernels/x86/conv_compute.cc @@ -0,0 +1,169 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/types.h" +#include "paddle/fluid/lite/operators/conv_op.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/depthwise_conv.h" +#include "paddle/fluid/operators/math/im2col.h" +#include "paddle/fluid/operators/math/vol2col.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +inline bool IsExpand(const std::vector& filter_dim, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations) { + bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true; + for (size_t j = 0; j < strides.size(); ++j) { + filter_1 = filter_1 && (static_cast(filter_dim[j + 2]) == 1); + strides_1 = strides_1 && (strides[j] == 1); + padding_0 = padding_0 && (paddings[j] == 0); + dilation_1 = dilation_1 && (dilations[j] == 1); + } + return !(filter_1 && strides_1 && padding_0 && dilation_1); +} + +template +class Conv2dCompute : public KernelLite { + public: + using param_t = operators::ConvParam; + void Run() override { + auto& param = *param_.get_mutable(); + lite::Tensor filter = *param.filter; + param.output->template mutable_data(); + + const int batch_size = static_cast(param.x->dims()[0]); + + std::vector filter_shape_vec(filter.dims().Vectorize()); + std::vector output_shape_vec(param.output->dims().Vectorize()); + + size_t data_dim = filter_shape_vec.size() - 2; + std::vector col_shape_vec(1 + 2 * data_dim); + col_shape_vec[0] = param.x->dims()[1] / param.groups; + for (size_t j = 0; j < data_dim; ++j) { + col_shape_vec[j + 1] = filter_shape_vec[j + 2]; + col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2]; + } + lite::DDim col_shape(col_shape_vec); + lite::DDim col_matrix_shape = col_shape.Flattern2D(data_dim + 1); + bool is_expand = IsExpand(filter_shape_vec, 
param.strides, param.paddings, + param.dilations); + + lite::Tensor col; + lite::Tensor col_matrix; + if (is_expand) { + col.Resize(col_shape); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } + lite::DDim input_shape = param.x->dims().Slice(1, param.x->dims().size()); + + lite::DDim filter_matrix_shape(std::vector{ + filter.dims()[0], filter.dims().production() / filter.dims()[0]}); + filter.Resize(filter_matrix_shape); + + lite::DDim output_matrix_shape(std::vector{ + param.output->dims()[1], + param.output->dims().production() / + (param.output->dims()[0] * param.output->dims()[1])}); + + int in_step = static_cast(param.x->dims()[1]) / param.groups; + int out_step = static_cast(param.output->dims()[1]) / param.groups; + + paddle::operators::math::Vol2ColFunctor + vol2col; + paddle::operators::math::Im2ColFunctor< + paddle::operators::math::ColFormat::kCFO, platform::CPUDeviceContext, T> + im2col; + auto blas = paddle::operators::math::GetBlas( + platform::CPUDeviceContext()); + for (int i = 0; i < batch_size; i++) { + lite::Tensor in_batch; + in_batch.ShareDataWith( + param.x->raw_tensor().Slice(i, i + 1).Resize(input_shape.data())); + lite::Tensor out_batch; + out_batch.ShareDataWith(param.output->raw_tensor().Slice(i, i + 1).Resize( + input_shape.data())); + + for (int g = 0; g < param.groups; g++) { + lite::Tensor in_slice; + in_slice.ShareDataWith( + in_batch.raw_tensor().Slice(g * in_step, (g + 1) * in_step)); + + if (!is_expand) { + col.ShareDataWith(in_slice); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } else if (data_dim == 2U) { + // im2col + im2col(platform::CPUDeviceContext(), in_slice.raw_tensor(), + param.dilations, param.strides, + std::vector{param.paddings[0], param.paddings[1], + param.paddings[0], param.paddings[1]}, + &(col.raw_tensor())); + } else if (data_dim == 3U) { + // vol2col + vol2col(platform::CPUDeviceContext(), in_slice.raw_tensor(), + param.dilations, param.strides, 
param.paddings, + &(col.raw_tensor())); + } + + // gemm + lite::Tensor out_slice; + out_slice.ShareDataWith( + out_batch.raw_tensor().Slice(g * out_step, (g + 1) * out_step)); + lite::Tensor filter_slice; + filter_slice.ShareDataWith( + filter.raw_tensor().Slice(g * out_step, (g + 1) * out_step)); + blas.MatMul(filter_slice.raw_tensor(), false, col_matrix.raw_tensor(), + false, T(1.0), &(out_slice.raw_tensor()), T(0.0)); + } + } + } + + virtual ~Conv2dCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(conv2d, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::Conv2dCompute, def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); + +REGISTER_LITE_KERNEL(depthwise_conv2d, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::Conv2dCompute, def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/x86/dropout_compute.cc b/paddle/fluid/lite/kernels/x86/dropout_compute.cc new file mode 100644 index 00000000000000..d762ec2a06f8b4 --- /dev/null +++ b/paddle/fluid/lite/kernels/x86/dropout_compute.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +using EigenMatrix = framework::EigenMatrix; + +template +class DropoutCompute : public KernelLite { + public: + using param_t = operators::DropoutParam; + void Run() override { + auto& param = *param_.get_mutable(); + const auto* x_data = param.x->data(); + auto* out_data = param.output->template mutable_data(); + if (!param.is_test) { + auto* mask_data = param.mask->template mutable_data(); + std::random_device rnd; + std::minstd_rand engine; + int seed = param.fix_seed ? 
param.seed : rnd(); + engine.seed(seed); + std::uniform_real_distribution dist(0, 1); + + size_t size = framework::product(param.mask->dims().data()); + for (size_t i = 0; i < size; ++i) { + if (dist(engine) < param.dropout_prob) { + mask_data[i] = 0; + out_data[i] = 0; + } else { + if (param.dropout_implementation == "upscale_in_train") { + mask_data[i] = 1.0f / static_cast(1.0f - param.dropout_prob); + out_data[i] = x_data[i] / static_cast(1.0f - param.dropout_prob); + } else { + mask_data[i] = 1; + out_data[i] = x_data[i]; + } + } + } + } else { + auto X = EigenMatrix::Reshape(param.x->raw_tensor(), 1); + auto Y = EigenMatrix::Reshape(param.output->raw_tensor(), 1); + auto& place = *platform::CPUDeviceContext().eigen_device(); + if (param.dropout_implementation == "upscale_in_train") { + Y.device(place) = X; + } else { + Y.device(place) = X * static_cast(1.0f - param.dropout_prob); + } + } + } + + virtual ~DropoutCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(dropout, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::DropoutCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Mask", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/x86/elementwise_compute.cc b/paddle/fluid/lite/kernels/x86/elementwise_compute.cc index e5fabd87323322..8e2ea92d6de24e 100644 --- a/paddle/fluid/lite/kernels/x86/elementwise_compute.cc +++ b/paddle/fluid/lite/kernels/x86/elementwise_compute.cc @@ -1,3 +1,17 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/lite/core/kernel.h" @@ -16,41 +30,118 @@ struct SubFunctor { inline HOSTDEVICE T operator()(T a, T b) const { return a - b; } }; +template +struct AddFunctor { + inline HOSTDEVICE T operator()(T a, T b) const { return a + b; } +}; + template class ElementwiseSubCompute - : public KernelLite { + : public KernelLite { public: using param_t = operators::ElementwiseParam; void Run() override { auto& param = *param_.get_mutable(); - auto& context = context_->As(); - CHECK(context.x86_device_context); + auto& context = ctx_->As(); + CHECK(context.x86_device_context()); param.Out->template mutable_data(); paddle::operators::ElementwiseComputeEx, platform::CPUDeviceContext, T>( - *context.x86_execution_context, ¶m.X->raw_tensor(), + *context.x86_execution_context(), ¶m.X->raw_tensor(), ¶m.Y->raw_tensor(), param.axis, SubFunctor(), ¶m.Out->raw_tensor()); } - // TargetType target() const override; - // PrecisionType precision() const override; - virtual ~ElementwiseSubCompute() = default; }; +template +struct SubGradDX { + T operator()(T x, T y, T out, T dout) const { return dout; } +}; + +template +struct SubGradDY { + T operator()(T x, T y, T out, T dout) const { return -dout; } +}; + +template +class ElementwiseSubGradCompute + : public KernelLite { + public: + using param_t = operators::ElementwiseGradParam; + void Run() override { + auto& param = *param_.get_mutable(); + auto& context = ctx_->As(); + CHECK(context.x86_device_context()); + + 
param.X_grad->template mutable_data(); + param.Y_grad->template mutable_data(); + // skip out, x, y + auto dout = param.Out_grad->raw_tensor(); + auto dx = param.X_grad->raw_tensor(); + auto dy = param.Y_grad->raw_tensor(); + auto& skip = dout; + paddle::operators::ElemwiseExplicitGradCompute< + platform::CPUDeviceContext, T, SubGradDX, SubGradDY>( + *context.x86_execution_context(), skip, skip, skip, dout, param.axis, + &dx, &dy, SubGradDX(), SubGradDY()); + } + + virtual ~ElementwiseSubGradCompute() = default; +}; + +template +class ElementwiseAddCompute + : public KernelLite { + public: + using param_t = operators::ElementwiseParam; + void Run() override { + auto& param = *param_.get_mutable(); + auto& context = ctx_->As(); + CHECK(context.x86_device_context()); + param.Out->template mutable_data(); + paddle::operators::ElementwiseComputeEx, + platform::CPUDeviceContext, T>( + *context.x86_execution_context(), ¶m.X->raw_tensor(), + ¶m.Y->raw_tensor(), param.axis, AddFunctor(), + ¶m.Out->raw_tensor()); + } + + virtual ~ElementwiseAddCompute() = default; +}; + } // namespace x86 } // namespace kernels } // namespace lite } // namespace paddle // float -REGISTER_LITE_KERNEL(square, kHost, kFloat, kNCHW, +REGISTER_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, paddle::lite::kernels::x86::ElementwiseSubCompute, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); + +REGISTER_LITE_KERNEL(elementwise_sub_grad, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::ElementwiseSubCompute, + def) + .BindInput(paddle::framework::GradVarName("Out"), + {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput(paddle::framework::GradVarName("X"), + {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput(paddle::framework::GradVarName("Y"), + {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); + +REGISTER_LITE_KERNEL(elementwise_add, kX86, kFloat, 
kNCHW, + paddle::lite::kernels::x86::ElementwiseAddCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/x86/fc_compute.cc b/paddle/fluid/lite/kernels/x86/fc_compute.cc new file mode 100644 index 00000000000000..dad37febc80433 --- /dev/null +++ b/paddle/fluid/lite/kernels/x86/fc_compute.cc @@ -0,0 +1,105 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/type_system.h" +#include "paddle/fluid/lite/operators/fc_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +void fc_compute_eigen(const T* x, int x_h, int x_w, // + const T* w, int w_h, int w_w, // + const T* b, // + T* out) { + using matrix_t = + Eigen::Matrix; + + Eigen::Map X(x, x_h, x_w); + Eigen::Map W(w, w_h, w_w); + Eigen::Map Out(out, x_h, w_w); + + Out = X * W; + + if (b) { + Eigen::Map> B(b, w_w); + Out = Out.array().rowwise() + B.transpose().array(); + } +} + +template +void fc_compute_naive(const T* x, int x_h, int x_w, // + const T* w, int w_h, int w_w, // + const T* b, // + T* out) { + CHECK_EQ(x_w, w_h); + // out shape: (x_h, w_w) + memset(out, 0, x_h * w_w * sizeof(T)); + for (int i = 0; i < x_h; i++) { + for (int j = 0; j < w_w; j++) { + T tmp = static_cast(0); + for (int k = 0; k < x_w; k++) { + tmp += x[i * x_w + k] * w[k * w_w + j]; + } + out[i * w_w + j] = tmp + b[j]; + } + } +} + +template +class FcCompute : public KernelLite { + public: + using param_t = operators::FcParam; + + void Run() override { + auto& param = *param_.get_mutable(); + CHECK_GE(param.input->dims().size(), 2UL); + CHECK_EQ(param.output->dims().size(), 2UL); + + fc_compute_eigen( + param.input->data(), // x + param.input->dims().Slice(0, param.in_num_col_dims).production(), + param.input->dims() + .Slice(param.in_num_col_dims, param.input->dims().size()) + .production(), + param.w->data(), // w + param.w->dims()[0], // w_h + param.w->dims()[1], // w_w + param.bias->data(), // b + param.output->mutable_data()); + } + + virtual ~FcCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + 
+REGISTER_LITE_KERNEL(fc, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::FcCompute, def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("W", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/x86/fill_constant_compute.cc b/paddle/fluid/lite/kernels/x86/fill_constant_compute.cc new file mode 100644 index 00000000000000..5a5a719af3b503 --- /dev/null +++ b/paddle/fluid/lite/kernels/x86/fill_constant_compute.cc @@ -0,0 +1,56 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/operators/activation_op.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class FillConstantCompute : public KernelLite { + public: + using param_t = operators::FillConstantParam; + + void Run() override { + auto& param = *param_.get_mutable(); + auto& context = ctx_->As(); + CHECK(context.x86_device_context()); + + param.Out->template mutable_data(); + + paddle::operators::math::set_constant( + *context.x86_device_context(), ¶m.Out->raw_tensor(), param.value); + } + + virtual ~FillConstantCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +// float +REGISTER_LITE_KERNEL(fill_constant, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::FillConstantCompute, + def) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/x86/mean_compute.cc b/paddle/fluid/lite/kernels/x86/mean_compute.cc new file mode 100644 index 00000000000000..ac1a37707adfc2 --- /dev/null +++ b/paddle/fluid/lite/kernels/x86/mean_compute.cc @@ -0,0 +1,98 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/operators/activation_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +using EigenScalar = framework::EigenScalar; +template +using EigenVector = framework::EigenVector; + +template +class MeanCompute : public KernelLite { + public: + using param_t = operators::MeanParam; + + void Run() override { + auto& param = *param_.get_mutable(); + auto& context = ctx_->As(); + CHECK(context.x86_device_context()); + + param.Out->template mutable_data(); + + auto X = EigenVector::Flatten(param.X->raw_tensor()); + auto y = EigenScalar::From(param.Out->raw_tensor()); + const auto& place = *(context.x86_device_context()->eigen_device()); + + y.device(place) = X.mean(); + } + + virtual ~MeanCompute() = default; +}; + +template +class MeanGradCompute : public KernelLite { + public: + using param_t = operators::MeanGradParam; + + void Run() override { + auto& param = *param_.get_mutable(); + auto& context = ctx_->As(); + CHECK_EQ(param.Out_grad->raw_tensor().numel(), 1); + CHECK(context.x86_device_context()); + + param.X_grad->template mutable_data(); + T x_grad_size = static_cast(param.X_grad->raw_tensor().numel()); + Eigen::DSizes bcast(static_cast(x_grad_size)); + EigenVector::Flatten(param.X_grad->raw_tensor()) + .device(*(context.x86_device_context()->eigen_device())) = + (EigenVector::From(param.Out_grad->raw_tensor()) / x_grad_size) + .broadcast(bcast); + } + + virtual ~MeanGradCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +// float +REGISTER_LITE_KERNEL(mean, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::MeanCompute, def) + .BindInput("X", 
{LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); + +REGISTER_LITE_KERNEL(mean_grad, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::MeanGradCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput(paddle::framework::GradVarName("Out"), + {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput(paddle::framework::GradVarName("X"), + {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/x86/mul_compute.cc b/paddle/fluid/lite/kernels/x86/mul_compute.cc new file mode 100644 index 00000000000000..ad009893c8a7c7 --- /dev/null +++ b/paddle/fluid/lite/kernels/x86/mul_compute.cc @@ -0,0 +1,149 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/types.h" +#include "paddle/fluid/operators/math/blas.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +using Tensor = framework::Tensor; + +template +class MulCompute : public KernelLite { + public: + using param_t = operators::MulParam; + + void Run() override { + auto& context = ctx_->As(); + auto& param = *param_.get_mutable(); + CHECK(context.x86_device_context()); + + param.output->template mutable_data(); + + auto* x = ¶m.x->raw_tensor(); + auto* y = ¶m.y->raw_tensor(); + + const Tensor x_matrix = x->dims().size() > 2 ? framework::ReshapeToMatrix( + *x, param.x_num_col_dims) + : *x; + const Tensor y_matrix = y->dims().size() > 2 ? framework::ReshapeToMatrix( + *y, param.y_num_col_dims) + : *y; + + auto* z = ¶m.output->raw_tensor(); + auto z_dim = z->dims(); + if (z_dim.size() != 2) { + z->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); + } + + auto blas = paddle::operators::math::GetBlas( + *context.x86_device_context()); + + blas.MatMul(x_matrix, y_matrix, z); + if (z_dim.size() != 2) { + z->Resize(z_dim); + } + } + + virtual ~MulCompute() = default; +}; + +template +class MulGradCompute : public KernelLite { + public: + void Run() override { + auto& context = ctx_->As(); + auto& param = *param_.get_mutable(); + CHECK(context.x86_device_context()); + + auto* x = ¶m.x->raw_tensor(); + auto* y = ¶m.y->raw_tensor(); + auto x_matrix = x->dims().size() > 2 + ? framework::ReshapeToMatrix(*x, param.x_num_col_dims) + : static_cast(*x); + auto y_matrix = y->dims().size() > 2 + ? 
framework::ReshapeToMatrix(*y, param.y_num_col_dims) + : static_cast(*y); + auto* dout = ¶m.output_grad->raw_tensor(); + + Tensor dout_mat; + dout_mat.ShareDataWith(*dout); + dout_mat.Resize( + {framework::flatten_to_2d(x->dims(), param.x_num_col_dims)[0], + framework::flatten_to_2d(y->dims(), param.y_num_col_dims)[1]}); + + auto* dx = ¶m.x_grad->raw_tensor(); + auto* dy = ¶m.y_grad->raw_tensor(); + + if (dx != nullptr) { + dx->set_lod(x->lod()); + } + if (dy != nullptr) { + dy->set_lod(y->lod()); + } + + auto blas = paddle::operators::math::GetBlas( + *context.x86_device_context()); + if (dx) { + // dx->mutable_data(context.x86_device_context->GetPlace()); + param.x_grad->template mutable_data(); + Tensor dx_matrix = dx->dims().size() > 2 ? framework::ReshapeToMatrix( + *dx, param.x_num_col_dims) + : *dx; + + // dx = dout * y'. dx: M x K, dout : M x N, y : K x N + blas.MatMul(dout_mat, false, y_matrix, true, &dx_matrix); + } + if (dy) { + // dy->yutable_data(context.x86_device_context->GetPlace()); + param.y_grad->template mutable_data(); + Tensor dy_matrix = dy->dims().size() > 2 ? framework::ReshapeToMatrix( + *dy, param.y_num_col_dims) + : *dy; + // dy = x' * dout. 
dy K x N, dout : M x N, x : M x K + blas.MatMul(x_matrix, true, dout_mat, false, &dy_matrix); + } + } + + virtual ~MulGradCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(mul, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::MulCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); + +REGISTER_LITE_KERNEL(mul_grad, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::MulGradCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput(paddle::framework::GradVarName("Out"), + {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput(paddle::framework::GradVarName("X"), + {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput(paddle::framework::GradVarName("Y"), + {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/x86/pool_compute.cc b/paddle/fluid/lite/kernels/x86/pool_compute.cc new file mode 100644 index 00000000000000..745c2a78789907 --- /dev/null +++ b/paddle/fluid/lite/kernels/x86/pool_compute.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/types.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/pooling.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class PoolCompute : public KernelLite { + public: + using param_t = operators::PoolParam; + void Run() override { + auto& param = *param_.get_mutable(); + if (param.global_pooling) { + for (size_t i = 0; i < param.ksize.size(); ++i) { + param.paddings[i] = 0; + param.ksize[i] = static_cast(param.x->dims()[i + 2]); + } + } + switch (param.ksize.size()) { + case 2: { + if (param.pooling_type == "max") { + paddle::operators::math::Pool2dFunctor< + platform::CPUDeviceContext, paddle::operators::math::MaxPool, + T> + pool2d_forward; + paddle::operators::math::MaxPool pool_process; + pool2d_forward(platform::CPUDeviceContext(), param.x->raw_tensor(), + param.ksize, param.strides, param.paddings, + pool_process, true, false, + &(param.output->raw_tensor())); + } else if (param.pooling_type == "avg") { + paddle::operators::math::Pool2dFunctor< + platform::CPUDeviceContext, paddle::operators::math::AvgPool, + T> + pool2d_forward; + paddle::operators::math::AvgPool pool_process; + pool2d_forward(platform::CPUDeviceContext(), param.x->raw_tensor(), + param.ksize, param.strides, param.paddings, + pool_process, param.exclusive, param.adaptive, + &(param.output->raw_tensor())); + } + } break; + case 3: { + } break; + } + } + virtual ~PoolCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(pool2d, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::PoolCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git 
a/paddle/fluid/lite/kernels/x86/relu_compute.cc b/paddle/fluid/lite/kernels/x86/relu_compute.cc new file mode 100644 index 00000000000000..44b1f525ab05ed --- /dev/null +++ b/paddle/fluid/lite/kernels/x86/relu_compute.cc @@ -0,0 +1,56 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/type_system.h" +#include "paddle/fluid/lite/operators/relu_op.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class ReluCompute : public KernelLite { + public: + using param_t = operators::ReluParam; + + void Run() override { + auto& param = *param_.get_mutable(); + auto n = param.input->dims().production(); + const float* input = param.input->data(); + float* output = param.output->mutable_data(); + for (int i = 0; i < n; i++) { + output[i] = std::max(0.f, input[i]); + } + } + + virtual ~ReluCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(relu, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::ReluCompute, def) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", 
{LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/paddle/fluid/lite/kernels/host/scale_compute.cc b/paddle/fluid/lite/kernels/x86/scale_compute.cc similarity index 63% rename from paddle/fluid/lite/kernels/host/scale_compute.cc rename to paddle/fluid/lite/kernels/x86/scale_compute.cc index 3fc542646ba7ae..0135a6f614ef4b 100644 --- a/paddle/fluid/lite/kernels/host/scale_compute.cc +++ b/paddle/fluid/lite/kernels/x86/scale_compute.cc @@ -13,14 +13,18 @@ // limitations under the License. #include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/op_registry.h" -#include "paddle/fluid/lite/core/types.h" +#include "paddle/fluid/lite/core/type_system.h" +#include "paddle/fluid/lite/operators/relu_op.h" namespace paddle { namespace lite { namespace kernels { -namespace host { +namespace x86 { template void scale_compute(const T* x, T* out, int size, float scale, float bias, @@ -31,13 +35,14 @@ void scale_compute(const T* x, T* out, int size, float scale, float bias, } } -class ScaleCompute : public KernelLite { +template +class ScaleCompute : public KernelLite { public: - using param_t = operators::MulParam; + using param_t = operators::ScaleParam; void Run() override { - auto& param = Param(); - scale_compute(param.x->data(), param.output->mutable_data(), + auto& param = *param_.get_mutable(); + scale_compute(param.x->data(), param.output->mutable_data(), param.x->dims().production(), param.scale, param.bias, param.bias_after_scale); } @@ -45,13 +50,13 @@ class ScaleCompute : public KernelLite { virtual ~ScaleCompute() = default; }; -} // namespace host +} // namespace x86 } // namespace kernels } // namespace lite } // namespace paddle -REGISTER_LITE_KERNEL(scale, kHost, kFloat, kNCHW, - paddle::lite::kernels::host::ScaleCompute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))}) - 
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) +REGISTER_LITE_KERNEL(scale, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::ScaleCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); diff --git a/paddle/fluid/lite/kernels/x86/sgd_compute.cc b/paddle/fluid/lite/kernels/x86/sgd_compute.cc new file mode 100644 index 00000000000000..2b50c9172a0bcb --- /dev/null +++ b/paddle/fluid/lite/kernels/x86/sgd_compute.cc @@ -0,0 +1,77 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/operators/jit/kernels.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class SGDCompute : public KernelLite { + public: + using param_t = operators::ActivationParam; + + void Run() override { + auto &context = ctx_->As(); + auto &sgd_param = *param_.get_mutable(); + CHECK(context.x86_device_context()); + + // param.Out->template mutable_data(); + + const auto *param = &sgd_param.Param->raw_tensor(); + const auto *grad = &sgd_param.Grad->raw_tensor(); + const auto *learning_rate = &sgd_param.LearningRate->raw_tensor(); + auto *param_out = &sgd_param.ParamOut->raw_tensor(); + + auto sz = param_out->numel(); + PADDLE_ENFORCE_EQ(param->numel(), sz); + PADDLE_ENFORCE_EQ(grad->numel(), sz); + + paddle::operators::jit::sgd_attr_t attr(1, sz, 1, sz, 1); + const T *lr = learning_rate->template data(); + const T *param_data = param->template data(); + const T *grad_data = grad->template data(); + int64_t rows_idx = 0; + T *out_data = param_out->template mutable_data( + context.x86_device_context()->GetPlace()); + + auto sgd = + paddle::operators::jit::KernelFuncs, + platform::CPUPlace>::Cache() + .At(attr); + sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr); + } + + virtual ~SGDCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +// float +REGISTER_LITE_KERNEL(sgd, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::SGDCompute, def) + .BindInput("Param", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("LearningRate", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Grad", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("ParamOut", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git 
a/paddle/fluid/lite/kernels/x86/softmax_compute.cc b/paddle/fluid/lite/kernels/x86/softmax_compute.cc new file mode 100644 index 00000000000000..fe408aa3c84239 --- /dev/null +++ b/paddle/fluid/lite/kernels/x86/softmax_compute.cc @@ -0,0 +1,90 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/math/softmax.h" +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +static inline int CanonicalAxis(const int axis, const int rank) { + if (axis < 0) { + return axis + rank; + } + return axis; +} + +static inline int SizeToAxis(const int axis, lite::DDim dims) { + int size = 1; + for (int i = 0; i < axis; i++) { + size *= dims[i]; + } + return size; +} + +static inline int SizeFromAxis(const int axis, lite::DDim dims) { + int size = 1; + for (int i = axis; i < dims.size(); i++) { + size *= dims[i]; + } + return size; +} + +template +class SoftmaxCompute : public KernelLite { + public: + using param_t = operators::SoftmaxParam; + + void Run() override { + auto& param = *param_.get_mutable(); + // auto& context = context_->As(); + CHECK(param.output); + CHECK(param.x); + const int rank = param.x->dims().size(); + const 
int axis = CanonicalAxis(param.axis, rank); + int axis_dim = param.x->dims()[axis]; + const int n = SizeToAxis(axis, param.x->dims()); + const int d = SizeFromAxis(axis, param.x->dims()); + std::vector shape{n, d}; + + lite::Tensor input_2d, out_2d; + input_2d.ShareDataWith(*param.x); + input_2d.Resize(lite::DDim(shape)); + out_2d.ShareDataWith(*param.output); + out_2d.Resize(lite::DDim(shape)); + + paddle::operators::math::SoftmaxFunctor()( + platform::CPUDeviceContext(), axis_dim, &input_2d.raw_tensor(), + &out_2d.raw_tensor()); + } + + virtual ~SoftmaxCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::SoftmaxCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/paddle/fluid/lite/model_parser/CMakeLists.txt b/paddle/fluid/lite/model_parser/CMakeLists.txt index 95d67c32c51e76..63fe21abdafb91 100644 --- a/paddle/fluid/lite/model_parser/CMakeLists.txt +++ b/paddle/fluid/lite/model_parser/CMakeLists.txt @@ -1,26 +1,29 @@ #cc_library(runtime_lite SRCS runtime.cc) -lite_cc_test(test_model_parser_lite SRCS model_parser_test.cc - DEPS model_parser_lite framework_proto_lite - ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model) -if(WITH_TESTING) -add_dependencies(test_model_parser_lite extern_lite_download_lite_naive_model_tar_gz) -endif(WITH_TESTING) +#TODO(Superjomn) enable it again. 
+if(NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) + lite_cc_test(test_model_parser_lite SRCS model_parser_test.cc + DEPS model_parser_lite framework_proto_lite + ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model) + if(WITH_TESTING) + add_dependencies(test_model_parser_lite extern_lite_download_lite_naive_model_tar_gz) + endif(WITH_TESTING) +endif() -if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) - cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite var_desc_lite) -else() - cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto_lite proto_desc) -endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) +cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite framework_proto_lite var_desc_lite) set(model_parser_deps variable_lite scope_lite ${tensor_lite} scope_lite target_wrapper_host compatible_pb_lite + memory_lite ) if (LITE_WITH_CUDA) set(model_parser_deps ${model_parser_deps} target_wrapper_cuda) endif() cc_library(model_parser_lite SRCS model_parser.cc DEPS ${model_parser_deps}) +lite_cc_test(test_op_desc_lite SRCS op_desc_test.cc DEPS cpp_op_desc_lite op_desc_lite compatible_pb_lite) + add_subdirectory(pb) +add_subdirectory(cpp) diff --git a/paddle/fluid/lite/model_parser/compatible_pb.cc b/paddle/fluid/lite/model_parser/compatible_pb.cc index ee0f7c41acc89c..23a09f8afbf0a7 100644 --- a/paddle/fluid/lite/model_parser/compatible_pb.cc +++ b/paddle/fluid/lite/model_parser/compatible_pb.cc @@ -13,3 +13,115 @@ // limitations under the License. 
#include "paddle/fluid/lite/model_parser/compatible_pb.h" +#include +#include + +namespace paddle { +namespace lite { + +void InputsPbToCpp(const pb::OpDesc &pb_desc, cpp::OpDesc *cpp_desc) { + for (const std::string ¶m : pb_desc.InputArgumentNames()) { + cpp_desc->SetInput(param, pb_desc.Input(param)); + } +} + +void InputsCppToPb(const cpp::OpDesc &cpp_desc, pb::OpDesc *pb_desc) { + for (const std::string ¶m : cpp_desc.InputArgumentNames()) { + pb_desc->SetInput(param, cpp_desc.Input(param)); + } +} + +void OutputsPbToCpp(const pb::OpDesc &pb_desc, cpp::OpDesc *cpp_desc) { + for (const std::string ¶m : pb_desc.OutputArgumentNames()) { + cpp_desc->SetOutput(param, pb_desc.Output(param)); + } +} + +void OutputsCppToPb(const cpp::OpDesc &cpp_desc, pb::OpDesc *pb_desc) { + for (const std::string ¶m : cpp_desc.OutputArgumentNames()) { + pb_desc->SetOutput(param, cpp_desc.Output(param)); + } +} + +void AttrsPbToCpp(const pb::OpDesc &pb_desc, cpp::OpDesc *cpp_desc) { + using AttrType = OpDescAPI::AttrType; + auto set_attr = [&](const std::string &name, AttrType type) { + switch (type) { + case AttrType::INT: + cpp_desc->SetAttr(name, pb_desc.GetAttr(name)); + break; + case AttrType::FLOAT: + cpp_desc->SetAttr(name, pb_desc.GetAttr(name)); + break; + case AttrType::STRING: + cpp_desc->SetAttr(name, + pb_desc.GetAttr(name)); + break; + case AttrType::INTS: + cpp_desc->SetAttr>( + name, pb_desc.GetAttr>(name)); + break; + case AttrType::FLOATS: + cpp_desc->SetAttr>( + name, pb_desc.GetAttr>(name)); + break; + case AttrType::BOOLEAN: + cpp_desc->SetAttr(name, pb_desc.GetAttr(name)); + break; + case AttrType::STRINGS: + cpp_desc->SetAttr>( + name, pb_desc.GetAttr>(name)); + break; + default: + LOG(FATAL) << "Unsupported attr type found " << static_cast(type); + } + }; + + for (const auto &attr_name : pb_desc.AttrNames()) { + auto type = pb_desc.GetAttrType(attr_name); + set_attr(attr_name, type); + } +} + +void AttrsCppToPb(const cpp::OpDesc &cpp_desc, pb::OpDesc *pb_desc) { 
+ using AttrType = OpDescAPI::AttrType; + auto set_attr = [&](const std::string &name, AttrType type) { + switch (type) { +#define IMPL_ONE(type__, T) \ + case AttrType::type__: \ + pb_desc->SetAttr(name, cpp_desc.GetAttr(name)); \ + break; + IMPL_ONE(INT, int32_t); + IMPL_ONE(FLOAT, float); + IMPL_ONE(STRING, std::string); + IMPL_ONE(STRINGS, std::vector); + IMPL_ONE(FLOATS, std::vector); + IMPL_ONE(INTS, std::vector); + IMPL_ONE(BOOLEAN, bool); + default: + LOG(FATAL) << "Unsupported attr type found: " << static_cast(type); + } + }; +#undef IMPL_ONE + for (const auto &attr_name : cpp_desc.AttrNames()) { + auto type = cpp_desc.GetAttrType(attr_name); + set_attr(attr_name, type); + } +} + +void TransformOpDescPbToCpp(const pb::OpDesc &pb_desc, cpp::OpDesc *cpp_desc) { + cpp_desc->SetType(pb_desc.Type()); + InputsPbToCpp(pb_desc, cpp_desc); + OutputsPbToCpp(pb_desc, cpp_desc); + AttrsPbToCpp(pb_desc, cpp_desc); +} + +void TransformOpDescCppToPb(const cpp::OpDesc &cpp_desc, pb::OpDesc *pb_desc) { + pb_desc->SetType(cpp_desc.Type()); + InputsCppToPb(cpp_desc, pb_desc); + OutputsCppToPb(cpp_desc, pb_desc); + AttrsCppToPb(cpp_desc, pb_desc); +} + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/model_parser/compatible_pb.h b/paddle/fluid/lite/model_parser/compatible_pb.h index cef1406f985554..23041ea1fe5161 100644 --- a/paddle/fluid/lite/model_parser/compatible_pb.h +++ b/paddle/fluid/lite/model_parser/compatible_pb.h @@ -20,39 +20,28 @@ * lite::pb::XXDesc. 
*/ -#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #include "paddle/fluid/lite/core/framework.pb.h" +#include "paddle/fluid/lite/model_parser/cpp/op_desc.h" #include "paddle/fluid/lite/model_parser/pb/op_desc.h" #include "paddle/fluid/lite/model_parser/pb/var_desc.h" -#else -#include "paddle/fluid/framework/op_desc.h" -#include "paddle/fluid/framework/var_desc.h" -#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK namespace paddle { namespace lite { -#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK using Attribute = lite::pb::Attribute; using OpDesc = lite::pb::OpDesc; using VarDesc = lite::pb::VarDesc; -#else // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK -using Attribute = framework::Attribute; -using OpDesc = framework::OpDesc; -using VarDesc = framework::VarDesc; -#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK -#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK template T GetAttr(const Attribute& x) { return x.get(); } -#else -template -T GetAttr(const Attribute& x) { - return boost::get(x); -} -#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK + +/// Transform an OpDesc from pb to cpp format. +void TransformOpDescPbToCpp(const pb::OpDesc& pb_desc, cpp::OpDesc* cpp_desc); + +/// Transform an OpDesc from cpp to pb format. +void TransformOpDescCppToPb(const cpp::OpDesc& cpp_desc, pb::OpDesc* pb_desc); } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/model_parser/cpp/CMakeLists.txt b/paddle/fluid/lite/model_parser/cpp/CMakeLists.txt new file mode 100644 index 00000000000000..71073179991294 --- /dev/null +++ b/paddle/fluid/lite/model_parser/cpp/CMakeLists.txt @@ -0,0 +1 @@ +cc_library(cpp_op_desc_lite SRCS op_desc.cc DEPS any_lite) diff --git a/paddle/fluid/lite/model_parser/cpp/op_desc.cc b/paddle/fluid/lite/model_parser/cpp/op_desc.cc new file mode 100644 index 00000000000000..b6b854d72afe92 --- /dev/null +++ b/paddle/fluid/lite/model_parser/cpp/op_desc.cc @@ -0,0 +1,68 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/model_parser/cpp/op_desc.h" +#include +#include + +namespace paddle { +namespace lite { +namespace cpp { + +#define SET_ATTR_IMPL(T, repr__) \ + template <> \ + void OpDesc::SetAttr(const std::string& name, const T& v) { \ + attr_types_[name] = AttrType::repr__; \ + attrs_[name].set(v); \ + } + +SET_ATTR_IMPL(int32_t, INT); +SET_ATTR_IMPL(float, FLOAT); +SET_ATTR_IMPL(std::string, STRING); +SET_ATTR_IMPL(bool, BOOLEAN); +SET_ATTR_IMPL(std::vector, INTS); +SET_ATTR_IMPL(std::vector, FLOATS); +SET_ATTR_IMPL(std::vector, STRINGS); + +std::pair +FindAttr(const cpp::OpDesc& desc, const std::string& name) { + auto it = desc.attrs().find(name); + CHECK(it != desc.attrs().end()) << "No attributes called " << name + << " found"; + auto attr_it = desc.attr_types().find(name); + CHECK(attr_it != desc.attr_types().end()); + return std::make_pair(it, attr_it); +} + +#define GET_IMPL_ONE(T, repr__) \ + template <> \ + T OpDesc::GetAttr(const std::string& name) const { \ + auto pair = FindAttr(*this, name); \ + CHECK(pair.second->second == AttrType::repr__) \ + << "required type is " << #repr__ << " not match the true type"; \ + return pair.first->second.get(); \ + } + +GET_IMPL_ONE(int32_t, INT); +GET_IMPL_ONE(float, FLOAT); +GET_IMPL_ONE(std::string, STRING); +GET_IMPL_ONE(bool, BOOLEAN); +GET_IMPL_ONE(std::vector, LONGS); +GET_IMPL_ONE(std::vector, FLOATS); +GET_IMPL_ONE(std::vector, INTS); 
+GET_IMPL_ONE(std::vector, STRINGS); + +} // namespace cpp +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/model_parser/cpp/op_desc.h b/paddle/fluid/lite/model_parser/cpp/op_desc.h new file mode 100644 index 00000000000000..b70c1692659a89 --- /dev/null +++ b/paddle/fluid/lite/model_parser/cpp/op_desc.h @@ -0,0 +1,126 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include "paddle/fluid/lite/model_parser/desc_apis.h" +#include "paddle/fluid/lite/utils/any.h" +#include "paddle/fluid/lite/utils/varient.h" + +namespace paddle { +namespace lite { +namespace cpp { + +/* + * The cpp::OpDesc is the internal representation for Op. All the internal + * imprementation should use it, not the pb::OpDesc. 
+ */ +class OpDesc : public OpDescAPI { + public: + using attrs_t = std::map; + using attr_types_t = std::map; + + protected: + std::string type_; + std::map> inputs_; + std::map> outputs_; + std::map attrs_; + std::map attr_types_; + + public: + OpDesc() = default; + + std::string Type() const override { return type_; } + void SetType(const std::string& x) override { type_ = x; } + + const std::map>& inputs() const { + return inputs_; + } + const std::map>& outputs() const { + return outputs_; + } + std::map>* mutable_inputs() { + return &inputs_; + } + std::map>* mutable_outputs() { + return &outputs_; + } + std::vector Input(const std::string& param) const override { + auto it = inputs_.find(param); + CHECK(it != inputs_.end()); + return it->second; + } + + std::vector InputArgumentNames() const override { + std::vector res; + for (const auto& x : inputs_) res.push_back(x.first); + return res; + } + std::vector OutputArgumentNames() const override { + std::vector res; + for (const auto& x : outputs_) res.push_back(x.first); + return res; + } + + std::vector Output(const std::string& param) const override { + auto it = outputs_.find(param); + CHECK(it != outputs_.end()); + return it->second; + } + + void SetInput(const std::string& param, + const std::vector& args) override { + inputs_[param] = args; + } + + void SetOutput(const std::string& param, + const std::vector& args) override { + outputs_[param] = args; + } + + bool HasAttr(const std::string& name) const override { + return attrs_.count(name); + } + + AttrType GetAttrType(const std::string& name) const override { + auto it = attr_types_.find(name); + CHECK(it != attr_types_.end()); + return it->second; + } + + std::vector AttrNames() const override { + std::vector res; + for (const auto& x : attrs_) { + res.push_back(x.first); + } + return res; + } + + template + void SetAttr(const std::string& name, const T& v); + + template + T GetAttr(const std::string& name) const; + + const std::map& attrs() const { 
return attrs_; } + const std::map& attr_types() const { + return attr_types_; + } +}; + +} // namespace cpp +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/model_parser/desc_apis.h b/paddle/fluid/lite/model_parser/desc_apis.h new file mode 100644 index 00000000000000..d28f82a0e73085 --- /dev/null +++ b/paddle/fluid/lite/model_parser/desc_apis.h @@ -0,0 +1,85 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +namespace paddle { +namespace lite { + +/* + * Compatible interfaces for all the different kinds of opdesc. All the OpDesc + * classes should implement this. + * NOTE Some interfaces are weried, we remain them unchanged to keep compatible + * with framework::OpDesc in Fluid framework. + */ +class OpDescAPI { + public: + // The AttrType is used to make the proto::AttrType portable. + enum class AttrType { + INT = 0, + FLOAT = 1, + STRING = 2, + INTS = 3, + FLOATS = 4, + STRINGS = 5, + BOOLEAN = 6, + BOOLEANS = 7, + BLOCK = 8, + LONG = 9, + BLOCKS = 10, + LONGS = 11, + UNK, + }; + + virtual ~OpDescAPI() = default; + + /// Get operator's type. + virtual std::string Type() const = 0; + /// Set operator's type. + virtual void SetType(const std::string& type) = 0; + /// Get arguments given the parameter. + virtual std::vector Input(const std::string& param) const = 0; + /// Get parameters. 
+ virtual std::vector InputArgumentNames() const = 0; + /// Get arguments given the parameter. + virtual std::vector Output(const std::string& param) const = 0; + /// Get parameters. + virtual std::vector OutputArgumentNames() const = 0; + /// Set a input given the parameter and arguments. + virtual void SetInput(const std::string& param, + const std::vector& args) = 0; + virtual void SetOutput(const std::string& param, + const std::vector& args) = 0; + /// Tell whether this desc has an attribute. + virtual bool HasAttr(const std::string& name) const = 0; + + /// Get the type of an attribute. + virtual AttrType GetAttrType(const std::string& name) const = 0; + + virtual std::vector AttrNames() const = 0; + + /// Set an attribute. + template + void SetAttr(const std::string& name, const T& v); + + /// Get an attribute. + template + T GetAttr(const std::string& name) const; +}; + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/model_parser/op_desc_test.cc b/paddle/fluid/lite/model_parser/op_desc_test.cc new file mode 100644 index 00000000000000..df74c626040509 --- /dev/null +++ b/paddle/fluid/lite/model_parser/op_desc_test.cc @@ -0,0 +1,107 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/lite/model_parser/cpp/op_desc.h" +#include +#include "paddle/fluid/lite/model_parser/compatible_pb.h" +#include "paddle/fluid/lite/model_parser/pb/op_desc.h" + +namespace paddle { +namespace lite { + +template +void TestX() { + OpDesc desc; + + desc.SetInput("X", {"a", "b"}); + auto X = desc.Input("X"); + ASSERT_EQ(X.size(), 2UL); + ASSERT_EQ(X[0], "a"); + ASSERT_EQ(X[1], "b"); + + desc.SetOutput("Y", {"c", "d"}); + auto Y = desc.Output("Y"); + ASSERT_EQ(Y.size(), 2UL); + ASSERT_EQ(Y[0], "c"); + ASSERT_EQ(Y[1], "d"); + + desc.template SetAttr("aint", 100); + ASSERT_TRUE(desc.HasAttr("aint")); + ASSERT_FALSE(desc.HasAttr("afloat")); + ASSERT_EQ(desc.template GetAttr("aint"), 100); +} + +TEST(OpDesc, Basic) { + TestX(); + TestX(); +} + +TEST(OpDesc, CppToPb) { + cpp::OpDesc desc; + + desc.SetInput("X", {"a", "b"}); + desc.SetOutput("Y", {"c", "d"}); + desc.template SetAttr("aint", 100); + + pb::OpDesc pb_desc; + + TransformOpDescCppToPb(desc, &pb_desc); + { + auto& desc = pb_desc; + auto X = desc.Input("X"); + ASSERT_EQ(X.size(), 2UL); + ASSERT_EQ(X[0], "a"); + ASSERT_EQ(X[1], "b"); + + auto Y = desc.Output("Y"); + ASSERT_EQ(Y.size(), 2UL); + ASSERT_EQ(Y[0], "c"); + ASSERT_EQ(Y[1], "d"); + + ASSERT_TRUE(desc.HasAttr("aint")); + ASSERT_FALSE(desc.HasAttr("afloat")); + ASSERT_EQ(desc.template GetAttr("aint"), 100); + } +} + +TEST(OpDesc, PbToCpp) { + pb::OpDesc desc; + + desc.SetInput("X", {"a", "b"}); + desc.SetOutput("Y", {"c", "d"}); + desc.template SetAttr("aint", 100); + + cpp::OpDesc cpp_desc; + + TransformOpDescPbToCpp(desc, &cpp_desc); + { + auto& desc = cpp_desc; + auto X = desc.Input("X"); + ASSERT_EQ(X.size(), 2UL); + ASSERT_EQ(X[0], "a"); + ASSERT_EQ(X[1], "b"); + + auto Y = desc.Output("Y"); + ASSERT_EQ(Y.size(), 2UL); + ASSERT_EQ(Y[0], "c"); + ASSERT_EQ(Y[1], "d"); + + ASSERT_TRUE(desc.HasAttr("aint")); + ASSERT_FALSE(desc.HasAttr("afloat")); + ASSERT_EQ(desc.template GetAttr("aint"), 100); + } +} + +} // namespace lite +} // 
namespace paddle diff --git a/paddle/fluid/lite/model_parser/pb/op_desc.cc b/paddle/fluid/lite/model_parser/pb/op_desc.cc index fb269cd067180b..7f84510a3fad91 100644 --- a/paddle/fluid/lite/model_parser/pb/op_desc.cc +++ b/paddle/fluid/lite/model_parser/pb/op_desc.cc @@ -18,10 +18,9 @@ namespace paddle { namespace lite { namespace pb { -template <> -void OpDesc::SetAttr(const std::string &name, - const std::string &v) { - auto &xs = *desc_.mutable_attrs(); +google::protobuf::internal::RepeatedPtrIterator +FindAttr(framework::proto::OpDesc *desc, const std::string &name) { + auto &xs = *desc->mutable_attrs(); auto it = std::find_if( xs.begin(), xs.end(), [&](const framework::proto::OpDesc_Attr &x) { return x.name() == name; }); @@ -33,11 +32,96 @@ void OpDesc::SetAttr(const std::string &name, return x.name() == name; }); } + return it; +} + +#define SET_IMPL_ONE(T, ty__, pb_f__) \ + template <> \ + void OpDesc::SetAttr(const std::string &name, const T &v) { \ + auto it = FindAttr(&desc_, name); \ + it->set_type(framework::proto::ty__); \ + it->set_##pb_f__(v); \ + } +SET_IMPL_ONE(int, INT, i); +SET_IMPL_ONE(float, FLOAT, f); +SET_IMPL_ONE(bool, BOOLEAN, b); +template <> +void OpDesc::SetAttr>(const std::string &name, + const std::vector &v) { + auto it = FindAttr(&desc_, name); + it->set_type(framework::proto::INTS); + it->clear_ints(); + for (auto &i : v) { + it->add_ints(i); + } +} + +template <> +void OpDesc::SetAttr(const std::string &name, + const std::string &v) { + auto it = FindAttr(&desc_, name); it->set_type(framework::proto::STRING); it->set_s(v.c_str()); } +template <> +void OpDesc::SetAttr>(const std::string &name, + const std::vector &v) { + auto it = FindAttr(&desc_, name); + it->set_type(framework::proto::FLOATS); + it->clear_floats(); + for (auto &i : v) { + it->add_floats(i); + } +} + +template <> +void OpDesc::SetAttr>( + const std::string &name, const std::vector &v) { + auto it = FindAttr(&desc_, name); + it->set_type(framework::proto::STRINGS); 
+ it->clear_strings(); + for (auto &i : v) { + it->add_strings(i); + } +} + +google::protobuf::internal::RepeatedPtrIterator< + const framework::proto::OpDesc_Attr> +GetFindAttr(const framework::proto::OpDesc &desc, const std::string &name) { + auto &xs = desc.attrs(); + auto it = std::find_if( + xs.begin(), xs.end(), + [&](const framework::proto::OpDesc_Attr &x) { return x.name() == name; }); + return it; +} + +#define GET_ATTR_IMPL(T, pb_f__) \ + template <> \ + T OpDesc::GetAttr(const std::string &name) const { \ + auto it = GetFindAttr(desc_, name); \ + return it->pb_f__(); \ + } + +#define GET_ATTRS_IMPL(T, pb_f__) \ + template <> \ + T OpDesc::GetAttr(const std::string &name) const { \ + auto it = GetFindAttr(desc_, name); \ + T res; \ + for (const auto &v : it->pb_f__()) { \ + res.push_back(v); \ + } \ + return res; \ + } +GET_ATTR_IMPL(int32_t, i); +GET_ATTR_IMPL(float, f); +GET_ATTR_IMPL(bool, b); +GET_ATTRS_IMPL(std::vector, ints); +GET_ATTRS_IMPL(std::vector, floats); +GET_ATTRS_IMPL(std::vector, strings); +GET_ATTR_IMPL(std::string, s); + } // namespace pb } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/model_parser/pb/op_desc.h b/paddle/fluid/lite/model_parser/pb/op_desc.h index b1fbce54d865c9..e8772e162a5e72 100644 --- a/paddle/fluid/lite/model_parser/pb/op_desc.h +++ b/paddle/fluid/lite/model_parser/pb/op_desc.h @@ -27,13 +27,15 @@ #include #include #include "paddle/fluid/lite/core/framework.pb.h" +#include "paddle/fluid/lite/model_parser/desc_apis.h" #include "paddle/fluid/lite/utils/all.h" namespace paddle { namespace lite { namespace pb { -using Attribute = variant>; +using Attribute = + variant, std::vector>; using VariableNameMap = std::map>; /* @@ -42,7 +44,7 @@ using VariableNameMap = std::map>; * except the desc_, to avoid the inconsistent state, which is normal in the * original interface and results in bugs. 
*/ -class OpDesc { +class OpDesc : public OpDescAPI { public: OpDesc() {} @@ -53,38 +55,38 @@ class OpDesc { framework::proto::OpDesc *Proto() { return &desc_; } const framework::proto::OpDesc &ReadonlyProto() const { return desc_; } - std::string Type() const { return desc_.type(); } + std::string Type() const override { return desc_.type(); } - void SetType(const std::string &type) { desc_.set_type(type); } + void SetType(const std::string &type) override { desc_.set_type(type); } // Get the arguments of parameter called `param` - std::vector Input(const std::string ¶m) const { + std::vector Input(const std::string ¶m) const override { return GetArguments(desc_.inputs(), param); } - std::vector InputArgumentNames() const { + std::vector InputArgumentNames() const override { return GetArgumentNames(desc_.inputs()); } void SetInput(const std::string ¶m, - const std::vector &args) { + const std::vector &args) override { SetArgument(desc_.mutable_inputs(), param, args); } - std::vector Output(const std::string ¶m) const { + std::vector Output(const std::string ¶m) const override { return GetArguments(desc_.outputs(), param); } - std::vector OutputArgumentNames() const { + std::vector OutputArgumentNames() const override { return GetArgumentNames(desc_.outputs()); } void SetOutput(const std::string ¶m, - const std::vector &args) { + const std::vector &args) override { SetArgument(desc_.mutable_outputs(), param, args); } - bool HasAttr(const std::string &name) const { + bool HasAttr(const std::string &name) const override { const auto &xs = desc_.attrs(); auto it = std::find_if(xs.begin(), xs.end(), [&](const framework::proto::OpDesc_Attr &x) { @@ -93,17 +95,38 @@ class OpDesc { return it != xs.end(); } - framework::proto::AttrType GetAttrType(const std::string &name) const { + AttrType GetAttrType(const std::string &name) const override { const auto &xs = desc_.attrs(); auto it = std::find_if(xs.begin(), xs.end(), [&](const framework::proto::OpDesc_Attr &x) { return 
x.name() == name; }); CHECK(it != xs.end()); - return it->type(); +#define DEF_ONE(type__) \ + case framework::proto::AttrType::type__: \ + return AttrType::type__; + + switch (it->type()) { + DEF_ONE(INT); + DEF_ONE(FLOAT); + DEF_ONE(STRING); + DEF_ONE(INTS); + DEF_ONE(FLOATS); + DEF_ONE(STRINGS); + DEF_ONE(BOOLEAN); + DEF_ONE(BOOLEANS); + DEF_ONE(BLOCK); + DEF_ONE(LONG); + DEF_ONE(BLOCKS); + DEF_ONE(LONGS); + default: + LOG(ERROR) << "Unknown attribute type"; + return AttrType::UNK; + } +#undef DEF_ONE } - std::vector AttrNames() const { + std::vector AttrNames() const override { std::vector res; const auto &xs = desc_.attrs(); std::transform( @@ -113,66 +136,10 @@ class OpDesc { } template - void SetAttr(const std::string &name, const T &v) { - auto &xs = *desc_.mutable_attrs(); - auto it = std::find_if(xs.begin(), xs.end(), - [&](const framework::proto::OpDesc_Attr &x) { - return x.name() == name; - }); - if (it == xs.end()) { - auto *attr = xs.Add(); - attr->set_name(name); - it = std::find_if(xs.begin(), xs.end(), - [&](const framework::proto::OpDesc_Attr &x) { - return x.name() == name; - }); - } - - size_t hash = typeid(T).hash_code(); - if (hash == typeid(int).hash_code()) { // NOLINT - it->set_type(framework::proto::INT); - it->set_i(v); - } else if (hash == typeid(float).hash_code()) { // NOLINT - it->set_type(framework::proto::FLOAT); - it->set_f(v); - } else if (hash == typeid(bool).hash_code()) { // NOLINT - it->set_type(framework::proto::BOOLEAN); - it->set_b(v); - } else { - LOG(FATAL) << "unsupport attr type"; - } - } - - Attribute GetAttr(const std::string &name) const { - auto &xs = desc_.attrs(); - auto it = std::find_if(xs.begin(), xs.end(), - [&](const framework::proto::OpDesc_Attr &x) { - return x.name() == name; - }); - - Attribute res; - CHECK(it != xs.end()); - - switch (it->type()) { - case framework::proto::INT: - res.set(it->i()); - break; - case framework::proto::FLOAT: - res.set(it->f()); - break; - case framework::proto::STRING: - 
res.set(it->s()); - break; - case framework::proto::BOOLEAN: - res.set(it->b()); - break; + void SetAttr(const std::string &name, const T &v); - default: - LOG(FATAL) << "unsupported attr type"; - } - - return res; - } + template + T GetAttr(const std::string &name) const; private: std::vector GetArguments( @@ -231,6 +198,10 @@ template <> void OpDesc::SetAttr(const std::string &name, const std::string &v); +template <> +void OpDesc::SetAttr>(const std::string &name, + const std::vector &v); + } // namespace pb } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt index 184acb8485d168..691ff743b173d2 100644 --- a/paddle/fluid/lite/operators/CMakeLists.txt +++ b/paddle/fluid/lite/operators/CMakeLists.txt @@ -4,21 +4,47 @@ cc_library(fc_op_lite SRCS fc_op.cc DEPS ${op_DEPS}) cc_library(relu_op_lite SRCS relu_op.cc DEPS ${op_DEPS}) cc_library(mul_op_lite SRCS mul_op.cc DEPS ${op_DEPS}) cc_library(scale_op_lite SRCS scale_op.cc DEPS ${op_DEPS}) +cc_library(softmax_op_lite SRCS softmax_op.cc DEPS ${op_DEPS}) +cc_library(reshape_op_lite SRCS reshape_op.cc DEPS ${op_DEPS} ) cc_library(feed_op_lite SRCS feed_op.cc DEPS ${op_DEPS}) cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS}) cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS ${op_DEPS}) cc_library(activation_ops_lite SRCS activation_ops.cc DEPS ${op_DEPS}) cc_library(elementwise_ops_lite SRCS elementwise_ops.cc DEPS ${op_DEPS}) +cc_library(mean_op_lite SRCS mean_op.cc DEPS ${op_DEPS}) +cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS}) +#cc_library(sgd_op_lite SRCS sgd_op.cc DEPS ${op_DEPS}) +cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framework_proto_lite) +cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS}) +cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS}) +cc_library(conv_op_lite SRCS conv_op.cc DEPS ${op_DEPS}) +cc_library(pool_op_lite SRCS 
pool_op.cc DEPS ${op_DEPS}) -cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite}) set(ops_lite - fc_op_lite - relu_op_lite - mul_op_lite - scale_op_lite - feed_op_lite - fetch_op_lite - io_copy_op_lite - PARENT_SCOPE) + fc_op_lite + relu_op_lite + mul_op_lite + scale_op_lite + softmax_op_lite + reshape_op_lite + feed_op_lite + fetch_op_lite + io_copy_op_lite + elementwise_ops_lite + mean_op_lite + fill_constant_op_lite + activation_ops_lite + dropout_op_lite + concat_op_lite + conv_op_lite + pool_op_lite + CACHE INTERNAL "ops lite") -lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite fc_compute_host) +lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc + DEPS fc_op_lite memory_lite + X86_DEPS fc_compute_x86 + ARM_DEPS fc_compute_arm) +lite_cc_test(test_scale_op_lite SRCS scale_op_test.cc DEPS scale_op_lite memory_lite) +lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite) +lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite) +lite_cc_test(test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite) diff --git a/paddle/fluid/lite/operators/activation_ops.cc b/paddle/fluid/lite/operators/activation_ops.cc index 1e824e8580ef5c..8cda67af14a786 100644 --- a/paddle/fluid/lite/operators/activation_ops.cc +++ b/paddle/fluid/lite/operators/activation_ops.cc @@ -1,3 +1,20 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef LITE_WITH_X86 +#include "paddle/fluid/framework/operator.h" +#endif #include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/op_registry.h" @@ -16,22 +33,62 @@ class ActivationOp : public OpLite { return true; } - bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override { + bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { auto X_name = opdesc.Input("X").front(); auto Out_name = opdesc.Output("Out").front(); param_.X = GetVar(scope, X_name); param_.Out = GetMutableVar(scope, Out_name); + return true; } void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "activation_op"; } + private: mutable ActivationParam param_; }; +#ifdef LITE_WITH_X86 +class ActivationGradOp : public OpLite { + public: + explicit ActivationGradOp(const std::string& type) : OpLite(type) {} + + bool CheckShape() const override { + CHECK_OR_FALSE(param_.X_grad); + CHECK_OR_FALSE(param_.Out_grad); + return true; + } + + bool InferShape() const override { + param_.X_grad->Resize(param_.Out_grad->dims()); + return true; + } + + bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { + auto Out_grad_name = opdesc.Input(framework::GradVarName("Out")).front(); + auto X_grad_name = opdesc.Output(framework::GradVarName("X")).front(); + + param_.Out_grad = GetVar(scope, Out_grad_name); + param_.X_grad = GetMutableVar(scope, X_grad_name); + return true; + } + + void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { return "activation_grad_op"; } + + private: + mutable ActivationGradParam param_; +}; +#endif + } // namespace operators } // namespace lite } // namespace paddle REGISTER_LITE_OP(square, paddle::lite::operators::ActivationOp); +#ifdef LITE_WITH_X86 
+REGISTER_LITE_OP(square_grad, paddle::lite::operators::ActivationGradOp); +#endif diff --git a/paddle/fluid/lite/operators/concat_op.cc b/paddle/fluid/lite/operators/concat_op.cc new file mode 100644 index 00000000000000..e51d6e0d349823 --- /dev/null +++ b/paddle/fluid/lite/operators/concat_op.cc @@ -0,0 +1,75 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/concat_op.h" +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool ConcatOpLite::CheckShape() const { + CHECK_GT_OR_FALSE(param_.x.size(), 1UL); + CHECK_OR_FALSE(param_.output); + return true; +} + +bool ConcatOpLite::InferShape() const { + std::vector input_dims; + for (auto p : param_.x) { + input_dims.push_back(p->dims()); + } + size_t axis = static_cast(param_.axis); + const size_t n = input_dims.size(); + CHECK_GT_OR_FALSE(n, 0); + auto &out_dims = input_dims[0]; + size_t in_zero_dims_size = out_dims.size(); + for (size_t i = 1; i < n; i++) { + for (size_t j = 0; j < in_zero_dims_size; j++) { + if (j == axis) { + out_dims[axis] += input_dims[i][j]; + } else { + CHECK_EQ_OR_FALSE(out_dims[j], input_dims[i][j]); + } + } + } + if (out_dims[axis] < 0) { + out_dims[axis] = -1; + } + // Set output dims + param_.output->Resize(lite::DDim(out_dims)); + return true; +} + +// TODO(Superjomn) 
replace framework::OpDesc with a lite one. +bool ConcatOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { + auto inputs = op_desc.Input("X"); + auto out = op_desc.Output("Out").front(); + + for (auto var : inputs) { + param_.x.push_back(scope->FindVar(var)->GetMutable()); + } + CHECK(scope->FindVar(out)); + param_.output = scope->FindVar(out)->GetMutable(); + param_.axis = op_desc.GetAttr("axis"); + + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(concat, paddle::lite::operators::ConcatOpLite); diff --git a/paddle/fluid/lite/operators/concat_op.h b/paddle/fluid/lite/operators/concat_op.h new file mode 100644 index 00000000000000..17408289a61175 --- /dev/null +++ b/paddle/fluid/lite/operators/concat_op.h @@ -0,0 +1,46 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/scope.h" +#include "paddle/fluid/lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class ConcatOpLite : public OpLite { + public: + ConcatOpLite() {} + explicit ConcatOpLite(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "concat"; } + + private: + mutable ConcatParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/concat_op_test.cc b/paddle/fluid/lite/operators/concat_op_test.cc new file mode 100644 index 00000000000000..3af3fc8ef78e63 --- /dev/null +++ b/paddle/fluid/lite/operators/concat_op_test.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/lite/operators/concat_op.h" +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +TEST(concat_op_lite, test) { + // prepare variables + lite::Scope scope; + auto* x0 = scope.Var("x0")->GetMutable(); + auto* x1 = scope.Var("x1")->GetMutable(); + auto* output = scope.Var("output")->GetMutable(); + x0->Resize(lite::DDim(std::vector({10, 20}))); + x1->Resize(lite::DDim(std::vector({10, 20}))); + output->Resize(lite::DDim(std::vector{20, 20})); + + // set data + for (int i = 0; i < 10 * 20; i++) { + x0->mutable_data()[i] = i; + } + for (int i = 0; i < 10 * 20; i++) { + x1->mutable_data()[i] = i; + } + for (int i = 0; i < 10 * 20; i++) { + output->mutable_data()[i] = 0.; + } + + // prepare op desc + cpp::OpDesc desc; + desc.SetType("concat"); + desc.SetInput("X", {"x0", "x1"}); + desc.SetOutput("Out", {"output"}); + desc.SetAttr("axis", static_cast(0)); + + ConcatOpLite concat("concat"); + + concat.SetValidPlaces({Place{TARGET(kX86), PRECISION(kFloat)}}); + concat.Attach(desc, &scope); +} + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/conv_op.cc b/paddle/fluid/lite/operators/conv_op.cc new file mode 100644 index 00000000000000..63838efd6fe571 --- /dev/null +++ b/paddle/fluid/lite/operators/conv_op.cc @@ -0,0 +1,60 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/conv_op.h" +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool ConvOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + CHECK_OR_FALSE(param_.filter); + return true; +} + +bool ConvOpLite::InferShape() const { + auto in_dims = param_.x->dims(); + auto filter_dims = param_.filter->dims(); + std::vector strides = param_.strides; + std::vector paddings = param_.paddings; + int groups = param_.groups; + std::vector dilations = param_.dilations; + + CHECK_OR_FALSE(in_dims.size() == 4 || in_dims.size() == 5); + CHECK_EQ_OR_FALSE(in_dims.size(), filter_dims.size()); + CHECK_OR_FALSE(in_dims.size() - strides.size() == 2U); + CHECK_EQ_OR_FALSE(paddings.size(), strides.size()); + CHECK_EQ_OR_FALSE(in_dims[1], filter_dims[1] * groups); + CHECK_EQ_OR_FALSE(filter_dims[0] % groups, 0); + + std::vector output_shape({in_dims[0], filter_dims[0]}); + for (size_t i = 0; i < strides.size(); ++i) { + output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], + dilations[i], paddings[i], + strides[i])); + } + param_.output->Resize(lite::DDim(output_shape)); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(conv2d, paddle::lite::operators::ConvOpLite); +REGISTER_LITE_OP(depthwise_conv2d, paddle::lite::operators::ConvOpLite); diff --git a/paddle/fluid/lite/operators/conv_op.h b/paddle/fluid/lite/operators/conv_op.h new file mode 100644 index 00000000000000..e5ad7fe67f9561 --- /dev/null +++ b/paddle/fluid/lite/operators/conv_op.h @@ -0,0 +1,108 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "paddle/fluid/lite/core/compatible_tensor.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/scope.h" +#include "paddle/fluid/lite/operators/op_params.h" +#include "paddle/fluid/lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +inline int ConvOutputSize(int input_size, int filter_size, int dilation, + int padding, int stride) { + const int dkernel = dilation * (filter_size - 1) + 1; + int output_size = (input_size + 2 * padding - dkernel) / stride + 1; + CHECK_OR_FALSE(output_size > 0); + + return output_size; +} + +inline bool IsExpand(const std::vector& filter_dim, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations) { + bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true; + for (size_t j = 0; j < strides.size(); ++j) { + filter_1 = filter_1 && (static_cast(filter_dim[j + 2]) == 1); + strides_1 = strides_1 && (strides[j] == 1); + padding_0 = padding_0 && (paddings[j] == 0); + dilation_1 = dilation_1 && (dilations[j] == 1); + } + return !(filter_1 && strides_1 && padding_0 && dilation_1); +} + +class ConvOpLite : public OpLite { + public: + ConvOpLite() {} + + explicit ConvOpLite(const std::string& type) : OpLite(type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + // TODO(Superjomn) replace framework::OpDesc with a 
lite one. + bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override { + auto X = op_desc.Input("Input").front(); + auto Filter = op_desc.Input("Filter").front(); + auto Out = op_desc.Output("Output").front(); + + param_.x = scope->FindVar(X)->GetMutable(); + param_.filter = scope->FindVar(Filter)->GetMutable(); + param_.output = scope->FindVar(Out)->GetMutable(); + + std::vector input_arg_names = op_desc.InputArgumentNames(); + if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") != + input_arg_names.end()) { + auto bias_var = scope->FindVar(op_desc.Input("Bias").front()); + if (bias_var != nullptr) { + param_.bias = + const_cast(&(bias_var->Get())); + } + } + if (std::find(input_arg_names.begin(), input_arg_names.end(), + "ResidualData") != input_arg_names.end()) { + auto residual_data_var = + scope->FindVar(op_desc.Input("ResidualData").front()); + if (residual_data_var != nullptr) { + param_.residualData = const_cast( + &(residual_data_var->Get())); + } + } + + param_.strides = op_desc.GetAttr>("strides"); + param_.paddings = op_desc.GetAttr>("paddings"); + param_.groups = op_desc.GetAttr("groups"); + param_.dilations = op_desc.GetAttr>("dilations"); + + return true; + } + + std::string DebugString() const override { return "conv2d"; } + + private: + mutable ConvParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/dropout_op.cc b/paddle/fluid/lite/operators/dropout_op.cc new file mode 100644 index 00000000000000..b5b50dc3d16687 --- /dev/null +++ b/paddle/fluid/lite/operators/dropout_op.cc @@ -0,0 +1,75 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +class DropoutOpLite : public OpLite { + public: + explicit DropoutOpLite(const std::string& type) : OpLite(type) {} + + bool CheckShape() const override { + CHECK_OR_FALSE(param_.x); + return true; + } + + bool InferShape() const override { + const auto x_dims = param_.x->dims(); + param_.output->Resize(x_dims); + if (param_.is_test == false) { + param_.mask->Resize(x_dims); + } + // share LoD + // param_.output->set_lod(param_.input->lod()); + return true; + } + + void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + // TODO(Superjomn) replace framework::OpDesc with a lite one. 
+ bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override { + auto input = op_desc.Input("X").front(); + auto out = op_desc.Output("Out").front(); + auto Mask = op_desc.Output("Mask").front(); + + param_.x = GetVar(scope, input); + param_.output = GetMutableVar(scope, out); + param_.mask = GetMutableVar(scope, Mask); + + param_.dropout_prob = op_desc.GetAttr("dropout_prob"); + if (op_desc.HasAttr("axis")) { + param_.is_test = op_desc.GetAttr("is_test"); + } + param_.fix_seed = op_desc.GetAttr("fix_seed"); + param_.seed = op_desc.GetAttr("seed"); + param_.dropout_implementation = + op_desc.GetAttr("dropout_implementation"); + return true; + } + + std::string DebugString() const override { return "dropout"; } + + private: + mutable DropoutParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(dropout, paddle::lite::operators::DropoutOpLite); diff --git a/paddle/fluid/lite/operators/elementwise_ops.cc b/paddle/fluid/lite/operators/elementwise_ops.cc index f4a22c6fcd8b5e..b400b1ab26c137 100644 --- a/paddle/fluid/lite/operators/elementwise_ops.cc +++ b/paddle/fluid/lite/operators/elementwise_ops.cc @@ -1,3 +1,17 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ #include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/op_registry.h" @@ -17,31 +31,82 @@ class ElementwiseOp : public OpLite { } bool InferShape() const override { - CHECK_OR_FALSE(param_.X->dims() == param_.Y->dims()); + CHECK_OR_FALSE(param_.X->dims().size() >= param_.Y->dims().size()); param_.Out->Resize(param_.X->dims()); return true; } - bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override { - CHECK_EQ(opdesc.Inputs().size(), 2UL); + bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { auto X_name = opdesc.Input("X").front(); auto Y_name = opdesc.Input("Y").front(); auto Out_name = opdesc.Output("Out").front(); param_.X = GetVar(scope, X_name); param_.Y = GetVar(scope, Y_name); - param_.Out = GetMutableVar(scope, Out_name); - param_.axis = boost::get(opdesc.GetAttr("axis")); + param_.Out = GetMutableVar(scope, Out_name); + param_.axis = opdesc.GetAttr("axis"); + return true; } void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "elementwise_op"; } + private: mutable operators::ElementwiseParam param_; }; +#ifdef LITE_WITH_X86 +class ElementwiseGradExplicitOp : public OpLite { + public: + explicit ElementwiseGradExplicitOp(const std::string& type) : OpLite(type) {} + + bool CheckShape() const override { + CHECK_OR_FALSE(param_.Y); + CHECK_OR_FALSE(param_.X_grad); + CHECK_OR_FALSE(param_.Y_grad); + CHECK_OR_FALSE(param_.Out_grad); + return true; + } + + bool InferShape() const override { + param_.X_grad->Resize(param_.Out_grad->dims()); + param_.Y_grad->Resize(param_.Y->dims()); + return true; + } + + bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { + CHECK_EQ(opdesc.InputArgumentNames().size(), 1UL); + auto Out_name = opdesc.Input(framework::GradVarName("Out")).front(); + auto X_name = opdesc.Output(framework::GradVarName("X")).front(); + auto Y_name = 
opdesc.Output(framework::GradVarName("Y")).front(); + + param_.Out_grad = GetVar(scope, Out_name); + param_.X_grad = GetMutableVar(scope, X_name); + param_.Y_grad = GetMutableVar(scope, Y_name); + param_.axis = opdesc.GetAttr("axis"); + + return true; + } + + void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { + return "elementwise_grad_explicit_op"; + } + + private: + mutable operators::ElementwiseGradParam param_; +}; +#endif + } // namespace operators } // namespace lite } // namespace paddle REGISTER_LITE_OP(elementwise_sub, paddle::lite::operators::ElementwiseOp); +#ifdef LITE_WITH_X86 +REGISTER_LITE_OP(elementwise_sub_grad, + paddle::lite::operators::ElementwiseGradExplicitOp); +#endif +REGISTER_LITE_OP(elementwise_add, paddle::lite::operators::ElementwiseOp); diff --git a/paddle/fluid/lite/operators/fc_op.h b/paddle/fluid/lite/operators/fc_op.h index a6043fa7b1f90b..0e738018322f42 100644 --- a/paddle/fluid/lite/operators/fc_op.h +++ b/paddle/fluid/lite/operators/fc_op.h @@ -46,7 +46,7 @@ class FcOpLite : public OpLite { */ // TODO(Superjomn) replace framework::OpDesc with a lite one. 
- bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override { + bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { auto input = op_desc.Input("Input").front(); auto W = op_desc.Input("W").front(); auto bias = op_desc.Input("Bias").front(); @@ -57,7 +57,7 @@ class FcOpLite : public OpLite { param_.bias = scope->FindVar(bias)->GetMutable(); CHECK(scope->FindVar(out)); param_.output = scope->FindVar(out)->GetMutable(); - param_.in_num_col_dims = GetAttr(op_desc.GetAttr("in_num_col_dims")); + param_.in_num_col_dims = op_desc.GetAttr("in_num_col_dims"); return true; } diff --git a/paddle/fluid/lite/operators/fc_op_test.cc b/paddle/fluid/lite/operators/fc_op_test.cc index dccb53f3be1568..880b8a820e537b 100644 --- a/paddle/fluid/lite/operators/fc_op_test.cc +++ b/paddle/fluid/lite/operators/fc_op_test.cc @@ -20,7 +20,7 @@ namespace paddle { namespace lite { namespace operators { -TEST(fc_op_lite, test) { +TEST(fc_op_lite, TestX86) { // prepare variables Scope scope; auto* x = scope.Var("x")->GetMutable(); @@ -47,7 +47,7 @@ TEST(fc_op_lite, test) { } // prepare op desc - framework::OpDesc desc; + cpp::OpDesc desc; desc.SetType("fc"); desc.SetInput("Input", {"x"}); desc.SetInput("W", {"w"}); @@ -57,9 +57,11 @@ TEST(fc_op_lite, test) { FcOpLite fc("fc"); - fc.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}}); + fc.SetValidPlaces({Place{TARGET(kX86), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}}); fc.Attach(desc, &scope); - auto kernels = fc.CreateKernels({Place{TARGET(kHost), PRECISION(kFloat)}}); + auto kernels = fc.CreateKernels({Place{TARGET(kX86), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kFloat)}}); ASSERT_FALSE(kernels.empty()); } @@ -67,4 +69,10 @@ TEST(fc_op_lite, test) { } // namespace lite } // namespace paddle -USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def); +#ifdef LITE_WITH_X86 +USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def); +#endif + +#ifdef LITE_WITH_ARM +USE_LITE_KERNEL(fc, kARM, kFloat, 
kNCHW, def); +#endif diff --git a/paddle/fluid/lite/operators/feed_op.cc b/paddle/fluid/lite/operators/feed_op.cc index 45a7c198cb6eeb..c977adfd4b32b6 100644 --- a/paddle/fluid/lite/operators/feed_op.cc +++ b/paddle/fluid/lite/operators/feed_op.cc @@ -34,12 +34,12 @@ class FeedOp : public OpLite { void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } protected: - bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override { + bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { auto feed_var_name = opdesc.Input("X").front(); auto* feed_var = scope->FindVar(feed_var_name); CHECK(feed_var); - auto& feed_tensor_list = feed_var->Get>(); - param_.feed_list = &feed_tensor_list; + auto* feed_tensor_list = feed_var->GetMutable>(); + param_.feed_list = feed_tensor_list; auto out_name = opdesc.Output("Out").front(); auto* out_var = scope->FindVar(out_name); @@ -48,7 +48,7 @@ class FeedOp : public OpLite { // NOTE need boost here // TODO(Superjomn) drop the need of framework::op_desc - param_.col = GetAttr(opdesc.GetAttr("col")); + param_.col = opdesc.GetAttr("col"); return true; } diff --git a/paddle/fluid/lite/operators/fetch_op.cc b/paddle/fluid/lite/operators/fetch_op.cc index 337a6ecc9d571b..51efda776b21b8 100644 --- a/paddle/fluid/lite/operators/fetch_op.cc +++ b/paddle/fluid/lite/operators/fetch_op.cc @@ -33,7 +33,7 @@ class FetchOp : public OpLite { void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } protected: - bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override { + bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { auto _x = opdesc.Input("X").front(); auto* x = scope->FindVar(_x); CHECK(x); @@ -43,7 +43,7 @@ class FetchOp : public OpLite { auto* out = scope->FindVar(_out); param_.fetch_list = out->GetMutable>(); - param_.col = GetAttr(opdesc.GetAttr("col")); + param_.col = opdesc.GetAttr("col"); return true; } diff --git 
a/paddle/fluid/lite/operators/fill_constant_op.cc b/paddle/fluid/lite/operators/fill_constant_op.cc new file mode 100644 index 00000000000000..b762f0d3c9215f --- /dev/null +++ b/paddle/fluid/lite/operators/fill_constant_op.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +class FillConstantOp : public OpLite { + public: + explicit FillConstantOp(const std::string& type) : OpLite(type) {} + + bool CheckShape() const override { + CHECK_OR_FALSE(param_.Out); + return true; + } + + bool InferShape() const override { + param_.Out->Resize(param_.shape); + return true; + } + + bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { + auto Out_name = opdesc.Output("Out").front(); + + param_.Out = GetMutableVar(scope, Out_name); + param_.dtype = opdesc.GetAttr("dtype"); + param_.shape = opdesc.GetAttr>("shape"); + param_.value = opdesc.GetAttr("value"); + param_.force_cpu = opdesc.GetAttr("force_cpu"); + return true; + } + + void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { return "fill_constant"; } + + private: + mutable operators::FillConstantParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle + 
+REGISTER_LITE_OP(fill_constant, paddle::lite::operators::FillConstantOp); diff --git a/paddle/fluid/lite/operators/io_copy_op.cc b/paddle/fluid/lite/operators/io_copy_op.cc index 220853fc2639c4..44d49a30a0eb0d 100644 --- a/paddle/fluid/lite/operators/io_copy_op.cc +++ b/paddle/fluid/lite/operators/io_copy_op.cc @@ -29,7 +29,8 @@ bool IoCopyOp::InferShape() const { return true; } bool IoCopyOp::Run() { return OpLite::Run(); } -bool IoCopyOp::AttachImpl(const OpDesc &opdesc, paddle::lite::Scope *scope) { +bool IoCopyOp::AttachImpl(const cpp::OpDesc &opdesc, + paddle::lite::Scope *scope) { auto x = opdesc.Input("Input").front(); auto out = opdesc.Output("Out").front(); param_.x = GetTensor(scope, x); diff --git a/paddle/fluid/lite/operators/io_copy_op.h b/paddle/fluid/lite/operators/io_copy_op.h index efcd11bc3092af..dd95ef8d33a662 100644 --- a/paddle/fluid/lite/operators/io_copy_op.h +++ b/paddle/fluid/lite/operators/io_copy_op.h @@ -31,7 +31,7 @@ class IoCopyOp : public OpLite { void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } protected: - bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override; + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; private: operators::IoCopyParam param_; diff --git a/paddle/fluid/lite/operators/mean_op.cc b/paddle/fluid/lite/operators/mean_op.cc new file mode 100644 index 00000000000000..411dcbb735a001 --- /dev/null +++ b/paddle/fluid/lite/operators/mean_op.cc @@ -0,0 +1,103 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef LITE_WITH_X86 +#include "paddle/fluid/framework/operator.h" +#endif +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +class MeanOp : public OpLite { + public: + explicit MeanOp(const std::string& type) : OpLite(type) {} + + bool CheckShape() const override { + CHECK_OR_FALSE(param_.X); + CHECK_OR_FALSE(param_.Out); + return true; + } + + bool InferShape() const override { + param_.Out->Resize(std::vector{1}); + return true; + } + + bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { + auto X_name = opdesc.Input("X").front(); + auto Out_name = opdesc.Output("Out").front(); + + param_.X = GetVar(scope, X_name); + param_.Out = GetMutableVar(scope, Out_name); + return true; + } + + void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { return "mean"; } + + private: + mutable operators::ElementwiseParam param_; +}; + +#ifdef LITE_WITH_X86 +class MeanGradOp : public OpLite { + public: + explicit MeanGradOp(const std::string& type) : OpLite(type) {} + + bool CheckShape() const override { + CHECK_OR_FALSE(param_.X); + CHECK_OR_FALSE(param_.Out_grad); + CHECK_OR_FALSE(param_.X_grad); + return true; + } + + bool InferShape() const override { + param_.X_grad->Resize(param_.X->dims()); + // param_.X_grad->set_lod(param_.X->lod()); + return true; + } + + bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { + 
CHECK_EQ(opdesc.InputArgumentNames().size(), 3UL); + auto X_name = opdesc.Input("X").front(); + auto Out_grad_name = opdesc.Input(framework::GradVarName("Out")).front(); + auto X_grad_name = opdesc.Output(framework::GradVarName("X")).front(); + + param_.X = GetVar(scope, X_name); + param_.Out_grad = GetVar(scope, Out_grad_name); + param_.X_grad = GetMutableVar(scope, X_grad_name); + return true; + } + + void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } + + std::string DebugString() const override { return "mean_grad"; } + + private: + mutable operators::MeanGradParam param_; +}; +#endif + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(mean, paddle::lite::operators::MeanOp); +#ifdef LITE_WITH_X86 +REGISTER_LITE_OP(mean_grad, paddle::lite::operators::MeanGradOp); +#endif diff --git a/paddle/fluid/lite/operators/mul_op.cc b/paddle/fluid/lite/operators/mul_op.cc index b78ae4578a6130..70eb37dd09b214 100644 --- a/paddle/fluid/lite/operators/mul_op.cc +++ b/paddle/fluid/lite/operators/mul_op.cc @@ -28,9 +28,19 @@ bool MulOpLite::CheckShape() const { const auto x_dims = param_.x->dims(); const auto y_dims = param_.y->dims(); - CHECK_EQ_OR_FALSE(y_dims.size(), 2UL); CHECK_GT_OR_FALSE(x_dims.size(), static_cast(param_.x_num_col_dims)); + CHECK_GT_OR_FALSE(y_dims.size(), static_cast(param_.y_num_col_dims)); + // auto x_mat_dims = + // framework::flatten_to_2d(x_dims.data(), param_.x_num_col_dims); + // auto y_mat_dims = + // framework::flatten_to_2d(y_dims.data(), param_.y_num_col_dims); + + // PADDLE_ENFORCE_EQ(x_mat_dims[1], y_mat_dims[0], + // "First matrix's width must be equal with second matrix's + // " + // "height. 
%s, %s", + // x_mat_dims[1], y_mat_dims[0]); return true; } @@ -39,11 +49,16 @@ bool MulOpLite::InferShape() const { const auto y_dims = param_.y->dims(); // Set output dims - std::vector out_dims(param_.x_num_col_dims + 1, 0); + std::vector out_dims( + param_.x_num_col_dims + y_dims.size() - param_.y_num_col_dims, 0); for (int i = 0; i < param_.x_num_col_dims; ++i) { out_dims[i] = x_dims[i]; } - out_dims.back() = y_dims[1]; + + for (auto i = static_cast(param_.y_num_col_dims); i < y_dims.size(); + ++i) { + out_dims[i] = y_dims[i]; + } param_.output->Resize(lite::DDim(out_dims)); @@ -52,6 +67,41 @@ bool MulOpLite::InferShape() const { return true; } +#ifdef LITE_WITH_X86 + +bool MulGradOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.y); + CHECK_OR_FALSE(param_.output_grad); + CHECK_OR_FALSE(param_.x_grad); + CHECK_OR_FALSE(param_.y_grad); + + return true; +} + +bool MulGradOpLite::InferShape() const { + param_.x_grad->Resize(param_.x->dims()); + param_.y_grad->Resize(param_.y->dims()); + return true; +} + +bool MulGradOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { + auto X_name = op_desc.Input("X").front(); + auto Y_name = op_desc.Input("Y").front(); + auto Out_grad_name = op_desc.Output(framework::GradVarName("Out")).front(); + auto X_grad_name = op_desc.Output(framework::GradVarName("X")).front(); + auto Y_grad_name = op_desc.Output(framework::GradVarName("Y")).front(); + + param_.x = GetVar(scope, X_name); + param_.y = GetVar(scope, Y_name); + param_.output_grad = GetVar(scope, Out_grad_name); + param_.x_grad = GetMutableVar(scope, X_grad_name); + param_.y_grad = GetMutableVar(scope, Y_grad_name); + + return true; +} +#endif + } // namespace operators } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/operators/mul_op.h b/paddle/fluid/lite/operators/mul_op.h index 806fdb01f9bd62..7aa1581bb2adb6 100644 --- a/paddle/fluid/lite/operators/mul_op.h +++ b/paddle/fluid/lite/operators/mul_op.h 
@@ -37,7 +37,7 @@ class MulOpLite : public OpLite { void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } // TODO(Superjomn) replace framework::OpDesc with a lite one. - bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override { + bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { auto input = op_desc.Input("X").front(); auto W = op_desc.Input("Y").front(); auto out = op_desc.Output("Out").front(); @@ -45,12 +45,13 @@ class MulOpLite : public OpLite { CHECK(var); param_.x = var->GetMutable(); var = scope->FindVar(W); - CHECK(var); + CHECK(var) << "no var called " << W; param_.y = var->GetMutable(); - CHECK(scope->FindVar(out)); - param_.output = scope->FindVar(out)->GetMutable(); - param_.x_num_col_dims = GetAttr(op_desc.GetAttr("x_num_col_dims")); - param_.y_num_col_dims = GetAttr(op_desc.GetAttr("y_num_col_dims")); + var = scope->FindVar(out); + CHECK(var) << "no var called " << out; + param_.output = var->GetMutable(); + param_.x_num_col_dims = op_desc.GetAttr("x_num_col_dims"); + param_.y_num_col_dims = op_desc.GetAttr("y_num_col_dims"); return true; } @@ -61,6 +62,26 @@ class MulOpLite : public OpLite { mutable MulParam param_; }; +class MulGradOpLite : public OpLite { + public: + MulGradOpLite() {} + + explicit MulGradOpLite(const std::string &type) : OpLite(type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; + + std::string DebugString() const override { return "mul_grad"; } + + private: + mutable MulGradParam param_; +}; + } // namespace operators } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h index d21c0e3135d17c..23b21cb276442d 100644 --- a/paddle/fluid/lite/operators/op_params.h +++ 
b/paddle/fluid/lite/operators/op_params.h @@ -13,8 +13,10 @@ // limitations under the License. #pragma once +#include #include #include "paddle/fluid/lite/core/compatible_tensor.h" +#include "paddle/fluid/lite/core/framework.pb.h" #include "paddle/fluid/lite/utils/all.h" /* @@ -72,6 +74,17 @@ struct MulParam { int y_num_col_dims{1}; }; +struct MulGradParam { + const lite::Tensor* x{}; + const lite::Tensor* y{}; + const lite::Tensor* output_grad{}; + lite::Tensor* x_grad{}; + lite::Tensor* y_grad{}; + + int x_num_col_dims{1}; + int y_num_col_dims{1}; +}; + // For Scale Op struct ScaleParam { lite::Tensor* x{}; @@ -82,6 +95,85 @@ struct ScaleParam { bool bias_after_scale{true}; }; +// For Softmax op +struct SoftmaxParam { + lite::Tensor* x{}; + lite::Tensor* output{}; + int axis{-1}; +}; + +// For Reshape and Reshape2 Op +struct ReshapeParam { + const lite::Tensor* x{}; + const lite::Tensor* actual_shape{nullptr}; + lite::Tensor* output{}; + lite::Tensor* xshape{}; + + std::vector shape{}; + bool inplace{false}; +}; + +// For Concat op +struct ConcatParam { + std::vector x{}; + lite::Tensor* output{}; + int axis{0}; +}; + +// For Convolution op +struct ConvParam { + lite::Tensor* x{}; + lite::Tensor* filter{}; + lite::Tensor* bias{}; + lite::Tensor* residualData{}; + lite::Tensor* output{}; + std::vector strides{1, 1}; + std::vector paddings{0, 0}; + int groups{1}; + std::vector dilations{1, 1}; + bool fuse_relu_before_depthwise_conv{false}; + bool use_mkldnn{false}; + bool fuse_relu{false}; // only used in mkldnn kernel + bool use_quantizer{ + false}; // set true for op that should be quantized, only used for cpu + bool fuse_residual_connection{false}; + float scale_in{1.0f}; // only used with mkl-dnn int8 + float scale_out{1.0f}; // only used with mkl-dnn int8 + float scale_in_eltwise{1.0f}; // only used with mkl-dnn int8 + float scale_weights{1.0f}; // only used with mkl-dnn int8 + bool force_fp32_output{false}; // only used in mkl-dnn int8 + std::string 
data_format{"Anylayout"}; +}; + +// For Pooling op +struct PoolParam { + lite::Tensor* x{}; + lite::Tensor* output{}; + std::string pooling_type{""}; + std::vector ksize{}; + bool global_pooling{ + false}; // if true, knernel size and paddings will be ignored + std::vector strides{1, 1}; + std::vector paddings{0, 0}; + bool exclusive{true}; + bool adaptive{false}; + bool ceil_mode{false}; + bool use_quantizer{false}; + std::string data_format{"AnyLayout"}; +}; + +// For Dropout op +struct DropoutParam { + const lite::Tensor* x{}; + lite::Tensor* output{}; + lite::Tensor* mask{}; + float dropout_prob{.5f}; + bool is_test{false}; + bool fix_seed{false}; + int seed{0}; + std::string dropout_implementation{"downgrade_in_infer"}; +}; + /// ----------------------- element wise operators ---------------------- struct ElementwiseParam { const lite::Tensor* X{}; @@ -91,9 +183,10 @@ struct ElementwiseParam { }; struct ElementwiseGradParam { - const lite::Tensor* X_grad{}; - const lite::Tensor* Y_grad{}; - lite::Tensor* Out_grad{}; + const lite::Tensor* Y{}; + const lite::Tensor* Out_grad{}; + lite::Tensor* X_grad{}; + lite::Tensor* Y_grad{}; int axis{-1}; // for broadcasting. 
}; @@ -111,6 +204,39 @@ struct ActivationGradParam { const lite::Tensor* Out_grad{}; }; +/// ----------------------- mean operators ---------------------- +struct MeanParam { + const lite::Tensor* X{}; + lite::Tensor* Out{}; +}; + +struct MeanGradParam { + const lite::Tensor* X{}; + const lite::Tensor* Out_grad{}; + // for backward + lite::Tensor* X_grad{}; +}; + +/// ----------------------- fill_constant operators ---------------------- +struct FillConstantParam { + int dtype{framework::proto::VarType::FP32}; + std::vector shape{}; + float value{0.0f}; + // useless for x86, keep it for compatibility + bool force_cpu{false}; + lite::Tensor* Out{}; +}; + +/// ----------------------- sgd operators ---------------------- +struct SGDParam { + int dtype{framework::proto::VarType::FP32}; + + const lite::Tensor* Param{}; + const lite::Tensor* LearningRate{}; + const lite::Tensor* Grad{}; + lite::Tensor* ParamOut{}; +}; + } // namespace operators } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/operators/pool_op.cc b/paddle/fluid/lite/operators/pool_op.cc new file mode 100644 index 00000000000000..055f00f90a4776 --- /dev/null +++ b/paddle/fluid/lite/operators/pool_op.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/lite/operators/pool_op.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +int PoolOutputSize(int input_size, int filter_size, int padding, int stride, + bool ceil_mode) { + int output_size; + if (!ceil_mode) { + output_size = (input_size - filter_size + 2 * padding) / stride + 1; + } else { + output_size = + (input_size - filter_size + 2 * padding + stride - 1) / stride + 1; + } + CHECK_OR_FALSE(output_size > 0); + return output_size; +} + +bool PoolOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + return true; +} + +bool PoolOpLite::InferShape() const { + const auto input_dims = param_.x->dims(); + CHECK_OR_FALSE(input_dims.size() == 4 || input_dims.size() == 5); + + if (param_.global_pooling) { + param_.ksize.resize(static_cast(input_dims.size()) - 2); + for (size_t i = 0; i < param_.ksize.size(); ++i) { + param_.paddings[i] = 0; + param_.ksize[i] = static_cast(input_dims[i + 2]); + } + } + + CHECK_OR_FALSE(input_dims.size() - param_.ksize.size() == 2U); + CHECK_EQ_OR_FALSE(param_.ksize.size(), param_.strides.size()); + CHECK_EQ_OR_FALSE(param_.ksize.size(), param_.paddings.size()); + + std::vector output_shape({input_dims[0], input_dims[1]}); + if (param_.adaptive) { + output_shape.insert(output_shape.end(), param_.ksize.begin(), + param_.ksize.end()); + } else { + for (size_t i = 0; i < param_.ksize.size(); ++i) { + output_shape.push_back( + PoolOutputSize(input_dims[i + 2], param_.ksize[i], param_.paddings[i], + param_.strides[i], param_.ceil_mode)); + } + } + // share LoD + // param_.output->set_lod(param_.input->lod()); + param_.output->Resize(lite::DDim(output_shape)); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(pool2d, paddle::lite::operators::PoolOpLite); diff --git a/paddle/fluid/lite/operators/pool_op.h b/paddle/fluid/lite/operators/pool_op.h new file mode 
100644 index 00000000000000..64c15ccf1db813 --- /dev/null +++ b/paddle/fluid/lite/operators/pool_op.h @@ -0,0 +1,64 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/scope.h" +#include "paddle/fluid/lite/operators/op_params.h" +#include "paddle/fluid/lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class PoolOpLite : public OpLite { + public: + PoolOpLite() {} + + explicit PoolOpLite(const std::string &type) : OpLite(type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + // TODO(Superjomn) replace framework::OpDesc with a lite one. 
+ bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { + auto input = op_desc.Input("X").front(); + auto out = op_desc.Output("Out").front(); + + param_.x = scope->FindVar(input)->GetMutable(); + param_.output = scope->FindVar(out)->GetMutable(); + param_.pooling_type = op_desc.GetAttr("pooling_type"); + param_.ksize = op_desc.GetAttr>("ksize"); + param_.strides = op_desc.GetAttr>("strides"); + param_.paddings = op_desc.GetAttr>("paddings"); + param_.ceil_mode = op_desc.GetAttr("ceil_mode"); + param_.adaptive = op_desc.GetAttr("adaptive"); + param_.global_pooling = op_desc.GetAttr("global_pooling"); + return true; + } + + std::string DebugString() const override { return "pool"; } + + private: + mutable PoolParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/relu_op.cc b/paddle/fluid/lite/operators/relu_op.cc index 8f6ffd139927ed..a588b1c8cbf101 100644 --- a/paddle/fluid/lite/operators/relu_op.cc +++ b/paddle/fluid/lite/operators/relu_op.cc @@ -25,25 +25,23 @@ bool ReluOp::InferShape() const { CHECK_OR_FALSE(param_.output); // TODO(Superjomn) Enable data sharing. 
param_.output->Resize(param_.input->dims()); - // param_.output->ShareDataWith(*param_.input); // share lod // param_.output->set_lod(param_.input->lod()); return true; } -bool ReluOp::AttachImpl(const OpDesc &opdesc, lite::Scope *scope) { +bool ReluOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { param_.input = const_cast( &scope->FindVar(opdesc.Input("Input").front())->Get()); param_.output = scope->FindVar(opdesc.Output("Out").front())->GetMutable(); CHECK(param_.input); CHECK(param_.output); - kernel_->SetParam(param_); return true; } -REGISTER_LITE_OP(relu, ReluOp); - } // namespace operators } // namespace lite } // namespace paddle + +REGISTER_LITE_OP(relu, paddle::lite::operators::ReluOp); diff --git a/paddle/fluid/lite/operators/relu_op.h b/paddle/fluid/lite/operators/relu_op.h index a6204a107d8c4b..945a9680a75d71 100644 --- a/paddle/fluid/lite/operators/relu_op.h +++ b/paddle/fluid/lite/operators/relu_op.h @@ -32,10 +32,10 @@ class ReluOp : public OpLite { bool InferShape() const override; - bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override; + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } - std::string DebugString() const override { return "tanh"; } + std::string DebugString() const override { return "relu"; } private: mutable ReluParam param_; diff --git a/paddle/fluid/lite/operators/reshape_op.cc b/paddle/fluid/lite/operators/reshape_op.cc new file mode 100644 index 00000000000000..6fc9c1af1e6646 --- /dev/null +++ b/paddle/fluid/lite/operators/reshape_op.cc @@ -0,0 +1,145 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/reshape_op.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool ReshapeOp::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + CHECK_OR_FALSE(!param_.shape.empty()); + return true; +} + +bool ReshapeOp::InferShape() const { + auto x_dims = param_.x->dims(); + auto output_dims = ValidateShape(param_.shape, x_dims); + param_.output->Resize(output_dims); + return true; +} + +bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + auto x_var = scope->FindVar(opdesc.Input("X").front()); + auto output_var = scope->FindVar(opdesc.Output("Out").front()); + CHECK(x_var); + CHECK(output_var); + param_.x = const_cast(&(x_var->Get())); + param_.output = output_var->GetMutable(); + std::vector input_arg_names = opdesc.InputArgumentNames(); + if (std::find(input_arg_names.begin(), input_arg_names.end(), "Shape") != + input_arg_names.end()) { + auto actual_shape_var = scope->FindVar(opdesc.Input("Shape").front()); + if (actual_shape_var != nullptr) { + param_.actual_shape = + const_cast(&(actual_shape_var->Get())); + } + } + param_.shape = (opdesc.GetAttr>("shape")); + if (opdesc.HasAttr("inplace")) { + param_.inplace = opdesc.GetAttr("inplace"); + } + CHECK(param_.x) << "Input(X) of ReshapeOp should not be null."; + CHECK(param_.output) << "Output(Out) of ReshapeOp should not be null."; + CHECK(!param_.shape.empty()) + << "The shape information must be set by Attr(shape)."; + return true; +} + +bool 
Reshape2Op::CheckShape() const { + ReshapeOp::CheckShape(); + CHECK_OR_FALSE(param_.xshape); + return true; +} + +bool Reshape2Op::InferShape() const { + ReshapeOp::InferShape(); + auto x_dims = param_.x->dims(); + std::vector xshape_dims(x_dims.size() + 1, 0); + for (size_t i = 0; i < x_dims.size(); i++) { + xshape_dims[i + 1] = x_dims[i]; + } + param_.xshape->Resize(DDim(xshape_dims)); + return true; +} + +bool Reshape2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + ReshapeOp::AttachImpl(opdesc, scope); + auto xshape_var = scope->FindVar(opdesc.Output("XShape").front()); + CHECK(xshape_var); + param_.xshape = xshape_var->GetMutable(); + CHECK(param_.xshape) << "Output(XShape) of ReshapeOp should not be null."; + return true; +} + +DDim ValidateShape(const std::vector &shape, const DDim &input_dims) { + const DDim::value_type input_size = input_dims.production(); + auto input_shape = input_dims.Vectorize(); + bool all_positive = std::all_of(input_shape.cbegin(), input_shape.cend(), + [](DDim::value_type i) { return i > 0; }); + // only one dimension can be set to -1, whose size will be automatically + // infered. + const int unk_dim_val = -1; + const int copy_dim_val = 0; + + std::vector output_shape(shape.size(), 0); + DDim::value_type capacity = 1; + int unk_dim_idx = -1; + for (size_t i = 0; i < shape.size(); ++i) { + if (shape[i] == unk_dim_val) { + CHECK_EQ(unk_dim_idx, -1) + << "Only one input dimension of Attr(shape) can be unknown."; + unk_dim_idx = i; + } else if (shape[i] == copy_dim_val) { + CHECK_LT(static_cast(i), input_shape.size()) + << "The index of dimension to copy from input shape must be less " + "than the size of input shape."; + } else { + CHECK_GT(shape[i], 0) << "Each input dimension of Attr(shape) must not " + "be negtive except one unknown dimension."; + } + + capacity *= + (shape[i] ? static_cast(shape[i]) : input_shape[i]); + output_shape[i] = + (shape[i] ? 
static_cast(shape[i]) : input_shape[i]); + } + + if (unk_dim_idx != -1) { + if (all_positive) { + // input_size < 0 and is un-determinate in compile time, skip the check, + // for example, input_dims = [-1, 8, 1, 1], shape = [-1, 3, 8], + // capacity = -24, input_size = -8, output_shape[0] = 0 + // the following check will fail. + output_shape[unk_dim_idx] = -input_size / capacity; + CHECK_EQ(output_shape[unk_dim_idx] * capacity, -input_size) + << "Invalid shape is given."; + } else { + output_shape[unk_dim_idx] = -1; + } + } else { + CHECK_EQ(capacity, input_size) << "Invalid shape is given."; + } + return DDim(output_shape); +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(reshape, paddle::lite::operators::ReshapeOp); +REGISTER_LITE_OP(reshape2, paddle::lite::operators::Reshape2Op); diff --git a/paddle/fluid/lite/operators/reshape_op.h b/paddle/fluid/lite/operators/reshape_op.h new file mode 100644 index 00000000000000..4f7e0b9c1348cf --- /dev/null +++ b/paddle/fluid/lite/operators/reshape_op.h @@ -0,0 +1,63 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/scope.h" +#include "paddle/fluid/lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class ReshapeOp : public OpLite { + public: + ReshapeOp() {} + explicit ReshapeOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "reshape"; } + + protected: + mutable ReshapeParam param_; +}; + +class Reshape2Op : public ReshapeOp { + public: + Reshape2Op() : ReshapeOp() {} + explicit Reshape2Op(const std::string &op_type) : ReshapeOp(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "reshape2"; } +}; + +DDim ValidateShape(const std::vector &shape, const DDim &input_dims); + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/reshape_op_test.cc b/paddle/fluid/lite/operators/reshape_op_test.cc new file mode 100644 index 00000000000000..4bf137f16fe798 --- /dev/null +++ b/paddle/fluid/lite/operators/reshape_op_test.cc @@ -0,0 +1,145 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/reshape_op.h" +#include +#include +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +TEST(reshape_op_lite, test) { + // prepare variables + Scope scope; + auto* x = scope.Var("x")->GetMutable(); + auto* actual_shape = scope.Var("actual_shape")->GetMutable(); + auto* output = scope.Var("output")->GetMutable(); + std::map, std::vector> shapes = { + {{-1, 0, 3, 2, 1}, {2, 4, 3, 2, 1}}, + {{0, -1, 3, 2, 1}, {2, 4, 3, 2, 1}}, + {{-1, 48}, {1, 48}}, + {{48, -1}, {48, 1}}, + {{0, 24}, {2, 24}}, + {{12, 0}, {12, 4}}, + }; + x->Resize(DDim(std::vector({2, 4, 6}))); + actual_shape->Resize(DDim(std::vector({2}))); + + auto* actual_shape_data = actual_shape->mutable_data(); + actual_shape_data[0] = 6; + actual_shape_data[1] = 8; + + for (auto& shape : shapes) { + for (auto& has_actual_shape : {true, false}) { + for (auto& inplace : {true, false}) { + // prepare op desc + cpp::OpDesc desc; + desc.SetType("reshape"); + desc.SetInput("X", {"x"}); + if (has_actual_shape) { + desc.SetInput("Shape", {"actual_shape"}); + } + desc.SetOutput("Out", {"output"}); + desc.SetAttr("shape", shape.first); + desc.SetAttr("inplace", inplace); + + ReshapeOp reshape("reshape"); + + reshape.SetValidPlaces( + {Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)}}); + reshape.Attach(desc, &scope); + reshape.CheckShape(); + reshape.InferShape(); + + // check output dims + auto output_dims = output->dims(); + CHECK_EQ(output_dims.size(), shape.second.size()); + for 
(size_t i = 0; i < output_dims.size(); i++) { + CHECK_EQ(output_dims[i], shape.second[i]); + } + } + } + } +} + +TEST(reshape2_op_lite, test) { + // prepare variables + Scope scope; + auto* x = scope.Var("x")->GetMutable(); + auto* actual_shape = scope.Var("actual_shape")->GetMutable(); + auto* output = scope.Var("output")->GetMutable(); + auto* xshape = scope.Var("xshape")->GetMutable(); + std::map, std::vector> shapes = { + {{-1, 0, 3, 2, 1}, {2, 4, 3, 2, 1}}, + {{0, -1, 3, 2, 1}, {2, 4, 3, 2, 1}}, + {{-1, 48}, {1, 48}}, + {{48, -1}, {48, 1}}, + {{0, 24}, {2, 24}}, + {{12, 0}, {12, 4}}, + }; + x->Resize(DDim(std::vector({2, 4, 6}))); + actual_shape->Resize(DDim(std::vector({2}))); + + auto* actual_shape_data = actual_shape->mutable_data(); + actual_shape_data[0] = 6; + actual_shape_data[1] = 8; + + for (auto& shape : shapes) { + for (auto& has_actual_shape : {true, false}) { + for (auto& inplace : {true, false}) { + // prepare op desc + cpp::OpDesc desc; + desc.SetType("reshape"); + desc.SetInput("X", {"x"}); + if (has_actual_shape) { + desc.SetInput("Shape", {"actual_shape"}); + } + desc.SetOutput("Out", {"output"}); + desc.SetOutput("XShape", {"xshape"}); + desc.SetAttr("shape", shape.first); + desc.SetAttr("inplace", inplace); + + Reshape2Op reshape2("reshape2"); + + reshape2.SetValidPlaces( + {Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)}}); + reshape2.Attach(desc, &scope); + reshape2.CheckShape(); + reshape2.InferShape(); + + // check output dims + auto output_dims = output->dims(); + CHECK_EQ(output_dims.size(), shape.second.size()); + for (int i = 0; i < output_dims.size(); i++) { + CHECK_EQ(output_dims[i], shape.second[i]); + } + // check xshape dims + auto x_dims = x->dims(); + auto xshape_dims = xshape->dims(); + CHECK_EQ(xshape_dims.size(), x_dims.size() + 1); + CHECK_EQ(xshape_dims[0], 0); + for (size_t i = 0; i < x_dims.size(); i++) { + CHECK_EQ(xshape_dims[i + 1], x_dims[i]); + } + } + } + } +} + +} // namespace operators +} // namespace 
lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/scale_op.cc b/paddle/fluid/lite/operators/scale_op.cc index 87cbe2a2e035bd..fb55366488cae9 100644 --- a/paddle/fluid/lite/operators/scale_op.cc +++ b/paddle/fluid/lite/operators/scale_op.cc @@ -12,58 +12,35 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include -#include "paddle/fluid/lite/core/kernel.h" -#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/operators/scale_op.h" #include "paddle/fluid/lite/core/op_registry.h" -#include "paddle/fluid/lite/core/scope.h" -#include "paddle/fluid/lite/operators/op_params.h" -#include "paddle/fluid/lite/utils/all.h" - namespace paddle { namespace lite { namespace operators { -class ScaleOp : public OpLite { - public: - ScaleOp() {} - - explicit ScaleOp(const std::string &type) : OpLite(type) {} - - bool CheckShape() const override { - CHECK_OR_FALSE(param_.x); - CHECK_OR_FALSE(param_.output); - return true; - } - - bool InferShape() const override { - param_.output->Resize(param_.x->dims()); - return true; - } - - void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } - - // TODO(Superjomn) replace framework::OpDesc with a lite one. 
- bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override { - auto x = op_desc.Input("X").front(); - auto out = op_desc.Output("Out").front(); - - param_.x = scope->FindVar(x)->GetMutable(); - CHECK(scope->FindVar(out)); - param_.output = scope->FindVar(out)->GetMutable(); - param_.scale = GetAttr(op_desc.GetAttr("scale")); - param_.bias = GetAttr(op_desc.GetAttr("bias")); - param_.bias_after_scale = - GetAttr(op_desc.GetAttr("bias_after_scale")); - return true; - } - - std::string DebugString() const override { return op_type_; } - - private: - mutable ScaleParam param_; -}; +bool ScaleOp::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + return true; +} + +bool ScaleOp::InferShape() const { + param_.output->Resize(param_.x->dims()); + return true; +} + +bool ScaleOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { + auto x = op_desc.Input("X").front(); + auto output = op_desc.Output("Out").front(); + param_.x = scope->FindVar(x)->GetMutable(); + param_.output = scope->FindVar(output)->GetMutable(); + param_.scale = op_desc.GetAttr("scale"); + param_.bias = op_desc.GetAttr("bias"); + param_.bias_after_scale = op_desc.GetAttr("bias_after_scale"); + CHECK(param_.x); + CHECK(param_.output); + return true; +} } // namespace operators } // namespace lite diff --git a/paddle/fluid/lite/operators/scale_op.h b/paddle/fluid/lite/operators/scale_op.h new file mode 100644 index 00000000000000..43493710bba591 --- /dev/null +++ b/paddle/fluid/lite/operators/scale_op.h @@ -0,0 +1,46 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/scope.h" +#include "paddle/fluid/lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class ScaleOp : public OpLite { + public: + ScaleOp() {} + explicit ScaleOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "scale"; } + + private: + mutable ScaleParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/scale_op_test.cc b/paddle/fluid/lite/operators/scale_op_test.cc new file mode 100644 index 00000000000000..33ab91ff05cea1 --- /dev/null +++ b/paddle/fluid/lite/operators/scale_op_test.cc @@ -0,0 +1,58 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/scale_op.h" +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +TEST(scale_op_lite, test) { + // prepare variables + Scope scope; + auto* x = scope.Var("x")->GetMutable(); + auto* output = scope.Var("output")->GetMutable(); + x->Resize(DDim(std::vector({10, 20}))); + output->Resize(DDim(std::vector{1, 1})); + + // prepare op desc + cpp::OpDesc desc; + desc.SetType("scale"); + desc.SetInput("X", {"x"}); + desc.SetOutput("Out", {"output"}); + desc.SetAttr("bias_after_scale", false); + desc.SetAttr("scale", 0.5f); + desc.SetAttr("bias", 0.125f); + + ScaleOp scale("scale"); + + scale.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}}); + scale.Attach(desc, &scope); + scale.CheckShape(); + scale.InferShape(); + + // check output dims + auto x_dims = x->dims(); + auto output_dims = output->dims(); + CHECK_EQ(output_dims.size(), x_dims.size()); + for (size_t i = 0; i < output_dims.size(); i++) { + CHECK_EQ(output_dims[i], x_dims[i]); + } +} + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/sgd_op.cc b/paddle/fluid/lite/operators/sgd_op.cc new file mode 100644 index 00000000000000..2571ad0b102d8e --- /dev/null +++ b/paddle/fluid/lite/operators/sgd_op.cc @@ -0,0 +1,57 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "/paddle/paddle/fluid/lite/operators/sgd_op.h" +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool SGDOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.Param); + CHECK_OR_FALSE(param_.LearningRate); + CHECK_OR_FALSE(param_.Grad); + CHECK_OR_FALSE(param_.ParamOut); + return true; +} + +bool SGDOpLite::InferShape() const { + auto lr_dims = param_.LearningRate->dims().data(); + CHECK_EQ_OR_FALSE(framework::product(lr_dims), 1); + param_.ParamOut->Resize(param_.Param->dims()); + return true; +} + +bool SGDOpLite::AttachImpl(const OpDesc& opdesc, lite::Scope* scope) { + CHECK_EQ(opdesc.Inputs().size(), 3UL); + auto Param_name = opdesc.Input("Param").front(); + auto LearningRate_name = opdesc.Input("LearningRate").front(); + auto Grad_name = opdesc.Input("Grad").front(); + auto ParamOut_name = opdesc.Output("ParamOut").front(); + + param_.Param = GetVar(scope, Param_name); + param_.LearningRate = GetVar(scope, LearningRate_name); + param_.Grad = GetVar(scope, Grad_name); + param_.ParamOut = GetMutableVar(scope, ParamOut_name); + + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(sgd, paddle::lite::operators::SGDOpLite); diff --git a/paddle/fluid/lite/operators/sgd_op.h b/paddle/fluid/lite/operators/sgd_op.h new file mode 100644 index 00000000000000..dea045c0b67cbf --- /dev/null +++ b/paddle/fluid/lite/operators/sgd_op.h @@ -0,0 +1,50 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/scope.h" +#include "paddle/fluid/lite/operators/op_params.h" +#include "paddle/fluid/lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class SGDOpLite : public OpLite { + public: + SGDOpLite() {} + + explicit SGDOpLite(const std::string &type) : OpLite(type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override; + + std::string DebugString() const override { return "sgd"; } + + private: + mutable SGDParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/softmax_op.cc b/paddle/fluid/lite/operators/softmax_op.cc new file mode 100644 index 00000000000000..41d7b335e80bc0 --- /dev/null +++ b/paddle/fluid/lite/operators/softmax_op.cc @@ -0,0 +1,52 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/softmax_op.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool SoftmaxOp::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.output); + auto x_dims = param_.x->dims(); + auto x_rank = x_dims.size(); + CHECK_OR_FALSE(param_.axis >= -static_cast(x_rank) && + param_.axis < static_cast(x_rank)); + return true; +} + +bool SoftmaxOp::InferShape() const { + param_.output->Resize(param_.x->dims()); + return true; +} + +bool SoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { + param_.x = const_cast( + &scope->FindVar(opdesc.Input("X").front())->Get()); + param_.output = + scope->FindVar(opdesc.Output("Out").front())->GetMutable(); + param_.axis = opdesc.GetAttr("axis"); + CHECK(param_.x); + CHECK(param_.output); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(softmax, paddle::lite::operators::SoftmaxOp); diff --git a/paddle/fluid/lite/operators/softmax_op.h b/paddle/fluid/lite/operators/softmax_op.h new file mode 100644 index 00000000000000..515e4493c9949a --- /dev/null +++ b/paddle/fluid/lite/operators/softmax_op.h @@ -0,0 +1,46 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/scope.h" +#include "paddle/fluid/lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class SoftmaxOp : public OpLite { + public: + SoftmaxOp() {} + explicit SoftmaxOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "softmax"; } + + private: + mutable SoftmaxParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/softmax_op_test.cc b/paddle/fluid/lite/operators/softmax_op_test.cc new file mode 100644 index 00000000000000..4659b35cd7bbe6 --- /dev/null +++ b/paddle/fluid/lite/operators/softmax_op_test.cc @@ -0,0 +1,54 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/softmax_op.h" +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +TEST(softmax_op_lite, test) { + // prepare variables + Scope scope; + auto* x = scope.Var("x")->GetMutable(); + auto* output = scope.Var("output")->GetMutable(); + x->Resize(DDim(std::vector({10, 20}))); + output->Resize(DDim(std::vector{10, 20})); + + // set data + for (int i = 0; i < 10 * 20; i++) { + x->mutable_data()[i] = i; + } + for (int i = 0; i < 10 * 20; i++) { + output->mutable_data()[i] = 0.; + } + + // prepare op desc + cpp::OpDesc desc; + desc.SetType("softmax"); + desc.SetInput("X", {"x"}); + desc.SetOutput("Out", {"output"}); + desc.SetAttr("axis", static_cast(-1)); + + SoftmaxOp softmax("softmax"); + + softmax.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}}); + softmax.Attach(desc, &scope); +} + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/tools/Dockerfile.mobile b/paddle/fluid/lite/tools/Dockerfile.mobile new file mode 100644 index 00000000000000..e48af1227513fe --- /dev/null +++ b/paddle/fluid/lite/tools/Dockerfile.mobile @@ -0,0 +1,90 @@ +# A image for paddle lite mobile cross compile and simulator on android + +FROM ubuntu:16.04 +MAINTAINER PaddlePaddle Authors + +RUN echo '\ +deb main restricted universe multiverse\n\ +deb -updates main restricted universe multiverse\n\ +deb -backports main restricted universe multiverse\n\ +deb -security main restricted universe multiverse\n'\ +> /etc/apt/sources.list +RUN sed -ie 's||http://mirrors.tuna.tsinghua.edu.cn/ubuntu/|' /etc/apt/sources.list +RUN sed -ie 's||xenial|' /etc/apt/sources.list + +RUN apt-get update && apt-get upgrade -y +RUN apt-get install -y --no-install-recommends \ + clang-format-3.8 \ + cmake-curses-gui \ + curl \ + fish \ + gawk \ + gcc 
\ + g++ \ + git \ + graphviz \ + less \ + make \ + patch \ + python \ + python-pip \ + python-setuptools \ + unzip \ + vim \ + wget + +# for android simulator +RUN apt-get install -y --no-install-recommends \ + libc6-i386 \ + lib32stdc++6 \ + redir \ + iptables \ + openjdk-8-jre + +# for cmake 3.10 +RUN curl -O https://mms-res.cdn.bcebos.com/cmake-3.10.3-Linux-x86_64.tar.gz && \ + tar xzf cmake-3.10.3-Linux-x86_64.tar.gz && \ + mv cmake-3.10.3-Linux-x86_64 /opt/cmake-3.10 && \ + mv /usr/bin/cmake /usr/bin/cmake.bak && ln -s /opt/cmake-3.10/bin/cmake /usr/bin/cmake && \ + mv /usr/bin/ccmake /usr/bin/ccmake.bak && ln -s /opt/cmake-3.10/bin/ccmake /usr/bin/ccmake + +# for arm linux compile +RUN apt-get install -y --no-install-recommends \ + g++-arm-linux-gnueabi \ + gcc-arm-linux-gnueabi \ + g++-arm-linux-gnueabihf \ + gcc-arm-linux-gnueabihf \ + gcc-aarch64-linux-gnu \ + g++-aarch64-linux-gnu + +# for android ndk17c +RUN cd /tmp && curl -O https://dl.google.com/android/repository/android-ndk-r17c-linux-x86_64.zip +ENV NDK_ROOT /opt/android-ndk-r17c +RUN cd /opt && unzip /tmp/android-ndk-r17c-linux-x86_64.zip + +# for android simulator +ENV ANDROID_HOME /opt/android_sdk +ENV PATH $PATH:${ANDROID_HOME}/tools:${ANDROID_HOME}/platform-tools:${ANDROID_HOME}/tools/bin +RUN wget "https://dl.google.com/android/repository/sdk-tools-linux-4333796.zip" && \ + unzip sdk-tools-linux-4333796.zip -d /opt/android_sdk && \ + mkdir /root/.android && touch /root/.android/repositories.cfg && \ + echo y | sdkmanager "platform-tools" "emulator" && \ + echo y | sdkmanager "platforms;android-24" && \ + echo y | sdkmanager "system-images;android-24;google_apis;arm64-v8a" "system-images;android-24;google_apis;armeabi-v7a" + +# this will install the ndk19c and only use clang to compile, then can replace the ndk17c +# echo y | sdkmanager "ndk;19.2.5345600" + +# Expose android port +EXPOSE 5555 +EXPOSE 5557 +# VNC port +EXPOSE 5900 + +# clean +RUN ln -s clang-format-3.8 /usr/bin/clang-format 
+RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple --upgrade pip +RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple wheel +RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pre-commit +RUN apt-get autoremove -y && apt-get clean +RUN rm -rf /sdk-tools-linux-4333796.zip /tmp/android-ndk-r17c-linux-x86_64.zip /cmake-3.10.3-Linux-x86_64.tar.gz diff --git a/paddle/fluid/lite/tools/build.sh b/paddle/fluid/lite/tools/build.sh new file mode 100755 index 00000000000000..2a31f8d1ff9a52 --- /dev/null +++ b/paddle/fluid/lite/tools/build.sh @@ -0,0 +1,229 @@ +#!/bin/bash +set -ex + +TESTS_FILE="./lite_tests.txt" +LIBS_FILE="./lite_libs.txt" + +readonly common_flags="-DWITH_LITE=ON -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=OFF -DWITH_PYTHON=OFF -DWITH_TESTING=ON -DLITE_WITH_ARM=OFF" + +# for code gen, a source file is generated after a test, but is dependended by some targets in cmake. +# here we fake an empty file to make cmake works. +function prepare_for_codegen { + # in build directory + mkdir -p ./paddle/fluid/lite/gen_code + touch ./paddle/fluid/lite/gen_code/__generated_code__.cc +} +function cmake_x86 { + prepare_for_codegen + cmake .. -DWITH_GPU=OFF -DWITH_MKLDNN=OFF -DLITE_WITH_X86=ON ${common_flags} +} + +function cmake_x86_for_CI { + prepare_for_codegen + cmake .. -DWITH_GPU=OFF -DWITH_MKLDNN=OFF -DLITE_WITH_X86=ON ${common_flags} -DLITE_WITH_PROFILE=ON +} + +function cmake_gpu { + prepare_for_codegen + cmake .. " -DWITH_GPU=ON {common_flags} -DLITE_WITH_GPU=ON" +} + +function cmake_arm { + # $1: ARM_TARGET_OS in "android" , "armlinux" + # $2: ARM_TARGET_ARCH_ABI in "arm64-v8a", "armeabi-v7a" ,"armeabi-v7a-hf" + cmake .. 
\ + -DWITH_GPU=OFF \ + -DWITH_MKL=OFF \ + -DWITH_LITE=ON \ + -DLITE_WITH_CUDA=OFF \ + -DLITE_WITH_X86=OFF \ + -DLITE_WITH_ARM=ON \ + -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \ + -DWITH_TESTING=ON \ + -DARM_TARGET_OS=$1 -DARM_TARGET_ARCH_ABI=$2 +} + +function build { + file=$1 + for _test in $(cat $file); do + make $_test -j$(expr $(nproc) - 2) + done +} + +# It will eagerly test all lite related unittests. +function test_lite { + local file=$1 + echo "file: ${file}" + + for _test in $(cat $file); do + # We move the build phase here to make the 'gen_code' test compiles after the + # corresponding test is executed and the C++ code generates. + make $_test -j$(expr $(nproc) - 2) + ctest -R $_test -V + done +} + +port_armv8=5554 +port_armv7=5556 + +# Run test on android +function test_lite_android { + local file=$1 + local adb_abi=$2 + local port= + if [[ ${adb_abi} == "armeabi-v7a" ]]; then + port=${port_armv7} + fi + + if [[ ${adb_abi} == "arm64-v8a" ]]; then + port=${port_armv8} + fi + if [[ "${port}x" == "x" ]]; then + echo "Port can not be empty" + exit 1 + fi + + echo "file: ${file}" + # push all to adb and test + adb_work_dir="/data/local/tmp" + skip_list="test_model_parser_lite" + for _test in $(cat $file); do + [[ $skip_list =~ (^|[[:space:]])$_test($|[[:space:]]) ]] && continue || echo 'skip $_test' + testpath=$(find ./paddle/fluid -name ${_test}) + adb -s emulator-${port} push ${testpath} ${adb_work_dir} + adb -s emulator-${port} shell chmod +x "${adb_work_dir}/${_test}" + adb -s emulator-${port} shell "./${adb_work_dir}/${_test}" + done +} + +# Build the code and run lite server tests. This is executed in the CI system. +function build_test_server { + mkdir -p ./build + cd ./build + export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/paddle/build/third_party/install/mklml/lib" + cmake_x86_for_CI + # compile the tests and execute them. + test_lite $TESTS_FILE + # build the remaining libraries to check compiling error. 
+ build $LIBS_FILE +} + +# Build the code and run lite server tests. This is executed in the CI system. +function build_test_arm { + adb kill-server + adb devices | grep emulator | cut -f1 | while read line; do adb -s $line emu kill; done + # start android arm64-v8a armeabi-v7a emulators first + echo n | avdmanager create avd -f -n paddle-armv8 -k "system-images;android-24;google_apis;arm64-v8a" + echo -ne '\n' | ${ANDROID_HOME}/emulator/emulator -avd paddle-armv8 -noaudio -no-window -gpu off -verbose -port ${port_armv8} & + sleep 1m + echo n | avdmanager create avd -f -n paddle-armv7 -k "system-images;android-24;google_apis;armeabi-v7a" + echo -ne '\n' | ${ANDROID_HOME}/emulator/emulator -avd paddle-armv7 -noaudio -no-window -gpu off -verbose -port ${port_armv7} & + sleep 1m + + for os in "android" "armlinux" ; do + for abi in "arm64-v8a" "armeabi-v7a" "armeabi-v7a-hf" ; do + if [[ ${abi} == "armeabi-v7a-hf" ]]; then + echo "armeabi-v7a-hf is not supported on both android and armlinux" + continue + fi + + if [[ ${os} == "armlinux" && ${abi} == "armeabi-v7a" ]]; then + echo "armeabi-v7a is not supported on armlinux yet" + continue + fi + + build_dir=build.lite.${os}.${abi} + mkdir -p $build_dir + cd $build_dir + cmake_arm ${os} ${abi} + build $TESTS_FILE + + if [[ ${os} == "android" ]]; then + adb_abi=${abi} + if [[ ${adb_abi} == "armeabi-v7a-hf" ]]; then + adb_abi="armeabi-v7a" + fi + if [[ ${adb_abi} == "armeabi-v7a" ]]; then + # skip v7 tests + continue + fi + test_lite_android $TESTS_FILE ${adb_abi} + # armlinux need in another docker + fi + cd - + done + done + adb devices | grep emulator | cut -f1 | while read line; do adb -s $line emu kill; done + echo "Done" +} + +############################# MAIN ################################# +function print_usage { + echo -e "\nUSAGE:" + echo + echo "----------------------------------------" + echo -e "cmake_x86: run cmake with X86 mode" + echo -e "cmake_cuda: run cmake with CUDA mode" + echo -e "cmake_arm: run cmake 
with ARM mode" + echo + echo -e "build: compile the tests" + echo + echo -e "test_server: run server tests" + echo -e "test_mobile: run mobile tests" + echo "----------------------------------------" + echo +} + +function main { + # Parse command line. + for i in "$@"; do + case $i in + --tests=*) + TESTS_FILE="${i#*=}" + shift + ;; + build) + build $TESTS_FILE + build $LIBS_FILE + shift + ;; + cmake_x86) + cmake_x86 + shift + ;; + cmake_cuda) + cmake_cuda + shift + ;; + cmake_arm) + cmake_arm $2 $3 + shift + ;; + test_server) + test_lite $TESTS_FILE + shift + ;; + test_mobile) + test_lite $TESTS_FILE + shift + ;; + build_test_server) + build_test_server + shift + ;; + build_test_arm) + build_test_arm + shift + ;; + *) + # unknown option + print_usage + exit 1 + ;; + esac + done +} + +print_usage + +main $@ diff --git a/paddle/fluid/lite/tools/mobile_readme.md b/paddle/fluid/lite/tools/mobile_readme.md new file mode 100644 index 00000000000000..2069de2af2664f --- /dev/null +++ b/paddle/fluid/lite/tools/mobile_readme.md @@ -0,0 +1,126 @@ + +# Paddle-lite-mobile开发指南 + +## 交叉编译 + +Paddle-lite-mobile 推荐在我们的Docker环境下交叉编译,减少环境配置上的不必要问题。 + +### 1. 拉取代码创建容器 + +```shell +$ git clone https://github.com/PaddlePaddle/Paddle.git +$ git checkout incubate/lite +``` + +编译docker环境: +`docker build --file paddle/fluid/lite/tools/Dockerfile.mobile --tag paddle-lite-mobile:latest . ` + +### 主要cmake选项 + +- `ARM_TARGET_OS` 代表目标操作系统, 目前支持 "android" "armlinux", 模型是Android +- `ARM_TARGET_ARCH_ABI` 代表ARCH, 目前支持 "arm64-v8a" "armeabi-v7a"。 模型是arm64-v8a + +### 编译 + +基于`paddle-lite-mobile`镜像创建容器,并在容器内外建立目录映射关系: + +```shell +$ docker run -it --name --net=host --privileged -v : paddle-lite-mobile bash +``` + +参考build.sh下的 cmake arm编译需要的平台。 + +参考示例: + +```shell +#!/bin/bash +cmake .. 
\ + -DWITH_GPU=OFF \ + -DWITH_LITE=ON \ + -DLITE_WITH_CUDA=OFF \ + -DLITE_WITH_X86=OFF \ + -DLITE_WITH_ARM=ON \ + -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \ + -DWITH_TESTING=ON \ + -DWITH_MKL=OFF \ + -DARM_TARGET_OS="android" -DARM_TARGET_ARCH_ABI="arm64-v8a" + +# fc层单测 +make test_fc_compute_arm -j + +``` +### 在Android上执行 + +#### 1. 创建模拟器(如果使用真机则跳过此步骤) + +```shell +# 创建Android avd (armv8) +$ echo n | avdmanager create avd -f -n paddle-armv8 -k "system-images;android-24;google_apis;arm64-v8a" +# 启动Android armv8 emulator +$ ${ANDROID_HOME}/emulator/emulator -avd paddle-armv8 -noaudio -no-window -gpu off -verbose & + +# 如果需要执行armv7版本,如下: +# $ echo n | avdmanager create avd -f -n paddle-armv7 -k "system-images;android-24;google_apis;armeabi-v7a" +# $ ${ANDROID_HOME}/emulator/emulator -avd paddle-armv7 -noaudio -no-window -gpu off -verbose & + +# 退出所有模拟器 +adb devices | grep emulator | cut -f1 | while read line; do adb -s $line emu kill; done +``` + +#### 2. 上传编译文件到手机上 + +键盘上`crtl+q+p`同时摁下,切换到容器外(容器还在后台运行),将刚刚编译出的程序`adb push`到手机上。USB线连接手机,确保`adb devices`可以找到手机设备。 +```shell +$ cd +$ adb push ./build/paddle/fluid/lite/kernels/arm/test_fc_compute_arm /data/local/tmp/ + +# 进入手机 +$ adb shell # 若多台手机设备先用命令adb devices查看目标手机的序列码 +$ cd /data/local/tmp + +# 执行编译的程序 +$ ./test_fc_compute_arm +``` + +### 在ARM LINUX下执行 + +拉取Linux arm64镜像 +```shell +$ docker pull multiarch/ubuntu-core:arm64-bionic +``` +运行容器并在内外建立目录映射关系 +```shell +$ docker run -it --name -v : multiarch/ubuntu-core:arm64-bionic +``` +进入bin目录,并运行并文件 +```shell +$ cd +$ ./test_fc_compute_arm +``` + +# Q&A + +#### 1. adb命令找不到:adb: command not found +解决:`sudo apt install -y adb` + +#### 2. 
明明手机USB连接电脑却显示找不到设备:`error: device not found` +解决: +第一步`lsusb`命令查看插上拔下手机前后usb设备的变化情况,确定手机设备的ID。 假设`lsusb`命令执行显示`Bus 003 Device 011: ID 2717:9039 `,则ID是`0x2717`; +第二步:创建`adb_usb.ini`文件并追加写入ID:`echo 0x2717 >> ~/.android/adb_usb.ini`; +第三步:给手机添加权限`sudo vim /etc/udev/rules.d/70-android.rules`,根据第一步骤取得的`ATTRS{idVendor}`和`ATTRS{idProduct}`这两个属性值,在该文件加入该设备信息: + `SUBSYSTEM=="usb", ATTRS{idVendor}=="2717", ATTRS{idProduct}=="9039",MODE="0666"`; +第四步:重启USB服务: +```shell +$ sudo chmod a+rx /etc/udev/rules.d/70-android.rules +$ sudo service udev restart +``` +第五步:重启adb服务,adb devices有设备说明adb安装成功。 +```shell +$ adb kill-server +$ sudo adb start-server +$ adb devices + +# 若显示连接的手机设备,则表示成功 +List of devices attached +5cb00b6 device +``` diff --git a/paddle/fluid/lite/utils/CMakeLists.txt b/paddle/fluid/lite/utils/CMakeLists.txt index 1d299367d235b2..08eeaa54f8eacd 100644 --- a/paddle/fluid/lite/utils/CMakeLists.txt +++ b/paddle/fluid/lite/utils/CMakeLists.txt @@ -1,10 +1,11 @@ -if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) - set(utils_DEPS) - lite_cc_test(test_logging_lite SRCS logging_test.cc) -else() - set(utils_DEPS glog) -endif() +# if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) +# set(utils_DEPS) +# lite_cc_test(test_logging_lite SRCS logging_test.cc) +# else() +# endif() + +set(utils_DEPS glog) lite_cc_test(test_varient SRCS varient_test.cc DEPS utils_lite) cc_library(any_lite SRCS any.cc) -cc_library(utils_lite SRCS cp_logging.cc DEPS ${utils_DEPS} any_lite) +cc_library(utils_lite SRCS cp_logging.cc string.cc DEPS ${utils_DEPS} any_lite) diff --git a/paddle/fluid/lite/utils/all.h b/paddle/fluid/lite/utils/all.h index 70e71ae3008ace..7cc98a45201eb3 100644 --- a/paddle/fluid/lite/utils/all.h +++ b/paddle/fluid/lite/utils/all.h @@ -14,6 +14,7 @@ #pragma once +#include "paddle/fluid/lite/utils/any.h" #include "paddle/fluid/lite/utils/check.h" #include "paddle/fluid/lite/utils/cp_logging.h" #include "paddle/fluid/lite/utils/factory.h" @@ -21,4 +22,3 @@ #include "paddle/fluid/lite/utils/io.h" #include 
"paddle/fluid/lite/utils/macros.h" #include "paddle/fluid/lite/utils/varient.h" -#include "paddle/fluid/lite/utils/any.h" diff --git a/paddle/fluid/lite/utils/cp_logging.h b/paddle/fluid/lite/utils/cp_logging.h index d356b337abd5f4..e3c0f392533dca 100644 --- a/paddle/fluid/lite/utils/cp_logging.h +++ b/paddle/fluid/lite/utils/cp_logging.h @@ -13,8 +13,8 @@ // limitations under the License. #pragma once -#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK -#include "paddle/fluid/lite/utils/logging.h" -#else // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +// #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +// #include "paddle/fluid/lite/utils/logging.h" +// #else // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #include -#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +// #endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK diff --git a/paddle/fluid/lite/utils/string.cc b/paddle/fluid/lite/utils/string.cc new file mode 100644 index 00000000000000..c608c31fb9ffb2 --- /dev/null +++ b/paddle/fluid/lite/utils/string.cc @@ -0,0 +1,19 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/lite/utils/string.h" + +namespace paddle { +namespace lite {} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/utils/string.h b/paddle/fluid/lite/utils/string.h new file mode 100644 index 00000000000000..31b131276bfa22 --- /dev/null +++ b/paddle/fluid/lite/utils/string.h @@ -0,0 +1,78 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include // For va_start, etc. +#include +#include +#include // For std::unique_ptr +#include +#include +#include + +namespace paddle { +namespace lite { + +static std::string string_format(const std::string fmt_str, ...) 
{ + /* Reserve two times as much as the length of the fmt_str */ + int final_n, n = (static_cast(fmt_str.size())) * 2; + std::unique_ptr formatted; + va_list ap; + while (1) { + formatted.reset( + new char[n]); /* Wrap the plain char array into the unique_ptr */ + std::strcpy(&formatted[0], fmt_str.c_str()); // NOLINT + va_start(ap, fmt_str); + final_n = vsnprintf(&formatted[0], n, fmt_str.c_str(), ap); + va_end(ap); + if (final_n < 0 || final_n >= n) + n += abs(final_n - n + 1); + else + break; + } + return std::string(formatted.get()); +} + +template +static std::string to_string_with_precision(const T& v, const int n = 6) { + std::stringstream ss; + ss.precision(n); + ss << std::fixed << v; + return ss.str(); +} + +static std::string Join(const std::vector& vec, + const std::string& delim) { + if (vec.empty()) return ""; + + std::stringstream ss; + for (size_t i = 0; i < vec.size() - 1; i++) ss << vec[i] << delim; + if (!vec.empty()) { + ss << vec.back(); + } + + return ss.str(); +} + +static std::string Repr(const std::string& x) { return "\"" + x + "\""; } + +static std::string Repr(const std::vector& v) { + std::vector tmp; + std::transform(v.begin(), v.end(), std::back_inserter(tmp), + [](const std::string& x) { return Repr(x); }); + return "{" + Join(tmp, ",") + "}"; +} + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/utils/varient.h b/paddle/fluid/lite/utils/varient.h index 49714eea7f4391..2d2a3061108978 100644 --- a/paddle/fluid/lite/utils/varient.h +++ b/paddle/fluid/lite/utils/varient.h @@ -116,8 +116,9 @@ struct variant { if (type_id == typeid(T).hash_code()) return *reinterpret_cast(&data); else - LOG(FATAL) << "unmatched type get, should be " << type_id << " but get " - << typeid(T).name(); + throw std::invalid_argument("unmatched type"); + // LOG(FATAL) << "unmatched type get, should be " << type_id << " but get " + // << typeid(T).name(); return *reinterpret_cast(&data); } @@ -127,8 +128,9 @@ struct variant { if (type_id 
== typeid(T).hash_code()) return reinterpret_cast(&data); else - LOG(FATAL) << "unmatched type get, should be " << type_id << " but get " + LOG(ERROR) << "unmatched type get, should be " << type_id << " but get " << typeid(T).name(); + throw std::invalid_argument("unmatched type"); } ~variant() { helper_t::destroy(type_id, &data); } }; diff --git a/paddle/fluid/lite/x86/CMakeLists.txt b/paddle/fluid/lite/x86/CMakeLists.txt index be772b921b4edc..0347593e38af4a 100644 --- a/paddle/fluid/lite/x86/CMakeLists.txt +++ b/paddle/fluid/lite/x86/CMakeLists.txt @@ -3,3 +3,4 @@ if (NOT LITE_WITH_X86) endif() cc_library(target_wrapper_x86 SRCS target_wrapper.cc) + diff --git a/paddle/fluid/memory/CMakeLists.txt b/paddle/fluid/memory/CMakeLists.txt index 7eb663ea280e65..0d4c5c37e1dc88 100644 --- a/paddle/fluid/memory/CMakeLists.txt +++ b/paddle/fluid/memory/CMakeLists.txt @@ -6,7 +6,8 @@ cc_library(memcpy SRCS memcpy.cc DEPS place) cc_library(memory DEPS malloc - memcpy) + memcpy + ) #if (WITH_GPU) # nv_test(pinned_memory_test SRCS pinned_memory_test.cu DEPS place memory) #endif() diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 5de00db55add1e..c4386689d3e4fd 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -72,7 +72,7 @@ ENDIF() # avoiding cycle dependencies cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc ${STREAM_CALLBACK_DEPS} place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS} - temp_allocator ${dgc_deps}) + temp_allocator ${dgc_deps} xxhash) if(WIN32) if(WITH_GPU AND NOT WITH_DSO)