@@ -61,99 +61,90 @@ def reshard_combine_value(op, operand, attr):
 
 
 def apply_partition_pass(program):
-    with paddle.static.program_guard(program):
-        for op in program.global_block().ops:
-            if op.name() in partition_skip_op_list:
-                continue
-            assert len(op.operands()) == len(
-                op.dist_attr.operands()
-            ), f"The number of operands and the number of op_dist_attr's operands are not equal in op: {op}"
-
-            for operand, attr in zip(op.operands(), op.dist_attr.operands()):
-                prev_var = operand.source()
-                if prev_var.is_combine():
-                    operand.set_source(reshard_combine_value(op, operand, attr))
-                else:
-                    operand.set_source(reshard_single_value(op, operand, attr))
-                prev_op = prev_var.get_defining_op()
-                if (
-                    prev_op
-                    and prev_op.num_results() == 1
-                    and prev_var.use_empty()
-                ):
-                    prev_op.erase()
-
-            for var, attr in zip(op.results(), op.dist_attr.results()):
-                if (
-                    var.initialized()
-                    and var.is_dist()
-                    and var.dist_attr() != attr
-                ):
-                    paddle.pir.set_insertion_point_after(op)
-                    old_dist_attr = var.dist_attr()
-                    var.update_dist_attr(attr.as_tensor_dist_attr())
-                    # insert reshard
-                    reshard_var = paddle._C_ops.reshard_v2(var, old_dist_attr)
-                    var.replace_all_uses_with(reshard_var)
-                    reshard_var.get_defining_op().operand(0).set_source(var)
-
-        # pruning op and value not belong to cur rank
-        cur_rank = paddle.distributed.get_rank()
-        for op in program.global_block().ops[::-1]:
-            if cur_rank not in op.dist_attr.process_mesh.process_ids:
-                program.global_block().remove_op(op)
+    for op in program.global_block().ops:
+        if op.name() in partition_skip_op_list:
+            continue
+        assert len(op.operands()) == len(
+            op.dist_attr.operands()
+        ), f"The number of operands and the number of op_dist_attr's operands are not equal in op: {op}"
+
+        for operand, attr in zip(op.operands(), op.dist_attr.operands()):
+            prev_var = operand.source()
+            if prev_var.is_combine():
+                operand.set_source(reshard_combine_value(op, operand, attr))
             else:
-                # set the operand as null when it is not belong to cur rank
-                if (
-                    op.name() == 'dist_op.reshard'
-                    and cur_rank
-                    not in op.operand(0)
-                    .source()
-                    .dist_attr()
-                    .process_mesh.process_ids
-                ):
-                    op.operand(0).set_source(None)
-
-        # merge pd.data ops for
-        lr_ops = []
-        for op in program.global_block().ops[::-1]:
+                operand.set_source(reshard_single_value(op, operand, attr))
+            prev_op = prev_var.get_defining_op()
+            if prev_op and prev_op.num_results() == 1 and prev_var.use_empty():
+                prev_op.erase()
+
+        for var, attr in zip(op.results(), op.dist_attr.results()):
+            if var.initialized() and var.is_dist() and var.dist_attr() != attr:
+                paddle.pir.set_insertion_point_after(op)
+                old_dist_attr = var.dist_attr()
+                var.update_dist_attr(attr.as_tensor_dist_attr())
+                # insert reshard
+                reshard_var = paddle._C_ops.reshard_v2(var, old_dist_attr)
+                var.replace_all_uses_with(reshard_var)
+                reshard_var.get_defining_op().operand(0).set_source(var)
+
+    # pruning op and value not belong to cur rank
+    cur_rank = paddle.distributed.get_rank()
+    for op in program.global_block().ops[::-1]:
+        if cur_rank not in op.dist_attr.process_mesh.process_ids:
+            op.erase()
+        else:
+            # set the operand as null when it is not belong to cur rank
             if (
-                op.name() == 'pd_op.data'
-                and "learning_rate" in op.attrs()["name"]
+                op.name() == 'dist_op.reshard'
+                and cur_rank
+                not in op.operand(0)
+                .source()
+                .dist_attr()
+                .process_mesh.process_ids
             ):
-                lr_ops.append(op)
-
-        if len(lr_ops) > 1:
-            lr_value = lr_ops[0].result(0)
-            for op in lr_ops[1:]:
-                lr = op.result(0)
-                lr.replace_all_uses_with(lr_value)
-                program.global_block().remove_op(op)
-    return program
+                op.operand(0).set_source(None)
+
+    # merge pd.data ops for
+    lr_ops = []
+    for op in program.global_block().ops[::-1]:
+        if op.name() == 'pd_op.data' and "learning_rate" in op.attrs()["name"]:
+            lr_ops.append(op)
+
+    if len(lr_ops) > 1:
+        lr_value = lr_ops[0].result(0)
+        for op in lr_ops[1:]:
+            lr = op.result(0)
+            lr.replace_all_uses_with(lr_value)
+            op.erase()
 
 
 def apply_reshard_pass(program):
-    new_program = program.clone()
-    with paddle.base.program_guard(new_program):
-        for op in new_program.global_block().ops:
-            if op.name() == 'dist_op.reshard':
-                var = op.operand_source(0)
-                op_dist_attr = op.dist_attr
-                src_dist_attr = op_dist_attr.operand(0).as_tensor_dist_attr()
-                dst_dist_attr = op_dist_attr.result(0).as_tensor_dist_attr()
-                assert (
-                    not var.initialized() or var.dist_attr() == src_dist_attr
-                ), f"The dist_attr of reshard op's input and operand should be equal, but got {var.dist_attr()} and {src_dist_attr}"
-
-                reshard_func = choose_reshard_func(src_dist_attr, dst_dist_attr)
-                assert (
-                    reshard_func is not None
-                ), f'There is no reshard function that matches src_dist_attr: {src_dist_attr} and dst_dist_attr: {dst_dist_attr}'
-                reshard_func.reshard(
-                    new_program, op, src_dist_attr, dst_dist_attr
-                )
-
-    return new_program
+    for op in program.global_block().ops:
+        if op.name() == 'dist_op.reshard':
+            var = op.operand_source(0)
+            op_dist_attr = op.dist_attr
+            src_dist_attr = op_dist_attr.operand(0).as_tensor_dist_attr()
+            dst_dist_attr = op_dist_attr.result(0).as_tensor_dist_attr()
+            assert (
+                not var.initialized() or var.dist_attr() == src_dist_attr
+            ), f"The dist_attr of reshard op's input and operand should be equal, but got {var.dist_attr()} and {src_dist_attr}"
+
+            reshard_func = choose_reshard_func(src_dist_attr, dst_dist_attr)
+            assert (
+                reshard_func is not None
+            ), f'There is no reshard function that matches src_dist_attr: {src_dist_attr} and dst_dist_attr: {dst_dist_attr}'
+            paddle.pir.set_insertion_point_after(op)
+            out_value = reshard_func.reshard(
+                src_dist_attr,
+                dst_dist_attr,
+                op.operand_source(0),
+                op.result(0).type(),
+            )
+            if out_value is not None:
+                op.result(0).replace_all_uses_with(out_value)
+            if op.result(0).use_empty():
+                op.erase()
 
 
 # In sequence_parallel, we need to transpose hidden_states
@@ -183,5 +174,5 @@ def eliminate_transpose_by_reshape(program):
                 transpose_var = op.result(0)
                 reshape_var = paddle._C_ops.reshape(var, transpose_var.shape)
                 transpose_var.replace_all_uses_with(reshape_var)
-                program.global_block().remove_op(op)
+                op.erase()
     return program
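
As rewritten above, both passes mutate the PIR program in place rather than working on a clone and returning it. A minimal caller sketch under that assumption (the driver name lower_dist_program is hypothetical; apply_partition_pass and apply_reshard_pass are assumed importable from this file):

def lower_dist_program(dist_program):
    # Both passes rewrite dist_program in place: the partition pass reshards
    # operands/results and prunes ops outside the current rank, then the
    # reshard pass expands each remaining dist_op.reshard into concrete ops.
    apply_partition_pass(dist_program)
    apply_reshard_pass(dist_program)
    return dist_program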