@@ -97,6 +97,8 @@ def test_to_static_program(self):
9797 main_program = dist_model ._engine ._pir_main_progs ["eval" ]
9898
9999 for op in main_program .global_block ().ops :
100+ if op .num_results () == 0 :
101+ continue
100102 tensor = op .result (0 )
101103 if op .name () == 'pd_op.data' :
102104 self .assertTrue (tensor .is_dist_dense_tensor_type ())
@@ -128,9 +130,24 @@ def test_to_static_program(self):
128130
129131 relu_idx = 0
130132 matmul_idx = 0
131-
132- for op in main_program .global_block ().ops :
133+ matmul_grad_idx = 0
134+ ops = main_program .global_block ().ops
135+ self .assertEqual (ops [- 1 ].name (), "pd_op.matmul_grad" )
136+ self .assertEqual (ops [- 2 ].name (), "pd_op.relu_grad" )
137+ self .assertEqual (ops [- 3 ].name (), "pd_op.matmul_grad" )
138+ self .assertEqual (ops [- 4 ].name (), "pd_op.relu_grad" )
139+ self .assertEqual (ops [- 5 ].name (), "pd_op.subtract_grad" )
140+ self .assertEqual (ops [- 6 ].name (), "pd_op.square_grad" )
141+ self .assertEqual (ops [- 7 ].name (), "pd_op.mean_grad" )
142+
143+ for op in ops :
144+ # skip shadow_output
145+ if op .num_results () == 0 :
146+ continue
133147 tensor = op .result (0 )
148+ # while tensor's stop_gradient is true, the corresponding grad tensor is initialized.
149+ if not tensor .initialized ():
150+ continue
134151 self .assertTrue (tensor .is_dist_dense_tensor_type ())
135152 self .assertEqual (tensor .dist_attr ().process_mesh .shape , [2 ])
136153 self .assertEqual (
@@ -143,8 +160,6 @@ def test_to_static_program(self):
143160 elif op .name () == 'builtin.parameter' :
144161 self .assertTrue (tensor .is_dense_tensor_type ())
145162 self .assertTrue (tensor .is_dist_dense_tensor_type ())
146- self .assertTrue (tensor .has_one_use ())
147-
148163 self .assertTrue (tensor .is_dist_dense_tensor_type ())
149164 self .assertEqual (tensor .dist_attr ().process_mesh .shape , [2 ])
150165 self .assertEqual (
@@ -189,6 +204,20 @@ def test_to_static_program(self):
189204 tensor ._local_shape , [BATCH_SIZE , CLASS_NUM ]
190205 )
191206 matmul_idx += 1
207+ if op .name () == 'pd_op.matmul_grad' :
208+ if matmul_grad_idx == 0 :
209+ self .assertEqual (tensor .dist_attr ().dims_mapping , [- 1 , 0 ])
210+ self .assertEqual (tensor .dist_attr ().partial_dims , set ())
211+ self .assertEqual (
212+ tensor ._local_shape , [BATCH_SIZE , CLASS_NUM ]
213+ )
214+ elif matmul_grad_idx == 1 :
215+ self .assertEqual (tensor .dist_attr ().dims_mapping , [- 1 , 0 ])
216+ self .assertEqual (tensor .dist_attr ().partial_dims , set ())
217+ self .assertEqual (
218+ tensor ._local_shape , [BATCH_SIZE , IMAGE_SIZE // 2 ]
219+ )
220+ matmul_grad_idx += 1
192221
193222 # dist_model.train()
194223 # for batch_id, (image, label) in enumerate(dist_loader()):
0 commit comments