Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cmake/operators.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ function(op_library TARGET)
list(REMOVE_ITEM miopen_cu_cc_srcs "affine_grid_cudnn_op.cu.cc")
list(REMOVE_ITEM miopen_cu_cc_srcs "grid_sampler_cudnn_op.cu.cc")
list(REMOVE_ITEM hip_srcs "cholesky_op.cu")
list(REMOVE_ITEM hip_srcs "correlation_op.cu")
list(REMOVE_ITEM hip_srcs "multinomial_op.cu")
hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${miopen_cu_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS ${op_library_DEPS}
${op_common_deps})
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/imperative/tracer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ paddle::framework::GarbageCollector* Tracer::MutableGarbageCollectorIfNotExists(
if (gcs_.count(place) == 0) {
std::unique_ptr<framework::GarbageCollector> gc;
if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
gc.reset(new framework::DefaultStreamGarbageCollector(
BOOST_GET_CONST(platform::CUDAPlace, place), 0));

Expand All @@ -94,7 +94,7 @@ paddle::framework::GarbageCollector* Tracer::MutableGarbageCollectorIfNotExists(
"Please recompile or reinstall Paddle with GPU support."));
#endif
} else if (platform::is_cuda_pinned_place(place)) {
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
gc.reset(new framework::CUDAPinnedGarbageCollector(
BOOST_GET_CONST(platform::CUDAPinnedPlace, place), 0));

Expand Down
49 changes: 38 additions & 11 deletions paddle/fluid/operators/conv_cudnn_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -699,24 +699,51 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {

// ------------------- cudnn conv backward data ---------------------
ScalingParamType<T> alpha = 1.0f;
#ifdef PADDLE_WITH_HIP
// MIOpen only supports beta == 0.0f
ScalingParamType<T> beta = 0.0f;
#else
ScalingParamType<T> beta = ctx.Attr<bool>("use_addto") ? 1.0f : 0.0f;
#endif
VLOG(4) << "Conv_grad: use_addto = " << ctx.Attr<bool>("use_addto");

if (input_grad) {
// When beta is 0, it is unnecessary to reset input_grad.
// When beta is 1, the output cannot be reset since the addto strategy is used.
#ifdef PADDLE_WITH_HIP
workspace_handle.RunFunc(
[&](void* cudnn_workspace_ptr) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenConvolutionBackwardData(
handle, &alpha, args1.odesc.desc(), output_grad_data,
args1.wdesc.desc(), filter_data, args1.cdesc.desc(),
data_algo, &beta, args1.idesc.desc(),
transformed_input_grad_data, cudnn_workspace_ptr,
workspace_size));
},
workspace_size);
if (ctx.Attr<bool>("use_addto")) {
Tensor temp_tensor(transformed_input_grad.type());
temp_tensor.Resize(transformed_input_grad.dims());
T* temp_tensor_data = temp_tensor.mutable_data<T>(ctx.GetPlace());
workspace_handle.RunFunc(
[&](void* cudnn_workspace_ptr) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenConvolutionBackwardData(
handle, &alpha, args1.odesc.desc(), output_grad_data,
args1.wdesc.desc(), filter_data, args1.cdesc.desc(),
data_algo, &beta, args1.idesc.desc(), temp_tensor_data,
cudnn_workspace_ptr, workspace_size));
},
workspace_size);
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenOpTensor(
handle, miopenTensorOpAdd, &alpha, args1.idesc.desc(),
transformed_input_grad_data, &alpha, args1.idesc.desc(),
temp_tensor_data, &beta, args1.idesc.desc(),
transformed_input_grad_data));
} else {
workspace_handle.RunFunc(
[&](void* cudnn_workspace_ptr) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenConvolutionBackwardData(
handle, &alpha, args1.odesc.desc(), output_grad_data,
args1.wdesc.desc(), filter_data, args1.cdesc.desc(),
data_algo, &beta, args1.idesc.desc(),
transformed_input_grad_data, cudnn_workspace_ptr,
workspace_size));
},
workspace_size);
}

#else
for (int i = 0; i < groups; i++) {
workspace_handle.RunFunc(
Expand Down
70 changes: 6 additions & 64 deletions paddle/fluid/operators/conv_miopen_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,28 +146,8 @@ struct SearchAlgorithm<miopenConvFwdAlgorithm_t> {
cudnn_workspace_ptr, workspace_size, false));
};

if (!exhaustive_search && !deterministic) {
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
algo = find_result.fwd_algo;
} else {
auto& temp = ctx.cuda_device_context();
AlgorithmsCache<algo_t>& algo_cache =
*(framework::ConvSearchCache::Instance().GetForward());

auto x_dims = framework::vectorize(args.x->dims());
auto w_dims = framework::vectorize(args.w->dims());

VLOG(10) << "miopenConvolutionFwdAlgoPerf_t:"
<< ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
<< args.s << ", args.p" << args.p << ", args.d" << args.d;

algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0,
static_cast<int64_t>(args.cudnn_dtype), [&]() {
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
return find_result.fwd_algo;
});
}
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
algo = find_result.fwd_algo;
VLOG(3) << "choose algo " << algo;
return algo;
}
Expand Down Expand Up @@ -208,27 +188,8 @@ struct SearchAlgorithm<miopenConvBwdDataAlgorithm_t> {
cudnn_workspace_ptr, workspace_size, false));
};

