4 changes: 4 additions & 0 deletions paddle/fluid/operators/norm_op.cc
@@ -88,7 +88,11 @@ class NormOpGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs());
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
#ifndef PADDLE_WITH_ASCEND_CL
op->SetInput("Norm", this->Output("Norm"));
#else
op->SetInput("Out", this->Output("Out"));
#endif
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
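When PADDLE_WITH_ASCEND_CL is defined, the grad maker above feeds the forward output Out to the backward op instead of the stored Norm, because the NPU backward kernel added below passes (x, y, dy) to L2NormalizeGrad. A minimal NumPy sketch of the gradient that call is assumed to compute, for y = x / (||x||_2 + epsilon) along axis (the helper name is illustrative, not part of the patch):

import numpy as np

def l2_normalize_grad_ref(x, y, dy, axis, epsilon):
    # Assumed backward of y = x / (||x||_2 + epsilon) along axis:
    # dx = (dy - y * sum(dy * y, axis)) / (||x||_2 + epsilon),
    # neglecting the epsilon term when differentiating the norm.
    r = np.sqrt(np.sum(np.square(x), axis=axis, keepdims=True)) + epsilon
    return (dy - y * np.sum(dy * y, axis=axis, keepdims=True)) / r

Only x, y (i.e. Out), dy and epsilon appear here; the separately stored Norm output is not needed, which is what the #ifdef above relies on for the Ascend path.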
68 changes: 53 additions & 15 deletions paddle/fluid/operators/norm_op_npu.cc
@@ -15,24 +15,26 @@ limitations under the License. */
namespace paddle {
namespace operators {

using DDim = framework::DDim;
using Tensor = framework::Tensor;

void CheckAxis(int axis, int rank) {
// check the axis is in [-rank, rank-1]
if (axis <= rank - 1 && axis >= -rank) return;
PADDLE_THROW(platform::errors::InvalidArgument(
"axis in norm operator must between (%d) and (%d)"
"but got (%d).",
-rank, rank - 1, axis));
}

template <typename DeviceContext, typename T>
class NormNPUKernel : public framework::OpKernel<T> {
private:
void CheckAxis(int axis, int rank) const {
// check the axis is in [-rank, rank-1]
if (axis <= rank - 1 && axis >= -rank) return;
PADDLE_THROW(platform::errors::InvalidArgument(
"axis in norm operator must between (%d) and (%d)"
"but got (%d).",
-rank, rank - 1, axis));
}

public:
void Compute(const framework::ExecutionContext& ctx) const override {
void Compute(const framework::ExecutionContext &ctx) const override {
VLOG(4) << "Launch Norm Op Kernel on NPU." << std::endl;
auto* in_x = ctx.Input<framework::Tensor>("X");
auto* out_y = ctx.Output<framework::Tensor>("Out");
auto* out_norm = ctx.Output<framework::Tensor>("Norm");
auto *in_x = ctx.Input<framework::Tensor>("X");
auto *out_y = ctx.Output<framework::Tensor>("Out");
auto *out_norm = ctx.Output<framework::Tensor>("Norm");
out_y->mutable_data<T>(ctx.GetPlace());
out_norm->mutable_data<T>(ctx.GetPlace());
auto xdim = in_x->dims();
@@ -46,7 +48,7 @@ class NormNPUKernel : public framework::OpKernel<T> {
attr_input_norm["p"] = 2;
attr_input_norm["keepdim"] = true;
attr_input_norm["epsilon"] = eps;
const auto& runner =
const auto &runner =
NpuOpRunner("LpNorm", {*in_x}, {*out_norm}, attr_input_norm);
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
@@ -56,12 +58,48 @@
}
};

template <typename DeviceContext, typename T>
class NormGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
float epsilon = ctx.Attr<float>("epsilon");
int axis = ctx.Attr<int>("axis");

auto *x = ctx.Input<Tensor>("X");
auto *y = ctx.Input<framework::Tensor>("Out");
auto *dy = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));

auto xdim = x->dims();
CheckAxis(axis, xdim.size());

auto place = ctx.GetPlace();

dx->mutable_data<T>(place);

framework::NPUAttributeMap attr_input_norm;
attr_input_norm["dim"] = std::vector<int>({axis});
attr_input_norm["eps"] = epsilon;
const auto &runner =
NpuOpRunner("L2NormalizeGrad", {*x, *y, *dy}, {*dx}, attr_input_norm);
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
runner.Run(stream);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_NPU_KERNEL(
norm, ops::NormNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::NormNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>)

REGISTER_OP_NPU_KERNEL(
norm_grad, ops::NormGradNPUKernel<plat::NPUDeviceContext, float>,
ops::NormGradNPUKernel<plat::NPUDeviceContext, plat::float16>);
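
A hypothetical end-to-end exercise of the newly registered kernels from the Python side (assumptions: an Ascend/NPU build of Paddle is available, and fluid.layers.l2_normalize lowers to the norm op, as the unit tests below suggest):

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()
main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name="x", shape=[2, 3, 4], dtype="float32")
    x.stop_gradient = False
    # Expected to append a "norm" op (and, after append_backward, a norm_grad op).
    y = fluid.layers.l2_normalize(x, axis=-1, epsilon=1e-10)
    loss = fluid.layers.reduce_sum(y)
    fluid.backward.append_backward(loss)

exe = fluid.Executor(paddle.NPUPlace(0))
exe.run(startup_prog)
out, dx = exe.run(main_prog,
                  feed={"x": np.random.rand(2, 3, 4).astype("float32")},
                  fetch_list=[y.name, x.name + "@GRAD"])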
55 changes: 36 additions & 19 deletions python/paddle/fluid/tests/unittests/npu/test_norm_op_npu.py
@@ -20,26 +20,18 @@
import numpy as np
import paddle
import paddle.fluid as fluid
from op_test import OpTest, skip_check_grad_ci
from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci
from paddle.fluid.tests.unittests.test_norm_op import l2_norm

SEED = 2021


def l2_norm(x, axis, epsilon):
x2 = x**2
s = np.sum(x2, axis=axis, keepdims=True)
r = np.sqrt(s) + epsilon
y = x / np.broadcast_to(r, x.shape)
return y, r


class TestNorm(OpTest):
class TestNPUNormOp(OpTest):
def setUp(self):
paddle.enable_static()
self.set_npu()
self.place = paddle.NPUPlace(0)
self.op_type = "norm"
self.init_dtype()
self.init_test_case()

x = np.random.random(self.shape).astype(self.dtype)
y, norm = l2_norm(x, self.axis, self.epsilon)
@@ -52,36 +44,59 @@ def set_npu(self):

def init_dtype(self):
self.dtype = np.float32

def init_test_case(self):
self.axis = 1
self.epsilon = 1e-10
self.shape = (2, 3, 4, 5)

def test_check_output(self):
self.check_output_with_place(self.place)

def test_check_grad(self):
if self.dtype == np.float16:
return

class TestNormOp2(TestNorm):
self.check_grad_with_place(
self.place, ['X'], 'Out', max_relative_error=0.006)


class TestNPUNormOp2(TestNPUNormOp):
def init_test_case(self):
self.shape = [5, 3, 9, 7]
self.axis = 0
self.epsilon = 1e-8
self.dtype = np.float32


class TestNormOp3(TestNorm):
class TestNPUNormOp3(TestNPUNormOp):
def init_test_case(self):
self.shape = [5, 3, 2, 7]
self.axis = -1
self.epsilon = 1e-8
self.dtype = np.float32


class TestNormOp4(TestNorm):
@skip_check_grad_ci(reason="'check_grad' on large inputs is too slow, " +
"however it is desirable to cover the forward pass")
class TestNPUNormOp4(TestNPUNormOp):
def init_test_case(self):
self.shape = [128, 1024, 14, 14]
self.axis = 2
self.epsilon = 1e-8
self.dtype = np.float32

def test_check_grad(self):
pass


@skip_check_grad_ci(reason="'check_grad' on large inputs is too slow, " +
"however it is desirable to cover the forward pass")
class TestNPUNormOp5(TestNPUNormOp):
def init_test_case(self):
self.shape = [2048, 2048]
self.axis = 1
self.epsilon = 1e-8

def test_check_grad(self):
pass


class API_NormTest(unittest.TestCase):
@@ -96,13 +111,15 @@ def test_norm_x_type():
self.assertRaises(TypeError, test_norm_x_type)


class TestNormFP16(TestNorm):
class TestNPUNormOpFP16(TestNPUNormOp):
def set_npu(self):
self.__class__.use_npu = True
self.__class__.no_need_check_grad = True

def init_dtype(self):
self.dtype = np.float16

def init_test_case(self):
self.axis = -1
self.epsilon = 1e-10
self.shape = (2, 3, 100)
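For completeness, a stand-alone NumPy sanity check (not part of the test suite) that the backward formula sketched after norm_op.cc agrees with a finite-difference gradient of the forward reference these tests import as l2_norm; all names here are illustrative:

import numpy as np

def l2_norm_ref(x, axis, epsilon):
    # Forward reference, mirroring the l2_norm helper from test_norm_op.
    r = np.sqrt(np.sum(x ** 2, axis=axis, keepdims=True)) + epsilon
    return x / np.broadcast_to(r, x.shape), r

def l2_normalize_grad_ref(x, y, dy, axis, epsilon):
    # Assumed analytic backward (same formula as the earlier sketch).
    r = np.sqrt(np.sum(np.square(x), axis=axis, keepdims=True)) + epsilon
    return (dy - y * np.sum(dy * y, axis=axis, keepdims=True)) / r

def numeric_grad(x, dy, axis, epsilon, delta=1e-4):
    # Central-difference gradient of sum(l2_norm_ref(x)[0] * dy) w.r.t. x.
    g = np.zeros_like(x)
    it = np.nditer(x, flags=["multi_index"])
    while not it.finished:
        xp, xm = x.copy(), x.copy()
        xp[it.multi_index] += delta
        xm[it.multi_index] -= delta
        g[it.multi_index] = np.sum(
            (l2_norm_ref(xp, axis, epsilon)[0] -
             l2_norm_ref(xm, axis, epsilon)[0]) * dy) / (2 * delta)
        it.iternext()
    return g

x = np.random.rand(2, 3, 4) + 0.1
dy = np.random.rand(2, 3, 4)
y, _ = l2_norm_ref(x, axis=-1, epsilon=1e-10)
dx = l2_normalize_grad_ref(x, y, dy, axis=-1, epsilon=1e-10)
assert np.allclose(dx, numeric_grad(x, dy, axis=-1, epsilon=1e-10), atol=1e-6)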