
Commit 1ec01d7 (parent: f27a0dc)

log_softmax_op_v1

2 files changed: +2 additions, -24 deletions

paddle/fluid/operators/log_softmax_op_npu.cc

Lines changed: 1 addition & 7 deletions
@@ -13,7 +13,6 @@
 // limitations under the License.
 
 #include "paddle/fluid/operators/log_softmax_op.h"
-#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/operators/npu_op_runner.h"
 namespace paddle {
 namespace operators {
@@ -43,9 +42,4 @@ namespace plat = paddle::platform;
 
 REGISTER_OP_NPU_KERNEL(
     log_softmax,
-    ops::LogSoftmaxNPUKernel<paddle::platform::NPUDeviceContext, float>,
-    ops::LogSoftmaxNPUKernel<paddle::platform::NPUDeviceContext, double>,
-    // ops::LogSoftmaxNPUKernel<paddle::platform::NPUDeviceContext, int>, //
-    // used to debug
-    ops::LogSoftmaxNPUKernel<paddle::platform::NPUDeviceContext,
-                             paddle::platform::float16>);
+    ops::LogSoftmaxNPUKernel<paddle::platform::NPUDeviceContext, float>);
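
Note: after this change, log_softmax on NPU is registered for float only; the double, debug-only int, and float16 kernels are dropped. A minimal usage sketch, assuming a PaddlePaddle build with NPU support (paddle.set_device("npu:0") and the dynamic-graph call are illustrative assumptions, not part of this commit):

import numpy as np
import paddle
import paddle.nn.functional as F

# Assumption: this Paddle build ships an NPU device; "npu:0" is hypothetical here.
paddle.set_device("npu:0")
x = paddle.to_tensor(np.random.rand(2, 3).astype("float32"))  # float32 is the only dtype still registered
y = F.log_softmax(x, axis=-1)  # would dispatch to LogSoftmaxNPUKernel<..., float>
print(y.numpy())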

python/paddle/fluid/tests/unittests/npu/test_log_softmax_op_npu.py

Lines changed: 1 addition & 17 deletions
@@ -22,27 +22,11 @@
 import paddle.fluid as fluid
 from paddle.fluid import core
 import paddle.nn.functional as F
+from test_log_softmax import ref_log_softmax, test_log_softmax
 paddle.enable_static()
 np.random.seed(10)
 
 
-def ref_log_softmax(x):
-    shiftx = (x - np.max(x))
-    out = shiftx - np.log(np.exp(shiftx).sum())
-    return out
-
-
-def ref_log_softmax_grad(x, axis):
-    if axis < 0:
-        axis += len(x.shape)
-    out = np.apply_along_axis(ref_log_softmax, axis, x)
-    axis_dim = x.shape[axis]
-    dout = np.full_like(x, fill_value=1. / x.size)
-    dx = dout - np.exp(out) * dout.copy().sum(axis=axis, keepdims=True).repeat(
-        axis_dim, axis=axis)
-    return dx
-
-
 class TestLogSoftmaxNPUOp(OpTest):
     def setUp(self):
         self.set_npu()
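
For reference, the deleted helper implements log-softmax in the numerically stable shifted form, log_softmax(x) = (x - max(x)) - log(sum(exp(x - max(x)))); the test now imports the shared reference from test_log_softmax instead of keeping a duplicate copy. A self-contained NumPy sketch of the same computation (a standalone illustration, not the imported module itself):

import numpy as np

def ref_log_softmax(x):
    # Shift by the max before exponentiating so np.exp cannot overflow;
    # the shift cancels out in the identity above.
    shiftx = x - np.max(x)
    return shiftx - np.log(np.exp(shiftx).sum())

x = np.random.rand(2, 3).astype("float32")
out = np.apply_along_axis(ref_log_softmax, 1, x)  # row-wise log-softmax
print(out)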
