-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathupdate_saved_model_class_names.py
More file actions
76 lines (55 loc) · 2.21 KB
/
update_saved_model_class_names.py
File metadata and controls
76 lines (55 loc) · 2.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from zipfile import ZipFile
from glob import glob
import re
import argparse
# Update the class names in a saved model file from older versions of nama
# to match the current class names.
#
# The model file is expected to be a zip archive holding a ``data.pkl``
# pickle (presumably the format written by ``torch.save`` -- confirm).
# Class names are rewritten by byte substitution on the raw pickle, which
# assumes class references are stored as ``module\nname\n`` text (the pickle
# GLOBAL opcode used by protocols 0-3).

# Old ``'module\nClassName'`` -> current ``'module\nClassName'``.
conversions = {
    'embedding_similarity\nEmbeddingSimilarityModel': 'embedding_similarity.similarity_model\nSimilarityModel',
    'embedding_similarity\nExponentWeights': 'embedding_similarity.similarity_model\nExponentWeights',
    'embedding_similarity\nTransformerProjector': 'embedding_similarity.embedding_model\nEmbeddingModel',
    'embedding_similarity\nExpCosSimilarity': 'embedding_similarity.scoring_model\nSimilarityScore',
}


def convert_class_names(data: bytes, table=None) -> tuple:
    r"""Rewrite old class references in raw pickle bytes.

    Args:
        data: Raw contents of a pickle file.
        table: Mapping of old ``'module\nClassName'`` strings to new ones;
            defaults to the module-level ``conversions``.

    Returns:
        ``(new_data, n_replaced)``: the converted bytes and the total number
        of replacements made. A message is printed for each name replaced.
    """
    if table is None:
        table = conversions
    total = 0
    for old, new in table.items():
        # Include the newline that terminates the class name in a GLOBAL
        # opcode so only whole names match: '...\nExponentWeights' must not
        # rewrite '...\nExponentWeightsV2'.
        old_name = (old + '\n').encode()
        new_name = (new + '\n').encode()
        count = data.count(old_name)
        if count:
            print(f'Replacing {count} instance of {repr(old.encode())} with {repr(new.encode())}')
            data = data.replace(old_name, new_name)
            total += count
    return data, total


def main(args) -> None:
    """Convert the class names in ``args.model_file`` and save the result.

    Args:
        args: Parsed command-line arguments with ``model_file`` (Path of the
            saved model zip archive) and ``replace`` (bool: overwrite
            ``model_file`` instead of writing ``<stem>_converted.bin`` next
            to it).

    Every archive member -- including hidden ones such as ``.data/...`` -- is
    copied with its original name, order, timestamp and compression type;
    only members named ``data.pkl`` have their class names converted. The
    new archive is built in a temporary directory beside the destination and
    then moved into place with ``os.replace``, so an error never leaves a
    truncated or half-written model behind (important with ``--replace``).

    Raises:
        FileNotFoundError: if the archive contains no ``data.pkl`` member.
    """
    import shutil
    from zipfile import ZipInfo

    model_file = args.model_file
    if args.replace:
        new_model_file = model_file
    else:
        new_model_file = model_file.parent / f'{model_file.stem}_converted.bin'

    # Same directory as the destination -> same filesystem -> os.replace works.
    with TemporaryDirectory(dir=new_model_file.parent) as temp_dir:
        temp_file = Path(temp_dir) / new_model_file.name
        with ZipFile(model_file, 'r') as src_zip, ZipFile(temp_file, 'w') as dst_zip:
            entries = src_zip.infolist()
            # Zip member names always use '/' as the separator.
            is_pickle = [e.filename.rsplit('/', 1)[-1] == 'data.pkl' for e in entries]
            if not any(is_pickle):
                raise FileNotFoundError(f'No data.pkl found in {model_file}')
            dst_zip.comment = src_zip.comment
            total = 0
            for entry, convert in zip(entries, is_pickle):
                out = ZipInfo(entry.filename, date_time=entry.date_time)
                out.compress_type = entry.compress_type
                out.external_attr = entry.external_attr
                if convert:
                    data, count = convert_class_names(src_zip.read(entry))
                    total += count
                    dst_zip.writestr(out, data)
                else:
                    # Stream other members rather than loading them whole;
                    # a known file_size lets ZipFile decide whether ZIP64 is needed.
                    out.file_size = entry.file_size
                    with src_zip.open(entry) as src, dst_zip.open(out, 'w') as dst:
                        shutil.copyfileobj(src, dst)
        # Both archives are closed here, so the temp file can be moved
        # (required on Windows when replacing the source file).
        if not total:
            print('No class names needed converting')
        if args.replace:
            # Keep the original permission bits, as an in-place overwrite would.
            shutil.copymode(model_file, temp_file)
        print(f'Saving converted model as {new_model_file}')
        os.replace(temp_file, new_model_file)
if __name__ == '__main__':
    # Command-line entry point: parse arguments and hand them to main().
    parser = argparse.ArgumentParser(
        description='Update the class names in a saved nama model file from '
                    'older versions of nama to match the current class names.')
    parser.add_argument('model_file', type=Path,
                        help='saved model file (zip archive containing data.pkl)')
    parser.add_argument('--replace', action='store_true',
                        help='overwrite MODEL_FILE instead of writing '
                             '<stem>_converted.bin next to it')
    args = parser.parse_args()
    main(args)