-
Notifications
You must be signed in to change notification settings - Fork 5.9k
broadcast qkv_op #35780
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
broadcast qkv_op #35780
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -233,6 +233,21 @@ __global__ void apply_scale(T *data, T scale, int n) { | |
| #endif | ||
| } | ||
|
|
||
| inline int round_up(int seq_len, int multiple = 32) { | ||
| assert(multiple); | ||
|
||
| return ((seq_len + multiple - 1) / multiple) * multiple; | ||
| } | ||
|
|
||
| template <typename T> | ||
| __global__ void broadcast(const T *src, T *dst, const int seq_len, | ||
| const int head_num) { | ||
| int batch_id = blockIdx.x / (head_num * seq_len); | ||
| int dst_offset = blockIdx.x * seq_len; | ||
| if (threadIdx.x < seq_len) { | ||
| dst[threadIdx.x + dst_offset] = src[threadIdx.x + batch_id * seq_len]; | ||
| } | ||
| } | ||
|
|
||
| int QkvToContextPluginDynamic::enqueue( | ||
| const nvinfer1::PluginTensorDesc *input_desc, | ||
| const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs, | ||
|
|
@@ -258,7 +273,21 @@ int QkvToContextPluginDynamic::enqueue( | |
| auto *tptr = multihead_temp_data + scratch_size; | ||
|
|
||
| const float *input0_data = static_cast<const float *>(inputs[0]); | ||
| const float *input1_data = static_cast<const float *>(inputs[1]); | ||
| // fit to [batch, head_num, length, length] + [batch, 1, 1, length] | ||
| framework::Tensor temp_qk_bias_tensor; | ||
| float *qk_bias = const_cast<float *>(static_cast<const float *>(inputs[1])); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里const_cast要使用的理由是什么,需要解释下吗,这个输入为什么需要是
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| if (ProductDim(input_desc[1].dims) == (batch * seq_len)) { | ||
| temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len}); | ||
| auto *temp_qk_bias = temp_qk_bias_tensor.mutable_data<float>( | ||
| platform::CUDAPlace(device_id)); | ||
| int grid = batch * head_number_ * seq_len; | ||
| int block = round_up(seq_len); | ||
| broadcast<<<grid, block, 0, stream>>>( | ||
| static_cast<const float *>(inputs[1]), temp_qk_bias, seq_len, | ||
| head_number_); | ||
| qk_bias = temp_qk_bias; | ||
| } | ||
| const float *input1_data = static_cast<const float *>(qk_bias); | ||
| // BxSx3xNxH => tptr: 3xBxNxSxH. | ||
| TransposeQKV(batch, seq_len, head_size_, head_number_, input0_data, tptr, | ||
| stream); | ||
|
|
@@ -290,7 +319,22 @@ int QkvToContextPluginDynamic::enqueue( | |
| half *tptr = qkptr + scratch_size; | ||
|
|
||
| const half *input0_data = static_cast<const half *>(inputs[0]); | ||
| const half *input1_data = static_cast<const half *>(inputs[1]); | ||
| // fit to [batch, head_num, length, length] + [batch, 1, 1, length] | ||
| framework::Tensor temp_qk_bias_tensor; | ||
| half *qk_bias = const_cast<half *>(static_cast<const half *>(inputs[1])); | ||
| if (ProductDim(input_desc[1].dims) == (batch * seq_len)) { | ||
| temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len}); | ||
| auto *temp_qk_bias = | ||
| reinterpret_cast<half *>(temp_qk_bias_tensor.mutable_data<int16_t>( | ||
| platform::CUDAPlace(device_id))); | ||
| int grid = batch * head_number_ * seq_len; | ||
| int block = round_up(seq_len); | ||
| broadcast<<<grid, block, 0, stream>>>( | ||
| static_cast<const half *>(inputs[1]), temp_qk_bias, seq_len, | ||
| head_number_); | ||
| qk_bias = temp_qk_bias; | ||
| } | ||
| const half *input1_data = static_cast<const half *>(qk_bias); | ||
| // BxSx3xNxH => tptr: 3xBxNxSxH. | ||
| TransposeQKV(batch, seq_len, head_size_, head_number_, input0_data, tptr, | ||
| stream); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -132,6 +132,21 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size, | |
| } | ||
| } | ||
|
|
||
| inline int round_up(int seq_len, int multiple = 32) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| assert(multiple); | ||
| return ((seq_len + multiple - 1) / multiple) * multiple; | ||
| } | ||
|
|
||
| template <typename T> | ||
| __global__ void broadcast(const T *src, T *dst, const int seq_len, | ||
| const int head_num) { | ||
| int batch_id = blockIdx.x / (head_num * seq_len); | ||
| int dst_offset = blockIdx.x * seq_len; | ||
| if (threadIdx.x < seq_len) { | ||
| dst[threadIdx.x + dst_offset] = src[threadIdx.x + batch_id * seq_len]; | ||
| } | ||
| } | ||
|
|
||
| template <typename DeviceContext, typename T> | ||
| class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> { | ||
| public: | ||
|
|
@@ -152,14 +167,25 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> { | |
| int head_number = context.Attr<int>("head_number"); | ||
| // compute q*k with eltadd | ||
| auto &device_ctx = context.template device_context<DeviceContext>(); | ||
| auto stream = device_ctx.stream(); | ||
| // should be (B * S * hidden) | ||
| auto input_dims = input->dims(); | ||
| // should be (hidden * 3 * all_head_size) | ||
| auto w_dims = w->dims(); | ||
| int batch = input_dims[0]; | ||
| int seq_len = input_dims[1]; | ||
| int hidden = input_dims[2]; | ||
|
|
||
| Tensor temp_bias_tensor; | ||
| // if bias_qk is [batch, 1, 1, seq_len], the bias_qk_d needs to be broadcast | ||
| if (bias_qk.numel() == (batch * seq_len)) { | ||
| temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len}); | ||
| auto *temp_qk_bias = temp_bias_tensor.mutable_data<T>(context.GetPlace()); | ||
| int grid = batch * head_number * seq_len; | ||
| int block = round_up(seq_len); | ||
| broadcast<<<grid, block, 0, stream>>>(bias_qk_d, temp_qk_bias, seq_len, | ||
| head_number); | ||
| bias_qk_d = static_cast<const T *>(temp_qk_bias); | ||
| } | ||
| int all_head_size = w_dims[2]; | ||
| int head_size = all_head_size / head_number; | ||
|
|
||
|
|
@@ -196,7 +222,6 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> { | |
| auto *qkptr = multihead_temp_data; | ||
| auto *tptr = multihead_temp_data + scratch_size; | ||
|
|
||
| auto stream = device_ctx.stream(); | ||
| // Do the transpose with bias. | ||
| // BxSx3xNxH => tptr: 3xBxNxSxH. | ||
| TransQKVWithBias(batch, seq_len, head_size, head_number, temp_out_data, | ||
|
|
||

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这两处代码是重复的吗?方便复用吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
考虑过,不过目前就用这两次,放到公共的头文件中发现这个函数和其他函数类型相比有点不伦不类,二者一个是trt,一个是cuda所以目前不太好放,后续如果常用或者有合适的地方会考虑重构一下