Add INT8 support for fused_multi_transformer_op #45284
Conversation
…hardWooSJTU/Paddle into fused_multi_transformrt_int8 merge fuse kernel
Your PR has been submitted successfully. Thank you for your contribution to this open-source project!
}

template <typename T>
void quantize_kernelLauncher(const T* input,
Rename to quantize_kernel_launcher.
DONE
tmp.w =
    __float2int_rn(static_cast<float>(input[m_id * n + n_id + 3]) * scale);
output[(m_id * n + n_id) >> 2] = tmp;
}
When m and n are large, having each thread process only 4 elements may be inefficient. It would be better to make this more general.
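For context, one way to make the kernel more general is a grid-stride loop so each thread handles an arbitrary number of char4 groups when m * n is large. The sketch below is illustrative only; the kernel name and launch assumptions are not the PR's actual code, and clamping/rounding are assumed to be handled by the scale as in the original kernel.

template <typename T>
__global__ void quantize_kernel_grid_stride(const T* input,
                                            char4* output,
                                            const float scale,
                                            const int m,
                                            const int n) {
  // Each char4 packs 4 consecutive int8 values; n is assumed to be a multiple of 4.
  const int num_groups = m * n / 4;
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num_groups;
       idx += gridDim.x * blockDim.x) {
    const int offset = idx * 4;
    char4 tmp;
    tmp.x = static_cast<int8_t>(
        __float2int_rn(static_cast<float>(input[offset]) * scale));
    tmp.y = static_cast<int8_t>(
        __float2int_rn(static_cast<float>(input[offset + 1]) * scale));
    tmp.z = static_cast<int8_t>(
        __float2int_rn(static_cast<float>(input[offset + 2]) * scale));
    tmp.w = static_cast<int8_t>(
        __float2int_rn(static_cast<float>(input[offset + 3]) * scale));
    output[idx] = tmp;  // vectorized int8 store
  }
}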
}

template <typename T>
void quantize_kernelLauncher(const T* input,
The naming does not follow conventions.
DONE
dim3 block(32, 32);

quantize_kernel<<<grid, block, 0, stream>>>(
    input, (char4*)output, scale, m, n);  // NOLINT
Same as above: this is not efficient when m and n are large.
if (check) {
  float out_scale = quant_out_scale_data[layer_offset + m_id];
  output[n_id * m + m_id] =
      static_cast<T>(static_cast<float>(input[n_id * m + m_id]) * out_scale);
Same as above.
const int m,  // hidden
const int n,  // batch size
const float* quant_out_scale_data,
const int layer_offset) {
The name layer_offset is not intuitive.
Renamed to quant_out_scale_offset.
hidden_units,
batch_size,
quant_out_scale_data,
layer_offset);
Same issue as above.
DONE
auto helper = std::make_shared<CublasLtHelper>(m, k, n);
helpers_.emplace_back(helper);
}
~AttnMatmulINT8() {}
If the name here uses INT8, the Q naming above should also be changed to INT8.
DONE.
As for the fused layernorm-quantization kernel, I am still trying to get rid of the redundant code.
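For reference, the rough idea of fusing quantization into the layernorm epilogue (writing the normalized output directly as int8) could look like the naive sketch below. This is not the PR's implementation; the kernel name, parameter list, and the one-thread-per-row layout are illustrative assumptions chosen for brevity rather than performance.

template <typename T>
__global__ void layernorm_quant_naive(const T* __restrict__ input,
                                      const T* __restrict__ gamma,
                                      const T* __restrict__ beta,
                                      int8_t* __restrict__ output,
                                      const int rows,
                                      const int cols,
                                      const float epsilon,
                                      const float quant_in_scale) {
  const int row = blockIdx.x * blockDim.x + threadIdx.x;
  if (row >= rows) return;
  // Compute the row mean and variance.
  float mean = 0.f, var = 0.f;
  for (int i = 0; i < cols; ++i) {
    mean += static_cast<float>(input[row * cols + i]);
  }
  mean /= cols;
  for (int i = 0; i < cols; ++i) {
    const float diff = static_cast<float>(input[row * cols + i]) - mean;
    var += diff * diff;
  }
  var /= cols;
  const float rstd = rsqrtf(var + epsilon);
  // Normalize, apply the affine parameters, and quantize in the same pass.
  for (int i = 0; i < cols; ++i) {
    const float normed =
        (static_cast<float>(input[row * cols + i]) - mean) * rstd *
            static_cast<float>(gamma[i]) +
        static_cast<float>(beta[i]);
    float q = roundf(127.0f * normed / quant_in_scale);
    q = fminf(fmaxf(q, -127.0f), 127.0f);
    output[row * cols + i] = static_cast<int8_t>(q);
  }
}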
void ComputeForward(
    const framework::Tensor*
        weight,  // [int8] which has been transformed in pass
The formatting here is a bit messy.
I have reviewed part of it; I will continue the review after it is cleaned up.
Removed some useless comments.
…hardWooSJTU/Paddle into fused_multi_transformrt_int8
…hardWooSJTU/Paddle into fused_multi_transformrt_int8 merge minghao
fix error
0c59ac2 to 2a967fd
…ync_params_among_devices pass
namespace operators {

template <typename T>
__forceinline__ __device__ int8_t clip_round(const T input, const float scale) {
Why not just call it "quant"? clip_round does not fully describe what this function does.
Renamed to quant_helper.
float quant_value = 127.0f * (1.0f / scale) * static_cast<float>(input);
quant_value = static_cast<float>(round(quant_value));
quant_value = quant_value > 127.0f ? 127.0f : quant_value;
quant_value = quant_value < -127.0f ? -127.0f : quant_value;
This is aligned with fake_quant_abs_max_op; confirmed that the lower bound is -127.
Reference:
} else {
Also, for partial decoupling, the clip no longer hard-codes 127.0f but uses max_bound/min_bound parameters, which are attributes of the op.
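Under the decoupling described above (max_bound/min_bound coming from op attributes instead of a hard-coded 127.0f), the helper could look roughly like the sketch below. The exact name and signature are assumptions based on this thread, not a copy of the PR's final code.

template <typename T>
__forceinline__ __device__ int8_t quant_helper(const T input,
                                               const float scale,
                                               const float max_bound,
                                               const float min_bound) {
  // max_bound (e.g. 127.0f) and min_bound (e.g. -127.0f) are op attributes.
  // Using max_bound as the scaling multiplier as well is an assumption here.
  float quant_value = max_bound * (1.0f / scale) * static_cast<float>(input);
  quant_value = static_cast<float>(round(quant_value));
  quant_value = quant_value > max_bound ? max_bound : quant_value;
  quant_value = quant_value < min_bound ? min_bound : quant_value;
  return static_cast<int8_t>(quant_value);
}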
const float quant_in_scale,
const float* quant_out_scale_data,
In theory, dequantization only needs to multiply by a single dequant_scale, where dequant_scale = input_scale * weight_scale.
What exactly do quant_in_scale and quant_out_scale mean?
quant_in_scale is the input scale among the op attributes, with the same meaning as the scale exported by PTQ.
quant_out_scale has been renamed to dequant_out_scale; it is the output scale among the op inputs, defined with the same meaning as the max_range attribute of the fake_dequant_range_abs_max op.
The op attributes and inputs are documented in fused_multi_transformer_int8_op.cc.
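To make the scale semantics above concrete, here is a minimal sketch of the per-element dequantization under those definitions; the function name dequantize_element and its parameter names are illustrative, not the op's actual API.

template <typename T>
__forceinline__ __device__ T dequantize_element(
    const int32_t gemm_out,
    const float quant_in_scale,           // input scale, as exported by PTQ
    const float* dequant_out_scale_data,  // per output channel, like max_range
    const int channel_id) {
  // Mirrors fake_dequantize_range_abs_max: out = in * scale / max_range.
  return static_cast<T>(static_cast<float>(gemm_out) * quant_in_scale /
                        dequant_out_scale_data[channel_id]);
}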
float out_scale = quant_out_scale_data[quant_out_scale_offset + m_id];
output[n_id * m + m_id] =
    static_cast<T>(static_cast<float>(input[n_id * m + m_id]) *
                   quant_in_scale / out_scale);
This looks like dequant + quant?
This is a temporary measure to align with the fake_dequant_abs_max op and improve accuracy. Reference:
out[i] = in[i] * scale[0] / max_range;
const int hidden_units,  // n
cudaStream_t stream,
const float quant_in_scale,
const float* quant_out_scale_data,
Why does quant_out_scale_data contain multiple values? Is this channel-wise dequantization?
Yes, it is channel-wise dequantization. It is also compatible with layer-wise dequantization: simply use the same value for every channel.
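A small illustration of the compatibility mentioned above: a layer-wise scale can reuse the channel-wise interface by broadcasting a single scalar over all output channels. The helper name below is hypothetical.

#include <vector>

// Layer-wise dequant expressed through the channel-wise interface:
// every channel receives the same scale value.
std::vector<float> make_layerwise_out_scales(float scale, int num_channels) {
  return std::vector<float>(num_channels, scale);
}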
LayerNormParamType<T>* var_data,
const float* quant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_in_scale_data = 1.0) {
If it is not a pointer, why not just name it quant_in_scale?
Fixed.
auto ffn2_in_scale = ctx.Attr<std::vector<float>>("ffn2_in_scale");

// output scales, tensor, size = [num_layers, n], n is gemm output size
auto *qkv_out_scale = ctx.Input<Tensor>("QKVOutScale");
Is this scale 1/weights_scale?
- In the old quantization format, the weight scale is stored in the out scale, which is not very reasonable.
- The constraints of the format should not be carried into the inference implementation; in other words, the inference implementation should be independent of the quantized model format.
- A pass should be used to decouple the quantization format from the inference implementation. In the pass, collect all the information needed for dequantization and compute the dequant scales; only the dequant scales need to be passed to the inference Dequant operator.
Discussed offline.
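Following up on the decoupling suggestion above, a hedged sketch of how a pass might precompute the dequant scales it hands to the inference dequant operator. The function and variable names are illustrative, and the exact formula depends on the quantization format; this just applies dequant_scale = input_scale * weight_scale from the comment above.

#include <vector>

// Illustrative only: combine the activation input scale with per-channel weight
// scales into the dequant scales consumed by the inference dequant step.
std::vector<float> compute_dequant_scales(float input_scale,
                                          const std::vector<float>& weight_scales) {
  std::vector<float> dequant_scales(weight_scales.size());
  for (size_t i = 0; i < weight_scales.size(); ++i) {
    dequant_scales[i] = input_scale * weight_scales[i];
  }
  return dequant_scales;
}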
const float quant_in_scale_data,
const framework::Tensor* quant_out_scale,
const int quant_out_scale_offset) {
int m = m_, k = k_, n = n_;
m, n, and k here do not seem to be used.
Fixed.
const int8_t* B_dev,
int32_t* C_dev,
cudaStream_t stream) {
// PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
Please confirm whether this is needed; if not, it can be removed.
Removed.
…hardWooSJTU/Paddle into fused_multi_transformrt_int8
Aurelius84
left a comment
LGTM for data registration
XieYunshen
left a comment
LGTM for the unit test timeout setting
XiaoguangHu01
left a comment
LGTM
Co-authored-by: RichardWooSJTU <[email protected]>
PR types
New features
PR changes
OPs
Describe
Add the fused_multi_transformer_int8 op to support quantized inference without TensorRT. The reason for using native inference instead of TensorRT is that the two TensorRT engines introduced by the while op cannot share weights, which doubles GPU memory usage. With native inference we can manage weights flexibly, although inference performance is slightly inferior to TensorRT. To gain better performance, we made the following attempts:
a. We define the above 5 arguments with default values, which means we did not need to modify the existing references.
b. The above changes only apply when dropout-rate != 1.0 and to the 3 classes.
c. Only pre-layernorm is fully tested.
Current limitations:
a. Batched GEMMs are not quantized.
b. quant/dequant is explicitly called around the QKV GEMM. Fusing it into the preceding/following kernels might be useful.