@@ -84,8 +84,12 @@ class TestMKLDNNSigmoidBF16Op(TestActivation):
8484 @OpTestTool .skip_if_not_cpu_bf16 ()
8585 def config (self ):
8686 self .op_type = "sigmoid"
87- self .op_func = lambda x : (1 / (1 + np .exp (- x )))
88- self .op_grad_func = lambda dout , x : (dout * self .op_func (x )) * (1 - self .op_func (x ))
87+
88+ def op_func (self , x ):
89+ return 1 / (1 + np .exp (- x ))
90+
91+ def op_grad_func (self , dout , x ):
92+ return dout * self .op_func (x ) * (1 - self .op_func (x ))
8993
9094 def set_attrs (self ):
9195 self .attrs = {"use_mkldnn" : True }
@@ -121,8 +125,14 @@ def test_check_grad(self):
121125class TestMKLDNNGeluErfBF16Op (TestMKLDNNSigmoidBF16Op ):
122126 def config (self ):
123127 self .op_type = "gelu"
124- self .op_func = lambda x : gelu (x , False )
125- self .op_grad_func = lambda dout , x : (dout * (0.5 + 0.5 * erf (x / np .sqrt (2 )) + (x / np .sqrt (2 * np .pi ) * np .exp (- 0.5 * np .power (x , 2 )))))
128+
129+ def op_func (self , x ):
130+ return gelu (x , False )
131+
132+ def op_grad_func (self , dout , x ):
133+ return (dout *
134+ (0.5 + 0.5 * erf (x / np .sqrt (2 )) +
135+ (x / np .sqrt (2 * np .pi ) * np .exp (- 0.5 * np .power (x , 2 )))))
126136
127137
128138class TestMKLDNNGeluErfDim2BF16Op (TestMKLDNNGeluErfBF16Op ):
@@ -133,8 +143,16 @@ def init_data(self):
133143class TestMKLDNNGeluTanhBF16Op (TestMKLDNNSigmoidBF16Op ):
134144 def config (self ):
135145 self .op_type = "gelu"
136- self .op_func = lambda x : gelu (x , True )
137- self .op_grad_func = lambda dout , x : (dout * 0.5 * (1 + np .tanh (np .sqrt (2 / np .pi ) * (x + 0.044715 * np .power (x , 3 )))) * (1 + np .sqrt (2 / np .pi ) * (x + 0.134145 * np .power (x , 3 )) * (1 - np .tanh (np .sqrt (2 / np .pi ) * (x + 0.044715 * np .power (x , 3 ))))))
146+
147+ def op_func (self , x ):
148+ return gelu (x , True )
149+
150+ def op_grad_func (self , dout , x ):
151+ grad_part = np .tanh (
152+ np .sqrt (2 / np .pi ) * (x + 0.044715 * np .power (x , 3 )))
153+ return dout * 0.5 * (1 + grad_part ) * (1 + np .sqrt (2 / np .pi ) *
154+ (x + 0.134145 * np .power (x , 3 )) *
155+ (1 - grad_part ))
138156
139157 def set_attrs (self ):
140158 self .attrs = {"use_mkldnn" : True , "approximate" : True }
0 commit comments