@@ -218,15 +218,15 @@ def __init__(
218218 forward_range = None ,
219219 backward_range = None ,
220220 ):
221- assert isinstance (
222- in_out_values , tuple
223- ), "in_out_values must be tuple with len == 3"
224- assert (
225- len ( in_out_values ) == 3
226- ), "in_out_values must be tuple with len == 3"
227- assert isinstance (
228- in_out_values [ 0 ], list
229- ), "in_out_values must be tuple with len == 3"
221+ assert isinstance (in_out_values , tuple ), (
222+ " in_out_values must be tuple with len == 3"
223+ )
224+ assert len ( in_out_values ) == 3 , (
225+ " in_out_values must be tuple with len == 3"
226+ )
227+ assert isinstance (in_out_values [ 0 ], list ), (
228+ " in_out_values must be tuple with len == 3"
229+ )
230230 self .program = program
231231 self .x_names = self .convert_name (in_out_values [0 ])
232232 self .param_names = self .convert_name (in_out_values [1 ])
@@ -310,9 +310,9 @@ def clone(self):
310310 )
311311
312312 def split_forward_backward (self ):
313- assert (
314- self . has_splited is False
315- ), "Please ensure only split once! don't call split_forward_backward manually."
313+ assert self . has_splited is False , (
314+ "Please ensure only split once! don't call split_forward_backward manually."
315+ )
316316 self .has_splited = True
317317 self .update_op_range ()
318318 (
@@ -406,9 +406,9 @@ def _forward_backward_program(self):
406406
407407 @cached_property # shouldn't changed when call this once.
408408 def program_attr (self ):
409- assert (
410- self . finish_pass is False
411- ), "program_attr() is called by PartialProgramLayer, don't call it manually, use program_name_attr instead."
409+ assert self . finish_pass is False , (
410+ "program_attr() is called by PartialProgramLayer, don't call it manually, use program_name_attr instead."
411+ )
412412 # can't apply pass after call this function.
413413 self .finish_pass = True
414414 fwd_map = RunnableProgram ._get_name_value_map_from_program (
@@ -445,9 +445,9 @@ def program_attr(self):
445445 program_attr [f"{ k } _names" ] = ns
446446
447447 # Restore stop_gradient for output values
448- assert len (program_attr ["fo_values" ]) == len (
449- self . out_stop_gradients
450- ), "Output values and stop gradients length mismatch"
448+ assert len (program_attr ["fo_values" ]) == len (self . out_stop_gradients ), (
449+ "Output values and stop gradients length mismatch"
450+ )
451451 for v , stop_gradient in zip (
452452 program_attr ["fo_values" ], self .out_stop_gradients
453453 ):
@@ -474,9 +474,9 @@ def unify_value_names(
474474 # Get all values again because some values has been erased.
475475 for value in RunnableProgram ._get_program_all_values (program ):
476476 if value .has_name :
477- assert (
478- value . _has_only_one_name ()
479- ), f"Expected all values in Program have only one name, but { value } has multiple names: { value . _names } "
477+ assert value . _has_only_one_name (), (
478+ f"Expected all values in Program have only one name, but { value } has multiple names: { value . _names } "
479+ )
480480 return rename_mapping
481481
482482 @staticmethod
0 commit comments