Skip to content

Commit 2778d45

Browse files
committed
added detailed error message for unregistered tensorrt_subgrah_pass
1 parent aafb48e commit 2778d45

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,13 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
100100
// compute the channel wise abs max of the weight tensor
101101
int quant_axis =
102102
BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis"));
103+
104+
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
105+
platform::errors::InvalidArgument(
106+
"'quant_axis' should be 0 or 1, but "
107+
"the received is %d",
108+
quant_axis));
109+
103110
const int64_t channel = w_dims[quant_axis];
104111
weight_scale.resize(channel, 0);
105112
if (quant_axis == 0) {

paddle/fluid/framework/ir/pass.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,19 @@ class PassRegistry {
206206
}
207207

208208
std::unique_ptr<Pass> Get(const std::string &pass_type) const {
209-
PADDLE_ENFORCE_EQ(Has(pass_type), true,
210-
platform::errors::InvalidArgument(
211-
"Pass %s has not been registered.", pass_type));
209+
if (pass_type == "tensorrt_subgraph_pass") {
210+
PADDLE_ENFORCE_EQ(Has(pass_type), true,
211+
platform::errors::InvalidArgument(
212+
"Pass %s has not been registered. Please "
213+
"use the paddle inference library "
214+
"compiled with tensorrt or disable "
215+
"the tensorrt engine in inference configuration! ",
216+
pass_type));
217+
} else {
218+
PADDLE_ENFORCE_EQ(Has(pass_type), true,
219+
platform::errors::InvalidArgument(
220+
"Pass %s has not been registered.", pass_type));
221+
}
212222
return map_.at(pass_type)();
213223
}
214224

0 commit comments

Comments
 (0)