Skip to content

Commit f48159a

Browse files
authored
Optimizer use init program (#5275)
* optimizer use init_program * create persistable variable * add create_persistable_var to block * optimizer use create_persistable_var * fix prefix * move create_global_persistable_var from Block to LayerHelper * Polish Optimizer initialization code. * Using the LayerHelper to create initialize operator and variables * add_accumulator should use an independent data type * default use param data type for accumulator
1 parent 90f4d5e commit f48159a

File tree

10 files changed

+213
-158
lines changed

10 files changed

+213
-158
lines changed

python/paddle/v2/framework/framework.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
__all__ = ['Block', 'Variable', 'Program', 'Operator']
88

99

10+
def unique_name(prefix):
11+
uid = core.unique_integer(prefix) # unique during whole process.
12+
return "_".join([prefix, str(uid)])
13+
14+
1015
class Variable(object):
1116
def __init__(self,
1217
block,

python/paddle/v2/framework/layer_helper.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,12 @@
11
import copy
22
import itertools
33

4-
import paddle.v2.framework.core as core
5-
64
from paddle.v2.framework.framework import Variable, g_program, \
7-
g_init_program
5+
g_init_program, unique_name, Program
86
from paddle.v2.framework.initializer import ConstantInitializer, \
97
UniformInitializer
108

119

12-
def unique_name(prefix):
13-
uid = core.unique_integer(prefix) # unique during whole process.
14-
return "_".join([prefix, str(uid)])
15-
16-
1710
class LayerHelper(object):
1811
def __init__(self, layer_type, **kwargs):
1912
self.kwargs = kwargs
@@ -138,9 +131,19 @@ def create_tmp_variable(self, dtype):
138131
def create_variable(self, *args, **kwargs):
139132
return self.program.current_block().create_var(*args, **kwargs)
140133

141-
def create_global_variable(self, *args, **kwargs):
134+
def create_global_variable(self, persistable=False, *args, **kwargs):
142135
return self.program.global_block().create_var(
143-
*args, persistable=False, **kwargs)
136+
*args, persistable=persistable, **kwargs)
137+
138+
def set_variable_initializer(self, var, initializer):
139+
assert isinstance(var, Variable)
140+
self.init_program.global_block().create_var(
141+
name=var.name,
142+
type=var.type,
143+
dtype=var.data_type,
144+
shape=var.shape,
145+
persistable=True,
146+
initializer=initializer)
144147

145148
def append_bias_op(self, input_var, num_flatten_dims=None):
146149
"""

0 commit comments

Comments
 (0)