1616import re
1717import argparse
1818import os
19+ import logging
1920from codegen_utils import core_ops_returns_info , core_ops_args_info , core_ops_args_type_info
2021from codegen_utils import yaml_types_mapping
2122from codegen_utils import ReadFwdFile , ReadBwdFile
3031from codegen_utils import ParseYamlForward , ParseYamlBackward
3132from codegen_utils import FunctionGeneratorBase , YamlGeneratorBase
3233from codegen_utils import ops_to_fill_zero_for_empty_grads
34+ from codegen_utils import AssertMessage
3335
3436
3537###########
@@ -398,14 +400,21 @@ def DygraphYamlValidationCheck(self):
398400 forward_api_contents = self .forward_api_contents
399401 grad_api_contents = self .grad_api_contents
400402
401- assert 'api' in forward_api_contents .keys ()
402- assert 'args' in forward_api_contents .keys ()
403- assert 'output' in forward_api_contents .keys ()
404- assert 'backward' in forward_api_contents .keys ()
405-
406- assert 'args' in grad_api_contents .keys ()
407- assert 'output' in grad_api_contents .keys ()
408- assert 'forward' in grad_api_contents .keys ()
403+ assert 'api' in forward_api_contents .keys (
404+ ), "Unable to find \" api\" in api.yaml"
405+ assert 'args' in forward_api_contents .keys (
406+ ), "Unable to find \" args\" in api.yaml"
407+ assert 'output' in forward_api_contents .keys (
408+ ), "Unable to find \" output\" in api.yaml"
409+ assert 'backward' in forward_api_contents .keys (
410+ ), "Unable to find \" backward\" in api.yaml"
411+
412+ assert 'args' in grad_api_contents .keys (
413+ ), "Unable to find \" args\" in backward.yaml"
414+ assert 'output' in grad_api_contents .keys (
415+ ), "Unable to find \" output\" in backward.yaml"
416+ assert 'forward' in grad_api_contents .keys (
417+ ), "Unable to find \" forward\" in backward.yaml"
409418
410419 def ForwardsValidationCheck (self ):
411420 forward_inputs_list = self .forward_inputs_list
@@ -424,8 +433,10 @@ def ForwardsValidationCheck(self):
424433 orig_input_type = orig_forward_inputs_list [i ][1 ]
425434 orig_input_pos = orig_forward_inputs_list [i ][2 ]
426435
427- assert forward_input_type == orig_input_type
428- assert forward_input_pos == orig_input_pos
436+ assert forward_input_type == orig_input_type , AssertMessage (
437+ forward_input_type , orig_input_type )
438+ assert forward_input_pos == orig_input_pos , AssertMessage (
439+ forward_input_pos , orig_input_pos )
429440
430441 for i in range (len (forward_attrs_list )):
431442 orig_attr_name = orig_forward_attrs_list [i ][0 ]
@@ -436,18 +447,23 @@ def ForwardsValidationCheck(self):
436447 forward_attr_type = forward_attrs_list [i ][1 ]
437448 forward_attr_default = forward_attrs_list [i ][2 ]
438449 forward_attr_pos = forward_attrs_list [i ][3 ]
439- assert orig_attr_type == forward_attr_type
440- assert orig_attr_default == forward_attr_default
441- assert orig_attr_pos == forward_attr_pos
450+ assert orig_attr_type == forward_attr_type , AssertMessage (
451+ orig_attr_type , forward_attr_type )
452+ assert orig_attr_default == forward_attr_default , AssertMessage (
453+ orig_attr_default , forward_attr_default )
454+ assert orig_attr_pos == forward_attr_pos , AssertMessage (
455+ orig_attr_pos , forward_attr_pos )
442456
443457 for i in range (len (forward_returns_list )):
444458 orig_return_type = orig_forward_returns_list [i ][1 ]
445459 orig_return_pos = orig_forward_returns_list [i ][2 ]
446460 forward_return_type = forward_returns_list [i ][1 ]
447461 forward_return_pos = forward_returns_list [i ][2 ]
448462
449- assert orig_return_type == forward_return_type
450- assert orig_return_pos == forward_return_pos
463+ assert orig_return_type == forward_return_type , AssertMessage (
464+ orig_return_type , forward_return_type )
465+ assert orig_return_pos == forward_return_pos , AssertMessage (
466+ orig_return_pos , forward_return_pos )
451467
452468 # Check Order: Inputs, Attributes
453469 max_input_position = - 1
@@ -456,7 +472,8 @@ def ForwardsValidationCheck(self):
456472
457473 max_attr_position = - 1
458474 for _ , _ , _ , pos in forward_attrs_list :
459- assert pos > max_input_position
475+ assert pos > max_input_position , AssertMessage (pos ,
476+ max_input_position )
460477 max_attr_position = max (max_attr_position , pos )
461478
462479 def BackwardValidationCheck (self ):
@@ -471,12 +488,14 @@ def BackwardValidationCheck(self):
471488
472489 max_grad_tensor_position = - 1
473490 for _ , (_ , _ , pos ) in backward_grad_inputs_map .items ():
474- assert pos > max_fwd_input_position
491+ assert pos > max_fwd_input_position , AssertMessage (
492+ pos , max_grad_tensor_position )
475493 max_grad_tensor_position = max (max_grad_tensor_position , pos )
476494
477495 max_attr_position = - 1
478496 for _ , _ , _ , pos in backward_attrs_list :
479- assert pos > max_grad_tensor_position
497+ assert pos > max_grad_tensor_position , AssertMessage (
498+ pos , max_grad_tensor_position )
480499 max_attr_position = max (max_attr_position , pos )
481500
482501 def IntermediateValidationCheck (self ):
@@ -491,7 +510,8 @@ def IntermediateValidationCheck(self):
491510 len (forward_returns_list ))
492511 for ret_name , _ , pos in forward_returns_list :
493512 if ret_name in intermediate_outputs :
494- assert pos in intermediate_positions
513+ assert pos in intermediate_positions , AssertMessage (
514+ pos , intermediate_positions )
495515
496516 def CollectBackwardInfo (self ):
497517 forward_api_contents = self .forward_api_contents
@@ -505,9 +525,12 @@ def CollectBackwardInfo(self):
505525
506526 self .backward_inputs_list , self .backward_attrs_list , self .backward_returns_list = ParseYamlBackward (
507527 backward_args_str , backward_returns_str )
508- print ("Parsed Backward Inputs List: " , self .backward_inputs_list )
509- print ("Prased Backward Attrs List: " , self .backward_attrs_list )
510- print ("Parsed Backward Returns List: " , self .backward_returns_list )
528+
529+ logging .info (
530+ f"Parsed Backward Inputs List: { self .backward_inputs_list } " )
531+ logging .info (f"Prased Backward Attrs List: { self .backward_attrs_list } " )
532+ logging .info (
533+ f"Parsed Backward Returns List: { self .backward_returns_list } " )
511534
512535 def CollectForwardInfoFromBackwardContents (self ):
513536
@@ -530,7 +553,9 @@ def SlotNameMatching(self):
530553 backward_fwd_name = FindForwardName (backward_input_name )
531554 if backward_fwd_name :
532555 # Grad Input
533- assert backward_fwd_name in forward_outputs_position_map .keys ()
556+ assert backward_fwd_name in forward_outputs_position_map .keys (
557+ ), AssertMessage (backward_fwd_name ,
558+ forward_outputs_position_map .keys ())
534559 matched_forward_output_type = forward_outputs_position_map [
535560 backward_fwd_name ][0 ]
536561 matched_forward_output_pos = forward_outputs_position_map [
@@ -556,17 +581,18 @@ def SlotNameMatching(self):
556581 backward_input_type , False , backward_input_pos
557582 ]
558583 else :
559- assert False , backward_input_name
584+ assert False , f"Cannot find { backward_input_name } in forward position map"
560585
561586 for backward_output in backward_returns_list :
562587 backward_output_name = backward_output [0 ]
563588 backward_output_type = backward_output [1 ]
564589 backward_output_pos = backward_output [2 ]
565590
566591 backward_fwd_name = FindForwardName (backward_output_name )
567- assert backward_fwd_name is not None
592+ assert backward_fwd_name is not None , f"Detected { backward_fwd_name } = None"
568593 assert backward_fwd_name in forward_inputs_position_map .keys (
569- ), f"Unable to find { backward_fwd_name } in forward inputs"
594+ ), AssertMessage (backward_fwd_name ,
595+ forward_inputs_position_map .keys ())
570596
571597 matched_forward_input_type = forward_inputs_position_map [
572598 backward_fwd_name ][0 ]
@@ -577,12 +603,15 @@ def SlotNameMatching(self):
577603 backward_output_type , matched_forward_input_pos ,
578604 backward_output_pos
579605 ]
580- print ("Generated Backward Fwd Input Map: " ,
581- self .backward_forward_inputs_map )
582- print ("Generated Backward Grad Input Map: " ,
583- self .backward_grad_inputs_map )
584- print ("Generated Backward Grad Output Map: " ,
585- self .backward_grad_outputs_map )
606+ logging .info (
607+ f"Generated Backward Fwd Input Map: { self .backward_forward_inputs_map } "
608+ )
609+ logging .info (
610+ f"Generated Backward Grad Input Map: { self .backward_grad_inputs_map } "
611+ )
612+ logging .info (
613+ f"Generated Backward Grad Output Map: { self .backward_grad_outputs_map } "
614+ )
586615
587616 def GenerateNodeDeclaration (self ):
588617 forward_op_name = self .forward_api_name
@@ -642,7 +671,7 @@ def GenerateNodeDeclaration(self):
642671 set_tensor_wrapper_methods_str , set_attribute_methods_str ,
643672 tensor_wrapper_members_str , attribute_members_str )
644673
645- print ( "Generated Node Declaration: " , self .node_declaration_str )
674+ logging . info ( f "Generated Node Declaration: { self .node_declaration_str } " )
646675
647676 def GenerateNodeDefinition (self ):
648677 namespace = self .namespace
@@ -710,7 +739,7 @@ def GenerateNodeDefinition(self):
710739 grad_node_name , fill_zero_str , grad_node_name , grad_api_namespace ,
711740 backward_api_name , grad_api_args_str , returns_str )
712741
713- print ( "Generated Node Definition: " , self .node_definition_str )
742+ logging . info ( f "Generated Node Definition: { self .node_definition_str } " )
714743
715744 def GenerateForwardDefinition (self , is_inplaced ):
716745 namespace = self .namespace
@@ -813,8 +842,10 @@ def GenerateForwardDefinition(self, is_inplaced):
813842 dygraph_event_str , node_creation_str , returns_str )
814843 self .forward_declaration_str += f"{ returns_type_str } { forward_function_name } ({ inputs_args_declaration_str } );\n "
815844
816- print ("Generated Forward Definition: " , self .forward_definition_str )
817- print ("Generated Forward Declaration: " , self .forward_declaration_str )
845+ logging .info (
846+ f"Generated Forward Definition: { self .forward_definition_str } " )
847+ logging .info (
848+ f"Generated Forward Declaration: { self .forward_declaration_str } " )
818849
819850 def GenerateNodeCreationCodes (self , forward_call_str ):
820851 forward_api_name = self .forward_api_name
@@ -921,7 +952,8 @@ def GenerateNodeCreationCodes(self, forward_call_str):
921952 else :
922953 if num_fwd_outputs > 1 :
923954 # Aligned with forward output position
924- assert name in forward_outputs_position_map .keys ()
955+ assert name in forward_outputs_position_map .keys (
956+ ), AssertMessage (name , forward_outputs_position_map .keys ())
925957 fwd_output_pos = forward_outputs_position_map [name ][1 ]
926958 tw_name = f"std::get<{ fwd_output_pos } >(api_result)"
927959 else :
@@ -1114,7 +1146,8 @@ def GetBackwardAPIContents(self, forward_api_contents):
11141146 if 'backward' not in forward_api_contents .keys (): return None
11151147
11161148 backward_api_name = forward_api_contents ['backward' ]
1117- assert backward_api_name in grad_api_dict .keys ()
1149+ assert backward_api_name in grad_api_dict .keys (), AssertMessage (
1150+ backward_api_name , grad_api_dict .keys ())
11181151 backward_api_contents = grad_api_dict [backward_api_name ]
11191152
11201153 return backward_api_contents
0 commit comments