@@ -97,6 +97,8 @@ def test_to_static_program(self):
9797 main_program = dist_model ._engine ._pir_main_progs ["eval" ]
9898
9999 for op in main_program .global_block ().ops :
100+ if op .num_results () == 0 :
101+ continue
100102 tensor = op .result (0 )
101103 if op .name () == 'pd_op.data' :
102104 self .assertTrue (tensor .is_dist_dense_tensor_type ())
@@ -128,9 +130,22 @@ def test_to_static_program(self):
128130
129131 relu_idx = 0
130132 matmul_idx = 0
131-
132- for op in main_program .global_block ().ops :
133+ matmul_grad_idx = 0
134+ ops = main_program .global_block ().ops
135+ self .assertEqual (ops [- 1 ].name (), "pd_op.matmul_grad" )
136+ self .assertEqual (ops [- 2 ].name (), "pd_op.relu_grad" )
137+ self .assertEqual (ops [- 3 ].name (), "pd_op.matmul_grad" )
138+ self .assertEqual (ops [- 4 ].name (), "pd_op.relu_grad" )
139+ self .assertEqual (ops [- 5 ].name (), "pd_op.subtract_grad" )
140+ self .assertEqual (ops [- 6 ].name (), "pd_op.square_grad" )
141+ self .assertEqual (ops [- 7 ].name (), "pd_op.mean_grad" )
142+
143+ for op in ops :
144+ if op .num_results () == 0 :
145+ continue
133146 tensor = op .result (0 )
147+ if not tensor .initialized ():
148+ continue
134149 self .assertTrue (tensor .is_dist_dense_tensor_type ())
135150 self .assertEqual (tensor .dist_attr ().process_mesh .shape , [2 ])
136151 self .assertEqual (
@@ -143,8 +158,6 @@ def test_to_static_program(self):
143158 elif op .name () == 'builtin.parameter' :
144159 self .assertTrue (tensor .is_dense_tensor_type ())
145160 self .assertTrue (tensor .is_dist_dense_tensor_type ())
146- self .assertTrue (tensor .has_one_use ())
147-
148161 self .assertTrue (tensor .is_dist_dense_tensor_type ())
149162 self .assertEqual (tensor .dist_attr ().process_mesh .shape , [2 ])
150163 self .assertEqual (
@@ -189,6 +202,20 @@ def test_to_static_program(self):
189202 tensor ._local_shape , [BATCH_SIZE , CLASS_NUM ]
190203 )
191204 matmul_idx += 1
205+ if op .name () == 'pd_op.matmul_grad' :
206+ if matmul_grad_idx == 0 :
207+ self .assertEqual (tensor .dist_attr ().dims_mapping , [- 1 , 0 ])
208+ self .assertEqual (tensor .dist_attr ().partial_dims , set ())
209+ self .assertEqual (
210+ tensor ._local_shape , [BATCH_SIZE , CLASS_NUM ]
211+ )
212+ elif matmul_grad_idx == 1 :
213+ self .assertEqual (tensor .dist_attr ().dims_mapping , [- 1 , 0 ])
214+ self .assertEqual (tensor .dist_attr ().partial_dims , set ())
215+ self .assertEqual (
216+ tensor ._local_shape , [BATCH_SIZE , IMAGE_SIZE // 2 ]
217+ )
218+ matmul_grad_idx += 1
192219
193220 # dist_model.train()
194221 # for batch_id, (image, label) in enumerate(dist_loader()):
0 commit comments