if (!exhaustive_search && !deterministic) {
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
algo = find_result.bwd_data_algo;
} else {
AlgorithmsCache<algo_t>& algo_cache =
*(framework::ConvSearchCache::Instance().GetBackwardData());

auto x_dims = framework::vectorize(args.x->dims());
auto w_dims = framework::vectorize(args.w->dims());

VLOG(10) << "miopenConvolutionFwdAlgoPerf_t"
<< ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
<< args.s << ", args.p" << args.p << ", args.d" << args.d;

algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0,
static_cast<int64_t>(args.cudnn_dtype), [&]() {
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
return find_result.bwd_data_algo;
});
}
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
algo = find_result.bwd_data_algo;
VLOG(3) << "choose algo " << algo;
return algo;
}
Expand Down Expand Up @@ -269,27 +230,8 @@ struct SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t> {
cudnn_workspace_ptr, workspace_size, false));
};

if (!exhaustive_search && !deterministic) {
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
algo = find_result.bwd_weights_algo;
} else {
AlgorithmsCache<algo_t>& algo_cache =
*(framework::ConvSearchCache::Instance().GetBackwardFilter());

auto x_dims = framework::vectorize(args.x->dims());
auto w_dims = framework::vectorize(args.w->dims());

VLOG(10) << "miopenConvolutionFwdAlgoPerf_t:"
<< ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
<< args.s << ", args.p" << args.p << ", args.d" << args.d;

algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0,
static_cast<int64_t>(args.cudnn_dtype), [&]() {
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
return find_result.bwd_weights_algo;
});
}
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
algo = find_result.bwd_weights_algo;
VLOG(3) << "choose algo " << algo;
return algo;
}
Expand Down
21 changes: 16 additions & 5 deletions paddle/fluid/operators/correlation_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,45 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef PADDLE_WITH_HIP
// HIP not supported yet

#include <algorithm>
#include <string>
#include "paddle/fluid/framework/op_registry.h"

#ifdef __HIPCC__
#define __syncwarp() __all(1)
#endif

namespace paddle {
namespace operators {

#ifdef __HIPCC__
#define THREADS_PER_BLOCK 64
#else
#define THREADS_PER_BLOCK 32
#endif
#define FULL_MASK 0xffffffff

using framework::Tensor;

// Sums `val` across the lanes of a single warp (wavefront on HIP) using a
// shuffle-down tree reduction and returns this lane's accumulated value.
//
// After the loop, lane 0 holds the sum over every lane of the warp; the
// values left in the other lanes are partial sums and should not be relied on.
// Assumes all lanes of the warp execute this call together (the CUDA path
// passes FULL_MASK to __shfl_down_sync).
template <typename T>
__forceinline__ __device__ T warpReduceSum(T val) {
  // Start at half of the hardware warp width rather than a hard-coded 16:
  // a fixed 16 only folds 32 lanes, which silently drops lanes 32..63 on
  // HIP devices with 64-wide wavefronts. On 32-wide hardware warpSize / 2
  // is still 16, so the CUDA behaviour is unchanged.
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
#ifdef __HIPCC__
    // HIP has no *_sync shuffle variant here; the whole wavefront takes part.
    val += __shfl_down(val, offset);
#else
    val += __shfl_down_sync(FULL_MASK, val, offset);
#endif
  }
  return val;
}

template <typename T>
__forceinline__ __device__ T blockReduceSum(T val) {
#ifdef __HIPCC__
static __shared__ T shared[64];
#else
static __shared__ T shared[32];
#endif
int lane = threadIdx.x % warpSize;
int wid = threadIdx.x / warpSize;

Expand Down Expand Up @@ -483,5 +496,3 @@ REGISTER_OP_CUDA_KERNEL(correlation, ops::CorrelationCUDAKernel<float>,
ops::CorrelationCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(correlation_grad, ops::CorrelationCUDAGradKernel<float>,
ops::CorrelationCUDAGradKernel<double>);

#endif // not PADDLE_WITH_HIP
3 changes: 1 addition & 2 deletions paddle/fluid/operators/fused/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ if (WITH_GPU OR WITH_ROCM)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_batch_norm_act);\n")
endif()
# conv_fusion_op requires cuDNN 7.1.0 or newer
# HIP not support cudnnConvolutionBiasActivationForward
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 7100))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里需要保留NOT ${CUDNN_VERSION} VERSION_LESS 7100

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已更新

if (NOT ${CUDNN_VERSION} VERSION_LESS 7100)
op_library(conv_fusion_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_fusion);\n")
endif()
Expand Down
83 changes: 81 additions & 2 deletions paddle/fluid/operators/fused/conv_fusion_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@ limitations under the License. */
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/operators/math/padding.h"
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#else
#include "paddle/fluid/platform/cudnn_helper.h"
#endif

DECLARE_int64(cudnn_exhaustive_search_times);

