@@ -165,22 +165,24 @@ def apply_reshard_pass(program):
# training, the transpose is equal to reshape.
# So, this pass is to handle the specific case.
def eliminate_transpose_by_reshape(program):
    """Replace eligible transpose ops in *program* with reshape ops.

    A transpose that only swaps dim 0 and dim 1 is equivalent to a reshape
    when one of those dims has size 1 — no data actually moves, only the
    metadata changes — and a reshape is cheaper than a transpose kernel.
    Applies to both ``pd_op.transpose`` and ``pd_op.transpose_grad``.

    Args:
        program: the PIR program to rewrite. Mutated in place.

    Returns:
        The same ``program`` object, after rewriting.
    """
    for op in program.global_block().ops:
        if op.name() in ('pd_op.transpose', 'pd_op.transpose_grad'):
            var = op.operand(0).source()
            rank = len(var.shape)
            # Normalize negative axes so the exact comparison below works.
            perm = op.attrs()['perm']
            perm = [p + rank if p < 0 else p for p in perm]
            # Only the swap of dim 0 and dim 1 is supported.
            expected_perm = [1, 0] + [i + 2 for i in range(rank - 2)]
            if perm == expected_perm and (
                var.shape[0] == 1 or var.shape[1] == 1
            ):
                # Insert the replacement reshape right where the transpose
                # was, rewire all consumers to it, then drop the transpose.
                paddle.pir.set_insertion_point(op)
                transpose_var = op.result(0)
                reshape_var = paddle._C_ops.reshape(var, transpose_var.shape)
                transpose_var.replace_all_uses_with(reshape_var)
                program.global_block().remove_op(op)
    return program
0 commit comments