diff --git a/.github/workflows/metax_work.yaml b/.github/workflows/metax_work.yaml new file mode 100644 index 00000000000..0d3d2637cdd --- /dev/null +++ b/.github/workflows/metax_work.yaml @@ -0,0 +1,52 @@ +name: paddle metax gpu test + +on: + workflow_dispatch: + pull_request: + types: [opened, synchronize] + branches: [develop, release/**] + paths: + - "**" + - "!backends/**" + - "backends/metax_gpu/**" + +permissions: read-all + +defaults: + run: + shell: bash + +jobs: + metax-gpu-test: + runs-on: paddle-metax-runner-set + steps: + - name: Checkout repository + run: | + git config --global user.name "GitHub Actions" + git config --global user.email "actions@github.com" + + if [ "${{ github.event_name }}" == "pull_request" ]; then + BRANCH_NAME=${{ github.head_ref }} + else + BRANCH_NAME=${{ github.ref_name }} + fi + + git clone \ + --reference-if-able /home/runner/PaddleCustomDevice \ + --depth=1 \ + --shallow-submodules \ + --jobs=8 \ + --branch "$BRANCH_NAME" \ + --recurse-submodules \ + https://${{ github.actor }}:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }}.git . 
+ + + - name: compile + run: | + cd backends/metax_gpu + bash build.sh + + - name: run test + run: | + cd backends/metax_gpu/tests + bash run_test.sh diff --git a/backends/metax_gpu/CMakeLists.txt b/backends/metax_gpu/CMakeLists.txt index 6048b59e6c1..cca23ab42f5 100755 --- a/backends/metax_gpu/CMakeLists.txt +++ b/backends/metax_gpu/CMakeLists.txt @@ -37,6 +37,8 @@ include(cblas) include(flashattn) include(cutlass) include(dgc) +include(warpctc) +include(warprnnt) set(PLUGIN_VERSION ${PADDLE_VERSION}) diff --git a/backends/metax_gpu/change_patch.sh b/backends/metax_gpu/change_patch.sh index 833ae00f6bd..60d74ec0f3d 100644 --- a/backends/metax_gpu/change_patch.sh +++ b/backends/metax_gpu/change_patch.sh @@ -25,3 +25,4 @@ cp patch/tmp/mixed_vector* ../../Paddle/paddle/phi/core cd ../../Paddle/ git apply --verbose ../backends/metax_gpu/patch/paddle.patch cd - +cp -r patch/intrinsics.cuh ../../Paddle/third_party/warpctc/include/contrib/moderngpu/include/device/ diff --git a/backends/metax_gpu/cmake/warpctc.cmake b/backends/metax_gpu/cmake/warpctc.cmake new file mode 100644 index 00000000000..71c892a6cfa --- /dev/null +++ b/backends/metax_gpu/cmake/warpctc.cmake @@ -0,0 +1,149 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. 
+ +include(ExternalProject) + +if(WITH_ROCM) + add_definitions(-DWARPCTC_WITH_HIP) +endif() + +set(WARPCTC_PREFIX_DIR ${THIRD_PARTY_PATH}/warpctc) +set(WARPCTC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/warpctc) +# in case of low internet speed set(WARPCTC_REPOSITORY +# https://gitee.com/tianjianhe/warp-ctc.git) +set(WARPCTC_TAG bdc2b4550453e0ef2d3b5190f9c6103a84eff184) +set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/warpctc) +set(WARPCTC_PATCH_COMMAND "") +set(WARPCTC_CCBIN_OPTION "") +if(WIN32) + set(WARPCTC_PATCH_CUDA_COMMAND + git checkout -- . && git checkout ${WARPCTC_TAG} && git apply + ${PADDLE_SOURCE_DIR}/patches/warpctc/CMakeLists.txt.cuda.patch) +else() + set(WARPCTC_PATCH_CUDA_COMMAND + git checkout -- . && git checkout ${WARPCTC_TAG} && patch -Nd + ${SOURCE_DIR} < + ${PADDLE_SOURCE_DIR}/patches/warpctc/CMakeLists.txt.cuda.patch) +endif() + +if(NOT WIN32 AND WITH_GPU) + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 12.0 AND ${CMAKE_CXX_COMPILER_VERSION} + VERSION_GREATER 12.0) + file(TO_NATIVE_PATH + ${PADDLE_SOURCE_DIR}/patches/warpctc/CMakeLists.txt.patch native_src) + set(WARPCTC_PATCH_COMMAND git checkout -- . 
&& git checkout ${WARPCTC_TAG} + && patch -Nd ${SOURCE_DIR} < ${native_src} &&) + set(WARPCTC_CCBIN_OPTION -DCCBIN_COMPILER=${CCBIN_COMPILER}) + endif() +endif() + +if(WITH_ROCM) + set(WARPCTC_PATHCH_ROCM_COMMAND + patch -p1 < + ${PADDLE_SOURCE_DIR}/patches/warpctc/CMakeLists.txt.rocm.patch && patch + -p1 < ${PADDLE_SOURCE_DIR}/patches/warpctc/devicetypes.cuh.patch && patch + -p1 < ${PADDLE_SOURCE_DIR}/patches/warpctc/hip.cmake.patch) +endif() + +set(WARPCTC_INCLUDE_DIR + "${WARPCTC_INSTALL_DIR}/include" + CACHE PATH "Warp-ctc Directory" FORCE) +# Used in unit test test_WarpCTCLayer +set(WARPCTC_LIB_DIR + "${WARPCTC_INSTALL_DIR}/lib" + CACHE PATH "Warp-ctc Library Directory" FORCE) + +if(WIN32) + set(WARPCTC_LIBRARIES + "${WARPCTC_INSTALL_DIR}/bin/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}" + CACHE FILEPATH "Warp-ctc Library" FORCE) +else() + set(WARPCTC_LIBRARIES + "${WARPCTC_INSTALL_DIR}/lib/libwarpctc${CMAKE_SHARED_LIBRARY_SUFFIX}" + CACHE FILEPATH "Warp-ctc Library" FORCE) +endif() + +if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" + OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" + OR WIN32) + set(USE_OMP OFF) +else() + set(USE_OMP ON) +endif() + +if(WIN32) + set(WARPCTC_C_FLAGS $<FILTER:${CMAKE_C_FLAGS},EXCLUDE,/Zc:inline>) + set(WARPCTC_C_FLAGS_DEBUG $<FILTER:${CMAKE_C_FLAGS_DEBUG},EXCLUDE,/Zc:inline>) + set(WARPCTC_C_FLAGS_RELEASE + $<FILTER:${CMAKE_C_FLAGS_RELEASE},EXCLUDE,/Zc:inline>) + set(WARPCTC_CXX_FLAGS $<FILTER:${CMAKE_CXX_FLAGS},EXCLUDE,/Zc:inline>) + set(WARPCTC_CXX_FLAGS_RELEASE + $<FILTER:${CMAKE_CXX_FLAGS_RELEASE},EXCLUDE,/Zc:inline>) + set(WARPCTC_CXX_FLAGS_DEBUG + $<FILTER:${CMAKE_CXX_FLAGS_DEBUG},EXCLUDE,/Zc:inline>) +else() + set(WARPCTC_C_FLAGS ${CMAKE_C_FLAGS}) + set(WARPCTC_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG}) + set(WARPCTC_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE}) + set(WARPCTC_CXX_FLAGS ${CMAKE_CXX_FLAGS}) + set(WARPCTC_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE}) + set(WARPCTC_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG}) +endif() + +ExternalProject_Add( + extern_warpctc + ${EXTERNAL_PROJECT_LOG_ARGS} + SOURCE_DIR ${SOURCE_DIR} + PREFIX ${WARPCTC_PREFIX_DIR} + UPDATE_COMMAND "" + PATCH_COMMAND + COMMAND ${WARPCTC_PATCH_COMMAND} + COMMAND ${WARPCTC_PATCH_CUDA_COMMAND} + COMMAND ${WARPCTC_PATHCH_ROCM_COMMAND} + # BUILD_ALWAYS 1 + CMAKE_ARGS 
-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} + -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} + -DCMAKE_C_FLAGS=${WARPCTC_C_FLAGS} + -DCMAKE_C_FLAGS_DEBUG=${WARPCTC_C_FLAGS_DEBUG} + -DCMAKE_C_FLAGS_RELEASE=${WARPCTC_C_FLAGS_RELEASE} + -DCMAKE_CXX_FLAGS=${WARPCTC_CXX_FLAGS} + -DCMAKE_CXX_FLAGS_RELEASE=${WARPCTC_CXX_FLAGS_RELEASE} + -DCMAKE_CXX_FLAGS_DEBUG=${WARPCTC_CXX_FLAGS_DEBUG} + -DCMAKE_INSTALL_PREFIX=${WARPCTC_INSTALL_DIR} + -DWITH_GPU=${WITH_GPU} + -DWITH_ROCM=${WITH_ROCM} + -DWITH_OMP=${USE_OMP} + -DNVCC_FLAGS_EXTRA=${NVCC_FLAGS_EXTRA} + -DWITH_TORCH=OFF + -DCMAKE_DISABLE_FIND_PACKAGE_Torch=ON + -DBUILD_SHARED=ON + -DBUILD_TESTS=OFF + -DCMAKE_POSITION_INDEPENDENT_CODE=ON + -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} + -DCUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR} + ${EXTERNAL_OPTIONAL_ARGS} + ${WARPCTC_CCBIN_OPTION} + CMAKE_CACHE_ARGS + -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + -DCMAKE_INSTALL_PREFIX:PATH=${WARPCTC_INSTALL_DIR} + BUILD_BYPRODUCTS ${WARPCTC_LIBRARIES}) + +message(STATUS "warp-ctc library: ${WARPCTC_LIBRARIES}") +get_filename_component(WARPCTC_LIBRARY_PATH ${WARPCTC_LIBRARIES} DIRECTORY) +include_directories(${WARPCTC_INCLUDE_DIR}) # For warpctc code to include its + # headers. + +add_library(warpctc INTERFACE) +add_dependencies(warpctc extern_warpctc) diff --git a/backends/metax_gpu/cmake/warprnnt.cmake b/backends/metax_gpu/cmake/warprnnt.cmake new file mode 100644 index 00000000000..54a7ad6be86 --- /dev/null +++ b/backends/metax_gpu/cmake/warprnnt.cmake @@ -0,0 +1,142 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. 
You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. + +include(ExternalProject) + +if(WITH_ROCM) + add_definitions(-DWARPRNNT_WITH_HIP) +endif() + +set(WARPRNNT_PREFIX_DIR ${THIRD_PARTY_PATH}/warprnnt) +set(WARPRNNT_INSTALL_DIR ${THIRD_PARTY_PATH}/install/warprnnt) +set(WARPRNNT_TAG 7ea6bfe748779c245a0fcaa5dd9383826273eff2) +set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/warprnnt) +set(WARPRNNT_PATCH_COMMAND "") +set(WARPRNNT_CCBIN_OPTION "") +if(WIN32) + set(WARPCTC_PATCH_CUDA_COMMAND + ${CMAKE_COMMAND} -E copy_if_different + ${PADDLE_SOURCE_DIR}/patches/warprnnt/CMakeLists.txt.cuda.patch + "<SOURCE_DIR>/") +else() + set(WARPCTC_PATCH_CUDA_COMMAND + git checkout -- . && git checkout ${WARPRNNT_TAG} && patch -Nd + ${SOURCE_DIR} < + ${PADDLE_SOURCE_DIR}/patches/warprnnt/CMakeLists.txt.cuda.patch) +endif() +if(WITH_ROCM) + set(WARPRNNT_PATCH_ROCM_COMMAND + patch -p1 < + ${PADDLE_SOURCE_DIR}/patches/warprnnt/CMakeLists.txt.rocm.patch) +endif() +if(NOT WIN32 AND WITH_GPU) + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 12.0 AND ${CMAKE_CXX_COMPILER_VERSION} + VERSION_GREATER 12.0) + file(TO_NATIVE_PATH + ${PADDLE_SOURCE_DIR}/patches/warprnnt/CMakeLists.txt.patch native_src) + set(WARPRNNT_PATCH_COMMAND + git checkout -- . 
&& git checkout ${WARPRNNT_TAG} && patch -Nd + ${SOURCE_DIR} < ${native_src}) + set(WARPRNNT_CCBIN_OPTION -DCCBIN_COMPILER=${CCBIN_COMPILER}) + endif() +endif() + +set(WARPRNNT_INCLUDE_DIR + "${WARPRNNT_INSTALL_DIR}/include" + CACHE PATH "Warp-rnnt Directory" FORCE) +# Used in unit test test_WarpCTCLayer +set(WARPRNNT_LIB_DIR + "${WARPRNNT_INSTALL_DIR}/lib" + CACHE PATH "Warp-rnnt Library Directory" FORCE) + +if(WIN32) + set(WARPRNNT_LIBRARIES + "${WARPRNNT_INSTALL_DIR}/bin/warprnnt${CMAKE_SHARED_LIBRARY_SUFFIX}" + CACHE FILEPATH "Warp-rnnt Library" FORCE) +else() + set(WARPRNNT_LIBRARIES + "${WARPRNNT_INSTALL_DIR}/lib/libwarprnnt${CMAKE_SHARED_LIBRARY_SUFFIX}" + CACHE FILEPATH "Warp-rnnt Library" FORCE) +endif() + +if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" + OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" + OR WIN32) + set(USE_OMP OFF) +else() + set(USE_OMP ON) +endif() + +if(WIN32) + set(WARPRNNT_C_FLAGS $<FILTER:${CMAKE_C_FLAGS},EXCLUDE,/Zc:inline>) + set(WARPRNNT_C_FLAGS_DEBUG + $<FILTER:${CMAKE_C_FLAGS_DEBUG},EXCLUDE,/Zc:inline>) + set(WARPRNNT_C_FLAGS_RELEASE + $<FILTER:${CMAKE_C_FLAGS_RELEASE},EXCLUDE,/Zc:inline>) + set(WARPRNNT_CXX_FLAGS $<FILTER:${CMAKE_CXX_FLAGS},EXCLUDE,/Zc:inline>) + set(WARPRNNT_CXX_FLAGS_RELEASE + $<FILTER:${CMAKE_CXX_FLAGS_RELEASE},EXCLUDE,/Zc:inline>) + set(WARPRNNT_CXX_FLAGS_DEBUG + $<FILTER:${CMAKE_CXX_FLAGS_DEBUG},EXCLUDE,/Zc:inline>) +else() + set(WARPRNNT_C_FLAGS ${CMAKE_C_FLAGS}) + set(WARPRNNT_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG}) + set(WARPRNNT_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE}) + set(WARPRNNT_CXX_FLAGS ${CMAKE_CXX_FLAGS}) + set(WARPRNNT_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE}) + set(WARPRNNT_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG}) +endif() +ExternalProject_Add( + extern_warprnnt + ${EXTERNAL_PROJECT_LOG_ARGS} + SOURCE_DIR ${SOURCE_DIR} + PREFIX ${WARPRNNT_PREFIX_DIR} + UPDATE_COMMAND "" + PATCH_COMMAND + COMMAND ${WARPCTC_PATCH_CUDA_COMMAND} + COMMAND ${WARPRNNT_PATCH_ROCM_COMMAND} + # BUILD_ALWAYS 1 + CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} + -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} + -DCMAKE_C_FLAGS=${WARPRNNT_C_FLAGS} + -DCMAKE_C_FLAGS_DEBUG=${WARPRNNT_C_FLAGS_DEBUG} + -DCMAKE_C_FLAGS_RELEASE=${WARPRNNT_C_FLAGS_RELEASE} + -DCMAKE_CXX_FLAGS=${WARPRNNT_CXX_FLAGS} + 
-DCMAKE_CXX_FLAGS_RELEASE=${WARPRNNT_CXX_FLAGS_RELEASE} + -DCMAKE_CXX_FLAGS_DEBUG=${WARPRNNT_CXX_FLAGS_DEBUG} + -DCMAKE_INSTALL_PREFIX=${WARPRNNT_INSTALL_DIR} + -DWITH_GPU=${WITH_GPU} + -DWITH_ROCM=${WITH_ROCM} + -DWITH_OMP=${USE_OMP} + -DNVCC_FLAGS_EXTRA=${NVCC_FLAGS_EXTRA} + -DBUILD_SHARED=ON + -DBUILD_TESTS=OFF + -DCMAKE_POSITION_INDEPENDENT_CODE=ON + -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} + ${EXTERNAL_OPTIONAL_ARGS} + ${WARPRNNT_CCBIN_OPTION} + CMAKE_CACHE_ARGS + -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + -DCMAKE_INSTALL_PREFIX:PATH=${WARPRNNT_INSTALL_DIR} + BUILD_BYPRODUCTS ${WARPRNNT_LIBRARIES}) + +message(STATUS "warp-rnnt library: ${WARPRNNT_LIBRARIES}") +get_filename_component(WARPRNNT_LIBRARY_PATH ${WARPRNNT_LIBRARIES} DIRECTORY) +include_directories(${WARPRNNT_INCLUDE_DIR}) # For warprnnt code to include its + # headers. + +add_library(warprnnt INTERFACE) +# set_property(TARGET warprnnt PROPERTY IMPORTED_LOCATION ${WARPRNNT_LIBRARIES}) +add_dependencies(warprnnt extern_warprnnt) diff --git a/backends/metax_gpu/kernels/cuda_kernels/warpctc_grad_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/warpctc_grad_kernel_register.cu index e77a29d12fe..d02f805a671 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/warpctc_grad_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/warpctc_grad_kernel_register.cu @@ -17,7 +17,7 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/warpctc_grad_kernel.h" -PD_REGISTER_PLUGIN_KERNEL(warpctc_grad, +PD_CUSTOM_KERNEL_REGISTER(warpctc_grad, metax_gpu, ALL_LAYOUT, phi::WarpctcGradKernel, diff --git a/backends/metax_gpu/kernels/cuda_kernels/warpctc_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/warpctc_kernel_register.cu index 5b343506cad..c488e23fba9 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/warpctc_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/warpctc_kernel_register.cu @@ 
-17,5 +17,5 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/warpctc_kernel.h" -PD_REGISTER_PLUGIN_KERNEL( +PD_CUSTOM_KERNEL_REGISTER( warpctc, metax_gpu, ALL_LAYOUT, phi::WarpctcKernel, float, double) {} diff --git a/backends/metax_gpu/kernels/impl/warpctc_kernel_impl.h b/backends/metax_gpu/kernels/impl/warpctc_kernel_impl.h index eb64f21c90f..9794ba1b3c0 100644 --- a/backends/metax_gpu/kernels/impl/warpctc_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/warpctc_kernel_impl.h @@ -204,7 +204,8 @@ class WarpCTCFunctor { void init(const Context& dev_ctx, const size_t blank) { warpctc_version_ = phi::dynload::get_warpctc_version(); - if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU) { + if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || + dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) options_.loc = CTC_GPU; options_.stream = diff --git a/backends/metax_gpu/kernels/impl/warprnnt_kernel_impl.h b/backends/metax_gpu/kernels/impl/warprnnt_kernel_impl.h index 96e756b16b1..bb4311f5912 100644 --- a/backends/metax_gpu/kernels/impl/warprnnt_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/warprnnt_kernel_impl.h @@ -138,7 +138,8 @@ class WarpRNNTFunctor { // There is no memory allocated operations within warp-rnnt. 
rnntStatus_t status = RNNT_STATUS_UNKNOWN_ERROR; bool gpu = false; - if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU) { + if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || + dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) gpu = true; #else @@ -207,7 +208,8 @@ class WarpRNNTFunctor { options_.fastemit_lambda = fastemit_lambda; options_.batch_first = true; - if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU) { + if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || + dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) options_.loc = RNNT_GPU; options_.stream = diff --git a/backends/metax_gpu/patch/intrinsics.cuh b/backends/metax_gpu/patch/intrinsics.cuh new file mode 100644 index 00000000000..71365b6577c --- /dev/null +++ b/backends/metax_gpu/patch/intrinsics.cuh @@ -0,0 +1,459 @@ +/****************************************************************************** + * Copyright (c) 2013, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/****************************************************************************** + * + * Code and text by Sean Baxter, NVIDIA Research + * See http://nvlabs.github.io/moderngpu for repository and documentation. 
+ * + ******************************************************************************/ + +#include "devicetypes.cuh" + +#pragma once + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" + +namespace mgpu { + +MGPU_HOST_DEVICE uint2 ulonglong_as_uint2(uint64 x) { + return *reinterpret_cast<uint2*>(&x); +} +MGPU_HOST_DEVICE uint64 uint2_as_ulonglong(uint2 x) { + return *reinterpret_cast<uint64*>(&x); +} + +MGPU_HOST_DEVICE int2 longlong_as_int2(int64 x) { + return *reinterpret_cast<int2*>(&x); +} +MGPU_HOST_DEVICE int64 int2_as_longlong(int2 x) { + return *reinterpret_cast<int64*>(&x); +} + +MGPU_HOST_DEVICE int2 double_as_int2(double x) { + return *reinterpret_cast<int2*>(&x); +} +MGPU_HOST_DEVICE double int2_as_double(int2 x) { + return *reinterpret_cast<double*>(&x); +} + +MGPU_HOST_DEVICE void SetDoubleX(double& d, int x) { + reinterpret_cast<int*>(&d)[0] = x; +} +MGPU_HOST_DEVICE int GetDoubleX(double d) { + return double_as_int2(d).x; +} +MGPU_HOST_DEVICE void SetDoubleY(double& d, int y) { + reinterpret_cast<int*>(&d)[1] = y; +} +MGPU_HOST_DEVICE int GetDoubleY(double d) { + return double_as_int2(d).y; +} + + +//////////////////////////////////////////////////////////////////////////////// +// PTX for bfe and bfi + +#if __CUDA_ARCH__ >= 200 + +MGPU_DEVICE uint bfe_ptx(uint x, uint bit, uint numBits) { + uint result; + asm("bfe.u32 %0, %1, %2, %3;" : + "=r"(result) : "r"(x), "r"(bit), "r"(numBits)); + return result; +} + + +MGPU_DEVICE uint bfi_ptx(uint x, uint y, uint bit, uint numBits) { + uint result; + asm("bfi.b32 %0, %1, %2, %3, %4;" : + "=r"(result) : "r"(x), "r"(y), "r"(bit), "r"(numBits)); + return result; +} + +MGPU_DEVICE uint prmt_ptx(uint a, uint b, uint index) { + uint ret; + asm("prmt.b32 %0, %1, %2, %3;" : "=r"(ret) : "r"(a), "r"(b), "r"(index)); + return ret; +} + +#endif // __CUDA_ARCH__ >= 200 + + +//////////////////////////////////////////////////////////////////////////////// +// shfl_up + +__device__ __forceinline__ float shfl_up(float var, + unsigned int delta, int 
width = 32) { + +#if __CUDA_ARCH__ >= 300 +#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 9) + var = __shfl_up_sync(0xFFFFFFFF, var, delta, width); +#else + var = __shfl_up(var, delta, width); +#endif +#endif + return var; +} + +__device__ __forceinline__ double shfl_up(double var, + unsigned int delta, int width = 32) { + +#if __CUDA_ARCH__ >= 300 + int2 p = mgpu::double_as_int2(var); +#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 9) + p.x = __shfl_up_sync(0xFFFFFFFF, p.x, delta, width); + p.y = __shfl_up_sync(0xFFFFFFFF, p.y, delta, width); +#else + p.x = __shfl_up(p.x, delta, width); + p.y = __shfl_up(p.y, delta, width); +#endif + var = mgpu::int2_as_double(p); +#endif + + return var; +} + +//////////////////////////////////////////////////////////////////////////////// +// shfl_add + +// MGPU_DEVICE int shfl_add(int x, int offset, int width = WARP_SIZE) { +// int result = 0; +// #if __CUDA_ARCH__ >= 300 +// int mask = (WARP_SIZE - width)<< 8; +// #if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 9) +// asm( +// "{.reg .s32 r0;" +// ".reg .pred p;" +// "shfl.up.sync.b32 r0|p, %1, %2, %3, 0xFFFFFFFF;" +// "@p add.s32 r0, r0, %4;" +// "mov.s32 %0, r0; }" +// : "=r"(result) : "r"(x), "r"(offset), "r"(mask), "r"(x)); +// #else +// asm( +// "{.reg .s32 r0;" +// ".reg .pred p;" +// "shfl.up.b32 r0|p, %1, %2, %3;" +// "@p add.s32 r0, r0, %4;" +// "mov.s32 %0, r0; }" +// : "=r"(result) : "r"(x), "r"(offset), "r"(mask), "r"(x)); +// #endif +// #endif +// return result; +// } + +MGPU_DEVICE int shfl_add(int x, int offset, int width = 32) +{ +#if __CUDA_ARCH__ >= 300 + // Full-warp sync mask, matching shfl_up above and the PTX version (0xFFFFFFFF); width only sets the segment size. + unsigned mask = 0xffffffffU; + int src = 0; +#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 9 + src = __shfl_up_sync(mask, x, offset, width); // CUDA 9+ +#else + src = __shfl_up(x, offset, width); // CUDA 8- +#endif + int lane = threadIdx.x & (width - 1); // lane index inside its width-sized segment (width is a power of two) + return (lane >= offset) ? 
(src + x) : x; +#else + return x; +#endif +} + +MGPU_DEVICE int shfl_max(int x, int offset, int width = WARP_SIZE) { + int result = 0; +#if __CUDA_ARCH__ >= 300 + int mask = (WARP_SIZE - width)<< 8; +#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 9) + asm( + "{.reg .s32 r0;" + ".reg .pred p;" + "shfl.up.sync.b32 r0|p, %1, %2, %3, 0xFFFFFFFF;" + "@p max.s32 r0, r0, %4;" + "mov.s32 %0, r0; }" + : "=r"(result) : "r"(x), "r"(offset), "r"(mask), "r"(x)); +#else + asm( + "{.reg .s32 r0;" + ".reg .pred p;" + "shfl.up.b32 r0|p, %1, %2, %3;" + "@p max.s32 r0, r0, %4;" + "mov.s32 %0, r0; }" + : "=r"(result) : "r"(x), "r"(offset), "r"(mask), "r"(x)); +#endif +#endif + return result; +} + +//////////////////////////////////////////////////////////////////////////////// +// brev, popc, clz, bfe, bfi, prmt + +// Reverse the bits in an integer. +MGPU_HOST_DEVICE uint brev(uint x) { +#if __CUDA_ARCH__ >= 200 + uint y = __brev(x); +#else + uint y = 0; + for(int i = 0; i < 32; ++i) + y |= (1 & (x>> i))<< (31 - i); +#endif + return y; +} + +// Count number of bits in a register. +MGPU_HOST_DEVICE int popc(uint x) { +#if __CUDA_ARCH__ >= 200 + return __popc(x); +#else + int c; + for(c = 0; x; ++c) + x &= x - 1; + return c; +#endif +} + +// Count leading zeros - start from most significant bit. +MGPU_HOST_DEVICE int clz(int x) { +#if __CUDA_ARCH__ >= 200 + return __clz(x); +#else + for(int i = 31; i >= 0; --i) + if((1<< i) & x) return 31 - i; + return 32; +#endif +} + +// Find first set - start from least significant bit. LSB is 1. ffs(0) is 0. 
+MGPU_HOST_DEVICE int ffs(int x) { +#if __CUDA_ARCH__ >= 200 + return __ffs(x); +#else + for(int i = 0; i < 32; ++i) + if((1<< i) & x) return i + 1; + return 0; +#endif +} + +MGPU_HOST_DEVICE uint bfe(uint x, uint bit, uint numBits) { +#if __CUDA_ARCH__ >= 200 + return bfe_ptx(x, bit, numBits); +#else + return ((1<< numBits) - 1) & (x>> bit); +#endif +} + +MGPU_HOST_DEVICE uint bfi(uint x, uint y, uint bit, uint numBits) { + uint result; +#if __CUDA_ARCH__ >= 200 + result = bfi_ptx(x, y, bit, numBits); +#else + if(bit + numBits > 32) numBits = 32 - bit; + uint mask = ((1<< numBits) - 1)<< bit; + result = y & ~mask; + result |= mask & (x<< bit); +#endif + return result; +} + +MGPU_HOST_DEVICE uint prmt(uint a, uint b, uint index) { + uint result; +#if __CUDA_ARCH__ >= 200 + result = prmt_ptx(a, b, index); +#else + result = 0; + for(int i = 0; i < 4; ++i) { + uint sel = 0xf & (index>> (4 * i)); + uint x = ((7 & sel) > 3) ? b : a; + x = 0xff & (x>> (8 * (3 & sel))); + if(8 & sel) x = (128 & x) ? 0xff : 0; + result |= x<< (8 * i); + } +#endif + return result; +} + +// Find log2(x) and optionally round up to the next integer logarithm. +MGPU_HOST_DEVICE int FindLog2(int x, bool roundUp = false) { + int a = 31 - clz(x); + if(roundUp) a += !MGPU_IS_POW_2(x); + return a; +} + +//////////////////////////////////////////////////////////////////////////////// +// vset4 + +#if __CUDA_ARCH__ >= 300 + +// Performs four byte-wise comparisons and returns 1 for each byte that +// satisfies the conditional, and zero otherwise. 
+MGPU_DEVICE uint vset4_lt_add_ptx(uint a, uint b, uint c) { + uint result; + asm("vset4.u32.u32.lt.add %0, %1, %2, %3;" : + "=r"(result) : "r"(a), "r"(b), "r"(c)); + return result; +} +MGPU_DEVICE uint vset4_eq_ptx(uint a, uint b) { + uint result; + asm("vset4.u32.u32.eq %0, %1, %2, %3;" : + "=r"(result) : "r"(a), "r"(b), "r"(0)); + return result; +} +#endif // __CUDA_ARCH__ >= 300 + +MGPU_HOST_DEVICE uint vset4_lt_add(uint a, uint b, uint c) { + uint result; +#if __CUDA_ARCH__ >= 300 + result = vset4_lt_add_ptx(a, b, c); +#else + result = c; + if((0x000000ff & a) < (0x000000ff & b)) result += 0x00000001; + if((0x0000ff00 & a) < (0x0000ff00 & b)) result += 0x00000100; + if((0x00ff0000 & a) < (0x00ff0000 & b)) result += 0x00010000; + if((0xff000000 & a) < (0xff000000 & b)) result += 0x01000000; +#endif + return result; +} + +MGPU_HOST_DEVICE uint vset4_eq(uint a, uint b) { + uint result; +#if __CUDA_ARCH__ >= 300 + result = vset4_eq_ptx(a, b); +#else + result = 0; + if((0x000000ff & a) == (0x000000ff & b)) result = 0x00000001; + if((0x0000ff00 & a) == (0x0000ff00 & b)) result += 0x00000100; + if((0x00ff0000 & a) == (0x00ff0000 & b)) result += 0x00010000; + if((0xff000000 & a) == (0xff000000 & b)) result += 0x01000000; +#endif + return result; +} + +//////////////////////////////////////////////////////////////////////////////// +// + +MGPU_HOST_DEVICE uint umulhi(uint x, uint y) { +#if __CUDA_ARCH__ >= 100 + return __umulhi(x, y); +#else + uint64 product = (uint64)x * y; + return (uint)(product>> 32); +#endif +} + +//////////////////////////////////////////////////////////////////////////////// +// ldg() function defined for all devices and all types. 
Only compiles to __ldg +// intrinsic for __CUDA_ARCH__ >= 320 && __CUDA_ARCH__ < 400 for types supported +// by __ldg in sm_32_intrinsics.h + +template<typename T> +struct IsLdgType { + enum { value = false }; +}; +#define DEFINE_LDG_TYPE(T) \ + template<> struct IsLdgType<T> { enum { value = true }; }; + +template<typename T, bool UseLDG = IsLdgType<T>::value> +struct LdgShim { + MGPU_DEVICE static T Ldg(const T* p) { + return *p; + } +}; + +#if __CUDA_ARCH__ >= 320 && __CUDA_ARCH__ < 400 + + // List of __ldg-compatible types from sm_32_intrinsics.h. + DEFINE_LDG_TYPE(char) + DEFINE_LDG_TYPE(short) + DEFINE_LDG_TYPE(int) + DEFINE_LDG_TYPE(long long) + DEFINE_LDG_TYPE(char2) + DEFINE_LDG_TYPE(char4) + DEFINE_LDG_TYPE(short2) + DEFINE_LDG_TYPE(short4) + DEFINE_LDG_TYPE(int2) + DEFINE_LDG_TYPE(int4) + DEFINE_LDG_TYPE(longlong2) + + DEFINE_LDG_TYPE(unsigned char) + DEFINE_LDG_TYPE(unsigned short) + DEFINE_LDG_TYPE(unsigned int) + DEFINE_LDG_TYPE(unsigned long long) + DEFINE_LDG_TYPE(uchar2) + DEFINE_LDG_TYPE(uchar4) + DEFINE_LDG_TYPE(ushort2) + DEFINE_LDG_TYPE(ushort4) + DEFINE_LDG_TYPE(uint2) + DEFINE_LDG_TYPE(uint4) + DEFINE_LDG_TYPE(ulonglong2) + + DEFINE_LDG_TYPE(float) + DEFINE_LDG_TYPE(double) + DEFINE_LDG_TYPE(float2) + DEFINE_LDG_TYPE(float4) + DEFINE_LDG_TYPE(double2) + + template<typename T> struct LdgShim<T, true> { + MGPU_DEVICE static T Ldg(const T* p) { + return __ldg(p); + } + }; +#endif + +template<typename T> +MGPU_DEVICE T ldg(const T* p) { + return LdgShim<T>::Ldg(p); +} + +//////////////////////////////////////////////////////////////////////////////// + +// Fast division for 31-bit integers. +// Uses the method in Hacker's Delight (2nd edition) page 228. +// Evaluates for denom > 1 and x < 2^31. 
+struct FastDivide { + uint denom; + uint coef; + uint shift; + + MGPU_HOST_DEVICE uint Divide(uint x) { + return umulhi(x, coef)>> shift; + } + MGPU_HOST_DEVICE uint Modulus(uint x) { + return x - Divide(x) * denom; + } + + explicit FastDivide(uint denom_) { + denom = denom_; + uint p = 31 + FindLog2(denom, true); + coef = (uint)(((1ull<< p) + denom - 1) / denom); + shift = p - 32; + } +}; + +#pragma GCC diagnostic pop + +} // namespace mgpu diff --git a/backends/metax_gpu/patch/paddle.patch b/backends/metax_gpu/patch/paddle.patch index 8127caee61e..0283a443adb 100755 --- a/backends/metax_gpu/patch/paddle.patch +++ b/backends/metax_gpu/patch/paddle.patch @@ -1087,6 +1087,32 @@ index 6f03f76eeb..5fe2c3e7dc 100644 #include "paddle/phi/kernels/funcs/for_range.h" #include "paddle/phi/kernels/funcs/matrix_inverse.h" +diff --git a/paddle/phi/kernels/impl/merged_momentum_impl.h b/paddle/phi/kernels/impl/merged_momentum_impl.h +index 7b85903776..3f4b298807 100644 +--- a/paddle/phi/kernels/impl/merged_momentum_impl.h ++++ b/paddle/phi/kernels/impl/merged_momentum_impl.h +@@ -297,7 +297,7 @@ void MergedMomentumInnerCompute( + params_out[idx], + velocities_out[idx]); + VLOG(10) << "Launch MergedMomentum cpu kernel."; +- } else if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU) { ++ } else if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM) { + phi::funcs::ForRange for_range( + static_cast(dev_ctx), params[idx]->numel()); + const auto grad_type = grads[idx]->dtype(); +diff --git a/paddle/phi/kernels/impl/momentum_kernel_impl.h b/paddle/phi/kernels/impl/momentum_kernel_impl.h +index de5bcfc30b..eb2a9714f5 100644 +--- a/paddle/phi/kernels/impl/momentum_kernel_impl.h ++++ b/paddle/phi/kernels/impl/momentum_kernel_impl.h +@@ -457,7 +457,7 @@ void MomentumDenseImpl(const Context& dev_ctx, + regularization_coeff, + param_out, + velocity_out); +- } else if (dev_ctx.GetPlace().GetType() == 
phi::AllocationType::GPU) { ++ } else if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM) { + funcs::ForRange for_range(dev_ctx, param.numel()); + const auto grad_type = grad.dtype(); + #define PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(__nesterov, __reg_type) \ diff --git a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h b/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h index 4099d8b506..baef2cd643 100644 --- a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h