namespace paddle {
namespace operators {

#if CUDNN_VERSION >= 7100
#if PADDLE_WITH_HIP || CUDNN_VERSION >= 7100
using Tensor = framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
Expand Down Expand Up @@ -162,7 +166,78 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
if (input->dims().size() == 5) {
layout = DataLayout::kNCDHW;
}
#ifdef PADDLE_WITH_HIP
miopenConvolutionDescriptor_t cudnn_conv_desc =
conv_desc.descriptor<T>(padding_common, strides, dilations);
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenSetConvolutionGroupCount(cudnn_conv_desc,
groups));
// Currently only the NCHW layout is supported
std::vector<int> bias_dim = {
1, static_cast<int>(transformed_output.dims()[1]), 1, 1};
miopenTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize<int>(transformed_input.dims()));
miopenTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize<int>(transformed_output.dims()));
miopenTensorDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize<int>(filter->dims()));
miopenTensorDescriptor_t cudnn_bias_desc =
bias_desc.descriptor<T>(layout, bias_dim);
miopenActivationDescriptor_t cudnn_act_desc =
act_desc.descriptor<T>(activation);

miopenConvFwdAlgorithm_t algo;
auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();

auto x_dims = framework::vectorize(transformed_input.dims());
auto f_dims = framework::vectorize(filter->dims());

size_t workspace_size = 0;
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenConvolutionForwardGetWorkSpaceSize(
handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc,
cudnn_output_desc, &workspace_size));
int find_count;
miopenConvAlgoPerf_t find_result;
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenFindConvolutionForwardAlgorithm(
handle, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, cudnn_output_desc, output_data,
kNUM_CUDNN_FWD_ALGS, &find_count, &find_result,
cudnn_workspace_ptr, workspace_size, false));
};
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
algo = find_result.fwd_algo;
VLOG(3) << "cuDNN forward algo " << algo;

{
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
auto cudnn_func = [&](void* cudnn_workspace) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenConvolutionForward(
handle, &alpha, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, algo, &beta, cudnn_output_desc,
output_data, cudnn_workspace, workspace_size));
};
workspace_handle.RunFunc(cudnn_func, workspace_size);
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::miopenConvolutionForwardBias(
handle, &alpha, cudnn_bias_desc, bias_data, &beta,
cudnn_output_desc, output_data));
if (activation != "identity") {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenActivationForward(
handle, cudnn_act_desc, &alpha, cudnn_output_desc, output_data,
&beta, cudnn_output_desc, output_data));
}
if (residual) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenOpTensor(
handle, miopenTensorOpAdd, &alpha, cudnn_output_desc, output_data,
&alpha, cudnn_output_desc, residual_data, &beta, cudnn_output_desc,
output_data));
}
}
#else // PADDLE_WITH_HIP
cudnnConvolutionDescriptor_t cudnn_conv_desc =
conv_desc.descriptor<T>(padding_common, strides, dilations);
PADDLE_ENFORCE_CUDA_SUCCESS(
Expand Down Expand Up @@ -327,6 +402,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
};
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
#endif
std::vector<int> channels = ctx.Attr<std::vector<int>>("split_channels");
if (channels.size()) {
auto outs = ctx.MultiOutput<framework::Tensor>("Outputs");
Expand Down Expand Up @@ -358,8 +434,11 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle

#if CUDNN_VERSION >= 7100
namespace ops = paddle::operators;
#if CUDNN_VERSION >= 7100
REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>,
ops::CUDNNConvFusionOpKernel<double>);
#endif
#ifdef PADDLE_WITH_HIP
REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>);
#endif
8 changes: 8 additions & 0 deletions paddle/fluid/operators/math/unpooling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,11 @@ class Unpool2dMaxFunctor<platform::CUDADeviceContext, T> {
const T* input_data = input.data<T>();
const int* indices_data = indices.data<int>();
T* output_data = output->mutable_data<T>(context.GetPlace());
#ifdef __HIPCC__
int threads = 256;
#else
int threads = 1024;
#endif
int grid = (input.numel() + threads - 1) / threads;
KernelUnpool2dMax<T><<<grid, threads, 0, context.stream()>>>(
input.numel(), input_data, indices_data, input_height, input_width,
Expand Down Expand Up @@ -117,7 +121,11 @@ class Unpool2dMaxGradFunctor<platform::CUDADeviceContext, T> {
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
#ifdef __HIPCC__
int threads = 256;
#else
int threads = 1024;
#endif
int grid = (input.numel() + threads - 1) / threads;
KernelUnpool2dMaxGrad<T><<<grid, threads, 0, context.stream()>>>(
input.numel(), input_data, indices_data, input_height, input_width,
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/memcpy_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(memcpy, float, ops::MemcpyKernel, double,
ops::MemcpyKernel, plat::float16,
ops::MemcpyKernel);

#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_ROCM)
REGISTER_OP_CUDA_KERNEL_FUNCTOR(memcpy, float, ops::MemcpyKernel, double,
ops::MemcpyKernel, int, ops::MemcpyKernel,
int64_t, ops::MemcpyKernel, bool,
Expand Down
Loading