Skip to content

Commit 8f7d0b1

Browse files
authored
add param_attr for nets (#6509)
1 parent c8d4efb commit 8f7d0b1

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

python/paddle/v2/fluid/layers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1732,8 +1732,10 @@ def conv2d_transpose(input,
17321732

17331733
h_in = input.shape[2]
17341734
w_in = input.shape[3]
1735-
filter_size_h = output_size[0] - (h_in - 1) * stride[0] + 2 * padding[0]
1736-
filter_size_w = output_size[1] - (w_in - 1) * stride[1] + 2 * padding[1]
1735+
filter_size_h = output_size[0] - \
1736+
(h_in - 1) * stride[0] + 2 * padding[0]
1737+
filter_size_w = output_size[1] - \
1738+
(w_in - 1) * stride[1] + 2 * padding[1]
17371739
filter_size = [filter_size_h, filter_size_w]
17381740
elif isinstance(filter_size, int):
17391741
filter_size = [filter_size, filter_size]

python/paddle/v2/fluid/nets.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@ def simple_img_conv_pool(input,
99
pool_size,
1010
pool_stride,
1111
act,
12+
param_attr=None,
1213
pool_type='max',
1314
main_program=None,
1415
startup_program=None):
1516
conv_out = layers.conv2d(
1617
input=input,
1718
num_filters=num_filters,
1819
filter_size=filter_size,
20+
param_attr=param_attr,
1921
act=act,
2022
main_program=main_program,
2123
startup_program=startup_program)
@@ -36,6 +38,7 @@ def img_conv_group(input,
3638
conv_padding=1,
3739
conv_filter_size=3,
3840
conv_act=None,
41+
param_attr=None,
3942
conv_with_batchnorm=False,
4043
conv_batchnorm_drop_rate=None,
4144
pool_stride=1,
@@ -57,6 +60,7 @@ def __extend_list__(obj):
5760

5861
conv_padding = __extend_list__(conv_padding)
5962
conv_filter_size = __extend_list__(conv_filter_size)
63+
param_attr = __extend_list__(param_attr)
6064
conv_with_batchnorm = __extend_list__(conv_with_batchnorm)
6165
conv_batchnorm_drop_rate = __extend_list__(conv_batchnorm_drop_rate)
6266

@@ -70,6 +74,7 @@ def __extend_list__(obj):
7074
num_filters=conv_num_filter[i],
7175
filter_size=conv_filter_size[i],
7276
padding=conv_padding[i],
77+
param_attr=param_attr[i],
7378
act=local_conv_act,
7479
main_program=main_program,
7580
startup_program=startup_program)
@@ -101,6 +106,7 @@ def __extend_list__(obj):
101106
def sequence_conv_pool(input,
102107
num_filters,
103108
filter_size,
109+
param_attr=None,
104110
act="sigmoid",
105111
pool_type="max",
106112
main_program=None,
@@ -109,6 +115,7 @@ def sequence_conv_pool(input,
109115
input=input,
110116
num_filters=num_filters,
111117
filter_size=filter_size,
118+
param_attr=param_attr,
112119
act=act,
113120
main_program=main_program,
114121
startup_program=startup_program)

0 commit comments

Comments
 (0)