-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Add INT8 support for fused_multi_transformer_op #45284
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
wanghaoshuang
merged 27 commits into
PaddlePaddle:develop
from
RichardWooSJTU:fused_multi_transformrt_int8
Sep 18, 2022
Merged
Changes from all commits
Commits
Show all changes
27 commits
Select commit
Hold shift + click to select a range
6a84086
add fused_multi_transformer_int8_op and unittest
minghaoBD 139ec0d
fuse kernel
minghaoBD 51315d6
fix build error
RichardWooSJTU 5674acc
Merge branch 'fused_multi_transformrt_int8' of https://github.com/Ric…
RichardWooSJTU ce00ad9
fix fuse kernel
RichardWooSJTU d5b00da
Merge branch 'PaddlePaddle:develop' into fused_multi_transformrt_int8
minghaoBD df2d3a0
rename kernels and skip UT on non-GPU platforms
minghaoBD 30b2c3d
Merge branch 'fused_multi_transformrt_int8' of https://github.com/Ric…
minghaoBD 44e1e46
add layer API for create model
RichardWooSJTU ef3ef70
Merge branch 'fused_multi_transformrt_int8' of https://github.com/Ric…
RichardWooSJTU c42512a
clean debug code
RichardWooSJTU f26e42b
code clean
minghaoBD cb31f82
Merge pull request #1 from RichardWooSJTU/tmp_branch
minghaoBD 1422b17
Merge branch 'develop' into fused_multi_transformrt_int8
minghaoBD 2a967fd
resolve conflicts and fix UT bugs.
minghaoBD 07fc384
skip unnit test in cpu and skip cast when tensor has been casted in s…
RichardWooSJTU 117f588
modify input_scale and output_scale to align fake quant/dequant op
RichardWooSJTU 4392807
skip windows unittest
RichardWooSJTU dfa79a3
modify mutable_data to gpucontext.Alloc
RichardWooSJTU 9323569
fix CI-ROCM error and add quantization argument description
RichardWooSJTU a6038ba
add branch to decoupling roundtype and clip type
RichardWooSJTU 9672af1
Delete .python-version
RichardWooSJTU 4794a24
fix dyload error and disable algo select with cuda10.2
RichardWooSJTU 2524c63
Merge branch 'fused_multi_transformrt_int8' of https://github.com/Ric…
RichardWooSJTU d134702
fix ci problem
RichardWooSJTU 991fff6
fix unittest timeout setting in ci-rocm
RichardWooSJTU c1c22c6
delete api related codes in fp16utils
RichardWooSJTU File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,189 @@ | ||
| /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <iostream> | ||
| #include <vector> | ||
| #include "paddle/fluid/operators/fused/cublaslt.h" | ||
| #include "paddle/fluid/operators/fused/quant_dequant_kernel.h" | ||
| #include "paddle/fluid/platform/device/gpu/gpu_info.h" | ||
| #include "paddle/fluid/platform/float16.h" | ||
| #include "paddle/phi/kernels/funcs/broadcast_function.h" | ||
| #include "paddle/phi/kernels/funcs/elementwise_functor.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| using Tensor = framework::Tensor; | ||
|
|
||
| template <typename T> | ||
| class AttnMatmulINT8 { | ||
| public: | ||
| AttnMatmulINT8( | ||
| const phi::GPUContext& dev_ctx, int m, int n, int k, bool compute_bias) | ||
| : dev_ctx_(dev_ctx), m_(m), n_(n), k_(k), compute_bias_(compute_bias) { | ||
| auto helper = std::make_shared<CublasLtHelper>(m, k, n); | ||
| helpers_.emplace_back(helper); | ||
| } | ||
| ~AttnMatmulINT8() {} | ||
|
|
||
| // This function is used to execute GEMM, with input and output's types are | ||
| // both T. | ||
| void ComputeForward(const framework::Tensor* weight, | ||
| const framework::Tensor* input, | ||
| framework::Tensor* input_tmp, | ||
| const framework::Tensor* bias, | ||
| framework::Tensor* output, | ||
| framework::Tensor* output_tmp, | ||
| framework::Tensor* bias_out, | ||
| const float quant_in_scale, | ||
| const framework::Tensor* dequant_out_scale, | ||
| const int quant_out_scale_offset, | ||
| const int quant_round_type = 1, | ||
| const float quant_max_bound = 127.0, | ||
| const float quant_min_bound = -127.0) { | ||
| quantize_kernel_launcher<T>(input->data<T>(), | ||
| input_tmp->data<int8_t>(), | ||
| quant_in_scale, | ||
| m_, | ||
| k_, | ||
| quant_round_type, | ||
| quant_max_bound, | ||
| quant_min_bound, | ||
| dev_ctx_.stream()); | ||
|
|
||
| helpers_[0]->GEMM(input_tmp->data<int8_t>(), | ||
| weight->data<int8_t>(), | ||
| output_tmp->data<int32_t>(), | ||
| dev_ctx_.stream()); | ||
|
|
||
| dequantize_kernel_launcher<T>(output_tmp->data<int32_t>(), | ||
| output->data<T>(), | ||
| m_, | ||
| n_, | ||
| dev_ctx_.stream(), | ||
| quant_in_scale, | ||
| dequant_out_scale->data<float>(), | ||
| quant_out_scale_offset); | ||
|
|
||
| if (compute_bias_) { | ||
| // bias_out = output + bias | ||
| std::vector<const framework::Tensor*> ins = {output, bias}; | ||
| std::vector<framework::Tensor*> outs = {bias_out}; | ||
| phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>( | ||
| dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>()); | ||
| PADDLE_ENFORCE_EQ(cudaGetLastError(), | ||
| cudaSuccess, | ||
| platform::errors::Fatal( | ||
| "cuda error occured after computing bias. " | ||
| "But it does not mean this error is caused by " | ||
| "bias computing")); | ||
| } | ||
| } | ||
|
|
||
| // This function is used to execute GEMM, with input and output's types are | ||
| // both INT8. | ||
| void ComputeForwardINT8ToINT8(const framework::Tensor* weight, | ||
| framework::Tensor* input, | ||
| const framework::Tensor* bias, | ||
| framework::Tensor* output, | ||
| framework::Tensor* bias_out) { | ||
| helpers_[0]->GEMM(input->data<int8_t>(), | ||
| weight->data<int8_t>(), | ||
| output->data<int32_t>(), | ||
| dev_ctx_.stream()); | ||
| } | ||
|
|
||
| // This function is used to execute GEMM, with input and output's types are | ||
| // INT8 and T. | ||
| void ComputeForwardINT8ToT(const framework::Tensor* weight, | ||
| const float quant_in_scale, | ||
| framework::Tensor* input, | ||
| const framework::Tensor* bias, | ||
| framework::Tensor* output, | ||
| framework::Tensor* output_tmp, | ||
| framework::Tensor* bias_out, | ||
| const framework::Tensor* dequant_out_scale, | ||
| const int quant_out_scale_offset) { | ||
| helpers_[0]->GEMM(input->data<int8_t>(), | ||
| weight->data<int8_t>(), | ||
| output_tmp->data<int32_t>(), | ||
| dev_ctx_.stream()); | ||
|
|
||
| dequantize_kernel_launcher<T>(output_tmp->data<int32_t>(), | ||
| output->data<T>(), | ||
| m_, | ||
| n_, | ||
| dev_ctx_.stream(), | ||
| quant_in_scale, | ||
| dequant_out_scale->data<float>(), | ||
| quant_out_scale_offset); | ||
|
|
||
| if (compute_bias_) { | ||
| // bias_out = output + bias | ||
| std::vector<const framework::Tensor*> ins = {output, bias}; | ||
| std::vector<framework::Tensor*> outs = {bias_out}; | ||
| phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>( | ||
| dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>()); | ||
| PADDLE_ENFORCE_EQ(cudaGetLastError(), | ||
| cudaSuccess, | ||
| platform::errors::Fatal( | ||
| "cuda error occured after computing bias. " | ||
| "But it does not mean this error is caused by " | ||
| "bias computing")); | ||
| } | ||
| } | ||
|
|
||
| // This function is used to execute GEMM, with input and output's types are T | ||
| // and INT8. | ||
| void ComputeForwardTToINT8(const framework::Tensor* weight, | ||
| const float quant_in_scale, | ||
| const framework::Tensor* input, | ||
| framework::Tensor* input_tmp, | ||
| const framework::Tensor* bias, | ||
| framework::Tensor* output, | ||
| framework::Tensor* bias_out, | ||
| const int quant_round_type = 1, | ||
| const float quant_max_bound = 127.0, | ||
| const float quant_min_bound = -127.0) { | ||
| quantize_kernel_launcher<T>(input->data<T>(), | ||
| input_tmp->data<int8_t>(), | ||
| quant_in_scale, | ||
| m_, | ||
| k_, | ||
| quant_round_type, | ||
| quant_max_bound, | ||
| quant_min_bound, | ||
| dev_ctx_.stream()); | ||
|
|
||
| helpers_[0]->GEMM(input_tmp->data<int8_t>(), | ||
| weight->data<int8_t>(), | ||
| output->data<int32_t>(), | ||
| dev_ctx_.stream()); | ||
| } | ||
|
|
||
| private: | ||
| const phi::GPUContext& dev_ctx_; | ||
|
|
||
| int m_; // m | ||
| int n_; // n | ||
| int k_; // k | ||
|
|
||
| int compute_bias_; | ||
| std::vector<std::shared_ptr<CublasLtHelper>> helpers_; | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里命名INT8的话,上面Q命名也改成INT8
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DONE.
As for the fused layernorm-quantization kernel, I am still trying to git rid of the redundant code.