9 changes: 9 additions & 0 deletions paddle/fluid/inference/analysis/argument.h
@@ -256,6 +256,15 @@ struct Argument {
DECL_ARGUMENT_FIELD(tensorrt_disabled_ops,
TensorRtDisabledOPs,
std::vector<std::string>);
DECL_ARGUMENT_FIELD(trt_parameter_run_fp16,
TRTParameterRunFp16,
std::vector<std::string>);
DECL_ARGUMENT_FIELD(trt_parameter_run_int8,
TRTParameterRunInt8,
std::vector<std::string>);
DECL_ARGUMENT_FIELD(trt_parameter_run_bfp16,
TRTParameterRunBfp16,
std::vector<std::string>);
DECL_ARGUMENT_FIELD(tensorrt_precision_mode, TensorRtPrecisionMode, int);
DECL_ARGUMENT_FIELD(tensorrt_use_static_engine,
TensorRtUseStaticEngine,
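For readers unfamiliar with the macro: DECL_ARGUMENT_FIELD declares a typed field on Argument together with its accessors. A simplified sketch of the pattern (illustrative only; the real macro in argument.h differs in detail):

#define DECL_ARGUMENT_FIELD(field__, Field, type__)     \
 public:                                                \
  type__& field__() { return field__##_; }              \
  void Set##Field(const type__& x) {                    \
    field__##_ = x;                                     \
    field__##_valid_ = true;                            \
  }                                                     \
  bool field__##_valid() { return field__##_valid_; }   \
                                                        \
 private:                                               \
  type__ field__##_;                                    \
  bool field__##_valid_{false};

Under this reading, DECL_ARGUMENT_FIELD(trt_parameter_run_fp16, TRTParameterRunFp16, ...) yields the argument->trt_parameter_run_fp16() getter and argument->SetTRTParameterRunFp16(...) setter used in the files below.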
9 changes: 9 additions & 0 deletions paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -173,6 +173,15 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set(
"trt_exclude_var_names",
new std::vector<std::string>(argument->trt_exclude_var_names()));
pass->Set(
"trt_parameter_run_fp16",
new std::vector<std::string>(argument->trt_parameter_run_fp16()));
pass->Set(
"trt_parameter_run_int8",
new std::vector<std::string>(argument->trt_parameter_run_int8()));
pass->Set(
"trt_parameter_run_bfp16",
new std::vector<std::string>(argument->trt_parameter_run_bfp16()));
pass->Set("forbid_dynamic_op",
new bool(argument->trt_forbid_dynamic_op()));

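Each pass->Set above hands a heap-allocated copy of a config list to the pass under a string key; the subgraph pass later reads it back with Get<std::vector<std::string>>. A minimal sketch of such a type-erased attribute store (an illustration of the pattern, not Paddle's actual ir::Pass implementation, which adds ownership bookkeeping and error checks):

#include <any>
#include <map>
#include <string>

class PassAttrs {
 public:
  // Stores a copy of *value under key and releases the heap allocation.
  template <typename T>
  void Set(const std::string& key, T* value) {
    attrs_[key] = *value;
    delete value;
  }
  // Retrieves the value stored under key; throws if the type mismatches.
  template <typename T>
  const T& Get(const std::string& key) const {
    return std::any_cast<const T&>(attrs_.at(key));
  }

 private:
  std::map<std::string, std::any> attrs_;
};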
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -14,7 +14,6 @@
// limitations under the License.

#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"

#include <fcntl.h>
#include <cstddef>
#include <memory>
@@ -476,9 +475,47 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
}
auto precision_mode =
static_cast<phi::DataType>(Get<int>("trt_precision_mode"));
auto trt_params_run_fp16 =
Get<std::vector<std::string>>("trt_parameter_run_fp16");
auto trt_params_run_int8 =
Get<std::vector<std::string>>("trt_parameter_run_int8");
auto trt_params_run_bfp16 =
Get<std::vector<std::string>>("trt_parameter_run_bfp16");

for (const auto &para : parameters) {
if (std::find(trt_params_run_fp16.begin(),
trt_params_run_fp16.end(),
para) != trt_params_run_fp16.end()) {
precision_mode = phi::DataType::FLOAT16;
break;
}
}

bool enable_fp16 = false;
if (precision_mode == phi::DataType::FLOAT16) enable_fp16 = true;
auto enable_int8 = Get<bool>("enable_int8");

for (const auto &para : parameters) {
if (std::find(trt_params_run_int8.begin(),
trt_params_run_int8.end(),
para) != trt_params_run_int8.end()) {
enable_int8 = true;
precision_mode = phi::DataType::INT8;
break;
}
}

for (const auto &para : parameters) {
if (std::find(trt_params_run_bfp16.begin(),
trt_params_run_bfp16.end(),
para) != trt_params_run_bfp16.end()) {
precision_mode = phi::DataType::BFLOAT16;
break;
}
}
bool enable_bfp16 = false;
if (precision_mode == phi::DataType::BFLOAT16) enable_bfp16 = true;
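  // Note on precedence: the three membership checks run fp16 -> int8 ->
  // bfp16, so a parameter listed in more than one list leaves
  // precision_mode at the value set by the last matching loop (e.g. fp16 +
  // bfp16 ends up BFLOAT16), while flags set earlier (enable_fp16,
  // enable_int8) remain true.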

auto use_calib_mode = Get<bool>("use_calib_mode");
auto &subgraph_nodes = *framework::ir::Agent(node).subgraph();
auto min_input_shape =
@@ -724,6 +761,7 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
op_desc->SetAttr("calibration_data", calibration_data);
op_desc->SetAttr("enable_int8", enable_int8);
op_desc->SetAttr("enable_fp16", enable_fp16);
op_desc->SetAttr("enbale_bfp16", enable_bfp16);
op_desc->SetAttr("use_calib_mode", use_calib_mode);
op_desc->SetAttr("engine_key", engine_key);
op_desc->SetAttr("calibration_engine_key", calibration_engine_key);
24 changes: 24 additions & 0 deletions paddle/fluid/inference/api/analysis_config.cc
@@ -462,6 +462,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(tensorrt_min_subgraph_size_);
CP_MEMBER(tensorrt_precision_mode_);
CP_MEMBER(trt_mark_output_);
CP_MEMBER(trt_parameters_run_fp16_);
CP_MEMBER(trt_parameters_run_int8_);
CP_MEMBER(trt_parameters_run_bfp16_);
CP_MEMBER(trt_forbid_dynamic_op_)
CP_MEMBER(trt_output_tensor_names_);
CP_MEMBER(trt_disabled_ops_);
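(CP_MEMBER is this translation unit's copy-constructor helper; presumably it expands to roughly `member__ = other.member__;`, copying each listed field verbatim from `other`. That reading is an assumption, not the verified macro body.)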
@@ -880,6 +883,21 @@ void AnalysisConfig::Exp_DisableTensorRtSubgraph(
var_name_not_trt.end());
}

void AnalysisConfig::Exp_SpecifyTensorRTSubgraphPrecision(
const std::vector<std::string> &trt_parameters_run_fp16,
const std::vector<std::string> &trt_parameters_run_int8,
const std::vector<std::string> &trt_parameters_run_bfp16) {
trt_parameters_run_fp16_.insert(trt_parameters_run_fp16_.end(),
trt_parameters_run_fp16.begin(),
trt_parameters_run_fp16.end());
trt_parameters_run_int8_.insert(trt_parameters_run_int8_.end(),
trt_parameters_run_int8.begin(),
trt_parameters_run_int8.end());
trt_parameters_run_bfp16_.insert(trt_parameters_run_bfp16_.end(),
trt_parameters_run_bfp16.begin(),
trt_parameters_run_bfp16.end());
}
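For illustration, a hedged C++ usage sketch of the new API; the model paths and parameter names are placeholders mirroring the Python test below, and the surrounding calls follow the usual paddle_infer flow:

// Sketch only: pin individual parameters to fp16 / int8 / bfp16 by name.
// "model.pdmodel", "model.pdiparams" and the weight names are placeholders.
paddle_infer::Config config("model.pdmodel", "model.pdiparams");
config.EnableUseGpu(256 /* memory in MB */, 0 /* gpu_id */);
config.EnableTensorRtEngine(1 << 30 /* workspace_size */,
                            1 /* max_batch_size */,
                            3 /* min_subgraph_size */,
                            paddle_infer::PrecisionType::kFloat32,
                            false /* use_static */,
                            false /* use_calib_mode */);
config.Exp_SpecifyTensorRTSubgraphPrecision(
    {"conv2d_1.w_0"},   // run these parameters in fp16
    {},                 // none pinned to int8
    {"conv2d_2.w_0"});  // run these in bfp16 (TensorRT >= 9.0)
auto predictor = paddle_infer::CreatePredictor(config);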

void AnalysisConfig::EnableVarseqlen() { trt_use_varseqlen_ = true; }

void AnalysisConfig::SetTensorRtOptimizationLevel(int level) {
@@ -1135,6 +1153,12 @@ std::string AnalysisConfig::SerializeInfoCache() {
ss << tensorrt_max_batchsize_;
ss << tensorrt_min_subgraph_size_;
ss << trt_mark_output_;
for (auto &name : trt_parameters_run_fp16_) ss << name.c_str();
ss << ";";
for (auto &name : trt_parameters_run_int8_) ss << name.c_str();
ss << ";";
for (auto &name : trt_parameters_run_bfp16_) ss << name.c_str();
ss << ";";
ss << trt_forbid_dynamic_op_;

ss << use_dlnne_;
3 changes: 3 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -1755,6 +1755,9 @@ void AnalysisPredictor::PrepareArgument() {
argument_->SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_);
argument_->SetTRTMarkOutput(config_.trt_mark_output_);
argument_->SetTRTOutputTensorNames(config_.trt_output_tensor_names_);
argument_->SetTRTParameterRunFp16(config_.trt_parameters_run_fp16_);
argument_->SetTRTParameterRunInt8(config_.trt_parameters_run_int8_);
argument_->SetTRTParameterRunBfp16(config_.trt_parameters_run_bfp16_);
argument_->SetTensorRtDisabledOPs(config_.trt_disabled_ops_);
argument_->SetTRTExcludeVarNames(config_.trt_exclude_var_names_);
argument_->SetTRTForbidDynamicOp(config_.trt_forbid_dynamic_op_);
22 changes: 22 additions & 0 deletions paddle/fluid/inference/api/paddle_analysis_config.h
@@ -810,9 +810,27 @@ struct PD_INFER_DECL AnalysisConfig {
///
void Exp_DisableTensorRtOPs(const std::vector<std::string>& ops);

///
/// \brief Prevent TensorRtSubgraph running in Paddle-TRT
/// NOTE: just experimental, not an official stable API, easy to be broken.
///
void Exp_DisableTensorRtSubgraph(
const std::vector<std::string>& var_name_not_trt);

///
/// \brief Specify the TensorRT subgraph precision: fp16, int8, or bfp16
/// (bfp16 requires TensorRT version >= 9.0).
/// NOTE: just experimental, not an official stable API, easy to be broken.
///
void Exp_SpecifyTensorRTSubgraphPrecision(
    const std::vector<std::string>& trt_parameters_fp16,
    const std::vector<std::string>& trt_parameters_int8,
    const std::vector<std::string>& trt_parameters_bfp16);

Review thread on this declaration:
Contributor: Add doc comments for all of these APIs in the next PR.
Contributor Author: This is the last PR of the series, so I'll add the comments here.

///
/// \brief Prevent DynamicShape OPs running in Paddle-TRT
/// NOTE: just experimental, not an official stable API, easy to be broken.
///
void Exp_DisableTensorRTDynamicShapeOPs(bool trt_forbid_dynamic_op);

///
@@ -1289,6 +1307,10 @@ struct PD_INFER_DECL AnalysisConfig {

std::vector<std::string> trt_output_tensor_names_{};
std::vector<std::string> trt_exclude_var_names_{};
std::vector<std::string> trt_parameters_run_fp16_{};
std::vector<std::string> trt_parameters_run_int8_{};
std::vector<std::string> trt_parameters_run_bfp16_{};

std::string tensorrt_transformer_posid_{""};
std::string tensorrt_transformer_maskid_{""};
bool trt_use_dla_{false};
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/inference_api.cc
@@ -937,6 +937,8 @@ void BindAnalysisConfig(py::module *m) {
.def("exp_disable_tensorrt_ops", &AnalysisConfig::Exp_DisableTensorRtOPs)
.def("exp_disable_tensorrt_subgraph",
&AnalysisConfig::Exp_DisableTensorRtSubgraph)
.def("exp_specify_tensorrt_subgraph_precision",
&AnalysisConfig::Exp_SpecifyTensorRTSubgraphPrecision)
.def("exp_disable_tensorrt_dynamic_shape_ops",
&AnalysisConfig::Exp_DisableTensorRTDynamicShapeOPs)
.def("enable_tensorrt_dla",
144 changes: 144 additions & 0 deletions test/ir/inference/test_trt_ops_fp16_mix_precision.py
@@ -0,0 +1,144 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
import unittest

import numpy as np

import paddle
from paddle import nn, static
from paddle.inference import Config, PrecisionType, create_predictor

paddle.enable_static()


class SimpleNet(nn.Layer):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2D(
in_channels=4,
out_channels=4,
kernel_size=3,
stride=2,
padding=0,
)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2D(
in_channels=4,
out_channels=2,
kernel_size=3,
stride=2,
padding=0,
)
self.relu2 = nn.ReLU()
self.conv3 = nn.Conv2D(
in_channels=2,
out_channels=1,
kernel_size=3,
stride=2,
padding=0,
)
self.relu3 = nn.ReLU()
self.flatten = nn.Flatten()
self.fc = nn.Linear(729, 10)
self.softmax = nn.Softmax()

def forward(self, x):
x = self.conv1(x)
x = self.relu1(x)
x = self.conv2(x)
x = self.relu2(x)
x = self.conv3(x)
x = self.relu3(x)
x = self.flatten(x)
x = self.fc(x)
x = self.softmax(x)
return x


class TestTRTSubgraphMixPrecision(unittest.TestCase):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.temp_dir = tempfile.TemporaryDirectory()
self.path = os.path.join(self.temp_dir.name, 'optimization_level', '')
self.model_prefix = self.path + 'infer_model'

def tearDown(self):
shutil.rmtree(self.path)

def build_model(self):
image = static.data(
name='img', shape=[None, 4, 224, 224], dtype='float32'
)
predict = SimpleNet()(image)
exe = paddle.static.Executor(self.place)
exe.run(paddle.static.default_startup_program())
paddle.static.save_inference_model(
self.model_prefix, [image], [predict], exe
)

def init_predictor(self):
config = Config(
self.model_prefix + '.pdmodel', self.model_prefix + '.pdiparams'
)
config.enable_use_gpu(256, 0, PrecisionType.Float32)
config.exp_disable_tensorrt_ops(["relu_1.tmp_0"])
config.enable_tensorrt_engine(
workspace_size=1 << 30,
max_batch_size=1,
min_subgraph_size=3,
precision_mode=PrecisionType.Float32,
use_static=False,
use_calib_mode=False,
)

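        # Pin conv2d_1's weights to fp16 and conv2d_2's weights to bfp16;
        # the int8 list is effectively empty ([""] names no real parameter).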
config.exp_specify_tensorrt_subgraph_precision(
["conv2d_1.w_0"], [""], ["conv2d_2.w_0"]
)

config.enable_memory_optim()
# config.disable_glog_info()
config.set_tensorrt_optimization_level(0)
self.assertEqual(config.tensorrt_optimization_level(), 0)
predictor = create_predictor(config)
return predictor

def infer(self, predictor, img):
input_names = predictor.get_input_names()
for i, name in enumerate(input_names):
input_tensor = predictor.get_input_handle(name)
input_tensor.reshape(img[i].shape)
input_tensor.copy_from_cpu(img[i].copy())

predictor.run()
results = []
output_names = predictor.get_output_names()
for i, name in enumerate(output_names):
output_tensor = predictor.get_output_handle(name)
output_data = output_tensor.copy_to_cpu()
results.append(output_data)
return results

    def test_trt_subgraph_mix_precision(self):
        self.build_model()
        predictor = self.init_predictor()
        img = np.ones((1, 4, 224, 224), dtype=np.float32)
        results = self.infer(predictor, img=[img])
        # SimpleNet has a single softmax output.
        self.assertEqual(len(results), 1)


if __name__ == '__main__':
unittest.main()