@@ -275,6 +275,7 @@ def generate_weight():
275275 self .outputs = outputs
276276 self .input_type = input_type
277277 self .no_cast_list = [] if no_cast_list is None else no_cast_list
278+ self .supported_cast_type = [np .float32 , np .float16 ]
278279
279280 def __repr__ (self ):
280281 log_str = ''
@@ -292,11 +293,9 @@ def __repr__(self):
292293 return log_str
293294
294295 def set_input_type (self , _type : np .dtype ) -> None :
295- assert _type in [
296- np .float32 ,
297- np .float16 ,
298- None ,
299- ], "PaddleTRT only supports FP32 / FP16 IO"
296+ assert (
297+ _type in self .supported_cast_type or _type is None
298+ ), "PaddleTRT only supports FP32 / FP16 IO"
300299
301300 ver = paddle .inference .get_trt_compile_version ()
302301 trt_version = ver [0 ] * 1000 + ver [1 ] * 100 + ver [2 ] * 10
@@ -309,15 +308,14 @@ def set_input_type(self, _type: np.dtype) -> None:
309308 def get_feed_data (self ) -> Dict [str , Dict [str , Any ]]:
310309 feed_data = {}
311310 for name , tensor_config in self .inputs .items ():
312- do_casting = (
313- self .input_type is not None and name not in self .no_cast_list
314- )
311+ data = tensor_config .data
315312 # Cast to target input_type
316- data = (
317- tensor_config .data .astype (self .input_type )
318- if do_casting
319- else tensor_config .data
320- )
313+ if (
314+ self .input_type is not None
315+ and name not in self .no_cast_list
316+ and data .dtype in self .supported_cast_type
317+ ):
318+ data = data .astype (self .input_type )
321319 # Truncate FP32 tensors to FP16 precision for FP16 test stability
322320 if data .dtype == np .float32 and name not in self .no_cast_list :
323321 data = data .astype (np .float16 ).astype (np .float32 )
@@ -334,10 +332,14 @@ def _cast(self) -> None:
334332 for name , inp in self .inputs .items ():
335333 if name in self .no_cast_list :
336334 continue
335+ if inp .dtype not in self .supported_cast_type :
336+ continue
337337 inp .convert_type_inplace (self .input_type )
338338 for name , weight in self .weights .items ():
339339 if name in self .no_cast_list :
340340 continue
341+ if weight .dtype not in self .supported_cast_type :
342+ continue
341343 weight .convert_type_inplace (self .input_type )
342344 return self
343345
0 commit comments