-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_convert.py
More file actions
72 lines (57 loc) · 2.16 KB
/
model_convert.py
File metadata and controls
72 lines (57 loc) · 2.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import argparse
import shutil
from collections.abc import Iterable
from pathlib import Path
from typing import TypeVar

import h5py
import numpy as np
T = TypeVar("T")


def get_typed_item(parent: h5py.Group, key: str, expected_type: type[T]) -> T:
    """Return ``parent[key]`` after checking it is an ``expected_type`` instance.

    Args:
        parent: HDF5 group to look the member up in (an ``h5py.File`` is
            passed here by callers in this script as well).
        key: Name of the member to fetch.
        expected_type: Class the fetched member must be an instance of.

    Returns:
        The member, narrowed to ``expected_type`` for type checkers.

    Raises:
        TypeError: if the member is not an instance of ``expected_type``.
        Any error raised by ``parent[key]`` itself (e.g. for a missing name)
        is propagated unchanged.
    """
    item = parent[key]
    # Use an explicit raise rather than ``assert`` so the check still runs
    # under ``python -O``, which strips assert statements.
    if not isinstance(item, expected_type):
        raise TypeError(
            f"Expected {expected_type.__name__} at {key!r}, got {type(item)}"
        )
    return item
def get_group(parent: h5py.Group, key: str) -> h5py.Group:
    """Fetch ``parent[key]`` and return it as an HDF5 group.

    The type check is delegated to ``get_typed_item`` with ``h5py.Group``
    as the required type.
    """
    return get_typed_item(parent=parent, key=key, expected_type=h5py.Group)
def get_dataset(parent: h5py.Group, key: str) -> h5py.Dataset:
    """Fetch ``parent[key]`` and return it as an HDF5 dataset.

    The type check is delegated to ``get_typed_item`` with ``h5py.Dataset``
    as the required type.
    """
    return get_typed_item(parent=parent, key=key, expected_type=h5py.Dataset)
def process_bias_dataset(group: h5py.Group) -> None:
    """Replace ``group["bias:0"]`` with a float32 copy reshaped to two rows.

    The stored values are flattened and split row-major into an array of
    shape ``(2, size // 2)``. That array is written back under the same name
    as little-endian float32 (``"<f4"``).

    The row length is derived from the data rather than fixed. So biases of
    any even length are handled, and a dataset that is already ``(2, N)`` is
    rewritten with the same shape and values (re-running is harmless).

    NOTE(review): the two rows presumably hold the input and recurrent bias
    halves of the ``*_cu_dnngru_*`` layers this is applied to. Confirm this
    against the target layer's expected bias layout.

    Args:
        group: HDF5 group that directly contains a ``bias:0`` dataset.

    Raises:
        ValueError: if the bias has an odd number of elements and therefore
            cannot be split into two equal rows.
        Errors from ``get_dataset`` (missing member or wrong member type) are
        propagated unchanged.
    """
    key = "bias:0"
    # Load the values into memory first; the dataset is deleted below.
    values = np.asarray(get_dataset(group, key)[()]).reshape(-1)
    if values.size % 2:
        raise ValueError(
            f"{key!r} in {group.name} has {values.size} elements; "
            "expected an even count to split into 2 rows"
        )
    reshaped = values.reshape(2, -1)
    # An existing HDF5 dataset cannot change rank, so drop it and recreate
    # it under the same name with the new shape.
    del group[key]
    group.create_dataset(key, data=reshaped, dtype="<f4")
def convert_file(
    file_path: Path, layer_ids: Iterable[int] = range(10, 15)
) -> None:
    """Reshape the forward/backward GRU bias datasets of several layers in place.

    For every ``layer_id`` in ``layer_ids``, this visits two groups and passes
    each one to ``process_bias_dataset``, forward first, then backward:

        model_weights/bidirectional_<id>/bidirectional_<id>/forward_cu_dnngru_<id>
        model_weights/bidirectional_<id>/bidirectional_<id>/backward_cu_dnngru_<id>

    The file is opened in ``"r+"`` mode, so it is modified in place.

    Args:
        file_path: HDF5 weights file to modify.
        layer_ids: Numeric suffixes of the ``bidirectional_<id>`` layers to
            convert. Defaults to ``range(10, 15)`` (layers 10-14).

    Raises:
        Whatever the group lookups or ``process_bias_dataset`` raise when an
        expected group or dataset is missing or malformed. Layers processed
        before the failure will already have been rewritten.
    """
    with h5py.File(file_path, "r+") as f:
        # Look up the top-level weights group once instead of per direction.
        model_weights = get_group(f, "model_weights")
        for layer_id in layer_ids:
            # The layer's group is nested inside a same-named group.
            wrapper = get_group(
                get_group(model_weights, f"bidirectional_{layer_id}"),
                f"bidirectional_{layer_id}",
            )
            for direction in ("forward", "backward"):
                process_bias_dataset(
                    get_group(wrapper, f"{direction}_cu_dnngru_{layer_id}")
                )
def main() -> None:
    """Copy the ``--file`` weights file and convert the copy in place.

    The input file is never modified. The converted copy is written next to
    it with ``.modified`` inserted before the suffix; for example,
    ``model.h5`` becomes ``model.modified.h5``.

    If the conversion fails or is interrupted, the partially converted copy
    is deleted so no half-converted file is left behind. The error is then
    re-raised.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Reshape GRU bias datasets in a copy of an HDF5 weights file; "
            "the input file is left unchanged."
        )
    )
    parser.add_argument(
        "--file",
        type=Path,
        required=True,
        help="HDF5 weights file to convert (a modified copy is written).",
    )
    args = parser.parse_args()
    original_file: Path = args.file
    copy_file = original_file.with_suffix(".modified" + original_file.suffix)
    shutil.copy2(original_file, copy_file)
    try:
        convert_file(copy_file)
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a half-converted copy.
        copy_file.unlink(missing_ok=True)
        raise


if __name__ == "__main__":
    main()