1 change: 1 addition & 0 deletions cmake/cuda.cmake
@@ -233,3 +233,4 @@ endif()
mark_as_advanced(CUDA_BUILD_CUBIN CUDA_BUILD_EMULATION CUDA_VERBOSE_BUILD)
mark_as_advanced(CUDA_SDK_ROOT_DIR CUDA_SEPARABLE_COMPILATION)

include(thrust)
2 changes: 2 additions & 0 deletions cmake/hip.cmake
@@ -85,3 +85,5 @@ message(STATUS "HIP library name: ${hip_library_name}")
# set HIP link libs
find_library(ROCM_HIPRTC_LIB ${hip_library_name} HINTS ${HIP_PATH}/lib)
message(STATUS "ROCM_HIPRTC_LIB: ${ROCM_HIPRTC_LIB}")

include(thrust)
24 changes: 24 additions & 0 deletions cmake/thrust.cmake
@@ -0,0 +1,24 @@
function(add_thrust_patches_if_necessary)
  # Compile-and-run probe: build a tiny CUDA program that includes
  # thrust/shuffle.h. If the bundled Thrust is too old to provide it,
  # prepend the backported headers in patches/thrust to the include path.
  set(thrust_detect_file ${PROJECT_BINARY_DIR}/detect_thrust.cu)
  file(WRITE ${thrust_detect_file} ""
    "#include \"thrust/version.h\"\n"
    "#include \"thrust/shuffle.h\"\n"
    "#include \"stdio.h\"\n"
    "int main() {\n"
    " int version = THRUST_VERSION;\n"
    " printf(\"%d\", version);\n"
    " return 0;\n"
    "}\n")

  execute_process(COMMAND "${CUDA_NVCC_EXECUTABLE}"
    "--run" "${thrust_detect_file}"
    WORKING_DIRECTORY "${PROJECT_BINARY_DIR}/CMakeFiles/"
    RESULT_VARIABLE nvcc_res ERROR_QUIET)
  if(NOT nvcc_res EQUAL 0)
    set(thrust_patches "${PADDLE_SOURCE_DIR}/patches/thrust")
    message(STATUS "Add thrust patches: ${thrust_patches}")
    include_directories(${thrust_patches})
  endif()
endfunction()

add_thrust_patches_if_necessary()
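
The probe above compiles and runs a small CUDA file with `nvcc --run`; if the bundled Thrust is too old to ship `thrust/shuffle.h`, the run fails and the backported headers under `patches/thrust` are added to the include path. For reference, a minimal sketch (not part of the patch; array size and seed are arbitrary illustration values) of the call this check guards, the same `thrust::shuffle_copy` the new CUDA kernel uses to build its permutation:

// Sketch only: shuffle_copy writes a random permutation of [0, n) into
// device memory, leaving the input range untouched.
#include <cstdio>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/shuffle.h>

int main() {
  const int64_t n = 8;                        // illustrative size
  thrust::device_vector<int64_t> idx(n);
  thrust::default_random_engine engine(1);    // illustrative seed
  thrust::counting_iterator<int64_t> cnt(0);
  thrust::shuffle_copy(thrust::device, cnt, cnt + n, idx.begin(), engine);
  thrust::host_vector<int64_t> h = idx;       // copy back to print on the host
  for (int64_t i = 0; i < n; ++i) std::printf("%lld ", (long long)h[i]);
  std::printf("\n");
  return 0;
}
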
10 changes: 10 additions & 0 deletions paddle/fluid/operators/shuffle_batch_op.cc
@@ -53,6 +53,16 @@ class ShuffleBatchOp : public framework::OperatorWithKernel {
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(data_type, ctx.device_context());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name, const framework::Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    // Keep the "Seed" input on its original (CPU) place: returning the
    // expected kernel type skips the usual data transform for this variable.
    if (var_name == "Seed") {
      return expected_kernel_type;
    }
    return framework::OperatorWithKernel::GetKernelTypeForVar(
        var_name, tensor, expected_kernel_type);
  }
};

class ShuffleBatchOpMaker : public framework::OpProtoAndCheckerMaker {
159 changes: 159 additions & 0 deletions paddle/fluid/operators/shuffle_batch_op.cu
@@ -0,0 +1,159 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

#ifndef _MSC_VER
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/shuffle.h>
#endif

#include "paddle/fluid/operators/shuffle_batch_op.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

template <typename T, bool kIsForward>
struct ReorderFunctor {
  ReorderFunctor(const T *x, const int64_t *shuffle_idx, T *y, int64_t stride)
      : x_(x), shuffle_idx_(shuffle_idx), y_(y), stride_(stride) {}

  HOSTDEVICE void operator()(int64_t idx) {
    auto reorder_idx = shuffle_idx_[idx / stride_] * stride_ + idx % stride_;
    if (kIsForward) {
      y_[idx] = x_[reorder_idx];
    } else {
      y_[reorder_idx] = x_[idx];
    }
  }

 private:
  const T *x_;
  const int64_t *shuffle_idx_;
  T *y_;
  int64_t stride_;
};

template <typename T>
class ShuffleBatchCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
#ifdef _MSC_VER
    PADDLE_THROW(platform::errors::Unimplemented(
        "GPU shuffle_batch is not supported on Windows yet"));
#else
    auto *x = ctx.Input<framework::Tensor>("X");
    auto *seed = ctx.Input<framework::Tensor>("Seed");
    auto *out = ctx.Output<framework::Tensor>("Out");
    auto *shuffleidx = ctx.Output<framework::Tensor>("ShuffleIdx");
    auto *seed_out = ctx.Output<framework::Tensor>("SeedOut");

    int64_t x_embed_size = x->dims()[x->dims().size() - 1];
    int64_t elem_size = 1;
    for (int i = 0; i < x->dims().size() - 1; i++) {
      elem_size *= x->dims()[i];
    }
    shuffleidx->Resize(framework::make_ddim({elem_size}));

    int64_t seed_int = 0;
    if (seed->IsInitialized()) {
      const auto &seed_place = seed->place();
      if (platform::is_gpu_place(seed_place)) {
        // NOTE: We have overwritten GetKernelTypeForVar, so seed_place would
        // not be CUDAPlace in practice. This case would only happen in Python
        // op_test framework.
        framework::Tensor tmp_tensor;
        framework::TensorCopySync(*seed, platform::CPUPlace(), &tmp_tensor);
        seed_int = *(tmp_tensor.data<int64_t>());
      } else {
        seed_int = *(seed->data<int64_t>());
      }
    } else {
      seed_int = ctx.Attr<int>("startup_seed");
    }

    auto *shuffleidx_data = shuffleidx->mutable_data<int64_t>(ctx.GetPlace());

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
#ifdef PADDLE_WITH_CUDA
    const auto &exec_policy = thrust::cuda::par.on(dev_ctx.stream());
#else
    const auto &exec_policy = thrust::hip::par.on(dev_ctx.stream());
#endif
    thrust::random::default_random_engine engine(seed_int);
    thrust::counting_iterator<int64_t> cnt_iter(0);
    thrust::shuffle_copy(exec_policy, cnt_iter, cnt_iter + elem_size,
                         thrust::device_pointer_cast(shuffleidx_data), engine);
    // TODO(zengjinle): for small data, direct cudaMemcpy may be better
    auto *x_data = x->data<T>();
    auto *out_data = out->mutable_data<T>(ctx.GetPlace());
    ReorderFunctor<T, true> functor(x_data, shuffleidx_data, out_data,
                                    x_embed_size);
    platform::ForRange<platform::CUDADeviceContext> for_range(
        dev_ctx, elem_size * x_embed_size);
    for_range(functor);

    auto *seed_out_data = seed_out->mutable_data<int64_t>(
        framework::make_ddim({1}), platform::CPUPlace());
    *seed_out_data = engine();
#endif
  }
};

template <typename T>
class ShuffleBatchGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
#ifdef _MSC_VER
    PADDLE_THROW(platform::errors::Unimplemented(
        "GPU shuffle_batch_grad is not supported on Windows yet"));
#else
    const auto *out_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    const auto *shuffleidx = ctx.Input<framework::Tensor>("ShuffleIdx");
    auto *x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));

    const auto *out_grad_data = out_grad->data<T>();
    const auto *shuffleidx_data = shuffleidx->data<int64_t>();
    auto *x_grad_data = x_grad->mutable_data<T>(ctx.GetPlace());
    auto x_embed_size = x_grad->dims()[x_grad->dims().size() - 1];
    ReorderFunctor<T, false> functor(out_grad_data, shuffleidx_data,
                                     x_grad_data, x_embed_size);
    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    // TODO(zengjinle): for small data, direct cudaMemcpy may be better
    platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx,
                                                              x_grad->numel());
    for_range(functor);
#endif
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(shuffle_batch, ops::ShuffleBatchCUDAKernel<float>,
                        ops::ShuffleBatchCUDAKernel<double>,
                        ops::ShuffleBatchCUDAKernel<int32_t>,
                        ops::ShuffleBatchCUDAKernel<int64_t>);

REGISTER_OP_CUDA_KERNEL(shuffle_batch_grad,
                        ops::ShuffleBatchGradCUDAKernel<float>,
                        ops::ShuffleBatchGradCUDAKernel<double>,
                        ops::ShuffleBatchGradCUDAKernel<int32_t>,
                        ops::ShuffleBatchGradCUDAKernel<int64_t>);
#endif
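
ReorderFunctor is used twice with the same shuffled index table: the forward kernel gathers row shuffle_idx[i] into row i of Out, and the gradient kernel scatters row i of the output gradient back to row shuffle_idx[i]. A host-only sketch of that index arithmetic (illustrative values, not taken from the operator) showing that the two passes undo each other:

// Sketch only: forward gathers through shuffle_idx, backward scatters back,
// so the round trip reproduces the original rows.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t rows = 4, stride = 2;               // stride plays x_embed_size
  std::vector<float> x = {0, 1, 2, 3, 4, 5, 6, 7};  // 4 rows of width 2
  std::vector<int64_t> shuffle_idx = {2, 0, 3, 1};  // a fixed permutation
  std::vector<float> y(x.size()), back(x.size());

  for (int64_t idx = 0; idx < rows * stride; ++idx) {
    int64_t reorder = shuffle_idx[idx / stride] * stride + idx % stride;
    y[idx] = x[reorder];                            // forward: kIsForward == true
  }
  for (int64_t idx = 0; idx < rows * stride; ++idx) {
    int64_t reorder = shuffle_idx[idx / stride] * stride + idx % stride;
    back[reorder] = y[idx];                         // backward: kIsForward == false
  }
  for (size_t i = 0; i < x.size(); ++i) {
    std::printf("%g %g\n", x[i], back[i]);          // columns should match
  }
  return 0;
}
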
85 changes: 85 additions & 0 deletions patches/thrust/thrust/detail/shuffle.inl
@@ -0,0 +1,85 @@
/*
* Copyright 2008-2020 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*! \file shuffle.inl
* \brief Inline file for shuffle.h.
*/

#include <thrust/detail/config.h>
#include <thrust/detail/cpp11_required.h>

#if THRUST_CPP_DIALECT >= 2011

#include <thrust/iterator/iterator_traits.h>
#include <thrust/shuffle.h>
#include <thrust/system/detail/generic/select_system.h>
#include <thrust/system/detail/generic/shuffle.h>

namespace thrust {

__thrust_exec_check_disable__
template <typename DerivedPolicy, typename RandomIterator, typename URBG>
__host__ __device__ void shuffle(
    const thrust::detail::execution_policy_base<DerivedPolicy>& exec,
    RandomIterator first, RandomIterator last, URBG&& g) {
  using thrust::system::detail::generic::shuffle;
  return shuffle(
      thrust::detail::derived_cast(thrust::detail::strip_const(exec)),
      first, last, g);
}

template <typename RandomIterator, typename URBG>
__host__ __device__ void shuffle(RandomIterator first, RandomIterator last,
                                 URBG&& g) {
  using thrust::system::detail::generic::select_system;

  typedef typename thrust::iterator_system<RandomIterator>::type System;
  System system;

  return thrust::shuffle(select_system(system), first, last, g);
}

__thrust_exec_check_disable__
template <typename DerivedPolicy, typename RandomIterator,
          typename OutputIterator, typename URBG>
__host__ __device__ void shuffle_copy(
    const thrust::detail::execution_policy_base<DerivedPolicy>& exec,
    RandomIterator first, RandomIterator last, OutputIterator result,
    URBG&& g) {
  using thrust::system::detail::generic::shuffle_copy;
  return shuffle_copy(
      thrust::detail::derived_cast(thrust::detail::strip_const(exec)),
      first, last, result, g);
}

template <typename RandomIterator, typename OutputIterator, typename URBG>
__host__ __device__ void shuffle_copy(RandomIterator first, RandomIterator last,
                                      OutputIterator result, URBG&& g) {
  using thrust::system::detail::generic::select_system;

  typedef typename thrust::iterator_system<RandomIterator>::type System1;
  typedef typename thrust::iterator_system<OutputIterator>::type System2;

  System1 system1;
  System2 system2;

  return thrust::shuffle_copy(select_system(system1, system2), first, last,
                              result, g);
}

} // namespace thrust

#endif
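
The backported header only forwards to the generic implementations: the policy overloads strip and derive the execution policy, while the iterator-only overloads pick a backend via select_system. A usage sketch, assuming these headers (or a Thrust release that already ships thrust/shuffle.h) are on the include path; sizes and the seed are illustrative:

// Sketch only: both entry points provided by the backported shuffle API.
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/random.h>
#include <thrust/sequence.h>
#include <thrust/shuffle.h>

int main() {
  thrust::device_vector<int> keys(16);
  thrust::sequence(keys.begin(), keys.end());  // 0, 1, ..., 15
  thrust::default_random_engine engine(2021);

  // Implicit system: select_system picks the device backend from the iterators.
  thrust::shuffle(keys.begin(), keys.end(), engine);

  // Explicit policy: the form the shuffle_batch kernel uses (there with
  // thrust::cuda::par.on(stream) instead of thrust::device).
  thrust::device_vector<int> out(keys.size());
  thrust::shuffle_copy(thrust::device, keys.begin(), keys.end(), out.begin(),
                       engine);
  return 0;
}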