
Commit 94a702b

switched from lambdas to local functions
1 parent 5da3813 commit 94a702b

File tree

1 file changed (+24, -6 lines)


python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py

Lines changed: 24 additions & 6 deletions
@@ -84,8 +84,12 @@ class TestMKLDNNSigmoidBF16Op(TestActivation):
     @OpTestTool.skip_if_not_cpu_bf16()
     def config(self):
         self.op_type = "sigmoid"
-        self.op_func = lambda x: (1 / (1 + np.exp(-x)))
-        self.op_grad_func = lambda dout, x: (dout * self.op_func(x)) * (1 - self.op_func(x))
+
+    def op_func(self, x):
+        return 1 / (1 + np.exp(-x))
+
+    def op_grad_func(self, dout, x):
+        return dout * self.op_func(x) * (1 - self.op_func(x))

     def set_attrs(self):
         self.attrs = {"use_mkldnn": True}
@@ -121,8 +125,14 @@ def test_check_grad(self):
 class TestMKLDNNGeluErfBF16Op(TestMKLDNNSigmoidBF16Op):
     def config(self):
         self.op_type = "gelu"
-        self.op_func = lambda x: gelu(x, False)
-        self.op_grad_func = lambda dout, x: (dout * (0.5 + 0.5 * erf(x / np.sqrt(2)) + (x / np.sqrt(2 * np.pi) * np.exp(-0.5 * np.power(x, 2)))))
+
+    def op_func(self, x):
+        return gelu(x, False)
+
+    def op_grad_func(self, dout, x):
+        return (dout *
+                (0.5 + 0.5 * erf(x / np.sqrt(2)) +
+                 (x / np.sqrt(2 * np.pi) * np.exp(-0.5 * np.power(x, 2)))))


 class TestMKLDNNGeluErfDim2BF16Op(TestMKLDNNGeluErfBF16Op):
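The gelu helper comes from the test utilities; assuming the exact (erf-based) form GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))), the expression in op_grad_func is the analytic derivative of that product. A hedged sketch (scipy's erf stands in for whatever the test file imports; sample points and tolerance are arbitrary):

import numpy as np
from scipy.special import erf

def gelu_erf(x):
    # Assumed exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1 + erf(x / np.sqrt(2)))

x = np.linspace(-3.0, 3.0, 7)
eps = 1e-6
# Central finite difference of gelu_erf at x
numeric = (gelu_erf(x + eps) - gelu_erf(x - eps)) / (2 * eps)
analytic = (0.5 + 0.5 * erf(x / np.sqrt(2)) +
            x / np.sqrt(2 * np.pi) * np.exp(-0.5 * np.power(x, 2)))
assert np.allclose(analytic, numeric, atol=1e-7)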
@@ -133,8 +143,16 @@ def init_data(self):
 class TestMKLDNNGeluTanhBF16Op(TestMKLDNNSigmoidBF16Op):
     def config(self):
         self.op_type = "gelu"
-        self.op_func = lambda x: gelu(x, True)
-        self.op_grad_func = lambda dout, x: (dout * 0.5 * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) * (1 + np.sqrt(2 / np.pi) * (x + 0.134145 * np.power(x, 3)) * (1 - np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))))
+
+    def op_func(self, x):
+        return gelu(x, True)
+
+    def op_grad_func(self, dout, x):
+        grad_part = np.tanh(
+            np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))
+        return dout * 0.5 * (1 + grad_part) * (1 + np.sqrt(2 / np.pi) *
+                                               (x + 0.134145 * np.power(x, 3)) *
+                                               (1 - grad_part))

     def set_attrs(self):
         self.attrs = {"use_mkldnn": True, "approximate": True}
