-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathconvert_pytorch2tf.py
More file actions
168 lines (140 loc) · 9.81 KB
/
convert_pytorch2tf.py
File metadata and controls
168 lines (140 loc) · 9.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import numpy as np
import tensorflow.compat.v1 as tf
import torch
import argparse
from model import LiteFlowNet
# The TF1 graph API used by this script (tf.placeholder, tf.Session,
# tf.assign) is not available in eager mode, so switch to graph mode first.
tf.disable_eager_execution()

parser = argparse.ArgumentParser(
    description='Copy LiteFlowNet weights from a PyTorch state dict into a '
                'TensorFlow checkpoint.')
parser.add_argument(
    '--input_model', default='network-default.pytorch',
    help='path of the PyTorch state dict read with torch.load '
         '(default: %(default)s)')
parser.add_argument(
    '--output_model', default='model',
    help='checkpoint path prefix passed to tf.train.Saver.save '
         '(default: %(default)s)')
args = parser.parse_args()
def ToTensor(sample, transfrm=lambda x: np.transpose(x, [0, 3, 1, 2])):
    """Return ``transfrm(sample)`` wrapped as a torch tensor.

    The default ``transfrm`` permutes axes as (0, 3, 1, 2), moving the last
    axis of a 4-D array to position 1 (presumably NHWC -> NCHW; confirm
    against callers). ``torch.from_numpy`` shares memory with the array that
    ``transfrm`` returns, so no copy is made.
    """
    rearranged = transfrm(sample)
    return torch.from_numpy(rearranged)
def _feature_extractor_mapping():
    """Build the PyTorch -> TF variable-name map for the feature extractor.

    Each PyTorch conv ``moduleFeatures.<stage>.<idx>`` is paired with the TF
    layer ``flownet/feature_extractor/<sequential>/<conv2d>``. Both TF names
    are numbered: the first instance has no suffix (``sequential``,
    ``conv2d``) and later ones get ``_1``, ``_2``, ... Sequentials are
    numbered once per stage; convs are numbered across all stages.
    """
    # (PyTorch stage name, indices of the convs inside that nn.Sequential)
    stages = (
        ('moduleOne', (0,)),
        ('moduleTwo', (0, 2, 4)),
        ('moduleThr', (0, 2)),
        ('moduleFou', (0, 2)),
        ('moduleFiv', (0,)),
        ('moduleSix', (0,)),
    )

    def numbered(base, n):
        # 0 -> 'base', n -> 'base_n'
        return base if n == 0 else '%s_%i' % (base, n)

    mapping = {}
    conv_number = 0
    for stage_number, (stage, conv_indices) in enumerate(stages):
        scope = 'flownet/feature_extractor/' + numbered('sequential', stage_number)
        for idx in conv_indices:
            layer = scope + '/' + numbered('conv2d', conv_number)
            mapping['moduleFeatures.%s.%i.weight' % (stage, idx)] = layer + '/kernel'
            mapping['moduleFeatures.%s.%i.bias' % (stage, idx)] = layer + '/bias'
            conv_number += 1
    return mapping


weights_mapping = _feature_extractor_mapping()
# Load the PyTorch weights. map_location='cpu' lets a checkpoint that was
# saved from CUDA tensors load on a machine without a GPU; only host copies
# are needed for the conversion.
# NOTE(review): torch.load unpickles the file -- only convert trusted checkpoints.
pytorch_model_path = args.input_model
pytorch_state_dict = torch.load(pytorch_model_path, map_location='cpu')

# Calling the TF model once on placeholder inputs builds the graph. This
# presumably also creates the model's variables (the later lookup under the
# 'flownet' scope relies on them existing) -- confirm against model.py.
model = LiteFlowNet()
frame1 = tf.placeholder(tf.float32, shape=[None, None, None, 3])
frame2 = tf.placeholder(tf.float32, shape=[None, None, None, 3])
out = model(frame1, frame2)
# TF conv layers are numbered conv2d_<n>. The feature-extractor entries in
# weights_mapping stop at conv2d_9, so the per-level modules mapped here
# continue from conv2d_10 (assumes TF numbers layers in creation order).
c = 10
m_weights = {}


def _map_convs(mapping, torch_module, tf_scope, suffixes, counter):
    """Map one PyTorch conv per suffix onto consecutive TF ``conv2d_<n>`` layers.

    For every entry in ``suffixes`` this adds
    ``<torch_module><suffix>.weight -> <tf_scope>/conv2d_<counter>/kernel`` and
    ``<torch_module><suffix>.bias -> <tf_scope>/conv2d_<counter>/bias`` to
    ``mapping``, then advances ``counter`` by one.

    Returns the counter value that follows the last mapped layer.
    """
    for suffix in suffixes:
        for torch_param, tf_param in (('weight', 'kernel'), ('bias', 'bias')):
            mapping['%s%s.%s' % (torch_module, suffix, torch_param)] = (
                '%s/conv2d_%i/%s' % (tf_scope, counter, tf_param))
        counter += 1
    return counter


