
Commit cf40894

[NPU] Add norm_grad kernel (#35237)
* [NPU] fix for test_norm_op_npu
* [NPU] add norm_grad
* [NPU] add CheckAxis for axis
* [NPU] delete debug codes
* norm can not use L2Normalize, norm_grad can use L2NormalizeGrad
* [NPU] delete useless codes
* [NPU] optimize norm_grad OpMaker
* Update python import path
1 parent e928274 commit cf40894

File tree

3 files changed, +93 -34 lines changed
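For orientation: the norm op L2-normalizes its input along a single axis and also exposes the norm itself as a second output (Out and Norm), which is likely why, per the commit notes, the forward pass cannot be mapped onto CANN's L2Normalize while the backward pass can be mapped onto L2NormalizeGrad. A minimal numpy sketch of the forward reference, mirroring the l2_norm helper that the updated test now imports from test_norm_op:

    import numpy as np

    def l2_norm(x, axis, epsilon):
        # y = x / (||x||_2 + epsilon) along `axis`; also return the kept-dim norm
        s = np.sum(x ** 2, axis=axis, keepdims=True)
        r = np.sqrt(s) + epsilon
        y = x / np.broadcast_to(r, x.shape)
        return y, r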

paddle/fluid/operators/norm_op.cc

Lines changed: 4 additions & 0 deletions
@@ -88,7 +88,11 @@ class NormOpGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetAttrMap(this->Attrs());
     op->SetInput("X", this->Input("X"));
     op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+#ifndef PADDLE_WITH_ASCEND_CL
     op->SetInput("Norm", this->Output("Norm"));
+#else
+    op->SetInput("Out", this->Output("Out"));
+#endif
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
   }
 };
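With Ascend support enabled, the grad op maker above feeds the forward output Out (the normalized tensor) into the backward op instead of Norm, because the NPU kernel added below maps the backward pass onto L2NormalizeGrad, which consumes x, y and dy. For reference, a rough numpy sketch of the corresponding gradient (the helper name is made up here, and the exact epsilon handling inside L2NormalizeGrad is an assumption):

    import numpy as np

    def l2_normalize_grad(x, y, dy, axis, epsilon):
        # With y = x / r and r = ||x||_2 + epsilon (r recomputed from x here),
        # dx = (dy - y * sum(dy * y)) / r: the component of dy along y is
        # removed, then everything is rescaled by the norm.
        r = np.sqrt(np.sum(x ** 2, axis=axis, keepdims=True)) + epsilon
        return (dy - y * np.sum(dy * y, axis=axis, keepdims=True)) / r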

paddle/fluid/operators/norm_op_npu.cc

Lines changed: 53 additions & 15 deletions
@@ -15,24 +15,26 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+using DDim = framework::DDim;
+using Tensor = framework::Tensor;
+
+void CheckAxis(int axis, int rank) {
+  // check the axis is in [-rank, rank-1]
+  if (axis <= rank - 1 && axis >= -rank) return;
+  PADDLE_THROW(platform::errors::InvalidArgument(
+      "axis in norm operator must between (%d) and (%d)"
+      "but got (%d).",
+      -rank, rank - 1, axis));
+}
+
 template <typename DeviceContext, typename T>
 class NormNPUKernel : public framework::OpKernel<T> {
- private:
-  void CheckAxis(int axis, int rank) const {
-    // check the axis is in [-rank, rank-1]
-    if (axis <= rank - 1 && axis >= -rank) return;
-    PADDLE_THROW(platform::errors::InvalidArgument(
-        "axis in norm operator must between (%d) and (%d)"
-        "but got (%d).",
-        -rank, rank - 1, axis));
-  }
-
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
+  void Compute(const framework::ExecutionContext &ctx) const override {
     VLOG(4) << "Launch Norm Op Kernel on NPU." << std::endl;
-    auto* in_x = ctx.Input<framework::Tensor>("X");
-    auto* out_y = ctx.Output<framework::Tensor>("Out");
-    auto* out_norm = ctx.Output<framework::Tensor>("Norm");
+    auto *in_x = ctx.Input<framework::Tensor>("X");
+    auto *out_y = ctx.Output<framework::Tensor>("Out");
+    auto *out_norm = ctx.Output<framework::Tensor>("Norm");
     out_y->mutable_data<T>(ctx.GetPlace());
     out_norm->mutable_data<T>(ctx.GetPlace());
     auto xdim = in_x->dims();
@@ -46,7 +48,7 @@ class NormNPUKernel : public framework::OpKernel<T> {
     attr_input_norm["p"] = 2;
     attr_input_norm["keepdim"] = true;
     attr_input_norm["epsilon"] = eps;
-    const auto& runner =
+    const auto &runner =
         NpuOpRunner("LpNorm", {*in_x}, {*out_norm}, attr_input_norm);
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
@@ -56,12 +58,48 @@ class NormNPUKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class NormGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    float epsilon = ctx.Attr<float>("epsilon");
+    int axis = ctx.Attr<int>("axis");
+
+    auto *x = ctx.Input<Tensor>("X");
+    auto *y = ctx.Input<framework::Tensor>("Out");
+    auto *dy = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto *dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+
+    auto xdim = x->dims();
+    CheckAxis(axis, xdim.size());
+
+    auto place = ctx.GetPlace();
+
+    dx->mutable_data<T>(place);
+
+    framework::NPUAttributeMap attr_input_norm;
+    attr_input_norm["dim"] = std::vector<int>({axis});
+    attr_input_norm["eps"] = epsilon;
+    const auto &runner =
+        NpuOpRunner("L2NormalizeGrad", {*x, *y, *dy}, {*dx}, attr_input_norm);
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+    runner.Run(stream);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
+
 REGISTER_OP_NPU_KERNEL(
     norm, ops::NormNPUKernel<paddle::platform::NPUDeviceContext, float>,
     ops::NormNPUKernel<paddle::platform::NPUDeviceContext,
                        paddle::platform::float16>)
+
+REGISTER_OP_NPU_KERNEL(
+    norm_grad, ops::NormGradNPUKernel<plat::NPUDeviceContext, float>,
+    ops::NormGradNPUKernel<plat::NPUDeviceContext, plat::float16>);
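The gradient check enabled in the test below (check_grad_with_place with max_relative_error=0.006) is essentially a numeric finite-difference comparison against the kernel's output. A small illustrative sketch of the same idea in plain numpy, reusing the l2_norm and l2_normalize_grad helpers sketched above (the delta, shapes and printed diagnostic are arbitrary choices, not taken from OpTest):

    import numpy as np

    def numeric_grad(x, axis, epsilon, delta=1e-3):
        # central finite differences of sum(l2_norm(x)[0]) w.r.t. each element of x
        grad = np.zeros_like(x)
        it = np.nditer(x, flags=['multi_index'])
        for _ in it:
            idx = it.multi_index
            orig = x[idx]
            x[idx] = orig + delta
            hi = l2_norm(x, axis, epsilon)[0].sum()
            x[idx] = orig - delta
            lo = l2_norm(x, axis, epsilon)[0].sum()
            x[idx] = orig
            grad[idx] = (hi - lo) / (2.0 * delta)
        return grad

    x = np.random.random((2, 3, 4)).astype(np.float64)
    y, _ = l2_norm(x, axis=1, epsilon=1e-10)
    dy = np.ones_like(y)  # upstream gradient of sum(y)
    analytic = l2_normalize_grad(x, y, dy, axis=1, epsilon=1e-10)
    print(np.abs(analytic - numeric_grad(x, 1, 1e-10)).max())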

python/paddle/fluid/tests/unittests/npu/test_norm_op_npu.py

Lines changed: 36 additions & 19 deletions
@@ -20,26 +20,18 @@
 import numpy as np
 import paddle
 import paddle.fluid as fluid
-from op_test import OpTest, skip_check_grad_ci
+from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci
+from paddle.fluid.tests.unittests.test_norm_op import l2_norm
 
-SEED = 2021
 
-
-def l2_norm(x, axis, epsilon):
-    x2 = x**2
-    s = np.sum(x2, axis=axis, keepdims=True)
-    r = np.sqrt(s) + epsilon
-    y = x / np.broadcast_to(r, x.shape)
-    return y, r
-
-
-class TestNorm(OpTest):
+class TestNPUNormOp(OpTest):
     def setUp(self):
         paddle.enable_static()
         self.set_npu()
         self.place = paddle.NPUPlace(0)
         self.op_type = "norm"
         self.init_dtype()
+        self.init_test_case()
 
         x = np.random.random(self.shape).astype(self.dtype)
         y, norm = l2_norm(x, self.axis, self.epsilon)
@@ -52,36 +44,59 @@ def set_npu(self):
 
     def init_dtype(self):
         self.dtype = np.float32
+
+    def init_test_case(self):
         self.axis = 1
         self.epsilon = 1e-10
         self.shape = (2, 3, 4, 5)
 
     def test_check_output(self):
         self.check_output_with_place(self.place)
 
+    def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
 
-class TestNormOp2(TestNorm):
+        self.check_grad_with_place(
+            self.place, ['X'], 'Out', max_relative_error=0.006)
+
+
+class TestNPUNormOp2(TestNPUNormOp):
     def init_test_case(self):
         self.shape = [5, 3, 9, 7]
         self.axis = 0
         self.epsilon = 1e-8
-        self.dtype = np.float32
 
 
-class TestNormOp3(TestNorm):
+class TestNPUNormOp3(TestNPUNormOp):
     def init_test_case(self):
         self.shape = [5, 3, 2, 7]
         self.axis = -1
         self.epsilon = 1e-8
-        self.dtype = np.float32
 
 
-class TestNormOp4(TestNorm):
+@skip_check_grad_ci(reason="'check_grad' on large inputs is too slow, " +
+                    "however it is desirable to cover the forward pass")
+class TestNPUNormOp4(TestNPUNormOp):
     def init_test_case(self):
         self.shape = [128, 1024, 14, 14]
         self.axis = 2
         self.epsilon = 1e-8
-        self.dtype = np.float32
+
+    def test_check_grad(self):
+        pass
+
+
+@skip_check_grad_ci(reason="'check_grad' on large inputs is too slow, " +
+                    "however it is desirable to cover the forward pass")
+class TestNPUNormOp5(TestNPUNormOp):
+    def init_test_case(self):
+        self.shape = [2048, 2048]
+        self.axis = 1
+        self.epsilon = 1e-8
+
+    def test_check_grad(self):
+        pass
 
 
 class API_NormTest(unittest.TestCase):
@@ -96,13 +111,15 @@ def test_norm_x_type():
         self.assertRaises(TypeError, test_norm_x_type)
 
 
-class TestNormFP16(TestNorm):
+class TestNPUNormOpFP16(TestNPUNormOp):
     def set_npu(self):
         self.__class__.use_npu = True
         self.__class__.no_need_check_grad = True
 
     def init_dtype(self):
         self.dtype = np.float16
+
+    def init_test_case(self):
         self.axis = -1
         self.epsilon = 1e-10
         self.shape = (2, 3, 100)
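For completeness, these unit tests drive the op directly through OpTest; in a user program the same kernels are typically reached through fluid.layers.l2_normalize in static-graph mode. A hedged end-to-end sketch, assuming a Paddle build with Ascend/NPU support and that l2_normalize lowers to the norm op as it does on CPU/GPU:

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()
    main_prog, startup_prog = fluid.Program(), fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        x = fluid.data(name='x', shape=[2, 3, 4], dtype='float32')
        x.stop_gradient = False                      # so a gradient for x is built
        y = fluid.layers.l2_normalize(x, axis=1, epsilon=1e-10)
        loss = fluid.layers.reduce_sum(y)
        fluid.backward.append_backward(loss)         # inserts the norm_grad op

    exe = fluid.Executor(paddle.NPUPlace(0))         # requires an NPU-enabled build
    exe.run(startup_prog)
    y_np, = exe.run(main_prog,
                    feed={'x': np.random.random((2, 3, 4)).astype('float32')},
                    fetch_list=[y])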
