Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cmake/cuda.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ function(select_nvcc_arch_flags out_variable)
elseif(${CUDA_ARCH_NAME} STREQUAL "Maxwell")
set(cuda_arch_bin "50")
elseif(${CUDA_ARCH_NAME} STREQUAL "Pascal")
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} LESS 10.0)
add_definitions("-DSUPPORTS_CUDA_FP16")
endif()
set(cuda_arch_bin "60 61")
elseif(${CUDA_ARCH_NAME} STREQUAL "Volta")
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} LESS 10.0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr;

if (engine_->with_dynamic_shape()) {
plugin::DynamicPluginTensorRT* plugin = nullptr;
plugin = new plugin::EmbEltwiseLayernormPluginDynamic<float>(
auto use_fp16 = engine_->WithFp16();
auto plugin = new plugin::EmbEltwiseLayernormPluginDynamic(
input_embs, bias, scale, emb_sizes, bias_size, scale_size, hidden,
eps);
eps, use_fp16);
layer = engine_->AddPluginV2(input_ids.data(), input_num, plugin);
} else {
PADDLE_THROW(platform::errors::Fatal(
Expand Down
214 changes: 133 additions & 81 deletions paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,34 @@ namespace plugin {
#if IS_TRT_VERSION_GE(6000)

template <typename T>
int EmbEltwiseLayernormPluginDynamic<T>::initialize() {
EmbEltwiseLayernormPluginDynamicImpl<
T>::~EmbEltwiseLayernormPluginDynamicImpl() {
this->terminate();
}

inline half fp32tofp16(float x) { return static_cast<half>(x); }

template <typename T>
int EmbEltwiseLayernormPluginDynamicImpl<T>::initialize() {
embs_gpu_.resize(embs_.size());
for (int i = 0; i < embs_.size(); i++) {
if (embs_[i]) {
cudaMalloc(&embs_gpu_[i], sizeof(float) * emb_sizes_[i]);
cudaMemcpy(embs_gpu_[i], embs_[i], emb_sizes_[i] * sizeof(float),
T *host_ptr;
auto size = emb_sizes_[i];

if (std::is_same<T, half>::value) {
host_ptr = new T[size];
std::transform(embs_[i], (embs_[i] + size), host_ptr, fp32tofp16);
} else {
host_ptr = reinterpret_cast<T *>(embs_[i]);
}

cudaMalloc(&embs_gpu_[i], sizeof(T) * size);
cudaMemcpy(embs_gpu_[i], host_ptr, size * sizeof(T),
cudaMemcpyHostToDevice);
if (std::is_same<T, half>::value) {
delete[] host_ptr;
}
}
}

Expand All @@ -53,11 +74,105 @@ int EmbEltwiseLayernormPluginDynamic<T>::initialize() {
cudaMemcpyHostToDevice);
}

int input_num = embs_.size();
in_ptr_tensor_.Resize({input_num});
emb_ptr_tensor_.Resize({input_num});

cudaGetDevice(&device_id_);
auto emb_ptr_gpu_d =
emb_ptr_tensor_.mutable_data<int64_t>(platform::CUDAPlace(device_id_));
cudaMemcpy(emb_ptr_gpu_d, embs_gpu_.data(), sizeof(uintptr_t) * input_num,
cudaMemcpyHostToDevice);

return 0;
}

template <typename T>
nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic<T>::getOutputDimensions(
void EmbEltwiseLayernormPluginDynamicImpl<T>::terminate() {
for (int i = 0; i < embs_gpu_.size(); ++i) {
if (embs_gpu_[i]) {
cudaFree(embs_gpu_[i]);
embs_gpu_[i] = nullptr;
}
}

if (bias_gpu_) {
cudaFree(bias_gpu_);
bias_gpu_ = nullptr;
}

if (scale_gpu_) {
cudaFree(scale_gpu_);
scale_gpu_ = nullptr;
}
}

template <typename T>
int EmbEltwiseLayernormPluginDynamicImpl<T>::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) {
auto id_dims = input_desc[0].dims;
int batch = id_dims.d[0];
int seq_len = id_dims.d[1];
int input_num = embs_.size();

auto in_ptr_gpu_d =
in_ptr_tensor_.mutable_data<int64_t>(platform::CUDAPlace(device_id_));
auto emb_ptr_gpu_d =
emb_ptr_tensor_.mutable_data<int64_t>(platform::CUDAPlace(device_id_));

auto new_input_ptr = reinterpret_cast<uintptr_t>(inputs[0]);

if (old_input_ptr_ != new_input_ptr) {
old_input_ptr_ = new_input_ptr;

cudaMemcpyAsync(in_ptr_gpu_d, reinterpret_cast<const void *>(inputs),
sizeof(uintptr_t) * input_num, cudaMemcpyHostToDevice,
stream);
}

auto out_type = output_desc[0].type;

if (std::is_same<T, float>::value) {
PADDLE_ENFORCE_EQ(
out_type == nvinfer1::DataType::kFLOAT, true,
platform::errors::InvalidArgument(
"The EmbEltwiseLayernorm Plugin only support fp32 input."));
} else if (std::is_same<T, half>::value) {
PADDLE_ENFORCE_EQ(
out_type == nvinfer1::DataType::kHALF, true,
platform::errors::InvalidArgument(
"The EmbEltwiseLayernorm Plugin only support fp16 input."));
} else {
PADDLE_THROW(platform::errors::Fatal(
"Unsupport data type, the out type of EmbEltwiseLayernorm should be "
"float or half."));
}

auto *output_d = reinterpret_cast<T *>(outputs[0]);

operators::math::EmbEltwiseLayerNormFunctor<T> emb_eltwise_layernorm_func;
emb_eltwise_layernorm_func(batch, seq_len, hidden_size_, in_ptr_gpu_d,
scale_gpu_, bias_gpu_, emb_ptr_gpu_d, output_d,
eps_, input_num, stream);
return cudaGetLastError() != cudaSuccess;
}

template class EmbEltwiseLayernormPluginDynamicImpl<float>;
#ifdef SUPPORTS_CUDA_FP16
template class EmbEltwiseLayernormPluginDynamicImpl<half>;
#endif // SUPPORTS_CUDA_FP16

int EmbEltwiseLayernormPluginDynamic::initialize() {
impl_->initialize();

return 0;
}

void EmbEltwiseLayernormPluginDynamic::terminate() { impl_->terminate(); }

nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) { // NOLINT
PADDLE_ENFORCE_EQ(output_index, 0,
Expand All @@ -76,18 +191,7 @@ nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic<T>::getOutputDimensions(
return ret;
}

template <typename T>
void EmbEltwiseLayernormPluginDynamic<T>::terminate() {
for (auto ptr : embs_gpu_) {
if (ptr) cudaFree(ptr);
}

if (bias_gpu_) cudaFree(bias_gpu_);
if (scale_gpu_) cudaFree(scale_gpu_);
}

template <typename T>
bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
bool EmbEltwiseLayernormPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
int nb_outputs) {
PADDLE_ENFORCE_NOT_NULL(
Expand All @@ -98,6 +202,11 @@ bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
"The EmbEltwiseLayerNorm's output should be one"
"but it's (%d) outputs.",
nb_outputs));
PADDLE_ENFORCE_EQ(nb_outputs, 1,
platform::errors::InvalidArgument(
"The EmbEltwiseLayerNorm's output should be one"
"but it's (%d) outputs.",
nb_outputs));
PADDLE_ENFORCE_LT(
pos, nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
Expand All @@ -122,7 +231,7 @@ bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
}

if (pos == all_nums - 1) {
if (sizeof(T) == sizeof(float)) {
if (with_fp16_ == false) {
return desc.type == nvinfer1::DataType::kFLOAT;
} else {
return desc.type == nvinfer1::DataType::kHALF;
Expand All @@ -131,84 +240,27 @@ bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
return false;
}

template <typename T>
nvinfer1::DataType EmbEltwiseLayernormPluginDynamic<T>::getOutputDataType(
nvinfer1::DataType EmbEltwiseLayernormPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType *input_types, int nb_inputs) const {
PADDLE_ENFORCE_EQ(
index, 0, platform::errors::InvalidArgument(
"The EmbEltwiseLayernorm Plugin only has one input, so the "
"index value should be 0, but get %d.",
index));
return nvinfer1::DataType::kFLOAT;
if (with_fp16_)
return nvinfer1::DataType::kHALF;
else
return nvinfer1::DataType::kFLOAT;
}

template <typename T>
int EmbEltwiseLayernormPluginDynamic<T>::enqueue(
int EmbEltwiseLayernormPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) {
auto id_dims = input_desc[0].dims;
int batch = id_dims.d[0];
int seq_len = id_dims.d[1];
int input_num = embs_.size();

framework::Tensor in_ptr_tensor, emb_ptr_tensor;
int device_id;
cudaGetDevice(&device_id);

in_ptr_tensor.Resize({input_num});
emb_ptr_tensor.Resize({input_num});
int64_t *in_ptr_gpu_d =
in_ptr_tensor.mutable_data<int64_t>(platform::CUDAPlace(device_id));
int64_t *emb_ptr_gpu_d =
emb_ptr_tensor.mutable_data<int64_t>(platform::CUDAPlace(device_id));

std::vector<uintptr_t> in_ptr, emb_ptr;
for (int i = 0; i < input_num; i++) {
in_ptr.push_back(reinterpret_cast<uintptr_t>(inputs[i]));
emb_ptr.push_back(reinterpret_cast<uintptr_t>(embs_gpu_[i]));
}

cudaMemcpyAsync(in_ptr_gpu_d, in_ptr.data(), sizeof(int64_t) * input_num,
cudaMemcpyHostToDevice, stream);
cudaMemcpyAsync(emb_ptr_gpu_d, emb_ptr.data(), sizeof(int64_t) * input_num,
cudaMemcpyHostToDevice, stream);

auto out_type = output_desc[0].type;

const unsigned tpb = 256;
const dim3 grid(seq_len, batch, 1);
const dim3 block(tpb, 1, 1);
if (sizeof(T) == sizeof(float)) {
PADDLE_ENFORCE_EQ(
out_type == nvinfer1::DataType::kFLOAT, true,
platform::errors::InvalidArgument(
"The EmbEltwiseLayernorm Plugin only support fp32 input."));
} else if (sizeof(T) == sizeof(int16_t)) {
PADDLE_ENFORCE_EQ(
out_type == nvinfer1::DataType::kHALF, true,
platform::errors::InvalidArgument(
"The EmbEltwiseLayernorm Plugin only support fp16 input."));
} else {
PADDLE_THROW(platform::errors::Fatal(
"Unsupport data type, the out type of EmbEltwiseLayernorm should be "
"float or half."));
}

T *output_d = static_cast<T *>(outputs[0]);

operators::math::EmbEltwiseLayerNormFunctor<T> emb_eltwise_layernorm_func;
emb_eltwise_layernorm_func(batch, seq_len, hidden_size_, in_ptr_gpu_d,
scale_gpu_, bias_gpu_, emb_ptr_gpu_d, output_d,
eps_, input_num, stream);
impl_->enqueue(input_desc, output_desc, inputs, outputs, workspace, stream);
return cudaGetLastError() != cudaSuccess;
}

template class EmbEltwiseLayernormPluginDynamic<float>;
#ifdef SUPPORTS_CUDA_FP16
template class EmbEltwiseLayernormPluginDynamic<half>;
#endif // SUPPORTS_CUDA_FP16

#endif

} // namespace plugin
Expand Down
Loading