# Visit lvls = 6, 5, 4, 3, 2 in that order. PyTorch ModuleList slots are
# indexed 5 - i, while the TF scopes carry the suffix i itself.
for i in range(1, 6):
    lvls = 7 - i
    slot = 5 - i

    # Upflow / Upcorr map only a weight (no bias entry) and are mapped for
    # every level except lvls 6; they do not consume a conv2d number.
    if lvls < 6:
        for part in ('Upflow', 'Upcorr'):
            m_weights['moduleMatching.%i.module%s.weight' % (slot, part)] = (
                'flownet/matching_%i/module%s/filter_w' % (i, part))

    # Matching and subpixel share one layout: an optional one-conv moduleFeat
    # (lvls 2 only) followed by four moduleMain convs at indices 0, 2, 4, 6.
    for torch_name, tf_name in (('moduleMatching', 'matching'),
                                ('moduleSubpixel', 'subpixel')):
        torch_module = '%s.%i' % (torch_name, slot)
        tf_scope = 'flownet/%s_%i' % (tf_name, i)
        if lvls == 2:
            c = _map_convs(m_weights, torch_module + '.moduleFeat',
                           tf_scope + '/module_feat', ('.0',), c)
        c = _map_convs(m_weights, torch_module + '.moduleMain',
                       tf_scope + '/module_main', ('.0', '.2', '.4', '.6'), c)

    # Regularization: optional moduleFeat (lvls < 5), six moduleMain convs,
    # moduleDist, then the bare ScaleX / ScaleY convs.
    reg_module = 'moduleRegularization.%i' % slot
    reg_scope = 'flownet/regularization_%i' % i
    if lvls < 5:
        c = _map_convs(m_weights, reg_module + '.moduleFeat',
                       reg_scope + '/module_feat', ('.0',), c)
    c = _map_convs(m_weights, reg_module + '.moduleMain', reg_scope + '/module_main',
                   ('.0', '.2', '.4', '.6', '.8', '.10'), c)
    # moduleDist has two convs (.0, .1) when lvls < 5, otherwise just .0.
    dist_suffixes = ('.0', '.1') if lvls < 5 else ('.0',)
    c = _map_convs(m_weights, reg_module + '.moduleDist', reg_scope + '/module_dist',
                   dist_suffixes, c)
    # ScaleX / ScaleY keys have no Sequential index, hence the empty suffix.
    c = _map_convs(m_weights, reg_module + '.moduleScaleX',
                   reg_scope + '/moduleScaleX', ('',), c)
    c = _map_convs(m_weights, reg_module + '.moduleScaleY',
                   reg_scope + '/moduleScaleY', ('',), c)
weights_mapping.update(m_weights)

# Debug listing of every TF variable name the mapping targets.
for v in sorted(weights_mapping.values()):
    print(v)

# The `with` block closes the session once the checkpoint has been written.
with tf.Session() as sess:
    tfvarsg = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='flownet')
    # Drop the trailing ':0' so keys line up with the names in weights_mapping.
    tfvars = {v.name[:-2]: v for v in tfvarsg}
    for v in sorted(tfvarsg, key=lambda x: x.name):
        print(v.name)

    assigned = set()
    for state in pytorch_state_dict:
        if state not in weights_mapping:
            # Report skipped tensors so an incomplete mapping is visible.
            print('Skipping unmapped PyTorch tensor: ' + state)
            continue
        tf_name = weights_mapping[state]
        if tf_name not in tfvars:
            raise KeyError('%s maps to %s, which is not a variable of the TF '
                           'graph' % (state, tf_name))

        pytorch_data = pytorch_state_dict[state].cpu().detach().numpy()
        if pytorch_data.ndim > 3:
            # nn.Conv2d weights are (out, in, kH, kW) and TF conv kernels are
            # (kH, kW, in, out), hence the [2, 3, 1, 0] permutation.
            # NOTE(review): nn.ConvTranspose2d weights (Upflow/Upcorr) are
            # (in, out/groups, kH, kW) and end up as (kH, kW, out/groups, in);
            # presumably that is what the TF filter_w expects -- confirm.
            pytorch_data = np.transpose(pytorch_data, [2, 3, 1, 0])
        print("Assign: " + state + " ====> " + tf_name)
        sess.run(tf.assign(tfvars[tf_name], pytorch_data))
        assigned.add(tf_name)

    # No initializer is run, so an unassigned variable is still uninitialized
    # and Saver.save would fail on it with a far less specific error.
    unassigned = sorted(set(tfvars) - assigned)
    if unassigned:
        raise RuntimeError('TF variables with no PyTorch source: ' +
                           ', '.join(unassigned))

    # save model
    saver = tf.train.Saver(tfvars)
    saver.save(sess, args.output_model)