From dc979d0cc5c5e065a719f45b9957c843f0c2cb30 Mon Sep 17 00:00:00 2001 From: fengshuai03 Date: Wed, 15 Sep 2021 16:11:55 +0000 Subject: [PATCH 1/2] broadcast qkv_op --- .../tensorrt/plugin/qkv_to_context_plugin.cu | 48 ++++++++++++++++++- .../operators/fused/multihead_matmul_op.cu | 29 ++++++++++- 2 files changed, 73 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu index 0d978939c4bf35..c097d8afa147c3 100644 --- a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu @@ -233,6 +233,21 @@ __global__ void apply_scale(T *data, T scale, int n) { #endif } +inline int round_up(int seq_len, int multiple = 32) { + assert(multiple); + return ((seq_len + multiple - 1) / multiple) * multiple; +} + +template +__global__ void broadcast(const T *src, T *dst, const int seq_len, + const int head_num) { + int batch_id = blockIdx.x / (head_num * seq_len); + int dst_offset = blockIdx.x * seq_len; + if (threadIdx.x < seq_len) { + dst[threadIdx.x + dst_offset] = src[threadIdx.x + batch_id * seq_len]; + } +} + int QkvToContextPluginDynamic::enqueue( const nvinfer1::PluginTensorDesc *input_desc, const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs, @@ -258,7 +273,21 @@ int QkvToContextPluginDynamic::enqueue( auto *tptr = multihead_temp_data + scratch_size; const float *input0_data = static_cast(inputs[0]); - const float *input1_data = static_cast(inputs[1]); + // fit to [batch, head_num, length, length] + [batch, 1, 1, length] + framework::Tensor temp_qk_bias_tensor; + float *qk_bias = const_cast(static_cast(inputs[1])); + if (ProductDim(input_desc[1].dims) == (batch * seq_len)) { + temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len}); + auto *temp_qk_bias = temp_qk_bias_tensor.mutable_data( + platform::CUDAPlace(device_id)); + int grid = batch * 
head_number_ * seq_len; + int block = round_up(seq_len); + broadcast<<>>( + static_cast(inputs[1]), temp_qk_bias, seq_len, + head_number_); + qk_bias = temp_qk_bias; + } + const float *input1_data = static_cast(qk_bias); // BxSx3xNxH => tptr: 3xBxNxSxH. TransposeQKV(batch, seq_len, head_size_, head_number_, input0_data, tptr, stream); @@ -290,7 +319,22 @@ int QkvToContextPluginDynamic::enqueue( half *tptr = qkptr + scratch_size; const half *input0_data = static_cast(inputs[0]); - const half *input1_data = static_cast(inputs[1]); + // fit to [batch, head_num, length, length] + [batch, 1, 1, length] + framework::Tensor temp_qk_bias_tensor; + half *qk_bias = const_cast(static_cast(inputs[1])); + if (ProductDim(input_desc[1].dims) == (batch * seq_len)) { + temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len}); + auto *temp_qk_bias = + reinterpret_cast(temp_qk_bias_tensor.mutable_data( + platform::CUDAPlace(device_id))); + int grid = batch * head_number_ * seq_len; + int block = round_up(seq_len); + broadcast<<>>( + static_cast(inputs[1]), temp_qk_bias, seq_len, + head_number_); + qk_bias = temp_qk_bias; + } + const half *input1_data = static_cast(qk_bias); // BxSx3xNxH => tptr: 3xBxNxSxH. 
TransposeQKV(batch, seq_len, head_size_, head_number_, input0_data, tptr, stream); diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cu b/paddle/fluid/operators/fused/multihead_matmul_op.cu index c19e621b18fa7c..2c9ef35638f51f 100644 --- a/paddle/fluid/operators/fused/multihead_matmul_op.cu +++ b/paddle/fluid/operators/fused/multihead_matmul_op.cu @@ -132,6 +132,21 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size, } } +inline int round_up(int seq_len, int multiple = 32) { + assert(multiple); + return ((seq_len + multiple - 1) / multiple) * multiple; +} + +template +__global__ void broadcast(const T *src, T *dst, const int seq_len, + const int head_num) { + int batch_id = blockIdx.x / (head_num * seq_len); + int dst_offset = blockIdx.x * seq_len; + if (threadIdx.x < seq_len) { + dst[threadIdx.x + dst_offset] = src[threadIdx.x + batch_id * seq_len]; + } +} + template class MultiHeadMatMulV2Kernel : public framework::OpKernel { public: @@ -152,6 +167,7 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel { int head_number = context.Attr("head_number"); // compute q*k with eltadd auto &device_ctx = context.template device_context(); + auto stream = device_ctx.stream(); // should be (B * S * hidden) auto input_dims = input->dims(); // shouble be (hidden * 3 * all_head_size) @@ -159,7 +175,17 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel { int batch = input_dims[0]; int seq_len = input_dims[1]; int hidden = input_dims[2]; - + Tensor temp_bias_tensor; + // if bias_qk is[batch, 1, 1, seq_len], the bias_qk_d need to be broadcasted + if (bias_qk.numel() == (batch * seq_len)) { + temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len}); + auto *temp_qk_bias = temp_bias_tensor.mutable_data(context.GetPlace()); + int grid = batch * head_number * seq_len; + int block = round_up(seq_len); + broadcast<<>>(bias_qk_d, temp_qk_bias, seq_len, + head_number); + bias_qk_d = static_cast(temp_qk_bias); + } int 
all_head_size = w_dims[2]; int head_size = all_head_size / head_number; @@ -196,7 +222,6 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel { auto *qkptr = multihead_temp_data; auto *tptr = multihead_temp_data + scratch_size; - auto stream = device_ctx.stream(); // Do the transpose with bias. // BxSx3xNxH => tptr: 3xBxNxSxH. TransQKVWithBias(batch, seq_len, head_size, head_number, temp_out_data, From 8c39b995cfdf557d69fbec4f6ace50e470572c76 Mon Sep 17 00:00:00 2001 From: fengshuai03 Date: Thu, 16 Sep 2021 06:44:50 +0000 Subject: [PATCH 2/2] use PADDLE_ENFORCE_GT to replace assert --- .../fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu | 5 ++++- paddle/fluid/operators/fused/multihead_matmul_op.cu | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu index c097d8afa147c3..6bae3606afe0ef 100644 --- a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu @@ -234,7 +234,10 @@ __global__ void apply_scale(T *data, T scale, int n) { } inline int round_up(int seq_len, int multiple = 32) { - assert(multiple); + PADDLE_ENFORCE_GT( + multiple, 0, + platform::errors::InvalidArgument( + "multiple should be a positive number, but it's (%d)", multiple)); return ((seq_len + multiple - 1) / multiple) * multiple; } diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cu b/paddle/fluid/operators/fused/multihead_matmul_op.cu index 2c9ef35638f51f..69056189ac2218 100644 --- a/paddle/fluid/operators/fused/multihead_matmul_op.cu +++ b/paddle/fluid/operators/fused/multihead_matmul_op.cu @@ -133,7 +133,10 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size, } inline int round_up(int seq_len, int multiple = 32) { - assert(multiple); + PADDLE_ENFORCE_GT( + multiple, 0, + platform::errors::InvalidArgument( + 
"multiple should be a positive number, but it's (%d)", multiple)); return ((seq_len + multiple - 1) / multiple) * multiple; }