Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
6a84086
add fused_multi_transformer_int8_op and unittest
minghaoBD Aug 18, 2022
139ec0d
fuse kernel
minghaoBD Aug 19, 2022
51315d6
fix build error
RichardWooSJTU Aug 19, 2022
5674acc
Merge branch 'fused_multi_transformrt_int8' of https://github.com/Ric…
RichardWooSJTU Aug 19, 2022
ce00ad9
fix fuse kernel
RichardWooSJTU Aug 19, 2022
d5b00da
Merge branch 'PaddlePaddle:develop' into fused_multi_transformrt_int8
minghaoBD Aug 30, 2022
df2d3a0
rename kernels and skip UT on non-GPU platforms
minghaoBD Sep 5, 2022
30b2c3d
Merge branch 'fused_multi_transformrt_int8' of https://github.com/Ric…
minghaoBD Sep 5, 2022
44e1e46
add layer API for create model
RichardWooSJTU Sep 6, 2022
ef3ef70
Merge branch 'fused_multi_transformrt_int8' of https://github.com/Ric…
RichardWooSJTU Sep 6, 2022
c42512a
clean debug code
RichardWooSJTU Sep 6, 2022
f26e42b
code clean
minghaoBD Sep 8, 2022
cb31f82
Merge pull request #1 from RichardWooSJTU/tmp_branch
minghaoBD Sep 8, 2022
1422b17
Merge branch 'develop' into fused_multi_transformrt_int8
minghaoBD Sep 8, 2022
2a967fd
resolve conflicts and fix UT bugs.
minghaoBD Sep 14, 2022
07fc384
skip unnit test in cpu and skip cast when tensor has been casted in s…
RichardWooSJTU Sep 14, 2022
117f588
modify input_scale and output_scale to align fake quant/dequant op
RichardWooSJTU Sep 14, 2022
4392807
skip windows unittest
RichardWooSJTU Sep 14, 2022
dfa79a3
modify mutable_data to gpucontext.Alloc
RichardWooSJTU Sep 14, 2022
9323569
fix CI-ROCM error and add quantization argument description
RichardWooSJTU Sep 15, 2022
a6038ba
add branch to decoupling roundtype and clip type
RichardWooSJTU Sep 15, 2022
9672af1
Delete .python-version
RichardWooSJTU Sep 15, 2022
4794a24
fix dyload error and disable algo select with cuda10.2
RichardWooSJTU Sep 16, 2022
2524c63
Merge branch 'fused_multi_transformrt_int8' of https://github.com/Ric…
RichardWooSJTU Sep 16, 2022
d134702
fix ci problem
RichardWooSJTU Sep 16, 2022
991fff6
fix unittest timeout setting in ci-rocm
RichardWooSJTU Sep 16, 2022
c1c22c6
delete api related codes in fp16utils
RichardWooSJTU Sep 16, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
auto var_data_type = var_node->Var()->GetDataType();
VLOG(5) << "var_name is " << var_name << ", data type is "
<< var_data_type;
if (var_data_type == paddle::framework::proto::VarType::FP16) {
if (var_data_type == paddle::framework::proto::VarType::FP16 &&
t->dtype() != paddle::experimental::DataType::FLOAT16) {
framework::Tensor half_tensor;
half_tensor.set_type(paddle::experimental::DataType::FLOAT16);
half_tensor.Resize(t->dims());
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/operators/fused/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ register_operators(
fused_transformer_op
fused_feedforward_op
fused_multi_transformer_op
fused_multi_transformer_int8_op
fused_bias_dropout_residual_layer_norm_op
resnet_unit_op
fused_gemm_epilogue_op
Expand Down Expand Up @@ -118,6 +119,7 @@ if(WITH_GPU OR WITH_ROCM)
# fused_attention_op
op_library(fused_attention_op)
op_library(fused_multi_transformer_op)
op_library(fused_multi_transformer_int8_op)
op_library(fused_bias_dropout_residual_layer_norm_op)
endif()
# resnet_unit needs cudnn 8.0 above
Expand Down
30 changes: 24 additions & 6 deletions paddle/fluid/operators/fused/attention_layer_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ limitations under the License. */
namespace paddle {
namespace operators {

template <typename T>
// NOTE: T must be the same as OutType in ComputeBackward
template <typename T, typename InType = T, typename OutType = T>
class AttnLayerNorm {
public:
AttnLayerNorm(const phi::GPUContext& dev_ctx,
Expand All @@ -33,25 +34,42 @@ class AttnLayerNorm {

~AttnLayerNorm() {}

void ComputeForward(const T* x_data,
void ComputeForward(const InType* x_data,
const LayerNormParamType<T>* scale_data,
const LayerNormParamType<T>* bias_data,
T* y_data,
OutType* y_data,
LayerNormParamType<T>* mean_data,
LayerNormParamType<T>* var_data) {
LayerNormParamType<T>* var_data,
const float* dequant_out_scale_data = nullptr,
const int quant_out_scale_offset = 0,
const float quant_in_scale = 1.0,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
auto stream = dev_ctx_.stream();

switch (GetDesiredBlockDim(feature_size_)) {
FIXED_BLOCK_DIM_CASE(
LayerNormForward<T, LayerNormParamType<T>, kBlockDim>
LayerNormForward<T,
LayerNormParamType<T>,
kBlockDim,
false,
InType,
OutType>
<<<batch_size_, kBlockDim, 0, stream>>>(x_data,
scale_data,
bias_data,
y_data,
mean_data,
var_data,
epsilon_,
feature_size_));
feature_size_,
dequant_out_scale_data,
quant_out_scale_offset,
quant_in_scale,
quant_round_type,
quant_max_bound,
quant_min_bound));
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Feature_size must be larger than 1"));
Expand Down
189 changes: 189 additions & 0 deletions paddle/fluid/operators/fused/attn_gemm_int8.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <iostream>
#include <vector>
#include "paddle/fluid/operators/fused/cublaslt.h"
#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
class AttnMatmulINT8 {
public:
AttnMatmulINT8(
const phi::GPUContext& dev_ctx, int m, int n, int k, bool compute_bias)
: dev_ctx_(dev_ctx), m_(m), n_(n), k_(k), compute_bias_(compute_bias) {
auto helper = std::make_shared<CublasLtHelper>(m, k, n);
helpers_.emplace_back(helper);
}
~AttnMatmulINT8() {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里命名INT8的话,上面Q命名也改成INT8

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DONE.

As for the fused layernorm-quantization kernel, I am still trying to git rid of the redundant code.


// This function is used to execute GEMM, with input and output's types are
// both T.
void ComputeForward(const framework::Tensor* weight,
const framework::Tensor* input,
framework::Tensor* input_tmp,
const framework::Tensor* bias,
framework::Tensor* output,
framework::Tensor* output_tmp,
framework::Tensor* bias_out,
const float quant_in_scale,
const framework::Tensor* dequant_out_scale,
const int quant_out_scale_offset,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
quantize_kernel_launcher<T>(input->data<T>(),
input_tmp->data<int8_t>(),
quant_in_scale,
m_,
k_,
quant_round_type,
quant_max_bound,
quant_min_bound,
dev_ctx_.stream());

helpers_[0]->GEMM(input_tmp->data<int8_t>(),
weight->data<int8_t>(),
output_tmp->data<int32_t>(),
dev_ctx_.stream());

dequantize_kernel_launcher<T>(output_tmp->data<int32_t>(),
output->data<T>(),
m_,
n_,
dev_ctx_.stream(),
quant_in_scale,
dequant_out_scale->data<float>(),
quant_out_scale_offset);

if (compute_bias_) {
// bias_out = output + bias
std::vector<const framework::Tensor*> ins = {output, bias};
std::vector<framework::Tensor*> outs = {bias_out};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
PADDLE_ENFORCE_EQ(cudaGetLastError(),
cudaSuccess,
platform::errors::Fatal(
"cuda error occured after computing bias. "
"But it does not mean this error is caused by "
"bias computing"));
}
}

// This function is used to execute GEMM, with input and output's types are
// both INT8.
void ComputeForwardINT8ToINT8(const framework::Tensor* weight,
framework::Tensor* input,
const framework::Tensor* bias,
framework::Tensor* output,
framework::Tensor* bias_out) {
helpers_[0]->GEMM(input->data<int8_t>(),
weight->data<int8_t>(),
output->data<int32_t>(),
dev_ctx_.stream());
}

// This function is used to execute GEMM, with input and output's types are
// INT8 and T.
void ComputeForwardINT8ToT(const framework::Tensor* weight,
const float quant_in_scale,
framework::Tensor* input,
const framework::Tensor* bias,
framework::Tensor* output,
framework::Tensor* output_tmp,
framework::Tensor* bias_out,
const framework::Tensor* dequant_out_scale,
const int quant_out_scale_offset) {
helpers_[0]->GEMM(input->data<int8_t>(),
weight->data<int8_t>(),
output_tmp->data<int32_t>(),
dev_ctx_.stream());

dequantize_kernel_launcher<T>(output_tmp->data<int32_t>(),
output->data<T>(),
m_,
n_,
dev_ctx_.stream(),
quant_in_scale,
dequant_out_scale->data<float>(),
quant_out_scale_offset);

if (compute_bias_) {
// bias_out = output + bias
std::vector<const framework::Tensor*> ins = {output, bias};
std::vector<framework::Tensor*> outs = {bias_out};
phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
PADDLE_ENFORCE_EQ(cudaGetLastError(),
cudaSuccess,
platform::errors::Fatal(
"cuda error occured after computing bias. "
"But it does not mean this error is caused by "
"bias computing"));
}
}

// This function is used to execute GEMM, with input and output's types are T
// and INT8.
void ComputeForwardTToINT8(const framework::Tensor* weight,
const float quant_in_scale,
const framework::Tensor* input,
framework::Tensor* input_tmp,
const framework::Tensor* bias,
framework::Tensor* output,
framework::Tensor* bias_out,
const int quant_round_type = 1,
const float quant_max_bound = 127.0,
const float quant_min_bound = -127.0) {
quantize_kernel_launcher<T>(input->data<T>(),
input_tmp->data<int8_t>(),
quant_in_scale,
m_,
k_,
quant_round_type,
quant_max_bound,
quant_min_bound,
dev_ctx_.stream());

helpers_[0]->GEMM(input_tmp->data<int8_t>(),
weight->data<int8_t>(),
output->data<int32_t>(),
dev_ctx_.stream());
}

private:
const phi::GPUContext& dev_ctx_;

int m_; // m
int n_; // n
int k_; // k

int compute_bias_;
std::vector<std::shared_ptr<CublasLtHelper>> helpers_;
};

} // namespace operators
} // namespace paddle
Loading