Skip to content

Commit ba105b5

Browse files
committed
fix simple_rnn_cell, gru_cell and lstm_cell zero_div_error
1 parent 4cc3d9a commit ba105b5

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

python/paddle/fluid/tests/unittests/rnn/test_rnn_cells.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,16 @@ def test_with_zero_state(self):
6060
y2, h2 = rnn2(paddle.to_tensor(x))
6161
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
6262

63+
def test_errors(self):
64+
def test_zero_hidden_size():
65+
cell = paddle.nn.SimpleRNNCell(-1, 0)
66+
67+
self.assertRaises(ValueError, test_zero_hidden_size)
68+
6369
def runTest(self):
6470
self.test_with_initial_state()
6571
self.test_with_zero_state()
72+
self.test_errors()
6673

6774

6875
class TestGRUCell(unittest.TestCase):
@@ -103,9 +110,16 @@ def test_with_zero_state(self):
103110
y2, h2 = rnn2(paddle.to_tensor(x))
104111
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
105112

113+
def test_errors(self):
114+
def test_zero_hidden_size():
115+
cell = paddle.nn.GRUCell(-1, 0)
116+
117+
self.assertRaises(ValueError, test_zero_hidden_size)
118+
106119
def runTest(self):
107120
self.test_with_initial_state()
108121
self.test_with_zero_state()
122+
self.test_errors()
109123

110124

111125
class TestLSTMCell(unittest.TestCase):
@@ -150,9 +164,16 @@ def test_with_zero_state(self):
150164
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
151165
np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5)
152166

167+
def test_errors(self):
168+
def test_zero_hidden_size():
169+
cell = paddle.nn.LSTMCell(-1, 0)
170+
171+
self.assertRaises(ValueError, test_zero_hidden_size)
172+
153173
def runTest(self):
154174
self.test_with_initial_state()
155175
self.test_with_zero_state()
176+
self.test_errors()
156177

157178

158179
def load_tests(loader, tests, pattern):

python/paddle/nn/layer/rnn.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,10 @@ def __init__(self,
332332
bias_hh_attr=None,
333333
name=None):
334334
super(SimpleRNNCell, self).__init__()
335+
if hidden_size <= 0:
336+
raise ValueError(
337+
"hidden_size of {} must be greater than 0, but now equals to {}".
338+
format(self.__class__.__name__, hidden_size))
335339
std = 1.0 / math.sqrt(hidden_size)
336340
self.weight_ih = self.create_parameter(
337341
(hidden_size, input_size),
@@ -480,6 +484,10 @@ def __init__(self,
480484
bias_hh_attr=None,
481485
name=None):
482486
super(LSTMCell, self).__init__()
487+
if hidden_size <= 0:
488+
raise ValueError(
489+
"hidden_size of {} must be greater than 0, but now equals to {}".
490+
format(self.__class__.__name__, hidden_size))
483491
std = 1.0 / math.sqrt(hidden_size)
484492
self.weight_ih = self.create_parameter(
485493
(4 * hidden_size, input_size),
@@ -627,6 +635,10 @@ def __init__(self,
627635
bias_hh_attr=None,
628636
name=None):
629637
super(GRUCell, self).__init__()
638+
if hidden_size <= 0:
639+
raise ValueError(
640+
"hidden_size of {} must be greater than 0, but now equals to {}".
641+
format(self.__class__.__name__, hidden_size))
630642
std = 1.0 / math.sqrt(hidden_size)
631643
self.weight_ih = self.create_parameter(
632644
(3 * hidden_size, input_size),

0 commit comments

Comments
 (0)