From 84d8aebd73256f2469870a15ae25028e3e65800b Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Fri, 2 Jul 2021 09:09:47 +0000 Subject: [PATCH 01/10] add gpu implementation of shuffle batch test=develop --- paddle/fluid/operators/shuffle_batch_op.cc | 10 ++ paddle/fluid/operators/shuffle_batch_op.cu | 146 ++++++++++++++++++ .../tests/unittests/test_shuffle_batch_op.py | 57 ++++--- 3 files changed, 194 insertions(+), 19 deletions(-) create mode 100644 paddle/fluid/operators/shuffle_batch_op.cu diff --git a/paddle/fluid/operators/shuffle_batch_op.cc b/paddle/fluid/operators/shuffle_batch_op.cc index e540c728b69fe1..20459f92f3a590 100644 --- a/paddle/fluid/operators/shuffle_batch_op.cc +++ b/paddle/fluid/operators/shuffle_batch_op.cc @@ -53,6 +53,16 @@ class ShuffleBatchOp : public framework::OperatorWithKernel { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); return framework::OpKernelType(data_type, ctx.device_context()); } + + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + if (var_name == "Seed") { + return expected_kernel_type; + } + return framework::OperatorWithKernel::GetKernelTypeForVar( + var_name, tensor, expected_kernel_type); + } }; class ShuffleBatchOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/shuffle_batch_op.cu b/paddle/fluid/operators/shuffle_batch_op.cu new file mode 100644 index 00000000000000..b4ef49dbd35a1a --- /dev/null +++ b/paddle/fluid/operators/shuffle_batch_op.cu @@ -0,0 +1,146 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + +#include +#include +#include +#include "paddle/fluid/operators/shuffle_batch_op.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { + +template +struct ReorderFunctor { + ReorderFunctor(const T *x, const int64_t *shuffle_idx, T *y, int64_t stride) + : x_(x), shuffle_idx_(shuffle_idx), y_(y), stride_(stride) {} + + HOSTDEVICE void operator()(int64_t idx) { + auto reorder_idx = shuffle_idx_[idx / stride_] * stride_ + idx % stride_; + if (kIsForward) { + y_[idx] = x_[reorder_idx]; + } else { + y_[reorder_idx] = x_[idx]; + } + } + + private: + const T *x_; + const int64_t *shuffle_idx_; + T *y_; + int64_t stride_; +}; + +template +class ShuffleBatchCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *x = ctx.Input("X"); + auto *seed = ctx.Input("Seed"); + auto *out = ctx.Output("Out"); + auto *shuffleidx = ctx.Output("ShuffleIdx"); + auto *seed_out = ctx.Output("SeedOut"); + + int64_t x_embed_size = x->dims()[x->dims().size() - 1]; + int64_t elem_size = 1; + for (int i = 0; i < x->dims().size() - 1; i++) { + elem_size *= x->dims()[i]; + } + shuffleidx->Resize(framework::make_ddim({elem_size})); + + int64_t seed_int = 0; + if (seed->IsInitialized()) { + const auto &seed_place = seed->place(); + if (platform::is_gpu_place(seed_place)) { + // NOTE: We have overridden GetKernelTypeForVar, so seed_place would + // not be CUDAPlace in practice.
This case would only happen in Python + // op_test framework. + framework::Tensor tmp_tensor; + framework::TensorCopySync(*seed, platform::CPUPlace(), &tmp_tensor); + seed_int = *(tmp_tensor.data()); + } else { + seed_int = *(seed->data()); + } + } else { + seed_int = ctx.Attr("startup_seed"); + } + + auto *shuffleidx_data = shuffleidx->mutable_data(ctx.GetPlace()); + + auto &dev_ctx = ctx.template device_context(); +#ifdef PADDLE_WITH_CUDA + const auto &exec_policy = thrust::cuda::par.on(dev_ctx.stream()); +#else + const auto &exec_policy = thrust::hip::par.on(dev_ctx.stream()); +#endif + + thrust::default_random_engine engine(seed_int); + thrust::counting_iterator cnt_iter(0); + thrust::shuffle_copy(exec_policy, cnt_iter, cnt_iter + elem_size, + thrust::device_pointer_cast(shuffleidx_data), engine); + // TODO(zengjinle): for small data, direct cudaMemcpy may be better + auto *x_data = x->data(); + auto *out_data = out->mutable_data(ctx.GetPlace()); + ReorderFunctor functor(x_data, shuffleidx_data, out_data, + x_embed_size); + platform::ForRange for_range( + dev_ctx, elem_size * x_embed_size); + for_range(functor); + + auto *seed_out_data = seed_out->mutable_data( + framework::make_ddim({1}), platform::CPUPlace()); + *seed_out_data = engine(); + } +}; + +template +class ShuffleBatchGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const auto *out_grad = + ctx.Input(framework::GradVarName("Out")); + const auto *shuffleidx = ctx.Input("ShuffleIdx"); + auto *x_grad = ctx.Output(framework::GradVarName("X")); + + const auto *out_grad_data = out_grad->data(); + const auto *shuffleidx_data = shuffleidx->data(); + auto *x_grad_data = x_grad->mutable_data(ctx.GetPlace()); + auto x_embed_size = x_grad->dims()[x_grad->dims().size() - 1]; + ReorderFunctor functor(out_grad_data, shuffleidx_data, + x_grad_data, x_embed_size); + auto &dev_ctx = ctx.template device_context(); + // TODO(zengjinle): for 
small data, direct cudaMemcpy may be better + platform::ForRange for_range(dev_ctx, + x_grad->numel()); + for_range(functor); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(shuffle_batch, ops::ShuffleBatchCUDAKernel, + ops::ShuffleBatchCUDAKernel, + ops::ShuffleBatchCUDAKernel, + ops::ShuffleBatchCUDAKernel); + +REGISTER_OP_CUDA_KERNEL(shuffle_batch_grad, + ops::ShuffleBatchGradCUDAKernel, + ops::ShuffleBatchGradCUDAKernel, + ops::ShuffleBatchGradCUDAKernel, + ops::ShuffleBatchGradCUDAKernel); +#endif diff --git a/python/paddle/fluid/tests/unittests/test_shuffle_batch_op.py b/python/paddle/fluid/tests/unittests/test_shuffle_batch_op.py index 409c0c0cf70855..79ef1e9c79dc23 100644 --- a/python/paddle/fluid/tests/unittests/test_shuffle_batch_op.py +++ b/python/paddle/fluid/tests/unittests/test_shuffle_batch_op.py @@ -23,24 +23,26 @@ import random -class TestShuffleBatchOp(OpTest): +class TestShuffleBatchOpBase(OpTest): + def gen_random_array(self, shape, low=0, high=1): + rnd = (high - low) * np.random.random(shape) + low + return rnd.astype(self.dtype) + + def get_shape(self): + return (10, 10, 5) + def setUp(self): self.op_type = 'shuffle_batch' self.dtype = np.float64 - x = np.array( - [np.arange(100), np.arange(100)]).astype(self.dtype).reshape( - [2, 100]) - out = np.array( - [np.arange(100), np.arange(100)]).astype(self.dtype).reshape( - [2, 100]) - self.possible_res = [ - np.array([np.arange(100), np.arange(100)]).astype(self.dtype), - ] - self.inputs = {'X': x, 'Seed': np.array([1]).astype('int64')} + self.shape = self.get_shape() + x = self.gen_random_array(self.shape) + seed = np.random.random_integers( + low=10, high=100, size=(1, )).astype('int64') + self.inputs = {'X': x, 'Seed': seed} self.outputs = { - 'Out': out, - 'ShuffleIdx': np.array([1, 0]).astype('int64'), - 'SeedOut': np.array([1]).astype('int64') + 'Out': np.array([]).astype(x.dtype), + 'ShuffleIdx': 
np.array([]).astype('int64'), + 'SeedOut': np.array([]).astype(seed.dtype), } self.attrs = {'startup_seed': 1} @@ -48,16 +50,33 @@ def test_check_output(self): self.check_output_customized(self.verify_output) def verify_output(self, outs): - for elem in outs: - if elem.shape == self.outputs['Out'].shape: - out = elem + x = np.copy(self.inputs['X']) + y = None + for out in outs: + if out.shape == x.shape: + y = np.copy(out) break - is_equal = [np.all(out == res) for res in self.possible_res] - self.assertIn(True, is_equal) + + assert y is not None + sort_x = self.sort_array(x) + sort_y = self.sort_array(y) + self.assertTrue(np.array_equal(sort_x, sort_y)) + + def sort_array(self, array): + shape = array.shape + new_shape = [-1, shape[-1]] + arr_list = np.reshape(array, new_shape).tolist() + arr_list.sort(key=lambda x: x[0]) + return np.reshape(np.array(arr_list), shape) def test_check_grad(self): self.check_grad(['X'], 'Out') +class TestShuffleBatchOp2(TestShuffleBatchOpBase): + def get_shape(self): + return (4, 30) + + if __name__ == '__main__': unittest.main() From be068c070d88e198e49b9bd314857164511b164e Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Mon, 5 Jul 2021 07:17:38 +0000 Subject: [PATCH 02/10] add thrust cuda patches test=develop --- cmake/cuda.cmake | 23 ++ paddle/fluid/operators/shuffle_batch_op.cu | 22 +- .../cuda_includes/thrust/detail/shuffle.inl | 85 +++++++ patches/thrust/cuda_includes/thrust/shuffle.h | 216 +++++++++++++++++ .../thrust/system/detail/generic/shuffle.h | 74 ++++++ .../thrust/system/detail/generic/shuffle.inl | 220 ++++++++++++++++++ 6 files changed, 634 insertions(+), 6 deletions(-) create mode 100644 patches/thrust/cuda_includes/thrust/detail/shuffle.inl create mode 100644 patches/thrust/cuda_includes/thrust/shuffle.h create mode 100644 patches/thrust/cuda_includes/thrust/system/detail/generic/shuffle.h create mode 100644 patches/thrust/cuda_includes/thrust/system/detail/generic/shuffle.inl diff --git a/cmake/cuda.cmake 
b/cmake/cuda.cmake index 9bdfc36201d539..828bd75ac85759 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -61,6 +61,28 @@ function(detect_installed_gpus out_variable) endif() endfunction() +function(add_thrust_patches_if_necessary) + set(thrust_detect_file ${PROJECT_BINARY_DIR}/detect_thrust.cu) + file(WRITE ${thrust_detect_file} "" + "#include \"thrust/version.h\"\n" + "#include \"thrust/shuffle.h\"\n" + "#include \"stdio.h\"\n" + "int main() {\n" + " int version = THRUST_VERSION;\n" + " printf(\"%d\", version);\n" + " return 0;\n" + "}\n") + + execute_process(COMMAND "${CUDA_NVCC_EXECUTABLE}" + "--run" "${thrust_detect_file}" + WORKING_DIRECTORY "${PROJECT_BINARY_DIR}/CMakeFiles/" + RESULT_VARIABLE nvcc_res ERROR_QUIET) + if(NOT nvcc_res EQUAL 0) + set(thrust_patches "${PADDLE_SOURCE_DIR}/patches/thrust/cuda_includes") + message(STATUS "Add thrust patches: ${thrust_patches}") + include_directories(${thrust_patches}) + endif() +endfunction() ######################################################################## # Function for selecting GPU arch flags for nvcc based on CUDA_ARCH_NAME @@ -233,3 +255,4 @@ endif() mark_as_advanced(CUDA_BUILD_CUBIN CUDA_BUILD_EMULATION CUDA_VERBOSE_BUILD) mark_as_advanced(CUDA_SDK_ROOT_DIR CUDA_SEPARABLE_COMPILATION) +add_thrust_patches_if_necessary() diff --git a/paddle/fluid/operators/shuffle_batch_op.cu b/paddle/fluid/operators/shuffle_batch_op.cu index b4ef49dbd35a1a..36d5a8638e46a3 100644 --- a/paddle/fluid/operators/shuffle_batch_op.cu +++ b/paddle/fluid/operators/shuffle_batch_op.cu @@ -16,8 +16,10 @@ #include #include -#include #include "paddle/fluid/operators/shuffle_batch_op.h" +#ifdef PADDLE_WITH_CUDA +#include +#endif #include "paddle/fluid/platform/for_range.h" namespace paddle { @@ -81,12 +83,7 @@ class ShuffleBatchCUDAKernel : public framework::OpKernel { auto *shuffleidx_data = shuffleidx->mutable_data(ctx.GetPlace()); auto &dev_ctx = ctx.template device_context(); -#ifdef PADDLE_WITH_CUDA const auto 
&exec_policy = thrust::cuda::par.on(dev_ctx.stream()); -#else - const auto &exec_policy = thrust::hip::par.on(dev_ctx.stream()); -#endif - thrust::default_random_engine engine(seed_int); thrust::counting_iterator cnt_iter(0); thrust::shuffle_copy(exec_policy, cnt_iter, cnt_iter + elem_size, @@ -129,6 +126,19 @@ class ShuffleBatchGradCUDAKernel : public framework::OpKernel { } }; +#else + +template +class ShuffleBatchCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "shuffle_batch op is only supported in CUDA devices")); + } +}; + +#endif + } // namespace operators } // namespace paddle diff --git a/patches/thrust/cuda_includes/thrust/detail/shuffle.inl b/patches/thrust/cuda_includes/thrust/detail/shuffle.inl new file mode 100644 index 00000000000000..edccc878731ef4 --- /dev/null +++ b/patches/thrust/cuda_includes/thrust/detail/shuffle.inl @@ -0,0 +1,85 @@ +/* + * Copyright 2008-2020 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file shuffle.inl + * \brief Inline file for shuffle.h. 
+ */ + +#include +#include + +#if THRUST_CPP_DIALECT >= 2011 + +#include +#include +#include +#include + +namespace thrust { + +__thrust_exec_check_disable__ +template +__host__ __device__ void shuffle( + const thrust::detail::execution_policy_base& exec, + RandomIterator first, RandomIterator last, URBG&& g) { + using thrust::system::detail::generic::shuffle; + return shuffle( + thrust::detail::derived_cast(thrust::detail::strip_const(exec)), + first, last, g); +} + +template +__host__ __device__ void shuffle(RandomIterator first, RandomIterator last, + URBG&& g) { + using thrust::system::detail::generic::select_system; + + typedef typename thrust::iterator_system::type System; + System system; + + return thrust::shuffle(select_system(system), first, last, g); +} + +__thrust_exec_check_disable__ +template +__host__ __device__ void shuffle_copy( + const thrust::detail::execution_policy_base& exec, + RandomIterator first, RandomIterator last, OutputIterator result, + URBG&& g) { + using thrust::system::detail::generic::shuffle_copy; + return shuffle_copy( + thrust::detail::derived_cast(thrust::detail::strip_const(exec)), + first, last, result, g); +} + +template +__host__ __device__ void shuffle_copy(RandomIterator first, RandomIterator last, + OutputIterator result, URBG&& g) { + using thrust::system::detail::generic::select_system; + + typedef typename thrust::iterator_system::type System1; + typedef typename thrust::iterator_system::type System2; + + System1 system1; + System2 system2; + + return thrust::shuffle_copy(select_system(system1, system2), first, last, + result, g); +} + +} // namespace thrust + +#endif diff --git a/patches/thrust/cuda_includes/thrust/shuffle.h b/patches/thrust/cuda_includes/thrust/shuffle.h new file mode 100644 index 00000000000000..427414df7c11b9 --- /dev/null +++ b/patches/thrust/cuda_includes/thrust/shuffle.h @@ -0,0 +1,216 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Copyright 2008-2020 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file shuffle.h + * \brief Reorders range by a uniform random permutation + */ + +#pragma once + +#include +#include + +#if THRUST_CPP_DIALECT >= 2011 + +#include +#include + +namespace thrust { + +/*! \addtogroup reordering +* \ingroup algorithms +* +* \addtogroup shuffling +* \ingroup reordering +* \{ +*/ + +/*! \p shuffle reorders the elements [first, last) by a uniform + * pseudorandom permutation, defined by + * random engine \p g. + * + * The algorithm's execution is parallelized as determined by \p exec. + * + * \param exec The execution policy to use for parallelization. + * \param first The beginning of the sequence to shuffle. + * \param last The end of the sequence to shuffle. 
+ * \param g A UniformRandomBitGenerator + * + * \tparam DerivedPolicy The name of the derived execution policy. + * \tparam RandomIterator is a random access iterator + * \tparam URBG is a uniform random bit generator + * + * The following code snippet demonstrates how to use \p shuffle to create a + * random permutation + * using the \p thrust::host execution policy for parallelization: + * + * \code + * #include + * #include + * #include + * int A[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + * const int N = sizeof(A)/sizeof(int); + * thrust::default_random_engine g; + * thrust::shuffle(thrust::host, A, A + N, g); + * // A is now {6, 5, 8, 7, 2, 1, 4, 3, 10, 9} + * \endcode + * + * \see \p shuffle_copy + */ +template +__host__ __device__ void shuffle( + const thrust::detail::execution_policy_base& exec, + RandomIterator first, + RandomIterator last, + URBG&& g); + +/*! \p shuffle reorders the elements [first, last) by a uniform + * pseudorandom permutation, defined by + * random engine \p g. + * + * \param first The beginning of the sequence to shuffle. + * \param last The end of the sequence to shuffle. + * \param g A UniformRandomBitGenerator + * + * \tparam RandomIterator is a random access iterator + * \tparam URBG is a uniform random bit generator + * + * The following code snippet demonstrates how to use \p shuffle to create a + * random permutation. + * + * \code + * #include + * #include + * int A[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + * const int N = sizeof(A)/sizeof(int); + * thrust::default_random_engine g; + * thrust::shuffle(A, A + N, g); + * // A is now {6, 5, 8, 7, 2, 1, 4, 3, 10, 9} + * \endcode + * + * \see \p shuffle_copy + */ +template +__host__ __device__ void shuffle(RandomIterator first, + RandomIterator last, + URBG&& g); + +/*! shuffle_copy differs from shuffle only in that the reordered sequence is + written to different output sequences, rather than in place. 
+ * \p shuffle_copy reorders the elements [first, last) by a uniform + pseudorandom permutation, defined by + * random engine \p g. + * + * The algorithm's execution is parallelized as determined by \p exec. + + * \param exec The execution policy to use for parallelization. + * \param first The beginning of the sequence to shuffle. + * \param last The end of the sequence to shuffle. + * \param result Destination of shuffled sequence + * \param g A UniformRandomBitGenerator + * + * \tparam DerivedPolicy The name of the derived execution policy. + * \tparam RandomIterator is a random access iterator + * \tparam OutputIterator is a model of Output + Iterator. + * \tparam URBG is a uniform random bit generator + * + * The following code snippet demonstrates how to use \p shuffle_copy to create + a random permutation. + * + * \code + * #include + * #include + * #include + * int A[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + * int result[10]; + * const int N = sizeof(A)/sizeof(int); + * thrust::default_random_engine g; + * thrust::shuffle_copy(thrust::host, A, A + N, result, g); + * // result is now {6, 5, 8, 7, 2, 1, 4, 3, 10, 9} + * \endcode + * + * \see \p shuffle + */ +template +__host__ __device__ void shuffle_copy( + const thrust::detail::execution_policy_base& exec, + RandomIterator first, + RandomIterator last, + OutputIterator result, + URBG&& g); + +/*! shuffle_copy differs from shuffle only in that the reordered sequence is + *written to different output sequences, rather than in place. + *\p shuffle_copy reorders the elements [first, last) by a uniform + *pseudorandom permutation, defined by + * random engine \p g. + * + * \param first The beginning of the sequence to shuffle. + * \param last The end of the sequence to shuffle. + * \param result Destination of shuffled sequence + * \param g A UniformRandomBitGenerator + * + * \tparam RandomIterator is a random access iterator + * \tparam OutputIterator is a model of Output + *Iterator. 
+ * \tparam URBG is a uniform random bit generator + * + * The following code snippet demonstrates how to use \p shuffle_copy to create + *a random permutation. + * + * \code + * #include + * #include + * int A[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + * int result[10]; + * const int N = sizeof(A)/sizeof(int); + * thrust::default_random_engine g; + * thrust::shuffle_copy(A, A + N, result, g); + * // result is now {6, 5, 8, 7, 2, 1, 4, 3, 10, 9} + * \endcode + * + * \see \p shuffle + */ +template +__host__ __device__ void shuffle_copy(RandomIterator first, + RandomIterator last, + OutputIterator result, + URBG&& g); + +} // namespace thrust + +#include +#endif diff --git a/patches/thrust/cuda_includes/thrust/system/detail/generic/shuffle.h b/patches/thrust/cuda_includes/thrust/system/detail/generic/shuffle.h new file mode 100644 index 00000000000000..87008aaa10c4af --- /dev/null +++ b/patches/thrust/cuda_includes/thrust/system/detail/generic/shuffle.h @@ -0,0 +1,74 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Copyright 2008-2020 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file shuffle.h + * \brief Generic implementations of shuffle functions. + */ + +#pragma once + +#include +#include + +#if THRUST_CPP_DIALECT >= 2011 + +#include + +namespace thrust { +namespace system { +namespace detail { +namespace generic { + +template +__host__ __device__ void shuffle( + thrust::execution_policy& exec, + RandomIterator first, + RandomIterator last, + URBG&& g); + +template +__host__ __device__ void shuffle_copy( + thrust::execution_policy& exec, + RandomIterator first, + RandomIterator last, + OutputIterator result, + URBG&& g); + +} // end namespace generic +} // end namespace detail +} // end namespace system +} // end namespace thrust + +#include + +#endif diff --git a/patches/thrust/cuda_includes/thrust/system/detail/generic/shuffle.inl b/patches/thrust/cuda_includes/thrust/system/detail/generic/shuffle.inl new file mode 100644 index 00000000000000..a0a27833c62f76 --- /dev/null +++ b/patches/thrust/cuda_includes/thrust/system/detail/generic/shuffle.inl @@ -0,0 +1,220 @@ +/* + * Copyright 2008-20120 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace thrust { +template +using iterator_value_t = typename iterator_value::type; + +namespace system { +namespace detail { +namespace generic { + +// An implementation of a Feistel cipher for operating on 64 bit keys +class feistel_bijection { + struct round_state { + std::uint32_t left; + std::uint32_t right; + }; + + public: + template + __host__ __device__ feistel_bijection(std::uint64_t m, URBG&& g) { + std::uint64_t total_bits = get_cipher_bits(m); + // Half bits rounded down + left_side_bits = total_bits / 2; + left_side_mask = (1ull << left_side_bits) - 1; + // Half the bits rounded up + right_side_bits = total_bits - left_side_bits; + right_side_mask = (1ull << right_side_bits) - 1; + + for (std::uint64_t i = 0; i < num_rounds; i++) { + key[i] = g(); + } + } + + __host__ __device__ std::uint64_t nearest_power_of_two() const { + return 1ull << (left_side_bits + right_side_bits); + } + __host__ __device__ std::uint64_t operator()(const std::uint64_t val) const { + // Extract the right and left sides of the input + auto left = static_cast(val >> right_side_bits); + auto right = static_cast(val & right_side_mask); + round_state state = {left, right}; + + for (std::uint64_t i = 0; i < num_rounds; i++) { + state = do_round(state, i); + } + + // Check we have the correct number of bits on each side + assert((state.left >> left_side_bits) == 0); + assert((state.right >> right_side_bits) == 0); + + // Combine the left and right sides together to get result + return state.left << right_side_bits | state.right; + } + + private: + // Find the bit count n where 2^n is the smallest power of two >= m + __host__ __device__ std::uint64_t get_cipher_bits(std::uint64_t m) { + if (m == 0) return 0; + std::uint64_t i = 0; + m--; + while (m != 0) { + i++; + m >>= 1; + } + return i; + } + + // Equivalent
to boost::hash_combine + __host__ __device__ + std::size_t hash_combine(std::uint64_t lhs, std::uint64_t rhs) const { + lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); + return lhs; + } + + // Round function, a 'pseudorandom function' whose output is indistinguishable + // from random for each key value input. This is not cryptographically secure + // but sufficient for generating permutations. + __host__ __device__ std::uint32_t round_function(std::uint64_t value, + const std::uint64_t key_) const { + std::uint64_t hash0 = thrust::random::taus88(static_cast(value))(); + std::uint64_t hash1 = thrust::random::ranlux48(value)(); + return static_cast( + hash_combine(hash_combine(hash0, key_), hash1) & left_side_mask); + } + + __host__ __device__ round_state do_round(const round_state state, + const std::uint64_t round) const { + const std::uint32_t new_left = state.right & left_side_mask; + const std::uint32_t round_function_res = + state.left ^ round_function(state.right, key[round]); + if (right_side_bits != left_side_bits) { + // Upper bit of the old right becomes lower bit of new right if we have + // odd length feistel + const std::uint32_t new_right = + (round_function_res << 1ull) | state.right >> left_side_bits; + return {new_left, new_right}; + } + return {new_left, round_function_res}; + } + + static constexpr std::uint64_t num_rounds = 16; + std::uint64_t right_side_bits; + std::uint64_t left_side_bits; + std::uint64_t right_side_mask; + std::uint64_t left_side_mask; + std::uint64_t key[num_rounds]; +}; + +struct key_flag_tuple { + std::uint64_t key; + std::uint64_t flag; +}; + +// scan only flags +struct key_flag_scan_op { + __host__ __device__ key_flag_tuple operator()(const key_flag_tuple& a, + const key_flag_tuple& b) { + return {b.key, a.flag + b.flag}; + } +}; + +struct construct_key_flag_op { + std::uint64_t m; + feistel_bijection bijection; + __host__ __device__ construct_key_flag_op(std::uint64_t m, + feistel_bijection bijection) + : m(m),
bijection(bijection) {} + __host__ __device__ key_flag_tuple operator()(std::uint64_t idx) { + auto gather_key = bijection(idx); + return key_flag_tuple{gather_key, (gather_key < m) ? 1ull : 0ull}; + } +}; + +template +struct write_output_op { + std::uint64_t m; + InputIterT in; + OutputIterT out; + // flag contains inclusive scan of valid keys + // perform gather using valid keys + __thrust_exec_check_disable__ + __host__ __device__ std::size_t operator()(key_flag_tuple x) { + if (x.key < m) { + // -1 because inclusive scan + out[x.flag - 1] = in[x.key]; + } + return 0; // Discarded + } +}; + +template +__host__ __device__ void shuffle( + thrust::execution_policy& exec, RandomIterator first, + RandomIterator last, URBG&& g) { + using InputType = typename thrust::iterator_value_t; + + // copy input to temp buffer + thrust::detail::temporary_array temp(exec, first, + last); + thrust::shuffle_copy(exec, temp.begin(), temp.end(), first, g); +} + +template +__host__ __device__ void shuffle_copy( + thrust::execution_policy& exec, RandomIterator first, + RandomIterator last, OutputIterator result, URBG&& g) { + // m is the length of the input + // we have an available bijection of length n via a feistel cipher + std::size_t m = last - first; + feistel_bijection bijection(m, g); + std::uint64_t n = bijection.nearest_power_of_two(); + + // perform stream compaction over length n bijection to get length m + // pseudorandom bijection over the original input + thrust::counting_iterator indices(0); + thrust::transform_iterator + key_flag_it(indices, construct_key_flag_op(m, bijection)); + write_output_op write_functor{m, first, + result}; + auto gather_output_it = thrust::make_transform_output_iterator( + thrust::discard_iterator(), write_functor); + // the feistel_bijection outputs a stream of permuted indices in range [0,n) + // flag each value < m and compact it, so we have a set of permuted indices in + // range [0,m) each thread gathers an input element according to its + 
// pseudorandom permuted index + thrust::inclusive_scan(exec, key_flag_it, key_flag_it + n, gather_output_it, + key_flag_scan_op()); +} + +} // end namespace generic +} // end namespace detail +} // end namespace system +} // end namespace thrust From ed180b1edd9f953c55768d7695c8044b30cd761c Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Mon, 5 Jul 2021 07:24:17 +0000 Subject: [PATCH 03/10] fix macro guard --- paddle/fluid/operators/shuffle_batch_op.cu | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/shuffle_batch_op.cu b/paddle/fluid/operators/shuffle_batch_op.cu index 36d5a8638e46a3..2aa5b4790fcfc5 100644 --- a/paddle/fluid/operators/shuffle_batch_op.cu +++ b/paddle/fluid/operators/shuffle_batch_op.cu @@ -14,12 +14,11 @@ #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -#include -#include #include "paddle/fluid/operators/shuffle_batch_op.h" #ifdef PADDLE_WITH_CUDA +#include +#include #include -#endif #include "paddle/fluid/platform/for_range.h" namespace paddle { From a7de588d1005c543a19f5fd52bc1564fa2192c3f Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Mon, 5 Jul 2021 09:54:29 +0000 Subject: [PATCH 04/10] fix shuffle batch compile on windows/hip --- cmake/cuda.cmake | 1 + paddle/fluid/operators/shuffle_batch_op.cu | 26 +++++++++------------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index 828bd75ac85759..75b37d28da15c9 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -239,6 +239,7 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda") if(WIN32) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler \"/wd4244 /wd4267 /wd4819 \"") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /bigobj") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /Zc:__cplusplus") if(MSVC_STATIC_CRT) set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -Xcompiler /MTd") set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler /MT") diff 
--git a/paddle/fluid/operators/shuffle_batch_op.cu b/paddle/fluid/operators/shuffle_batch_op.cu index 2aa5b4790fcfc5..84346833b0412d 100644 --- a/paddle/fluid/operators/shuffle_batch_op.cu +++ b/paddle/fluid/operators/shuffle_batch_op.cu @@ -14,11 +14,10 @@ #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -#include "paddle/fluid/operators/shuffle_batch_op.h" -#ifdef PADDLE_WITH_CUDA #include #include #include +#include "paddle/fluid/operators/shuffle_batch_op.h" #include "paddle/fluid/platform/for_range.h" namespace paddle { @@ -49,6 +48,10 @@ template class ShuffleBatchCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { +#ifdef PADDLE_WITH_HIP + PADDLE_THROW(platform::errors::Unimplemented( + "shuffle_batch does not support to run on HIP devices yet")); +#else auto *x = ctx.Input("X"); auto *seed = ctx.Input("Seed"); auto *out = ctx.Output("Out"); @@ -99,6 +102,7 @@ class ShuffleBatchCUDAKernel : public framework::OpKernel { auto *seed_out_data = seed_out->mutable_data( framework::make_ddim({1}), platform::CPUPlace()); *seed_out_data = engine(); +#endif } }; @@ -106,6 +110,10 @@ template class ShuffleBatchGradCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { +#ifdef PADDLE_WITH_HIP + PADDLE_THROW(platform::errors::Unimplemented( + "shuffle_batch_grad does not support to run on HIP devices yet")); +#else const auto *out_grad = ctx.Input(framework::GradVarName("Out")); const auto *shuffleidx = ctx.Input("ShuffleIdx"); @@ -122,22 +130,10 @@ class ShuffleBatchGradCUDAKernel : public framework::OpKernel { platform::ForRange for_range(dev_ctx, x_grad->numel()); for_range(functor); +#endif } }; -#else - -template -class ShuffleBatchCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "shuffle_batch op is 
only supported in CUDA devices")); - } -}; - -#endif - } // namespace operators } // namespace paddle From cfaad9076a5f0a11c6efa263a264bf09505712fa Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Mon, 5 Jul 2021 14:28:09 +0000 Subject: [PATCH 05/10] fix hip compilation error --- CMakeLists.txt | 2 ++ cmake/cuda.cmake | 23 ------------------ cmake/thrust.cmake | 24 +++++++++++++++++++ paddle/fluid/operators/shuffle_batch_op.cu | 14 ++++------- .../thrust/detail/shuffle.inl | 0 .../{cuda_includes => }/thrust/shuffle.h | 0 .../thrust/system/detail/generic/shuffle.h | 0 .../thrust/system/detail/generic/shuffle.inl | 0 8 files changed, 30 insertions(+), 33 deletions(-) create mode 100644 cmake/thrust.cmake rename patches/thrust/{cuda_includes => }/thrust/detail/shuffle.inl (100%) rename patches/thrust/{cuda_includes => }/thrust/shuffle.h (100%) rename patches/thrust/{cuda_includes => }/thrust/system/detail/generic/shuffle.h (100%) rename patches/thrust/{cuda_includes => }/thrust/system/detail/generic/shuffle.inl (100%) diff --git a/CMakeLists.txt b/CMakeLists.txt index f6b422f5bca403..f814167c9bf05b 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -274,6 +274,7 @@ endif() if(WITH_GPU) include(cuda) + include(thrust) # lite subgraph compilation depends on CUDNN_ROOT, # so include(cudnn) needs to be in front of include(third_party/lite) include(cudnn) # set cudnn libraries, must before configure @@ -286,6 +287,7 @@ endif() if(WITH_ROCM) include(hip) + include(thrust) include(miopen) # set miopen libraries, must before configure endif(WITH_ROCM) diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index 75b37d28da15c9..af413bea43f16e 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -61,28 +61,6 @@ function(detect_installed_gpus out_variable) endif() endfunction() -function(add_thrust_patches_if_necessary) - set(thrust_detect_file ${PROJECT_BINARY_DIR}/detect_thrust.cu) - file(WRITE ${thrust_detect_file} "" - "#include \"thrust/version.h\"\n" - "#include 
\"thrust/shuffle.h\"\n" - "#include \"stdio.h\"\n" - "int main() {\n" - " int version = THRUST_VERSION;\n" - " printf(\"%d\", version);\n" - " return 0;\n" - "}\n") - - execute_process(COMMAND "${CUDA_NVCC_EXECUTABLE}" - "--run" "${thrust_detect_file}" - WORKING_DIRECTORY "${PROJECT_BINARY_DIR}/CMakeFiles/" - RESULT_VARIABLE nvcc_res ERROR_QUIET) - if(NOT nvcc_res EQUAL 0) - set(thrust_patches "${PADDLE_SOURCE_DIR}/patches/thrust/cuda_includes") - message(STATUS "Add thrust patches: ${thrust_patches}") - include_directories(${thrust_patches}) - endif() -endfunction() ######################################################################## # Function for selecting GPU arch flags for nvcc based on CUDA_ARCH_NAME @@ -256,4 +234,3 @@ endif() mark_as_advanced(CUDA_BUILD_CUBIN CUDA_BUILD_EMULATION CUDA_VERBOSE_BUILD) mark_as_advanced(CUDA_SDK_ROOT_DIR CUDA_SEPARABLE_COMPILATION) -add_thrust_patches_if_necessary() diff --git a/cmake/thrust.cmake b/cmake/thrust.cmake new file mode 100644 index 00000000000000..ff415b1e3c4bf6 --- /dev/null +++ b/cmake/thrust.cmake @@ -0,0 +1,24 @@ +function(add_thrust_patches_if_necessary) + set(thrust_detect_file ${PROJECT_BINARY_DIR}/detect_thrust.cu) + file(WRITE ${thrust_detect_file} "" + "#include \"thrust/version.h\"\n" + "#include \"thrust/shuffle.h\"\n" + "#include \"stdio.h\"\n" + "int main() {\n" + " int version = THRUST_VERSION;\n" + " printf(\"%d\", version);\n" + " return 0;\n" + "}\n") + + execute_process(COMMAND "${CUDA_NVCC_EXECUTABLE}" + "--run" "${thrust_detect_file}" + WORKING_DIRECTORY "${PROJECT_BINARY_DIR}/CMakeFiles/" + RESULT_VARIABLE nvcc_res ERROR_QUIET) + if(NOT nvcc_res EQUAL 0) + set(thrust_patches "${PADDLE_SOURCE_DIR}/patches/thrust") + message(STATUS "Add thrust patches: ${thrust_patches}") + include_directories(${thrust_patches}) + endif() +endfunction() + +add_thrust_patches_if_necessary() diff --git a/paddle/fluid/operators/shuffle_batch_op.cu b/paddle/fluid/operators/shuffle_batch_op.cu index 
84346833b0412d..6cd088a68968d9 100644 --- a/paddle/fluid/operators/shuffle_batch_op.cu +++ b/paddle/fluid/operators/shuffle_batch_op.cu @@ -48,10 +48,6 @@ template class ShuffleBatchCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { -#ifdef PADDLE_WITH_HIP - PADDLE_THROW(platform::errors::Unimplemented( - "shuffle_batch does not support to run on HIP devices yet")); -#else auto *x = ctx.Input("X"); auto *seed = ctx.Input("Seed"); auto *out = ctx.Output("Out"); @@ -85,7 +81,11 @@ class ShuffleBatchCUDAKernel : public framework::OpKernel { auto *shuffleidx_data = shuffleidx->mutable_data(ctx.GetPlace()); auto &dev_ctx = ctx.template device_context(); +#ifdef PADDLE_WITH_CUDA const auto &exec_policy = thrust::cuda::par.on(dev_ctx.stream()); +#else + const auto &exec_policy = thrust::hip::par.on(dev_ctx.stream()); +#endif thrust::default_random_engine engine(seed_int); thrust::counting_iterator cnt_iter(0); thrust::shuffle_copy(exec_policy, cnt_iter, cnt_iter + elem_size, @@ -102,7 +102,6 @@ class ShuffleBatchCUDAKernel : public framework::OpKernel { auto *seed_out_data = seed_out->mutable_data( framework::make_ddim({1}), platform::CPUPlace()); *seed_out_data = engine(); -#endif } }; @@ -110,10 +109,6 @@ template class ShuffleBatchGradCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { -#ifdef PADDLE_WITH_HIP - PADDLE_THROW(platform::errors::Unimplemented( - "shuffle_batch_grad does not support to run on HIP devices yet")); -#else const auto *out_grad = ctx.Input(framework::GradVarName("Out")); const auto *shuffleidx = ctx.Input("ShuffleIdx"); @@ -130,7 +125,6 @@ class ShuffleBatchGradCUDAKernel : public framework::OpKernel { platform::ForRange for_range(dev_ctx, x_grad->numel()); for_range(functor); -#endif } }; diff --git a/patches/thrust/cuda_includes/thrust/detail/shuffle.inl b/patches/thrust/thrust/detail/shuffle.inl similarity 
index 100% rename from patches/thrust/cuda_includes/thrust/detail/shuffle.inl rename to patches/thrust/thrust/detail/shuffle.inl diff --git a/patches/thrust/cuda_includes/thrust/shuffle.h b/patches/thrust/thrust/shuffle.h similarity index 100% rename from patches/thrust/cuda_includes/thrust/shuffle.h rename to patches/thrust/thrust/shuffle.h diff --git a/patches/thrust/cuda_includes/thrust/system/detail/generic/shuffle.h b/patches/thrust/thrust/system/detail/generic/shuffle.h similarity index 100% rename from patches/thrust/cuda_includes/thrust/system/detail/generic/shuffle.h rename to patches/thrust/thrust/system/detail/generic/shuffle.h diff --git a/patches/thrust/cuda_includes/thrust/system/detail/generic/shuffle.inl b/patches/thrust/thrust/system/detail/generic/shuffle.inl similarity index 100% rename from patches/thrust/cuda_includes/thrust/system/detail/generic/shuffle.inl rename to patches/thrust/thrust/system/detail/generic/shuffle.inl From 5a0fb953e94f5fe9833bc8dcbabef176cc66922b Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Mon, 5 Jul 2021 14:42:19 +0000 Subject: [PATCH 06/10] refine CMakeLists.txt --- CMakeLists.txt | 2 -- cmake/cuda.cmake | 4 +++- cmake/hip.cmake | 2 ++ 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f814167c9bf05b..f6b422f5bca403 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -274,7 +274,6 @@ endif() if(WITH_GPU) include(cuda) - include(thrust) # lite subgraph compilation depends on CUDNN_ROOT, # so include(cudnn) needs to be in front of include(third_party/lite) include(cudnn) # set cudnn libraries, must before configure @@ -287,7 +286,6 @@ endif() if(WITH_ROCM) include(hip) - include(thrust) include(miopen) # set miopen libraries, must before configure endif(WITH_ROCM) diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index af413bea43f16e..48e7bce1993325 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -217,7 +217,8 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} 
--expt-extended-lambda") if(WIN32) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler \"/wd4244 /wd4267 /wd4819 \"") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /bigobj") - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /Zc:__cplusplus") + add_definitions("-DTHRUST_CPP11_REQUIRED_NO_ERROR") + add_definitions("-DTHRUST_CPP14_REQUIRED_NO_ERROR") if(MSVC_STATIC_CRT) set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -Xcompiler /MTd") set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler /MT") @@ -234,3 +235,4 @@ endif() mark_as_advanced(CUDA_BUILD_CUBIN CUDA_BUILD_EMULATION CUDA_VERBOSE_BUILD) mark_as_advanced(CUDA_SDK_ROOT_DIR CUDA_SEPARABLE_COMPILATION) +include(thrust) diff --git a/cmake/hip.cmake b/cmake/hip.cmake index 4c492d7cc48f06..514f5ea9deaa32 100644 --- a/cmake/hip.cmake +++ b/cmake/hip.cmake @@ -85,3 +85,5 @@ message(STATUS "HIP library name: ${hip_library_name}") # set HIP link libs find_library(ROCM_HIPRTC_LIB ${hip_library_name} HINTS ${HIP_PATH}/lib) message(STATUS "ROCM_HIPRTC_LIB: ${ROCM_HIPRTC_LIB}") + +include(thrust) From 568dd19804a84bfa313ef827cd878db01aff207d Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Mon, 5 Jul 2021 16:44:45 +0000 Subject: [PATCH 07/10] fix windows compile error --- paddle/fluid/operators/shuffle_batch_op.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/shuffle_batch_op.cu b/paddle/fluid/operators/shuffle_batch_op.cu index 6cd088a68968d9..f051ae0f5264dc 100644 --- a/paddle/fluid/operators/shuffle_batch_op.cu +++ b/paddle/fluid/operators/shuffle_batch_op.cu @@ -16,6 +16,7 @@ #include #include +#include #include #include "paddle/fluid/operators/shuffle_batch_op.h" #include "paddle/fluid/platform/for_range.h" @@ -86,7 +87,7 @@ class ShuffleBatchCUDAKernel : public framework::OpKernel { #else const auto &exec_policy = thrust::hip::par.on(dev_ctx.stream()); #endif - thrust::default_random_engine engine(seed_int); + 
thrust::random::default_random_engine engine(seed_int); thrust::counting_iterator cnt_iter(0); thrust::shuffle_copy(exec_policy, cnt_iter, cnt_iter + elem_size, thrust::device_pointer_cast(shuffleidx_data), engine); From 76ace65bb2958f2f2c3414e89e11ace0d6bdd5ce Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Tue, 6 Jul 2021 04:10:16 +0000 Subject: [PATCH 08/10] try to fix windows CI compilation error --- cmake/cuda.cmake | 2 -- paddle/fluid/operators/shuffle_batch_op.cu | 8 ++++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index 48e7bce1993325..a79d566f3a54f2 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -217,8 +217,6 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda") if(WIN32) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler \"/wd4244 /wd4267 /wd4819 \"") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /bigobj") - add_definitions("-DTHRUST_CPP11_REQUIRED_NO_ERROR") - add_definitions("-DTHRUST_CPP14_REQUIRED_NO_ERROR") if(MSVC_STATIC_CRT) set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -Xcompiler /MTd") set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler /MT") diff --git a/paddle/fluid/operators/shuffle_batch_op.cu b/paddle/fluid/operators/shuffle_batch_op.cu index f051ae0f5264dc..99d1f83355a5a5 100644 --- a/paddle/fluid/operators/shuffle_batch_op.cu +++ b/paddle/fluid/operators/shuffle_batch_op.cu @@ -14,6 +14,14 @@ #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +// For MSVC, define __cplusplus to 201402L directly, +// otherwise, thrust would raise compilation error. 
+// See: +// https://docs.microsoft.com/en-us/cpp/build/reference/zc-cplusplus?view=msvc-160 +#if defined(_MSVC_LANG) && __cplusplus < 201103L +#define __cplusplus 201402L +#endif + #include #include #include From aafcbffe05e38dee27e70b963d72f1917bac1f99 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Tue, 6 Jul 2021 04:40:45 +0000 Subject: [PATCH 09/10] fix windows compilation again --- paddle/fluid/operators/shuffle_batch_op.cu | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/shuffle_batch_op.cu b/paddle/fluid/operators/shuffle_batch_op.cu index 99d1f83355a5a5..02210e64fb4398 100644 --- a/paddle/fluid/operators/shuffle_batch_op.cu +++ b/paddle/fluid/operators/shuffle_batch_op.cu @@ -14,18 +14,13 @@ #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -// For MSVC, define __cplusplus to 201402L directly, -// otherwise, thrust would raise compilation error. -// See: -// https://docs.microsoft.com/en-us/cpp/build/reference/zc-cplusplus?view=msvc-160 -#if defined(_MSVC_LANG) && __cplusplus < 201103L -#define __cplusplus 201402L -#endif - +#ifndef _MSC_VER #include #include #include #include +#endif + #include "paddle/fluid/operators/shuffle_batch_op.h" #include "paddle/fluid/platform/for_range.h" @@ -57,6 +52,10 @@ template class ShuffleBatchCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { +#ifdef _MSC_VER + PADDLE_THROW(platform::errors::Unimplemented( + "GPU shuffle_batch is not supported on Windows yet")); +#else auto *x = ctx.Input("X"); auto *seed = ctx.Input("Seed"); auto *out = ctx.Output("Out"); @@ -111,6 +110,7 @@ class ShuffleBatchCUDAKernel : public framework::OpKernel { auto *seed_out_data = seed_out->mutable_data( framework::make_ddim({1}), platform::CPUPlace()); *seed_out_data = engine(); +#endif } }; @@ -118,6 +118,10 @@ template class ShuffleBatchGradCUDAKernel : public framework::OpKernel { public: void Compute(const 
framework::ExecutionContext &ctx) const override { +#ifdef _MSC_VER + PADDLE_THROW(platform::errors::Unimplemented( + "GPU shuffle_batch_grad is not supported on Windows yet")); +#else const auto *out_grad = ctx.Input(framework::GradVarName("Out")); const auto *shuffleidx = ctx.Input("ShuffleIdx"); @@ -134,6 +138,7 @@ class ShuffleBatchGradCUDAKernel : public framework::OpKernel { platform::ForRange for_range(dev_ctx, x_grad->numel()); for_range(functor); +#endif } }; From 5957ec77232cd818328f716a617a8ced7812526b Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Tue, 6 Jul 2021 06:14:12 +0000 Subject: [PATCH 10/10] fix shuffle_batch op test on Windows --- .../paddle/fluid/tests/unittests/test_shuffle_batch_op.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_shuffle_batch_op.py b/python/paddle/fluid/tests/unittests/test_shuffle_batch_op.py index 79ef1e9c79dc23..62c26a73a8d434 100644 --- a/python/paddle/fluid/tests/unittests/test_shuffle_batch_op.py +++ b/python/paddle/fluid/tests/unittests/test_shuffle_batch_op.py @@ -20,6 +20,7 @@ import paddle.fluid.core as core import paddle.fluid.layers as layers from op_test import OpTest +import os import random @@ -31,6 +32,12 @@ def gen_random_array(self, shape, low=0, high=1): def get_shape(self): return (10, 10, 5) + def _get_places(self): + # NOTE: shuffle_batch is not supported on Windows + if os.name == 'nt': + return [fluid.CPUPlace()] + return super(TestShuffleBatchOpBase, self)._get_places() + def setUp(self): self.op_type = 'shuffle_batch' self.dtype = np.float64