From 866ca2eb4f4b9fdc7e75252ef1dbfeaceea29640 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Thu, 1 Aug 2019 13:18:56 +0000 Subject: [PATCH 1/9] fix con2d transpose bias by create and init it in build_onee --- paddle/fluid/API.spec | 2 +- python/paddle/fluid/dygraph/nn.py | 39 +++++++++++++++++++++++++++++-- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 8880da2e1ae6b4..74de1fb9869b0a 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -1043,4 +1043,4 @@ paddle.reader.Fake ('paddle.reader.decorator.Fake', ('document', '0d8f4847b99bed paddle.reader.Fake.__init__ (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.reader.creator.np_array (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', '28d457fbc9a71efa4ac91a3be179cada')) paddle.reader.creator.text_file (ArgSpec(args=['path'], varargs=None, keywords=None, defaults=None), ('document', 'f45fcb7add066c8e042c6774fc7c3db2')) -paddle.reader.creator.recordio (ArgSpec(args=['paths', 'buf_size'], varargs=None, keywords=None, defaults=(100,)), ('document', 'b4a94ee0e2cefb495619275c2f8c61d2')) +paddle.reader.creator.recordio (ArgSpec(args=['paths', 'buf_size'], varargs=None, keywords=None, defaults=(100,)), ('document', 'b4a94ee0e2cefb495619275c2f8c61d2')) \ No newline at end of file diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index f933e22ddfa552..9be8150d356d0f 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -2162,6 +2162,12 @@ def _build_once(self, input): self._img_filter = self.create_parameter( dtype=input.dtype, shape=filter_shape, attr=self._param_attr) + self._bias_param = self.create_parameter( + attr=self._bias_attr, + shape=[self._num_filters], + dtype=self._dtype, + is_bias=True) + def forward(self, input): pre_bias = self._helper.create_variable_for_type_inference( dtype=input.dtype) @@ -2179,7 +2185,18 @@ def forward(self, input): 'use_cudnn': self._use_cudnn }) - pre_act = self._helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) + if self._bias_param is not None: + pre_act = self._helper.create_variable_for_type_inference( + dtype=self._dtype) + self._helper.append_op( + type='elementwise_add', + inputs={'X': [pre_bias], + 'Y': [self._bias_param]}, + outputs={'Out': [pre_act]}, + attrs={'axis': 1}) + else: + pre_act = pre_bias + out = self._helper.append_activation(pre_act) return out @@ -2237,6 +2254,12 @@ def _build_once(self, input): self._filter_param = self.create_parameter( attr=self._param_attr, shape=filter_shape, dtype=self._dtype) + self._bias_param = self.create_parameter( + attr=self._bias_attr, + shape=[self._num_filters], + dtype=self._dtype, + is_bias=True) + def forward(self, input): pre_bias = self._helper.create_variable_for_type_inference(self._dtype) self._helper.append_op( @@ -2251,7 +2274,19 @@ def forward(self, input): 'contextStart': -int(self._filter_size // 2), 'contextLength': self._filter_size }) - pre_act = self._helper.append_bias_op(pre_bias) + + if self._bias_param is not None: + pre_act = self._helper.create_variable_for_type_inference( + dtype=self._dtype) + self._helper.append_op( + type='elementwise_add', + inputs={'X': [pre_bias], + 'Y': [self._bias_param]}, + outputs={'Out': [pre_act]}, + attrs={'axis': 1}) + else: + pre_act = pre_bias + return self._helper.append_activation(pre_act) From 8897328553e0acb16055d4c68536f4b76c83f859 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Thu, 1 Aug 2019 13:26:17 +0000 Subject: [PATCH 2/9] fix API spec --- paddle/fluid/API.spec | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 74de1fb9869b0a..8880da2e1ae6b4 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -1043,4 +1043,4 @@ paddle.reader.Fake ('paddle.reader.decorator.Fake', ('document', '0d8f4847b99bed paddle.reader.Fake.__init__ (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.reader.creator.np_array (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', '28d457fbc9a71efa4ac91a3be179cada')) paddle.reader.creator.text_file (ArgSpec(args=['path'], varargs=None, keywords=None, defaults=None), ('document', 'f45fcb7add066c8e042c6774fc7c3db2')) -paddle.reader.creator.recordio (ArgSpec(args=['paths', 'buf_size'], varargs=None, keywords=None, defaults=(100,)), ('document', 'b4a94ee0e2cefb495619275c2f8c61d2')) \ No newline at end of file +paddle.reader.creator.recordio (ArgSpec(args=['paths', 'buf_size'], varargs=None, keywords=None, defaults=(100,)), ('document', 'b4a94ee0e2cefb495619275c2f8c61d2')) From de85dcb7bba846b7c74e029abade3179c9ccf4de Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Thu, 1 Aug 2019 13:39:58 +0000 Subject: [PATCH 3/9] test=develop, invoke ci --- python/paddle/fluid/dygraph/nn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index 9be8150d356d0f..26028b6aab63ba 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -2649,6 +2649,7 @@ def forward(self, nodes_vector, edge_set): out = self.create_variable( name=self._name, dtype=self._dtype, persistable=False) else: + out = self._helper.create_variable_for_type_inference( dtype=self._dtype) From 534f801c177b3f23420117e10f2498fbbedb116c Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Thu, 15 Aug 2019 14:03:42 +0000 Subject: [PATCH 4/9] fix bias_attr and act has no effect error on layer norm, conv2dTranpose, billinearTensorProduct, sequece_conv. fix original_mode not used error on GRUunit. fix sample_weight not set error on NCE. Add ut for all thoese layer --- paddle/fluid/API.spec | 2 +- python/paddle/fluid/dygraph/nn.py | 25 +++-- .../fluid/tests/unittests/test_layers.py | 98 ++++++++++++++----- 3 files changed, 93 insertions(+), 32 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index aae345a5369692..5027c97d60ca2f 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -684,7 +684,7 @@ paddle.fluid.dygraph.LayerNorm.state_dict (ArgSpec(args=['self', 'destination', paddle.fluid.dygraph.LayerNorm.sublayers (ArgSpec(args=['self', 'include_sublayers'], varargs=None, keywords=None, defaults=(True,)), ('document', '00a881005ecbc96578faf94513bf0d62')) paddle.fluid.dygraph.LayerNorm.train (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.dygraph.NCE ('paddle.fluid.dygraph.nn.NCE', ('document', '47eb439a5568468fad70235f1e61ead9')) -paddle.fluid.dygraph.NCE.__init__ (ArgSpec(args=['self', 'name_scope', 'num_total_classes', 'param_attr', 'bias_attr', 'num_neg_samples', 'sampler', 'custom_dist', 'seed', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, 'uniform', None, 0, False)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) +paddle.fluid.dygraph.NCE.__init__ (ArgSpec(args=['self', 'name_scope', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'sampler', 'custom_dist', 'seed', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, 'uniform', None, 0, False)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.dygraph.NCE.add_parameter (ArgSpec(args=['self', 'name', 'parameter'], varargs=None, keywords=None, defaults=None), ('document', 'f35ab374c7d5165c3daf3bd64a5a2ec1')) paddle.fluid.dygraph.NCE.add_sublayer (ArgSpec(args=['self', 'name', 'sublayer'], varargs=None, keywords=None, defaults=None), ('document', '839ff3c0534677ba6ad8735c3fd4e995')) paddle.fluid.dygraph.NCE.backward (ArgSpec(args=['self'], varargs='inputs', keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index 26028b6aab63ba..7f0f80f823db35 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -23,6 +23,7 @@ from ..param_attr import ParamAttr from ..initializer import Normal, Constant, NumpyArrayInitializer import numpy as np +import logging __all__ = [ 'Conv2D', 'Conv3D', 'Pool2D', 'FC', 'BatchNorm', 'Embedding', 'GRUUnit', @@ -1374,6 +1375,10 @@ def _build_once(self, input): shape=param_shape, dtype=self._dtype, default_initializer=Constant(1.0)) + else: + if self._param_attr: + logging.warn("param_attr are only avaliable with scale is True") + if self._shift: assert self._bias_attr is not False self._bias_w = self.create_parameter( @@ -1381,6 +1386,9 @@ def _build_once(self, input): shape=param_shape, dtype=self._dtype, is_bias=True) + else: + if self._bias_attr: + logging.warn("bias_attr are only avaliable with shift is True") def forward(self, input): inputs = dict() @@ -1410,7 +1418,7 @@ def forward(self, input): "begin_norm_axis": self._begin_norm_axis }) - return self._helper.append_activation(layer_norm_out) + return self._helper.append_activation(layer_norm_out, act=self._act) class GRUUnit(layers.Layer): @@ -1648,6 +1656,7 @@ class NCE(layers.Layer): def __init__(self, name_scope, num_total_classes, + sample_weight=None, param_attr=None, bias_attr=None, num_neg_samples=None, @@ -1661,7 +1670,7 @@ def __init__(self, self._num_total_classes = num_total_classes self._inputs = dict() - + self._inputs['SampleWeight'] = sample_weight if sample_weight is not None else [] if sampler == "uniform": sampler = 0 elif sampler == "log_uniform": @@ -1941,15 +1950,15 @@ def _build_once(self, x, y): if self._bias_attr: bias_size = [1, self._size] - bias = self.create_parameter( + self._bias_param = self.create_parameter( attr=self._bias_attr, shape=bias_size, dtype=self._dtype, is_bias=True) - self._inputs["Bias"] = bias def forward(self, x, y): self._inputs = {"X": x, "Y": y, "Weight": self._w} + self._inputs["Bias"] = self._bias_param if self._name is not None: out = self._helper.create_variable( name=".".join([self.full_name(), self._name]), @@ -1964,7 +1973,7 @@ def forward(self, x, y): outputs={"Out": out}) # add activation - return self._helper.append_activation(out) + return self._helper.append_activation(out, act=self._act) class Conv2DTranspose(layers.Layer): @@ -2099,6 +2108,7 @@ def __init__(self, assert param_attr is not False, "param_attr should not be False in conv2d_transpose." self._param_attr = param_attr self._bias_attr = bias_attr + self._act = act self._groups = groups self._num_filters = num_filters self._use_cudnn = use_cudnn @@ -2197,7 +2207,7 @@ def forward(self, input): else: pre_act = pre_bias - out = self._helper.append_activation(pre_act) + out = self._helper.append_activation(pre_act, act=self._act) return out @@ -2247,6 +2257,7 @@ def __init__(self, self._padding = padding self._bias_attr = bias_attr self._param_attr = param_attr + self._act = act def _build_once(self, input): self._dtype = self._helper.input_dtype(input) @@ -2287,7 +2298,7 @@ def forward(self, input): else: pre_act = pre_bias - return self._helper.append_activation(pre_act) + return self._helper.append_activation(pre_act, act=self._act) class RowConv(layers.Layer): diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 9a5baa9c42401c..46138cb5f7e020 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -124,7 +124,10 @@ def test_layer_norm(self): shape=[3, 32, 32], dtype='float32', append_batch_size=False) - ret = layers.layer_norm(t) + ret = layers.layer_norm( + t, + bias_attr=fluid.initializer.ConstantInitializer(value=1), + act='sigmoid') static_ret = self.get_static_graph_result( feed={'data': inp}, fetch_list=[ret])[0] with self.static_graph(): @@ -133,16 +136,32 @@ def test_layer_norm(self): shape=[3, 32, 32], dtype='float32', append_batch_size=False) - lm = nn.LayerNorm('layer_norm') + lm = nn.LayerNorm( + 'layer_norm', + bias_attr=fluid.initializer.ConstantInitializer(value=1), + act='sigmoid') ret = lm(t) static_ret2 = self.get_static_graph_result( feed={'data': inp}, fetch_list=[ret])[0] with self.dynamic_graph(): - lm = nn.LayerNorm('layer_norm') + lm = nn.LayerNorm( + 'layer_norm', + bias_attr=fluid.initializer.ConstantInitializer(value=1), + act='sigmoid') dy_ret = lm(base.to_variable(inp)) + with self.dynamic_graph(): + lm = nn.LayerNorm( + 'layer_norm', + shift=False, + scale=False, + param_attr=fluid.initializer.ConstantInitializer(value=1), + bias_attr=fluid.initializer.ConstantInitializer(value=1), + act='sigmoid') + self.assertIsNone(lm._scale_w) + self.assertIsNone(lm._bias_w) - self.assertTrue(np.allclose(static_ret, static_ret2)) - self.assertTrue(np.allclose(dy_ret.numpy(), static_ret2)) + self.assertTrue(np.array_equal(static_ret, static_ret2)) + self.assertTrue(np.array_equal(dy_ret.numpy(), static_ret2)) def test_relu(self): with self.static_graph(): @@ -313,7 +332,7 @@ def test_sequence_conv(self): dtype='float32', lod_level=1, append_batch_size=False) - out = layers.sequence_conv(seq, 2) + out = layers.sequence_conv(seq, 2, act='sigmoid') static_rlt = self.get_static_graph_result( feed={ "seq_in": fluid.create_lod_tensor( @@ -331,7 +350,7 @@ def test_sequence_conv(self): dtype='float32', lod_level=1, append_batch_size=False) - seq_conv = nn.SequenceConv('seq_conv', num_filters=2) + seq_conv = nn.SequenceConv('seq_conv', num_filters=2, act='sigmoid') out = seq_conv(seq) static_rlt2 = self.get_static_graph_result( feed={ @@ -343,29 +362,41 @@ def test_sequence_conv(self): fetch_list=[out], with_lod=True)[0] self.assertTrue( - np.allclose(np.array(static_rlt), np.array(static_rlt2))) + np.array_equal(np.array(static_rlt), np.array(static_rlt2))) def test_conv2d_transpose(self): inp_np = np.arange(0, 24).reshape([2, 3, 2, 2]).astype('float32') with self.static_graph(): img = layers.data(name='pixel', shape=[3, 2, 2], dtype='float32') out = layers.conv2d_transpose( - input=img, num_filters=10, output_size=28) + input=img, + num_filters=10, + output_size=28, + act='sigmoid', + bias_attr=fluid.initializer.ConstantInitializer(value=1)) static_rlt = self.get_static_graph_result( feed={'pixel': inp_np}, fetch_list=[out])[0] with self.static_graph(): img = layers.data(name='pixel', shape=[3, 2, 2], dtype='float32') conv2d_transpose = nn.Conv2DTranspose( - 'conv2d_transpose', num_filters=10, output_size=28) + 'conv2d_transpose', + num_filters=10, + output_size=28, + act='sigmoid', + bias_attr=fluid.initializer.ConstantInitializer(value=1)) out = conv2d_transpose(img) static_rlt2 = self.get_static_graph_result( feed={'pixel': inp_np}, fetch_list=[out])[0] with self.dynamic_graph(): conv2d_transpose = nn.Conv2DTranspose( - 'conv2d_transpose', num_filters=10, output_size=28) + 'conv2d_transpose', + num_filters=10, + output_size=28, + act='sigmoid', + bias_attr=fluid.initializer.ConstantInitializer(value=1)) dy_rlt = conv2d_transpose(base.to_variable(inp_np)) - self.assertTrue(np.allclose(static_rlt2, static_rlt)) - self.assertTrue(np.allclose(dy_rlt.numpy(), static_rlt)) + self.assertTrue(np.array_equal(static_rlt2, static_rlt)) + self.assertTrue(np.array_equal(dy_rlt.numpy(), static_rlt)) def test_bilinear_tensor_product(self): inp_np_x = np.array([[1, 2, 3]]).astype('float32') @@ -382,7 +413,12 @@ def test_bilinear_tensor_product(self): shape=[1, 3], dtype="float32", append_batch_size=False) - out = layers.bilinear_tensor_product(data_x, data_y, 6) + out = layers.bilinear_tensor_product( + data_x, + data_y, + 6, + bias_attr=fluid.initializer.ConstantInitializer(value=1), + act='sigmoid') static_rlt = self.get_static_graph_result( feed={'x': inp_np_x, @@ -398,17 +434,25 @@ def test_bilinear_tensor_product(self): shape=[1, 3], dtype="float32", append_batch_size=False) - btp = nn.BilinearTensorProduct('btp', 6) + btp = nn.BilinearTensorProduct( + 'btp', + 6, + bias_attr=fluid.initializer.ConstantInitializer(value=1), + act='sigmoid') out = btp(data_x, data_y) static_rlt2 = self.get_static_graph_result( feed={'x': inp_np_x, 'y': inp_np_y}, fetch_list=[out])[0] with self.dynamic_graph(): - btp = nn.BilinearTensorProduct('btp', 6) + btp = nn.BilinearTensorProduct( + 'btp', + 6, + bias_attr=fluid.initializer.ConstantInitializer(value=1), + act='sigmoid') dy_rlt = btp(base.to_variable(inp_np_x), base.to_variable(inp_np_y)) - self.assertTrue(np.allclose(static_rlt2, static_rlt)) - self.assertTrue(np.allclose(dy_rlt.numpy(), static_rlt)) + self.assertTrue(np.array_equal(static_rlt2, static_rlt)) + self.assertTrue(np.array_equal(dy_rlt.numpy(), static_rlt)) def test_prelu(self): inp_np = np.ones([5, 200, 100, 100]).astype('float32') @@ -497,7 +541,8 @@ def test_nce(self): words.append( layers.data( name='word_{0}'.format(i), shape=[1], dtype='int64')) - + sample_weights = layers.fill_constant( + shape=[5, 1], dtype='float32', value=1) embs = [] for i in range(window_size): if i == label_word: @@ -519,7 +564,8 @@ def test_nce(self): custom_dist=nid_freq_arr.tolist(), seed=seed, param_attr='nce.w', - bias_attr='nce.b') + bias_attr='nce.b', + sample_weight=sample_weights) feed_dict = dict() for i in range(window_size): feed_dict['word_{0}'.format(i)] = inp_word[i] @@ -531,7 +577,8 @@ def test_nce(self): words.append( layers.data( name='word_{0}'.format(i), shape=[1], dtype='int64')) - + sample_weights = layers.fill_constant( + shape=[5, 1], dtype='float32', value=1) emb = nn.Embedding( 'embedding', size=[dict_size, 32], @@ -554,7 +601,8 @@ def test_nce(self): custom_dist=nid_freq_arr.tolist(), seed=seed, param_attr='nce.w', - bias_attr='nce.b') + bias_attr='nce.b', + sample_weight=sample_weights) nce_loss2 = nce(embs2, words[label_word]) feed_dict = dict() @@ -568,7 +616,8 @@ def test_nce(self): words = [] for i in range(window_size): words.append(base.to_variable(inp_word[i])) - + sample_weights = layers.fill_constant( + shape=[5, 1], dtype='float32', value=1) emb = nn.Embedding( 'embedding', size=[dict_size, 32], @@ -591,7 +640,8 @@ def test_nce(self): custom_dist=nid_freq_arr.tolist(), seed=seed, param_attr='nce.w', - bias_attr='nce.b') + bias_attr='nce.b', + sample_weight=sample_weights) nce_loss3 = nce(embs3, words[label_word]) From d94d4a96fcb1ed219f7a0da3e97d34905ec2c58a Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Tue, 20 Aug 2019 11:13:12 +0000 Subject: [PATCH 5/9] test=develop, change success standard for conv2dTranspose --- python/paddle/fluid/tests/unittests/test_layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 46138cb5f7e020..8c129598085063 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -395,8 +395,8 @@ def test_conv2d_transpose(self): act='sigmoid', bias_attr=fluid.initializer.ConstantInitializer(value=1)) dy_rlt = conv2d_transpose(base.to_variable(inp_np)) - self.assertTrue(np.array_equal(static_rlt2, static_rlt)) - self.assertTrue(np.array_equal(dy_rlt.numpy(), static_rlt)) + self.assertTrue(np.allclose(static_rlt2, static_rlt)) + self.assertTrue(np.allclose(dy_rlt.numpy(), static_rlt2)) def test_bilinear_tensor_product(self): inp_np_x = np.array([[1, 2, 3]]).astype('float32') From b6d4251644cfd65cd362bd3821387fdef1ef35a3 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Thu, 22 Aug 2019 12:34:59 +0000 Subject: [PATCH 6/9] test=develop, fix test_layers to invoke some error branch --- python/paddle/fluid/tests/unittests/test_layers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 8c129598085063..2175c8bd01af41 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -157,6 +157,7 @@ def test_layer_norm(self): param_attr=fluid.initializer.ConstantInitializer(value=1), bias_attr=fluid.initializer.ConstantInitializer(value=1), act='sigmoid') + lm(base.to_variable(inp)) self.assertIsNone(lm._scale_w) self.assertIsNone(lm._bias_w) From b0a352d27fb94fade6db8e2625a27a6a3e991641 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Mon, 26 Aug 2019 09:57:21 +0000 Subject: [PATCH 7/9] test=develop, fix sample code --- python/paddle/fluid/dygraph/nn.py | 3 ++- .../fluid/tests/unittests/test_layers.py | 24 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index 7f0f80f823db35..abc5c25f9c7846 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -1958,7 +1958,8 @@ def _build_once(self, x, y): def forward(self, x, y): self._inputs = {"X": x, "Y": y, "Weight": self._w} - self._inputs["Bias"] = self._bias_param + if self._bias_attr: + self._inputs["Bias"] = self._bias_param if self._name is not None: out = self._helper.create_variable( name=".".join([self.full_name(), self._name]), diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 2175c8bd01af41..e6f070fbf31542 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -452,6 +452,30 @@ def test_bilinear_tensor_product(self): act='sigmoid') dy_rlt = btp(base.to_variable(inp_np_x), base.to_variable(inp_np_y)) + with self.dynamic_graph(): + btp2 = nn.BilinearTensorProduct('btp', 6, act='sigmoid') + dy_rlt2 = btp2( + base.to_variable(inp_np_x), base.to_variable(inp_np_y)) + + with self.static_graph(): + data_x2 = layers.data( + name='x', + shape=[1, 3], + dtype="float32", + append_batch_size=False) + data_y2 = layers.data( + name='y', + shape=[1, 3], + dtype="float32", + append_batch_size=False) + out2 = layers.bilinear_tensor_product( + data_x2, data_y2, 6, act='sigmoid') + + static_rlt3 = self.get_static_graph_result( + feed={'x': inp_np_x, + 'y': inp_np_y}, fetch_list=[out2])[0] + + self.assertTrue(np.array_equal(dy_rlt2.numpy(), static_rlt3)) self.assertTrue(np.array_equal(static_rlt2, static_rlt)) self.assertTrue(np.array_equal(dy_rlt.numpy(), static_rlt)) From ffee1176324aa770b0bf355333d7b9e63f09094c Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Mon, 26 Aug 2019 17:46:21 +0000 Subject: [PATCH 8/9] test=develop, fix BilinearTensorProduct failed in dygraph mode --- python/paddle/fluid/dygraph/nn.py | 15 +++++++-------- .../paddle/fluid/tests/unittests/test_layers.py | 1 + 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index abc5c25f9c7846..d94b33ee481d05 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -1948,17 +1948,16 @@ def _build_once(self, x, y): dtype=self._dtype, is_bias=False) - if self._bias_attr: - bias_size = [1, self._size] - self._bias_param = self.create_parameter( - attr=self._bias_attr, - shape=bias_size, - dtype=self._dtype, - is_bias=True) + bias_size = [1, self._size] + self._bias_param = self.create_parameter( + attr=self._bias_attr, + shape=bias_size, + dtype=self._dtype, + is_bias=True) def forward(self, x, y): self._inputs = {"X": x, "Y": y, "Weight": self._w} - if self._bias_attr: + if self._bias_param: self._inputs["Bias"] = self._bias_param if self._name is not None: out = self._helper.create_variable( diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index e6f070fbf31542..eb83814d77bea4 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -424,6 +424,7 @@ def test_bilinear_tensor_product(self): static_rlt = self.get_static_graph_result( feed={'x': inp_np_x, 'y': inp_np_y}, fetch_list=[out])[0] + with self.static_graph(): data_x = layers.data( name='x', From fa7270c766a1f2e721c6721be4c3e60e6c975789 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Tue, 27 Aug 2019 04:15:34 +0000 Subject: [PATCH 9/9] test=develop, fix test_layers segment fault error --- python/paddle/fluid/tests/unittests/test_layers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index eb83814d77bea4..9fc27f4bdeb3a6 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -158,8 +158,9 @@ def test_layer_norm(self): bias_attr=fluid.initializer.ConstantInitializer(value=1), act='sigmoid') lm(base.to_variable(inp)) - self.assertIsNone(lm._scale_w) - self.assertIsNone(lm._bias_w) + + self.assertFalse(hasattr(lm, "_scale_w")) + self.assertFalse(hasattr(lm, "_bias_w")) self.assertTrue(np.array_equal(static_ret, static_ret2)) self.assertTrue(np.array_equal(dy_ret.numpy(), static_ret2))