@@ -66,9 +66,9 @@ def __init__(
6666 self .dtype = self .data .dtype
6767 self .shape = self .data .shape
6868 else :
69- assert (
70- shape is not None
71- ), "While data_gen is not defined, shape must not be None"
69+ assert shape is not None , (
70+ "While data_gen is not defined, shape must not be None"
71+ )
7272 self .data = np .random .normal (0.0 , 1.0 , shape ).astype (np .float32 )
7373 self .shape = shape
7474 self .dtype = self .data .dtype
@@ -291,9 +291,9 @@ def __repr__(self):
291291 return log_str
292292
293293 def set_input_type (self , _type : np .dtype ) -> None :
294- assert (
295- _type in self . supported_cast_type or _type is None
296- ), "PaddleTRT only supports FP32 / FP16 IO"
294+ assert _type in self . supported_cast_type or _type is None , (
295+ "PaddleTRT only supports FP32 / FP16 IO"
296+ )
297297
298298 ver = paddle .inference .get_trt_compile_version ()
299299 trt_version = ver [0 ] * 1000 + ver [1 ] * 100 + ver [2 ] * 10
@@ -629,9 +629,9 @@ def create_quant_model(
629629
630630 def _get_op_output_var_names (op ):
631631 """ """
632- assert isinstance (
633- op , ( IrNode , Operator )
634- ), "The input op should be IrNode or Operator."
632+ assert isinstance (op , ( IrNode , Operator )), (
633+ "The input op should be IrNode or Operator."
634+ )
635635 var_names = []
636636 op_name = op .name () if isinstance (op , IrNode ) else op .type
637637 if op_name not in op_real_in_out_name :
0 commit comments