Commit e74c821

add ffn python api test

1 parent 07e7e8e commit e74c821

3 files changed: +93 -9 lines changed

python/paddle/fluid/tests/unittests/test_fused_ffn_op.py

Lines changed: 75 additions & 4 deletions
@@ -153,10 +153,6 @@ def FusedFFN(self):
     def test_fused_ffn(self):
         base_out, base_grad = self.Base()
         fused_out, fused_grad = self.FusedFFN()
-        if base_grad is None:
-            print("base grad is none")
-        if fused_grad is None:
-            print("fused grad is none")

         np.testing.assert_allclose(
             base_out.numpy(), fused_out.numpy(), rtol=self.rtol, atol=self.atol)
@@ -205,5 +201,80 @@ def getShape(self):
         self.dim_feedforward = 8


+class TestFusedFFNOpApi(TestFusedFFNOp):
+    def setUp(self):
+        self.getDtype()
+        self.getShape()
+        self.getDiff()
+        self.getActivation()
+        self.getNormalizeBefore()
+        self.weight_attr = None
+        self.bias_attr = None
+
+        self.weight_attrs = fused_transformer._convert_param_attr_to_list(
+            self.weight_attr, 2)
+        self.bias_attrs = fused_transformer._convert_param_attr_to_list(
+            self.bias_attr, 2)
+        self.ffn_layer = fused_transformer.FusedFeedForward(
+            self.d_model, self.dim_feedforward, 0.0, self.act_method, 0.0,
+            self.normalize_before, self.weight_attrs[1], self.bias_attrs[1])
+
+        self.ln1_scale = self.ffn_layer._ln1_scale
+        self.ln1_bias = self.ffn_layer._ln1_bias
+        self.ln2_scale = self.ffn_layer._ln2_scale
+        self.ln2_bias = self.ffn_layer._ln2_bias
+        self.linear1_weight = self.ffn_layer._linear1_weight
+        self.linear1_bias = self.ffn_layer._linear1_bias
+        self.linear2_weight = self.ffn_layer._linear2_weight
+        self.linear2_bias = self.ffn_layer._linear2_bias
+
+        self.src = np.random.random((self.batch_size, self.query_length,
+                                     self.d_model)).astype(self.dtype)
+        self.dout = np.random.random((self.batch_size, self.query_length,
+                                      self.d_model)).astype(self.dtype)
+
+        self.dropout1 = Dropout(0.0, mode="upscale_in_train")
+        self.dropout2 = Dropout(0.0, mode="upscale_in_train")
+        self.activation = getattr(F, self.act_method)
+
+    def Base(self):
+        tensor_src = paddle.to_tensor(self.src, stop_gradient=False)
+        residual = paddle.to_tensor(self.src)
+        if self.normalize_before:
+            ln1_out = F.layer_norm(tensor_src,
+                                   list([self.d_model]), self.ln1_scale,
+                                   self.ln1_bias)
+            linear1_out = F.linear(ln1_out, self.linear1_weight,
+                                   self.linear1_bias)
+            act_out = self.activation(linear1_out)
+            dropout1_out = self.dropout1(act_out)
+            linear2_out = F.linear(dropout1_out, self.linear2_weight,
+                                   self.linear2_bias)
+            dropout2_out = residual + self.dropout2(linear2_out)
+            paddle.autograd.backward([dropout2_out],
+                                     [paddle.to_tensor(self.dout)], True)
+            return dropout2_out, tensor_src.grad
+        else:
+            linear1_out = F.linear(tensor_src, self.linear1_weight,
+                                   self.linear1_bias)
+            act_out = self.activation(linear1_out)
+            dropout1_out = self.dropout1(act_out)
+            linear2_out = F.linear(dropout1_out, self.linear2_weight,
+                                   self.linear2_bias)
+            dropout2_out = residual + self.dropout2(linear2_out)
+            dropout2_out = F.layer_norm(dropout2_out,
+                                        list([self.d_model]), self.ln2_scale,
+                                        self.ln2_bias)
+            paddle.autograd.backward([dropout2_out],
+                                     [paddle.to_tensor(self.dout)], True)
+            return dropout2_out, tensor_src.grad
+
+    def FusedFFN(self):
+        tensor_src = paddle.to_tensor(self.src, stop_gradient=False)
+        out = self.ffn_layer(tensor_src)
+        paddle.autograd.backward([out], [paddle.to_tensor(self.dout)])
+        return out, tensor_src.grad
+
+
 if __name__ == "__main__":
     unittest.main()
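The new TestFusedFFNOpApi class runs the same forward/backward comparison through the layer API instead of the raw op. A minimal usage sketch of that layer follows; it assumes the fused_transformer.FusedFeedForward module added by this patch is importable, that weight_attr and bias_attr are optional, and that the positional argument order matches the constructor call in the test above (d_model, dim_feedforward, dropout, activation, act_dropout, normalize_before).

    import numpy as np
    import paddle
    from paddle.nn.layer import fused_transformer

    batch_size, seq_len, d_model, dim_feedforward = 2, 4, 8, 8

    # Dropout probabilities are zero so the fused and unfused paths can be
    # compared deterministically, as in the test.
    ffn = fused_transformer.FusedFeedForward(
        d_model, dim_feedforward, 0.0, "relu", 0.0, False)

    x = paddle.to_tensor(
        np.random.random((batch_size, seq_len, d_model)).astype("float32"),
        stop_gradient=False)
    out = ffn(x)

    # Backward through the fused kernel populates x.grad, which the test
    # checks against the unfused reference built in Base().
    out.backward()
    print(out.shape, x.grad.shape)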

python/paddle/nn/functional/fused_ffn.py

Lines changed: 12 additions & 0 deletions
@@ -23,6 +23,16 @@
 __all__ = []


+def _verify_dropout_param(p, mode):
+    if not isinstance(p, (float, int)):
+        raise TypeError("p argument should be a number")
+    if p < 0 or p > 1:
+        raise ValueError("p argument should be between 0 and 1")
+    if mode not in ('downscale_in_infer', 'upscale_in_train'):
+        raise ValueError(
+            "mode argument should be 'downscale_in_infer' or 'upscale_in_train'")
+
+
 def fused_ffn(x,
               linear1_weight,
               linear2_weight,
@@ -49,6 +59,8 @@ def fused_ffn(x,
               seed2=0,
               normalize_pre_or_post=False,
               name=None):
+    _verify_dropout_param(dropout_prob1, dropout_implementation1)
+    _verify_dropout_param(dropout_prob2, dropout_implementation2)

     if in_dygraph_mode():
         out, _, _, _, _, _, _, _, _, _, _ = _C_ops.fused_ffn(
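The helper rejects bad dropout settings before any kernel work happens. Below is a standalone copy of it, shown only to illustrate the intended failure modes (the real function is private to this module):

    def _verify_dropout_param(p, mode):
        if not isinstance(p, (float, int)):
            raise TypeError("p argument should be a number")
        if p < 0 or p > 1:
            raise ValueError("p argument should be between 0 and 1")
        if mode not in ('downscale_in_infer', 'upscale_in_train'):
            raise ValueError(
                "mode argument should be 'downscale_in_infer' or 'upscale_in_train'")

    _verify_dropout_param(0.5, 'upscale_in_train')  # valid: returns silently
    try:
        _verify_dropout_param(1.5, 'upscale_in_train')
    except ValueError as err:
        print(err)  # p argument should be between 0 and 1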

python/paddle/nn/layer/fused_transformer.py

Lines changed: 6 additions & 5 deletions
@@ -16,6 +16,7 @@
 from .. import functional as F
 from paddle.nn import Layer
 from ...framework import ParamAttr
+from ..initializer import Constant

 import collections

@@ -188,9 +189,11 @@ def __init__(self,
                  bias_attr=None):

         super(FusedFeedForward, self).__init__()
-
-        #self._weight_attrs = _convert_param_attr_to_list(weight_attr, 2)
-        #self._bias_attrs = _convert_param_attr_to_list(bias_attr, 2)
+        assert d_model > 0, ("Expected d_model to be greater than 0, "
+                             "but received {}".format(d_model))
+        assert dim_feedforward > 0, (
+            "Expected dim_feedforward to be greater than 0, "
+            "but received {}".format(dim_feedforward))

         self._dtype = self._helper.get_default_dtype()
         self._d_model = d_model
@@ -199,8 +202,6 @@ def __init__(self,
         self._act_dropout = dropout if act_dropout is None else act_dropout
         self._act_method = activation
         self._normalize_before = normalize_before
-        #self._weight_attr = self._weight_attrs[1]
-        #self._bias_attr = self._bias_attrs[1]

         self._linear1_weight = self.create_parameter(
             shape=[d_model, dim_feedforward],
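The new asserts make FusedFeedForward fail fast on degenerate sizes instead of surfacing a shape error later in create_parameter. A minimal sketch of the failure mode, assuming the remaining constructor arguments are optional (the bias_attr=None default in the hunk above suggests they are):

    from paddle.nn.layer.fused_transformer import FusedFeedForward

    try:
        # d_model must be positive; the assert fires before any
        # parameters are created.
        FusedFeedForward(0, 8)
    except AssertionError as err:
        print(err)  # Expected d_model to be greater than 0, but received 0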
