@@ -1086,5 +1086,146 @@ def test_dygraph_forward(self):
10861086 )
10871087
10881088
class TestIndexPutPrim(unittest.TestCase):
    """Check ``paddle.index_put`` gradients under prim decomposition.

    For each (x, indices, value) configuration and both values of
    ``accumulate``, the gradients d(out)/dx and d(out)/dv are computed
    three ways — eager mode, a dynamic-shape ``to_static`` program, and a
    fixed-shape ``to_static`` program — and all must agree.
    """

    def test_prim(self):
        try:
            # Enable prim globally; restored in the ``finally`` below.
            paddle.framework.core._set_prim_all_enabled(True)
            for accumulate in [False, True]:
                for x_shape, indices_shape, value_shape in [
                    ([16], [10], [10]),
                    ([16, 16], [20, 2], [20]),
                    ([12, 13, 14], [88, 1], [88, 13, 14]),
                    ([12, 13, 14], [88, 2], [88, 14]),
                    ([12, 13, 14], [88, 3], [88]),
                    ([12, 13, 14], [12 * 13 * 14, 3], [12 * 13 * 14]),
                ]:
                    n_indices = indices_shape[0]
                    # Number of leading axes of x being indexed (always >= 1
                    # by construction of the test matrix above).
                    index_dim_size = (
                        indices_shape[1] if len(indices_shape) > 1 else 1
                    )

                    x_np = np.random.randn(*x_shape)
                    # One 1-D int index tensor per indexed axis; negative
                    # indices are included to exercise wrap-around indexing.
                    indices_np = tuple(
                        np.random.randint(
                            -x_shape[i], x_shape[i], [n_indices]
                        )
                        for i in range(index_dim_size)
                    )
                    value_np = np.random.randn(*value_shape).astype("float32")

                    x_pd = paddle.to_tensor(
                        x_np.copy(), "float32", stop_gradient=False
                    )
                    indices_pd = tuple(
                        paddle.to_tensor(idx.copy(), "int64", stop_gradient=True)
                        for idx in indices_np
                    )
                    value_pd = paddle.to_tensor(
                        value_np.copy(), "float32", stop_gradient=False
                    )

                    out_pd = paddle.index_put(
                        x_pd, indices_pd, value_pd, accumulate=accumulate
                    )
                    dout_np = np.random.randn(*out_pd.shape)
                    dout_pd = paddle.to_tensor(
                        dout_np.copy(), "float32", stop_gradient=False
                    )

                    # Bind ``accumulate`` as a default argument so to_static
                    # traces a program with the flag baked in as a constant.
                    def compute_dx_dv(x, indices, v, dy, accumulate=accumulate):
                        y = paddle.index_put(x, indices, v, accumulate)
                        return paddle.grad(y, [x, v], dy, create_graph=True)

                    # Eager-mode reference gradients.
                    dx_ref, dv_ref = compute_dx_dv(
                        x_pd, indices_pd, value_pd, dout_pd
                    )

                    # Static program with fully dynamic shapes. Spec ranks
                    # must match the actual inputs, which vary per case
                    # (rank 1 through rank 3).
                    st_func1 = paddle.jit.to_static(
                        compute_dx_dv,
                        input_spec=[
                            paddle.static.InputSpec(
                                shape=[-1] * len(x_shape), dtype='float32'
                            ),
                            tuple(
                                paddle.static.InputSpec(
                                    shape=[-1], dtype='int64'
                                )
                                for _ in range(len(indices_pd))
                            ),
                            paddle.static.InputSpec(
                                shape=[-1] * len(value_shape), dtype='float32'
                            ),
                            paddle.static.InputSpec(
                                shape=[-1] * len(out_pd.shape), dtype='float32'
                            ),
                        ],
                        full_graph=True,
                        backend=None,
                    )
                    dx_1, dv_1 = st_func1(x_pd, indices_pd, value_pd, dout_pd)

                    # Static program with shapes fixed from the first call.
                    st_func2 = paddle.jit.to_static(
                        compute_dx_dv, full_graph=True, backend=None
                    )
                    dx_2, dv_2 = st_func2(x_pd, indices_pd, value_pd, dout_pd)

                    err_msg = (
                        f"accumulate={accumulate}\nx_np:\n{x_np}\n"
                        f"indices_np:\n{indices_np}\nvalue_np:\n{value_np}\n"
                        f"out_np:{out_pd.numpy()}\n"
                    )
                    for got, want in [
                        (dx_1, dx_ref),
                        (dv_1, dv_ref),
                        (dx_2, dx_ref),
                        (dv_2, dv_ref),
                    ]:
                        np.testing.assert_allclose(
                            got.numpy(), want.numpy(), err_msg=err_msg
                        )
        finally:
            # Always restore global prim state so other tests are unaffected.
            paddle.framework.core._set_prim_all_enabled(False)
1229+
# Allow this test module to be executed directly.
if __name__ == '__main__':
    unittest.main()
0 commit comments