Skip to content

Commit 5e74409

Browse files
authored
kNCHW is deprecated, should use kLINEAR (#33777)
1 parent d91352c commit 5e74409

File tree

8 files changed

+10
-10
lines changed

8 files changed

+10
-10
lines changed

paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ bool GeluPlugin::supportsFormat(nvinfer1::DataType type,
4242
if (with_fp16_) {
4343
return ((type == nvinfer1::DataType::kFLOAT ||
4444
type == nvinfer1::DataType::kHALF) &&
45-
(format == nvinfer1::PluginFormat::kNCHW));
45+
(format == nvinfer1::PluginFormat::kLINEAR));
4646
} else {
4747
return ((type == nvinfer1::DataType::kFLOAT) &&
48-
(format == nvinfer1::PluginFormat::kNCHW));
48+
(format == nvinfer1::PluginFormat::kLINEAR));
4949
}
5050
}
5151

paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ class InstanceNormPlugin : public PluginTensorRT {
112112
nvinfer1::PluginFormat format) const override {
113113
return ((type == nvinfer1::DataType::kFLOAT ||
114114
type == nvinfer1::DataType::kHALF) &&
115-
(format == nvinfer1::PluginFormat::kNCHW));
115+
(format == nvinfer1::PluginFormat::kLINEAR));
116116
}
117117
};
118118

paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ bool PoolPluginDynamic::supportsFormatCombination(
174174
(in_out && pos < (nb_inputs + nb_outputs));
175175

176176
return ((in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
177-
in_out[pos].format == nvinfer1::PluginFormat::kNCHW);
177+
in_out[pos].format == nvinfer1::PluginFormat::kLINEAR);
178178
}
179179

180180
nvinfer1::DataType PoolPluginDynamic::getOutputDataType(

paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ bool PReluPluginDynamic::supportsFormatCombination(
129129
(in_out && pos < (nb_inputs + nb_outputs));
130130

131131
return ((in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
132-
in_out[pos].format == nvinfer1::PluginFormat::kNCHW);
132+
in_out[pos].format == nvinfer1::PluginFormat::kLINEAR);
133133
}
134134

135135
nvinfer1::DataType PReluPluginDynamic::getOutputDataType(

paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,10 @@ bool SlicePlugin::supportsFormat(nvinfer1::DataType type,
9090
if (with_fp16_) {
9191
return ((type == nvinfer1::DataType::kFLOAT ||
9292
type == nvinfer1::DataType::kHALF) &&
93-
(format == nvinfer1::PluginFormat::kNCHW));
93+
(format == nvinfer1::PluginFormat::kLINEAR));
9494
} else {
9595
return ((type == nvinfer1::DataType::kFLOAT) &&
96-
(format == nvinfer1::PluginFormat::kNCHW));
96+
(format == nvinfer1::PluginFormat::kLINEAR));
9797
}
9898
}
9999

paddle/fluid/inference/tensorrt/plugin/test_split_plugin.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ TEST(split_op_plugin, test_plugin) {
3333
input_dims.push_back(in_dims);
3434
sp_plugin.configurePlugin(input_dims.data(), 1, nullptr, 2,
3535
input_types.data(), nullptr, nullptr, nullptr,
36-
nvinfer1::PluginFormat::kNCHW, 4);
36+
nvinfer1::PluginFormat::kLINEAR, 4);
3737
sp_plugin.initialize();
3838
sp_plugin.getPluginType();
3939
sp_plugin.canBroadcastInputAcrossBatch(0);

paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ size_t PluginTensorRT::getBaseSerializationSize() {
6868
bool PluginTensorRT::supportsFormat(nvinfer1::DataType type,
6969
nvinfer1::PluginFormat format) const {
7070
return ((type == nvinfer1::DataType::kFLOAT) &&
71-
(format == nvinfer1::PluginFormat::kNCHW));
71+
(format == nvinfer1::PluginFormat::kLINEAR));
7272
}
7373

7474
void PluginTensorRT::configureWithFormat(

paddle/fluid/inference/tensorrt/plugin/trt_plugin.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ class PluginTensorRTV2Ext : public nvinfer1::IPluginV2Ext {
181181
bool supportsFormat(nvinfer1::DataType type,
182182
nvinfer1::PluginFormat format) const override {
183183
return ((type == nvinfer1::DataType::kFLOAT) &&
184-
(format == nvinfer1::PluginFormat::kNCHW));
184+
(format == nvinfer1::PluginFormat::kLINEAR));
185185
}
186186
// Initialize the layer for execution.
187187
// This is called when the engine is created.

0 commit comments

Comments
 (0)