Skip to content

Commit 0438b60

Browse files
authored
[Paddle-TRT] upgrade test_tensorrt to trt8 (#34294)
* upgrade test_tensorrt to trt8 * format
1 parent 6fc33a0 commit 0438b60

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

paddle/fluid/inference/tensorrt/test_tensorrt.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@ limitations under the License. */
1616
#include <glog/logging.h>
1717
#include <gtest/gtest.h>
1818
#include "NvInfer.h"
19+
#include "paddle/fluid/inference/tensorrt/helper.h"
1920
#include "paddle/fluid/platform/dynload/tensorrt.h"
2021

2122
namespace dy = paddle::platform::dynload;
2223

2324
class Logger : public nvinfer1::ILogger {
2425
public:
25-
void log(nvinfer1::ILogger::Severity severity, const char* msg) override {
26+
void log(nvinfer1::ILogger::Severity severity,
27+
const char* msg) TRT_NOEXCEPT override {
2628
switch (severity) {
2729
case Severity::kINFO:
2830
LOG(INFO) << msg;
@@ -74,10 +76,11 @@ nvinfer1::IHostMemory* CreateNetwork() {
7476
Logger logger;
7577
// Create the engine.
7678
nvinfer1::IBuilder* builder = createInferBuilder(&logger);
79+
auto config = builder->createBuilderConfig();
7780
ScopedWeights weights(2.);
7881
ScopedWeights bias(3.);
7982

80-
nvinfer1::INetworkDefinition* network = builder->createNetwork();
83+
nvinfer1::INetworkDefinition* network = builder->createNetworkV2(0U);
8184
// Add the input
8285
auto input = network->addInput(kInputTensor, nvinfer1::DataType::kFLOAT,
8386
nvinfer1::Dims3{1, 1, 1});
@@ -91,8 +94,8 @@ nvinfer1::IHostMemory* CreateNetwork() {
9194
network->markOutput(*output);
9295
// Build the engine.
9396
builder->setMaxBatchSize(1);
94-
builder->setMaxWorkspaceSize(1 << 10);
95-
auto engine = builder->buildCudaEngine(*network);
97+
config->setMaxWorkspaceSize(1 << 10);
98+
auto engine = builder->buildEngineWithConfig(*network, *config);
9699
EXPECT_NE(engine, nullptr);
97100
// Serialize the engine to create a model, then close.
98101
nvinfer1::IHostMemory* model = engine->serialize();

0 commit comments

Comments
 (0)