Skip to content

Commit 54896a2

Browse files
authored
Merge pull request #17 from LielinJiang/fix-save-error
Fix save error
2 parents b306aa7 + 0970de7 commit 54896a2

File tree

3 files changed

+10
-11
lines changed

3 files changed

+10
-11
lines changed

ppgan/models/base_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def test(self):
9393
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
9494
It also calls <compute_visuals> to produce additional visualization results
9595
"""
96-
with paddle.imperative.no_grad():
96+
with paddle.no_grad():
9797
self.forward()
9898
self.compute_visuals()
9999

ppgan/models/pix2pix_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def forward(self):
8686
self.fake_B = self.netG(self.real_A) # G(A)
8787

8888
def forward_test(self, input):
89-
input = paddle.imperative.to_variable(input)
89+
input = paddle.to_tensor(input)
9090
return self.netG(input)
9191

9292
def backward_D(self):

ppgan/utils/filesystem.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,19 @@
33
import pickle
44
import paddle
55

6+
67
def makedirs(dir):
78
if not os.path.exists(dir):
89
os.makedirs(dir)
910

10-
def save(state_dicts, file_name):
1111

12+
def save(state_dicts, file_name):
1213
def convert(state_dict):
1314
model_dict = {}
14-
15+
1516
for k, v in state_dict.items():
16-
if isinstance(v, (paddle.framework.Variable, paddle.imperative.core.VarBase)):
17+
if isinstance(
18+
v, (paddle.framework.Variable, paddle.fluid.core.VarBase)):
1719
model_dict[k] = v.numpy()
1820
else:
1921
model_dict[k] = v
@@ -22,14 +24,15 @@ def convert(state_dict):
2224

2325
final_dict = {}
2426
for k, v in state_dicts.items():
25-
if isinstance(v, (paddle.framework.Variable, paddle.imperative.core.VarBase)):
27+
if isinstance(v,
28+
(paddle.framework.Variable, paddle.fluid.core.VarBase)):
2629
final_dict = convert(state_dicts)
2730
break
2831
elif isinstance(v, dict):
2932
final_dict[k] = convert(v)
3033
else:
3134
final_dict[k] = v
32-
35+
3336
with open(file_name, 'wb') as f:
3437
pickle.dump(final_dict, f, protocol=2)
3538

@@ -39,7 +42,3 @@ def load(file_name):
3942
state_dicts = pickle.load(f) if six.PY2 else pickle.load(
4043
f, encoding='latin1')
4144
return state_dicts
42-
43-
44-
45-

0 commit comments

Comments
 (0)