Commit 46f88d5

remove program_guard
1 parent a9b0aae

File tree

  • python/paddle/distributed/auto_parallel/static

1 file changed: +20 -18 lines

python/paddle/distributed/auto_parallel/static/pir_pass.py

Lines changed: 20 additions & 18 deletions
@@ -165,22 +165,24 @@ def apply_reshard_pass(program):
 # training, the transpose is equal to reshape.
 # So, this pass is to haddle the specific case.
 def eliminate_transpose_by_reshape(program):
-    with paddle.static.program_guard(program):
-        for op in program.global_block().ops:
-            if op.name() == 'pd_op.transpose' or op.name() == 'pd_op.transpose':
-                var = op.operand(0).source()
-                rank = len(var.shape)
-                perm = op.attrs()['perm']
-                perm = [p + rank if p < 0 else p for p in perm]
-                # only support transpose dim 0 and dim 1
-                expected_perm = [1, 0] + [i + 2 for i in range(rank - 2)]
-                if perm == expected_perm and (
-                    var.shape[0] == 1 or var.shape[1] == 1
-                ):
-                    if var.shape == [1, 1024, 4096]:
-                        paddle.pir.set_insertion_point(op)
-                        transpose_var = op.result(0)
-                        reshape_var = paddle.reshape(var, transpose_var.shape)
-                        transpose_var.replace_all_uses_with(reshape_var)
-                        program.global_block().remove_op(op)
+    for op in program.global_block().ops:
+        if (
+            op.name() == 'pd_op.transpose'
+            or op.name() == 'pd_op.transpose_grad'
+        ):
+            var = op.operand(0).source()
+            rank = len(var.shape)
+            perm = op.attrs()['perm']
+            perm = [p + rank if p < 0 else p for p in perm]
+            # only support transpose dim 0 and dim 1
+            expected_perm = [1, 0] + [i + 2 for i in range(rank - 2)]
+            if perm == expected_perm and (
+                var.shape[0] == 1 or var.shape[1] == 1
+            ):
+                print('elinimate', op)
+                paddle.pir.set_insertion_point(op)
+                transpose_var = op.result(0)
+                reshape_var = paddle._C_ops.reshape(var, transpose_var.shape)
+                transpose_var.replace_all_uses_with(reshape_var)
+                program.global_block().remove_op(op)
     return program
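
The pass relies on a simple layout fact: swapping axis 0 and axis 1 cannot reorder the underlying data when one of those axes has size 1, so the transpose output can be produced by reshaping the input to the transposed shape instead. A minimal sketch of that equivalence, written with NumPy purely for illustration and not part of the commit:

import numpy as np

# Shape [1, 3, 4]: dim 0 has size 1, matching the condition checked by the pass.
x = np.arange(12).reshape(1, 3, 4)

# perm [1, 0, 2] swaps dims 0 and 1, i.e. expected_perm for a rank-3 tensor.
transposed = np.transpose(x, (1, 0, 2))   # shape [3, 1, 4]
reshaped = x.reshape(transposed.shape)    # same shape, no data reordering

# The size-1 axis carries no ordering information, so both results match
# element for element; this is why the pass can drop the transpose op.
assert np.array_equal(transposed, reshaped)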
