@@ -1247,3 +1247,152 @@ def replacement_insert(self, codes, inserted=None):
12471247 return inserted
12481248
12491249 replace_method (the_class , 'insert' , replacement_insert )
1250+
1251+ ######################################################
1252+ # Syntatic sugar for NeuralNet classes
1253+ ######################################################
1254+
1255+
1256+ def handle_Tensor2D (the_class ):
1257+ the_class .original_init = the_class .__init__
1258+
1259+ def replacement_init (self , * args ):
1260+ if len (args ) == 1 :
1261+ array , = args
1262+ n , d = array .shape
1263+ self .original_init (n , d )
1264+ faiss .copy_array_to_vector (
1265+ np .ascontiguousarray (array ).ravel (), self .v )
1266+ else :
1267+ self .original_init (* args )
1268+
1269+ def numpy (self ):
1270+ shape = np .zeros (2 , dtype = np .int64 )
1271+ faiss .memcpy (faiss .swig_ptr (shape ), self .shape , shape .nbytes )
1272+ return faiss .vector_to_array (self .v ).reshape (shape [0 ], shape [1 ])
1273+
1274+ the_class .__init__ = replacement_init
1275+ the_class .numpy = numpy
1276+
1277+
1278+ def handle_Embedding (the_class ):
1279+ the_class .original_init = the_class .__init__
1280+
1281+ def replacement_init (self , * args ):
1282+ if len (args ) != 1 or args [0 ].__class__ == the_class :
1283+ self .original_init (* args )
1284+ return
1285+ # assume it's a torch.Embedding
1286+ emb = args [0 ]
1287+ self .original_init (emb .num_embeddings , emb .embedding_dim )
1288+ self .from_torch (emb )
1289+
1290+ def from_torch (self , emb ):
1291+ """ copy weights from torch.Embedding """
1292+ assert emb .weight .shape == (self .num_embeddings , self .embedding_dim )
1293+ faiss .copy_array_to_vector (
1294+ np .ascontiguousarray (emb .weight .data ).ravel (), self .weight )
1295+
1296+ def from_array (self , array ):
1297+ """ copy weights from numpy array """
1298+ assert array .shape == (self .num_embeddings , self .embedding_dim )
1299+ faiss .copy_array_to_vector (
1300+ np .ascontiguousarray (array ).ravel (), self .weight )
1301+
1302+ the_class .from_array = from_array
1303+ the_class .from_torch = from_torch
1304+ the_class .__init__ = replacement_init
1305+
1306+
1307+ def handle_Linear (the_class ):
1308+ the_class .original_init = the_class .__init__
1309+
1310+ def replacement_init (self , * args ):
1311+ if len (args ) != 1 or args [0 ].__class__ == the_class :
1312+ self .original_init (* args )
1313+ return
1314+ # assume it's a torch.Linear
1315+ linear = args [0 ]
1316+ bias = linear .bias is not None
1317+ self .original_init (linear .in_features , linear .out_features , bias )
1318+ self .from_torch (linear )
1319+
1320+ def from_torch (self , linear ):
1321+ """ copy weights from torch.Linear """
1322+ assert linear .weight .shape == (self .out_features , self .in_features )
1323+ faiss .copy_array_to_vector (
1324+ linear .weight .data .numpy ().ravel (), self .weight )
1325+ if linear .bias is not None :
1326+ assert linear .bias .shape == (self .out_features ,)
1327+ faiss .copy_array_to_vector (linear .bias .data .numpy (), self .bias )
1328+
1329+ def from_array (self , array , bias = None ):
1330+ """ copy weights from numpy array """
1331+ assert array .shape == (self .out_features , self .in_features )
1332+ faiss .copy_array_to_vector (
1333+ np .ascontiguousarray (array ).ravel (), self .weight )
1334+ if bias is not None :
1335+ assert bias .shape == (self .out_features ,)
1336+ faiss .copy_array_to_vector (bias , self .bias )
1337+
1338+ the_class .__init__ = replacement_init
1339+ the_class .from_array = from_array
1340+ the_class .from_torch = from_torch
1341+
1342+ ######################################################
1343+ # Syntatic sugar for QINCo and QINCoStep
1344+ ######################################################
1345+
1346+ def handle_QINCoStep (the_class ):
1347+ the_class .original_init = the_class .__init__
1348+
1349+ def replacement_init (self , * args ):
1350+ if len (args ) != 1 or args [0 ].__class__ == the_class :
1351+ self .original_init (* args )
1352+ return
1353+ step = args [0 ]
1354+ # assume it's a Torch QINCoStep
1355+ self .original_init (step .d , step .K , step .L , step .h )
1356+ self .from_torch (step )
1357+
1358+ def from_torch (self , step ):
1359+ """ copy weights from torch.QINCoStep """
1360+ assert (step .d , step .K , step .L , step .h ) == (self .d , self .K , self .L , self .h )
1361+ self .codebook .from_torch (step .codebook )
1362+ self .MLPconcat .from_torch (step .MLPconcat )
1363+
1364+ for l in range (step .L ):
1365+ src = step .residual_blocks [l ]
1366+ dest = self .get_residual_block (l )
1367+ dest .linear1 .from_torch (src [0 ])
1368+ dest .linear2 .from_torch (src [2 ])
1369+
1370+ the_class .__init__ = replacement_init
1371+ the_class .from_torch = from_torch
1372+
1373+
1374+ def handle_QINCo (the_class ):
1375+ the_class .original_init = the_class .__init__
1376+
1377+ def replacement_init (self , * args ):
1378+ if len (args ) != 1 or args [0 ].__class__ == the_class :
1379+ self .original_init (* args )
1380+ return
1381+
1382+ # assume it's a Torch QINCo
1383+ qinco = args [0 ]
1384+ self .original_init (qinco .d , qinco .K , qinco .L , qinco .M , qinco .h )
1385+ self .from_torch (qinco )
1386+
1387+ def from_torch (self , qinco ):
1388+ """ copy weights from torch.QINCo """
1389+ assert (
1390+ (qinco .d , qinco .K , qinco .L , qinco .M , qinco .h ) ==
1391+ (self .d , self .K , self .L , self .M , self .h )
1392+ )
1393+ self .codebook0 .from_torch (qinco .codebook0 )
1394+ for m in range (qinco .M - 1 ):
1395+ self .get_step (m ).from_torch (qinco .steps [m ])
1396+
1397+ the_class .__init__ = replacement_init
1398+ the_class .from_torch = from_torch
0 commit comments