@@ -1247,3 +1247,152 @@ def replacement_insert(self, codes, inserted=None):
12471247 return inserted
12481248
12491249 replace_method (the_class , 'insert' , replacement_insert )
1250+
1251+ ######################################################
1252+ # Syntatic sugar for NeuralNet classes
1253+ ######################################################
1254+
def handle_Tensor2D(the_class):
    """Patch the SWIG-wrapped Tensor2D class so it can be built from a
    2D numpy array and converted back with .numpy()."""
    the_class.original_init = the_class.__init__

    def replacement_init(self, *args):
        # Anything other than a single positional argument goes straight
        # to the SWIG constructor.
        if len(args) != 1:
            self.original_init(*args)
            return
        # Single-argument form: treat it as a 2D numpy array and copy its
        # contents into the underlying vector.
        # NOTE(review): unlike the other handlers below, there is no
        # args[0].__class__ == the_class guard here — confirm copy
        # construction from another Tensor2D is not expected.
        (array,) = args
        nrow, ncol = array.shape
        self.original_init(nrow, ncol)
        faiss.copy_array_to_vector(
            np.ascontiguousarray(array).ravel(), self.v)

    def numpy(self):
        """Return the tensor contents as a (n, d) numpy array."""
        # The shape lives in C memory; copy the two int64s out first.
        shape = np.zeros(2, dtype=np.int64)
        faiss.memcpy(faiss.swig_ptr(shape), self.shape, shape.nbytes)
        return faiss.vector_to_array(self.v).reshape(shape[0], shape[1])

    the_class.numpy = numpy
    the_class.__init__ = replacement_init
1275+
1276+
def handle_Embedding(the_class):
    """Patch the SWIG-wrapped Embedding class with convenience
    constructors and weight-copy helpers (from torch or numpy)."""
    the_class.original_init = the_class.__init__

    def replacement_init(self, *args):
        if len(args) == 1 and args[0].__class__ != the_class:
            # assume it's a torch.Embedding
            emb = args[0]
            self.original_init(emb.num_embeddings, emb.embedding_dim)
            self.from_torch(emb)
        else:
            # Regular SWIG construction (or copy construction).
            self.original_init(*args)

    def from_torch(self, emb):
        """ copy weights from torch.Embedding """
        assert emb.weight.shape == (self.num_embeddings, self.embedding_dim)
        faiss.copy_array_to_vector(
            np.ascontiguousarray(emb.weight.data).ravel(), self.weight)

    def from_array(self, array):
        """ copy weights from numpy array """
        assert array.shape == (self.num_embeddings, self.embedding_dim)
        faiss.copy_array_to_vector(
            np.ascontiguousarray(array).ravel(), self.weight)

    the_class.__init__ = replacement_init
    the_class.from_torch = from_torch
    the_class.from_array = from_array
1304+
1305+
def handle_Linear(the_class):
    """Patch the SWIG-wrapped Linear class with convenience
    constructors and weight/bias copy helpers (from torch or numpy)."""
    the_class.original_init = the_class.__init__

    def replacement_init(self, *args):
        if len(args) != 1 or args[0].__class__ == the_class:
            # Regular SWIG construction (or copy construction).
            self.original_init(*args)
            return
        # assume it's a torch.Linear
        linear = args[0]
        has_bias = linear.bias is not None
        self.original_init(linear.in_features, linear.out_features, has_bias)
        self.from_torch(linear)

    def from_torch(self, linear):
        """ copy weights from torch.Linear """
        assert linear.weight.shape == (self.out_features, self.in_features)
        weights = linear.weight.data.numpy().ravel()
        faiss.copy_array_to_vector(weights, self.weight)
        if linear.bias is not None:
            assert linear.bias.shape == (self.out_features,)
            faiss.copy_array_to_vector(linear.bias.data.numpy(), self.bias)

    def from_array(self, array, bias=None):
        """ copy weights from numpy array """
        assert array.shape == (self.out_features, self.in_features)
        faiss.copy_array_to_vector(
            np.ascontiguousarray(array).ravel(), self.weight)
        if bias is not None:
            assert bias.shape == (self.out_features,)
            faiss.copy_array_to_vector(bias, self.bias)

    the_class.from_array = from_array
    the_class.from_torch = from_torch
    the_class.__init__ = replacement_init
1341+
1342+ ######################################################
1343+ # Syntatic sugar for QINCo and QINCoStep
1344+ ######################################################
1345+
def handle_QINCoStep(the_class):
    """Patch the SWIG-wrapped QINCoStep class so it can be built from,
    and filled in from, a torch QINCoStep."""
    the_class.original_init = the_class.__init__

    def replacement_init(self, *args):
        if len(args) != 1 or args[0].__class__ == the_class:
            # Regular SWIG construction (or copy construction).
            self.original_init(*args)
            return
        # assume it's a Torch QINCoStep
        step = args[0]
        self.original_init(step.d, step.K, step.L, step.h)
        self.from_torch(step)

    def from_torch(self, step):
        """ copy weights from torch.QINCoStep """
        assert (step.d, step.K, step.L, step.h) == (self.d, self.K, self.L, self.h)
        self.codebook.from_torch(step.codebook)
        self.MLPconcat.from_torch(step.MLPconcat)

        # Each torch residual block is a Sequential: [0] and [2] are the
        # two Linear layers (index [1] is the activation in between).
        for block_idx in range(step.L):
            src = step.residual_blocks[block_idx]
            dest = self.get_residual_block(block_idx)
            dest.linear1.from_torch(src[0])
            dest.linear2.from_torch(src[2])

    the_class.from_torch = from_torch
    the_class.__init__ = replacement_init
1372+
1373+
def handle_QINCo(the_class):
    """Patch the SWIG-wrapped QINCo class so it can be built from,
    and filled in from, a torch QINCo model."""
    the_class.original_init = the_class.__init__

    def replacement_init(self, *args):
        if len(args) == 1 and args[0].__class__ != the_class:
            # assume it's a Torch QINCo
            qinco = args[0]
            self.original_init(qinco.d, qinco.K, qinco.L, qinco.M, qinco.h)
            self.from_torch(qinco)
        else:
            # Regular SWIG construction (or copy construction).
            self.original_init(*args)

    def from_torch(self, qinco):
        """ copy weights from torch.QINCo """
        assert (
            (qinco.d, qinco.K, qinco.L, qinco.M, qinco.h) ==
            (self.d, self.K, self.L, self.M, self.h)
        )
        self.codebook0.from_torch(qinco.codebook0)
        # Step 0 is the initial codebook above; the remaining M - 1
        # refinement steps are copied one by one.
        for step_idx in range(qinco.M - 1):
            self.get_step(step_idx).from_torch(qinco.steps[step_idx])

    the_class.from_torch = from_torch
    the_class.__init__ = replacement_init
0 commit comments