Skip to content

Commit c0289d1

Browse files
authored
Add api doc and update unittest. (#43)
* Add doc strings. * Update overlap_add op unittest
1 parent d2eebba commit c0289d1

File tree

3 files changed

+284
-26
lines changed

3 files changed

+284
-26
lines changed

paddle/fluid/operators/overlap_add_op.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class OverlapAddOp : public framework::OperatorWithKernel {
4444
"Attribute(hop_length) of OverlapAddOp should be greater "
4545
"than 0, but got %s.",
4646
hop_length));
47+
4748
PADDLE_ENFORCE_EQ(
4849
(axis == 0 || axis == -1), true,
4950
platform::errors::InvalidArgument(
@@ -68,6 +69,13 @@ class OverlapAddOp : public framework::OperatorWithKernel {
6869
end_axis = x_rank - 3;
6970
}
7071

72+
PADDLE_ENFORCE_LE(
73+
hop_length, frame_length,
74+
platform::errors::InvalidArgument(
75+
"Attribute(hop_length) of OverlapAddOp should be less or equal "
76+
"than frame_length, but got hop_length(%s) > frame_length(%s).",
77+
hop_length, frame_length));
78+
7179
const int seq_length = (n_frames - 1) * hop_length + frame_length;
7280

7381
// It won't go into for loop when x_rank == 2U.

python/paddle/fluid/tests/unittests/test_overlap_add_op.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,10 @@ def setUp(self):
7979
self.outputs = {'Out': overlap_add(x=self.inputs['X'], **self.attrs)}
8080

8181
def initTestCase(self):
82-
input_shape = (150, 30)
82+
input_shape = (50, 3)
8383
input_type = 'float64'
8484
attrs = {
85-
'hop_length': 20,
85+
'hop_length': 4,
8686
'axis': -1,
8787
}
8888
return input_shape, input_type, attrs
@@ -100,54 +100,54 @@ def test_check_grad_normal(self):
100100

101101
class TestCase1(TestOverlapAddOp):
102102
def initTestCase(self):
103-
input_shape = (30, 150)
103+
input_shape = (3, 50)
104104
input_type = 'float64'
105105
attrs = {
106-
'hop_length': 15,
106+
'hop_length': 4,
107107
'axis': 0,
108108
}
109109
return input_shape, input_type, attrs
110110

111111

112112
class TestCase2(TestOverlapAddOp):
113113
def initTestCase(self):
114-
input_shape = (2, 250, 10)
114+
input_shape = (2, 40, 5)
115115
input_type = 'float64'
116116
attrs = {
117-
'hop_length': 50,
117+
'hop_length': 10,
118118
'axis': -1,
119119
}
120120
return input_shape, input_type, attrs
121121

122122

123123
class TestCase3(TestOverlapAddOp):
124124
def initTestCase(self):
125-
input_shape = (10, 250, 2)
125+
input_shape = (5, 40, 2)
126126
input_type = 'float64'
127127
attrs = {
128-
'hop_length': 30,
128+
'hop_length': 10,
129129
'axis': 0,
130130
}
131131
return input_shape, input_type, attrs
132132

133133

134134
class TestCase4(TestOverlapAddOp):
135135
def initTestCase(self):
136-
input_shape = (3, 5, 70, 20)
136+
input_shape = (3, 5, 12, 8)
137137
input_type = 'float64'
138138
attrs = {
139-
'hop_length': 27,
139+
'hop_length': 5,
140140
'axis': -1,
141141
}
142142
return input_shape, input_type, attrs
143143

144144

145145
class TestCase5(TestOverlapAddOp):
146146
def initTestCase(self):
147-
input_shape = (20, 70, 5, 3)
147+
input_shape = (8, 12, 5, 3)
148148
input_type = 'float64'
149149
attrs = {
150-
'hop_length': 33,
150+
'hop_length': 5,
151151
'axis': 0,
152152
}
153153
return input_shape, input_type, attrs

0 commit comments

Comments
 (0)