Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,24 @@ __global__ void apply_scale(T *data, T scale, int n) {
#endif
}

inline int round_up(int seq_len, int multiple = 32) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这两处代码是重复的吗?方便复用吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

考虑过,不过目前就用这两次,放到公共的头文件中发现这个函数和其他函数类型相比有点不伦不类,二者一个是trt,一个是cuda所以目前不太好放,后续如果常用或者有合适的地方会考虑重构一下

PADDLE_ENFORCE_GT(
multiple, 0,
platform::errors::InvalidArgument(
"multiple should be a positive number,but it's (%d)", multiple));
return ((seq_len + multiple - 1) / multiple) * multiple;
}

// Broadcasts a [batch, 1, 1, seq_len] bias to [batch, head_num, seq_len,
// seq_len]: every (head, row) pair within a batch receives a copy of that
// batch's seq_len-long bias vector. One block handles one destination row.
template <typename T>
__global__ void broadcast(const T *src, T *dst, const int seq_len,
                          const int head_num) {
  const int row = blockIdx.x;  // flattened (batch, head, row) index
  const int batch_id = row / (head_num * seq_len);
  const int col = threadIdx.x;
  if (col < seq_len) {
    dst[row * seq_len + col] = src[batch_id * seq_len + col];
  }
}

int QkvToContextPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
Expand All @@ -258,7 +276,21 @@ int QkvToContextPluginDynamic::enqueue(
auto *tptr = multihead_temp_data + scratch_size;

const float *input0_data = static_cast<const float *>(inputs[0]);
const float *input1_data = static_cast<const float *>(inputs[1]);
// fit to [batch, head_num, length, length] + [batch, 1, 1, length]
framework::Tensor temp_qk_bias_tensor;
float *qk_bias = const_cast<float *>(static_cast<const float *>(inputs[1]));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里const_cast要使用的理由是什么,需要解释下吗,这个输入为什么需要是const void *const *类型

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里const_cast要使用的理由是什么,需要解释下吗,这个输入为什么需要是const void *const *类型
这个是由于基类设置的接口的原因,目前没办法,trt这边plugin都是这么写的,具体也和秋良沟通过

if (ProductDim(input_desc[1].dims) == (batch * seq_len)) {
temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len});
auto *temp_qk_bias = temp_qk_bias_tensor.mutable_data<float>(
platform::CUDAPlace(device_id));
int grid = batch * head_number_ * seq_len;
int block = round_up(seq_len);
broadcast<<<grid, block, 0, stream>>>(
static_cast<const float *>(inputs[1]), temp_qk_bias, seq_len,
head_number_);
qk_bias = temp_qk_bias;
}
const float *input1_data = static_cast<const float *>(qk_bias);
// BxSx3xNxH => tptr: 3xBxNxSxH.
TransposeQKV(batch, seq_len, head_size_, head_number_, input0_data, tptr,
stream);
Expand Down Expand Up @@ -290,7 +322,22 @@ int QkvToContextPluginDynamic::enqueue(
half *tptr = qkptr + scratch_size;

const half *input0_data = static_cast<const half *>(inputs[0]);
const half *input1_data = static_cast<const half *>(inputs[1]);
// fit to [batch, head_num, length, length] + [batch, 1, 1, length]
framework::Tensor temp_qk_bias_tensor;
half *qk_bias = const_cast<half *>(static_cast<const half *>(inputs[1]));
if (ProductDim(input_desc[1].dims) == (batch * seq_len)) {
temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len});
auto *temp_qk_bias =
reinterpret_cast<half *>(temp_qk_bias_tensor.mutable_data<int16_t>(
platform::CUDAPlace(device_id)));
int grid = batch * head_number_ * seq_len;
int block = round_up(seq_len);
broadcast<<<grid, block, 0, stream>>>(
static_cast<const half *>(inputs[1]), temp_qk_bias, seq_len,
head_number_);
qk_bias = temp_qk_bias;
}
const half *input1_data = static_cast<const half *>(qk_bias);
// BxSx3xNxH => tptr: 3xBxNxSxH.
TransposeQKV(batch, seq_len, head_size_, head_number_, input0_data, tptr,
stream);
Expand Down
32 changes: 30 additions & 2 deletions paddle/fluid/operators/fused/multihead_matmul_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,24 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
}
}

inline int round_up(int seq_len, int multiple = 32) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么不使用驼峰式命名,其他地方也一样
image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么不使用驼峰式命名,其他地方也一样
image

我看这个文件里面很多地方都用下划线的方式,为了风格统一就延续了这种风格

PADDLE_ENFORCE_GT(
multiple, 0,
platform::errors::InvalidArgument(
"multiple should be a positive number,but it's (%d)", multiple));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个multiple需要标记一下吗?比如The input argument multiple,这个报错句子直接看语法是错的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个multiple需要标记一下吗?比如The input argument multiple,这个报错句子直接看语法是错的

这个可以修改一下

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

先合入,@fengxiaoshuai,develop提个PR改一下,或者下一个PR带一下。

return ((seq_len + multiple - 1) / multiple) * multiple;
}

// Broadcasts a [batch, 1, 1, seq_len] bias to [batch, head_num, seq_len,
// seq_len]: every (head, row) pair within a batch receives a copy of that
// batch's seq_len-long bias vector. One block handles one destination row.
template <typename T>
__global__ void broadcast(const T *src, T *dst, const int seq_len,
                          const int head_num) {
  const int row = blockIdx.x;  // flattened (batch, head, row) index
  const int batch_id = row / (head_num * seq_len);
  const int col = threadIdx.x;
  if (col < seq_len) {
    dst[row * seq_len + col] = src[batch_id * seq_len + col];
  }
}

template <typename DeviceContext, typename T>
class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
public:
Expand All @@ -152,14 +170,25 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
int head_number = context.Attr<int>("head_number");
// compute q*k with eltadd
auto &device_ctx = context.template device_context<DeviceContext>();
auto stream = device_ctx.stream();
// should be (B * S * hidden)
auto input_dims = input->dims();
// should be (hidden * 3 * all_head_size)
auto w_dims = w->dims();
int batch = input_dims[0];
int seq_len = input_dims[1];
int hidden = input_dims[2];

Tensor temp_bias_tensor;
// if bias_qk is [batch, 1, 1, seq_len], the bias_qk_d needs to be broadcasted
if (bias_qk.numel() == (batch * seq_len)) {
temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
auto *temp_qk_bias = temp_bias_tensor.mutable_data<T>(context.GetPlace());
int grid = batch * head_number * seq_len;
int block = round_up(seq_len);
broadcast<<<grid, block, 0, stream>>>(bias_qk_d, temp_qk_bias, seq_len,
head_number);
bias_qk_d = static_cast<const T *>(temp_qk_bias);
}
int all_head_size = w_dims[2];
int head_size = all_head_size / head_number;

Expand Down Expand Up @@ -196,7 +225,6 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
auto *qkptr = multihead_temp_data;
auto *tptr = multihead_temp_data + scratch_size;

auto stream = device_ctx.stream();
// Do the transpose with bias.
// BxSx3xNxH => tptr: 3xBxNxSxH.
TransQKVWithBias(batch, seq_len, head_size, head_number, temp_out_data,
Expand Down