@@ -1247,3 +1247,133 @@ def replacement_insert(self, codes, inserted=None):
12471247 return inserted
12481248
12491249 replace_method (the_class , 'insert' , replacement_insert )
1250+
1251+ ######################################################
1252+ # Syntatic sugar for NeuralNet classes
1253+ ######################################################
1254+
1255+ def handle_Tensor2D (the_class ):
1256+ the_class .original_init = the_class .__init__
1257+
1258+ def replacement_init (self , * args ):
1259+ if len (args ) == 1 :
1260+ array , = args
1261+ n , d = array .shape
1262+ self .original_init (n , d )
1263+ faiss .copy_array_to_vector (
1264+ np .ascontiguousarray (array ).ravel (), self .v )
1265+ else :
1266+ self .original_init (* args )
1267+
1268+ def numpy (self ):
1269+ shape = np .zeros (2 , dtype = np .int64 )
1270+ faiss .memcpy (faiss .swig_ptr (shape ), self .shape , shape .nbytes )
1271+ return faiss .vector_to_array (self .v ).reshape (shape [0 ], shape [1 ])
1272+
1273+ the_class .__init__ = replacement_init
1274+ the_class .numpy = numpy
1275+
1276+
1277+ def handle_Embedding (the_class ):
1278+ the_class .original_init = the_class .__init__
1279+
1280+ def replacement_init (self , * args ):
1281+ if len (args ) != 1 or args [0 ].__class__ == the_class :
1282+ self .original_init (* args )
1283+ return
1284+ # assume it's a torch.Embedding
1285+ emb = args [0 ]
1286+ self .original_init (emb .num_embeddings , emb .embedding_dim )
1287+ self .from_torch (emb )
1288+
1289+ def from_torch (self , emb ):
1290+ """ copy weights from torch.Embedding """
1291+ assert emb .weight .shape == (self .num_embeddings , self .embedding_dim )
1292+ faiss .copy_array_to_vector (
1293+ np .ascontiguousarray (emb .weight .data ).ravel (), self .weight )
1294+
1295+ the_class .from_torch = from_torch
1296+ the_class .__init__ = replacement_init
1297+
1298+
1299+ def handle_Linear (the_class ):
1300+ the_class .original_init = the_class .__init__
1301+
1302+ def replacement_init (self , * args ):
1303+ if len (args ) != 1 or args [0 ].__class__ == the_class :
1304+ self .original_init (* args )
1305+ return
1306+ # assume it's a torch.Linear
1307+ linear = args [0 ]
1308+ bias = linear .bias is not None
1309+ self .original_init (linear .in_features , linear .out_features , bias )
1310+ self .from_torch (linear )
1311+
1312+ def from_torch (self , linear ):
1313+ """ copy weights from torch.Linear """
1314+ assert linear .weight .shape == (self .out_features , self .in_features )
1315+ faiss .copy_array_to_vector (linear .weight .data .numpy ().ravel (), self .weight )
1316+ if linear .bias is not None :
1317+ assert linear .bias .shape == (self .out_features ,)
1318+ faiss .copy_array_to_vector (linear .bias .data .numpy (), self .bias )
1319+
1320+ the_class .__init__ = replacement_init
1321+ the_class .from_torch = from_torch
1322+
1323+ ######################################################
1324+ # Syntatic sugar for QINCo and QINCoStep
1325+ ######################################################
1326+
1327+ def handle_QINCoStep (the_class ):
1328+ the_class .original_init = the_class .__init__
1329+
1330+ def replacement_init (self , * args ):
1331+ if len (args ) != 1 or args [0 ].__class__ == the_class :
1332+ self .original_init (* args )
1333+ return
1334+ step = args [0 ]
1335+ # assume it's a Torch QINCoStep
1336+ self .original_init (step .d , step .K , step .L , step .h )
1337+ self .from_torch (step )
1338+
1339+ def from_torch (self , step ):
1340+ """ copy weights from torch.QINCoStep """
1341+ assert (step .d , step .K , step .L , step .h ) == (self .d , self .K , self .L , self .h )
1342+ self .codebook .from_torch (step .codebook )
1343+ self .MLPconcat .from_torch (step .MLPconcat )
1344+
1345+ for l in range (step .L ):
1346+ src = step .residual_blocks [l ]
1347+ dest = self .get_residual_block (l )
1348+ dest .linear1 .from_torch (src [0 ])
1349+ dest .linear2 .from_torch (src [2 ])
1350+
1351+ the_class .__init__ = replacement_init
1352+ the_class .from_torch = from_torch
1353+
1354+
1355+ def handle_QINCo (the_class ):
1356+ the_class .original_init = the_class .__init__
1357+
1358+ def replacement_init (self , * args ):
1359+ if len (args ) != 1 or args [0 ].__class__ == the_class :
1360+ self .original_init (* args )
1361+ return
1362+
1363+ # assume it's a Torch QINCo
1364+ qinco = args [0 ]
1365+ self .original_init (qinco .d , qinco .K , qinco .L , qinco .M , qinco .h )
1366+ self .from_torch (qinco )
1367+
1368+ def from_torch (self , qinco ):
1369+ """ copy weights from torch.QINCo """
1370+ assert (
1371+ (qinco .d , qinco .K , qinco .L , qinco .M , qinco .h ) ==
1372+ (self .d , self .K , self .L , self .M , self .h )
1373+ )
1374+ self .codebook0 .from_torch (qinco .codebook0 )
1375+ for m in range (qinco .M - 1 ):
1376+ self .get_step (m ).from_torch (qinco .steps [m ])
1377+
1378+ the_class .__init__ = replacement_init
1379+ the_class .from_torch = from_torch
0 commit comments