@@ -351,7 +351,7 @@ def test_eval_symbolic(self):
351351 np .testing .assert_equal (
352352 sym_shape_str_list [j ].find (self .expected [i ][j ]),
353353 0 ,
354- f'in case i,j = { i } ,{ j } : output shape ({ sym_shape_str_list [0 ]} ) is not expected { (self .expected [i ][j ])} ' ,
354+ f'in case i,j = { i } ,{ j } : output shape ({ sym_shape_str_list [j ]} ) is not expected { (self .expected [i ][j ])} ' ,
355355 )
356356
357357 return True
@@ -403,7 +403,7 @@ def test_eval_symbolic(self):
403403 np .testing .assert_equal (
404404 sym_shape_str_list [j ].find (self .expected [i ][j ]),
405405 0 ,
406- f'in case i,j = { i } ,{ j } : output shape ({ sym_shape_str_list [0 ]} ) is not expected { (self .expected [i ][j ])} ' ,
406+ f'in case i,j = { i } ,{ j } : output shape ({ sym_shape_str_list [j ]} ) is not expected { (self .expected [i ][j ])} ' ,
407407 )
408408
409409 return True
@@ -453,7 +453,7 @@ def test_eval_symbolic(self):
453453 np .testing .assert_equal (
454454 sym_shape_str_list [j ].find (self .expected [i ][j ]),
455455 0 ,
456- f'in case i,j = { i } ,{ j } : output shape ({ sym_shape_str_list [0 ]} ) is not expected { (self .expected [i ][j ])} ' ,
456+ f'in case i,j = { i } ,{ j } : output shape ({ sym_shape_str_list [j ]} ) is not expected { (self .expected [i ][j ])} ' ,
457457 )
458458
459459 return True
@@ -512,11 +512,84 @@ def test_eval_symbolic(self):
512512 np .testing .assert_equal (
513513 sym_shape_str_list [j ].find (self .expected [i ][j ]),
514514 0 ,
515- f'in case i,j = { i } ,{ j } : output shape ({ sym_shape_str_list [0 ]} ) is not expected { (self .expected [i ][j ])} ' ,
515+ f'in case i,j = { i } ,{ j } : output shape ({ sym_shape_str_list [j ]} ) is not expected { (self .expected [i ][j ])} ' ,
516516 )
517517
518518 return True
519519
520520
521+ class SplitNet (paddle .nn .Layer ):
522+ def __init__ (self ):
523+ super ().__init__ ()
524+
525+ def forward (self , x ):
526+ out = paddle .split (x , [- 1 ], axis = 1 )
527+ out = paddle .split (x , [1 , 2 , - 1 ], axis = 1 )
528+ out = paddle .split (x , [1 , - 1 ], axis = 1 )
529+ out = paddle .split (x , [1 , 2 , 3 ], axis = 1 )
530+ out = paddle .split (x , [1 , 2 , x .shape [1 ]], axis = 1 )
531+
532+ out = x .split ([- 1 ], axis = 1 )
533+ out = x .split ([1 , 2 , - 1 ], axis = 1 )
534+ out = x .split ([1 , - 1 ], axis = 1 )
535+ out = x .split ([1 , 2 , 3 ], axis = 1 )
536+ out = x .split ([1 , 2 , x .shape [1 ]], axis = 1 )
537+
538+ return out
539+
540+
541+ class TestSplitOpInferSymbolicShape (TestBase ):
542+ def prepare_data (self ):
543+ self .cases = [np .random .rand (4 , 6 , 5 )]
544+
545+ self .expected = [
546+ [
547+ 'shape[S0, S1, S2], data[NULL]' ,
548+ 'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, Add(S1, -3), S2], data[NULL]' ,
549+ 'shape[S0, 1, S2], data[NULL], shape[S0, Add(S1, -1), S2], data[NULL]' ,
550+ 'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, 3, S2], data[NULL]' ,
551+ 'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, S1, S2], data[NULL]' ,
552+ 'shape[S0, S1, S2], data[NULL]' ,
553+ 'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, Add(S1, -3), S2], data[NULL]' ,
554+ 'shape[S0, 1, S2], data[NULL], shape[S0, Add(S1, -1), S2], data[NULL]' ,
555+ 'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, 3, S2], data[NULL]' ,
556+ 'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, S1, S2], data[NULL]' ,
557+ ]
558+ ]
559+
560+ def test_eval_symbolic (self ):
561+ net = SplitNet ()
562+
563+ for i in range (len (self .cases )):
564+ x = self .cases [i ]
565+ x_spec = InputSpec (
566+ shape = [None for index in range (len (x .shape ))], dtype = 'float32'
567+ )
568+
569+ input_spec = [x_spec ]
570+ net = apply_to_static (net , False , input_spec )
571+ net .eval ()
572+
573+ # check the infer result
574+ sym_shape_str_list = get_sym_shape_str_for_op (
575+ net , input_spec , 'pd_op.split'
576+ )
577+ np .testing .assert_equal (
578+ len (sym_shape_str_list ), len (self .expected [i ])
579+ )
580+ for j in range (len (sym_shape_str_list )):
581+ np .testing .assert_equal (
582+ sym_shape_str_list [j ].find (self .expected [i ][j ]),
583+ 0 ,
584+ f'in case i,j = { i } ,{ j } : output shape ({ sym_shape_str_list [j ]} ) is not expected { (self .expected [i ][j ])} ' ,
585+ )
586+
587+ # TODO(fty1777): Add builtin.split op infer symbolic shape test
588+ # Not added because attribute `sym_shape_str` does not support multi-output op now.
589+ # See also: paddle/fluid/pir/transforms/shape_optimization_pass.cc:144.
590+
591+ return True
592+
593+
521594if __name__ == '__main__' :
522595 unittest .main ()
0 commit comments