3 changes: 2 additions & 1 deletion paddle/fluid/operators/array_operator.h
@@ -47,7 +47,8 @@ class ArrayOp : public framework::OperatorBase {

size_t offset;
if (platform::is_gpu_place(i_tensor.place()) ||
platform::is_xpu_place(i_tensor.place())) {
platform::is_xpu_place(i_tensor.place()) ||
platform::is_npu_place(i_tensor.place())) {
// FIXME: Avoid copy from GPU to CPU
framework::Tensor t;
framework::TensorCopy(i_tensor, platform::CPUPlace(), dev_ctx, &t);
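Note on the hunk above: ArrayOp reads the scalar index tensor I on the host, so when I lives on a device it is first copied to the CPU; the change adds NPU places to the same GPU/XPU path. A minimal Python sketch of the graph pattern that exercises this path (assuming an Ascend build and an available NPU device 0; it mirrors the test added below):

import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers

paddle.enable_static()

main = fluid.Program()
with fluid.program_guard(main, fluid.Program()):
    # The index tensor is produced on the NPU, so array_write/array_read
    # must copy it back to the CPU before reading the offset.
    i = layers.zeros(shape=[1], dtype='int32')   # fill_constant npu op doesn't support int64
    i = layers.cast(i, 'int64')
    x = layers.fill_constant(shape=[10], dtype='float32', value=1.0)
    arr = layers.array_write(x=x, i=i)
    y = layers.array_read(array=arr, i=i)

exe = fluid.Executor(paddle.NPUPlace(0))
print(exe.run(main, fetch_list=[y])[0])  # expected: ten ones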
11 changes: 7 additions & 4 deletions paddle/fluid/operators/controlflow/while_op_helper.cc
@@ -212,14 +212,17 @@ bool GetCondData(const framework::LoDTensor &cond) {
if (platform::is_cpu_place(cond.place())) {
return cond.data<bool>()[0];
}
// when platform::is_gpu_place(cond.place()) is true
// when platform::is_gpu_place(cond.place()) or
// platform::is_npu_place(cond.place()) is true
std::unique_ptr<framework::LoDTensor> cpu_cond{new framework::LoDTensor()};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
defined(PADDLE_WITH_ASCEND_CL)
framework::TensorCopySync(cond, platform::CPUPlace(), cpu_cond.get());
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"This version of PaddlePaddle does NOT support GPU but got GPU tensor "
"Cond in WhileOp. Please compile WITH_GPU option."));
"This version of PaddlePaddle does NOT support GPU/NPU but got GPU/NPU "
"tensor "
"Cond in WhileOp. Please compile WITH_GPU or WITH_ASCEND_CL option."));
#endif
return cpu_cond->data<bool>()[0];
}
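Note: GetCondData is what the While op uses to read its boolean Cond tensor on the host; with the hunk above it now copies the tensor from NPU to CPU via TensorCopySync (under PADDLE_WITH_ASCEND_CL) instead of throwing. From Python, the condition is the bool tensor handed to layers.While and refreshed inside the block; a trimmed sketch under the same assumptions as above:

import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers

paddle.enable_static()

main = fluid.Program()
with fluid.program_guard(main, fluid.Program()):
    i = layers.zeros(shape=[1], dtype='int32')   # int32 then cast: no int64 fill_constant on NPU
    i = layers.cast(i, 'int64')
    i.stop_gradient = True
    limit = layers.fill_constant(shape=[1], dtype='int32', value=5)
    limit = layers.cast(limit, 'int64')
    limit.stop_gradient = True
    cond = layers.less_than(x=i, y=limit)        # the bool tensor read by GetCondData
    while_op = layers.While(cond=cond)
    with while_op.block():
        i = layers.increment(x=i, in_place=True)
        layers.less_than(x=i, y=limit, cond=cond)  # refresh the condition

exe = fluid.Executor(paddle.NPUPlace(0))
print(exe.run(main, fetch_list=[i])[0])  # expected: [5]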
97 changes: 72 additions & 25 deletions paddle/fluid/operators/sum_op_npu.cc
@@ -27,36 +27,83 @@ using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class SumNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto x = ctx.MultiInput<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
void Compute(const framework::ExecutionContext &ctx) const override {
auto out_var = ctx.OutputVar("Out");
if (out_var->IsType<framework::LoDTensor>()) {
auto *out = out_var->GetMutable<framework::LoDTensor>();
auto x = ctx.MultiInput<Tensor>("X");
out->mutable_data<T>(ctx.GetPlace());

auto place = ctx.GetPlace();
auto place = ctx.GetPlace();

int n = static_cast<int>(x.size());
if (n == 1) {
TensorCopy(*x[0], place, out);
return;
}
int n = static_cast<int>(x.size());
if (n == 1) {
TensorCopy(*x[0], place, out);
return;
}

std::vector<framework::Tensor> inputs;
std::vector<std::string> names;
for (int i = 0; i < n; ++i) {
if (x[i] && x[i]->numel() > 0) {
inputs.push_back(*x[i]);
names.push_back("x" + std::to_string(i));
} else {
continue;
std::vector<framework::Tensor> inputs;
std::vector<std::string> names;
for (int i = 0; i < n; ++i) {
if (x[i] && x[i]->numel() > 0) {
inputs.push_back(*x[i]);
names.push_back("x" + std::to_string(i));
} else {
continue;
}
}
}

auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
NpuOpRunner runner{"AddN", {inputs}, {*out}, {{"N", n}}};
runner.AddInputNames(names);
runner.Run(stream);
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
NpuOpRunner runner{"AddN", {inputs}, {*out}, {{"N", n}}};
runner.AddInputNames(names);
runner.Run(stream);
} else if (out_var->IsType<framework::LoDTensorArray>()) {
auto in_vars = ctx.MultiInputVar("X");
bool in_place = out_var == in_vars[0];
auto &out_array = *out_var->GetMutable<framework::LoDTensorArray>();
for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) {
PADDLE_ENFORCE_EQ(in_vars[i]->IsType<framework::LoDTensorArray>(), true,
platform::errors::InvalidArgument(
"Only support all inputs are TensorArray, "
"but inputs[%d] is not TensorArray.",
i));
auto &in_array = in_vars[i]->Get<framework::LoDTensorArray>();

for (size_t i = 0; i < in_array.size(); ++i) {
if (in_array[i].IsInitialized() && (in_array[i].numel() != 0)) {
if (i >= out_array.size()) {
out_array.resize(i + 1);
}
if (!out_array[i].IsInitialized() || (out_array[i].numel() == 0)) {
framework::TensorCopy(in_array[i], in_array[i].place(),
ctx.device_context(), &out_array[i]);
out_array[i].set_lod(in_array[i].lod());
} else {
PADDLE_ENFORCE_EQ(
out_array[i].lod(), in_array[i].lod(),
platform::errors::InvalidArgument(
"The lod message between inputs[%d] and"
" outputs[%d] must be same, but now is not same.",
i, i));
auto stream = ctx.template device_context<
paddle::platform::NPUDeviceContext>()
.stream();
NpuOpRunner runner{
"Add", {out_array[i], in_array[i]}, {out_array[i]}, {}};
runner.Run(stream);
}
}
}
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Expected type of Output(out) must be Tensor or "
"LoDTensorArray. But got "
"unsupport type: %s.",
framework::ToTypeName(out_var->Type())));
}
}
};

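Note: the rewritten kernel now branches on the type of Out. The LoDTensor branch keeps the original AddN path; the new LoDTensorArray branch merges the input tensor arrays slot by slot, growing the output array as needed, copying slots that are still empty on the output side, and otherwise running the Ascend "Add" op (the LoDs must match). In graphs this branch is hit when append_backward accumulates the gradients of a while loop, which is what the test below does. The slot-merging semantics, sketched with plain numpy lists standing in for LoDTensorArrays (illustration only, not the kernel code):

import numpy as np

def sum_tensor_arrays(in_arrays):
    # Mirrors the LoDTensorArray branch: empty output slots are copied,
    # non-empty ones are accumulated with an element-wise add.
    out = []
    for arr in in_arrays:
        for i, t in enumerate(arr):
            if t is None or t.size == 0:          # uninitialized / empty slot: skip
                continue
            if i >= len(out):
                out.extend([None] * (i + 1 - len(out)))
            if out[i] is None or out[i].size == 0:
                out[i] = t.copy()                 # TensorCopy path
            else:
                out[i] = out[i] + t               # NPU "Add" path
    return out

a = [np.ones(3), None, np.full(3, 2.0)]
b = [np.ones(3), np.ones(3)]
print(sum_tensor_arrays([a, b]))
# [array([2., 2., 2.]), array([1., 1., 1.]), array([2., 2., 2.])]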
130 changes: 130 additions & 0 deletions python/paddle/fluid/tests/unittests/npu/test_while_op_npu.py
@@ -0,0 +1,130 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import paddle
import paddle.fluid.layers as layers
from paddle.fluid.executor import Executor
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid.backward import append_backward
import numpy
from paddle.fluid import compiler, Program, program_guard

paddle.enable_static()


class TestWhileOp(unittest.TestCase):
def simple_net(self):
d0 = layers.data(
"d0", shape=[10], append_batch_size=False, dtype='float32')
d1 = layers.data(
"d1", shape=[10], append_batch_size=False, dtype='float32')
d2 = layers.data(
"d2", shape=[10], append_batch_size=False, dtype='float32')
# fill_constant npu op doesn't support int64
i = layers.zeros(shape=[1], dtype='int32')
i = layers.cast(i, 'int64')
i.stop_gradient = True
init = layers.zeros(shape=[10], dtype='float32')
mem_array = layers.array_write(x=init, i=i)
data_array = layers.array_write(x=d0, i=i)
i = layers.increment(i)
layers.array_write(d1, i, array=data_array)
i = layers.increment(i)
layers.array_write(d2, i, array=data_array)
i = layers.zeros(shape=[1], dtype='int32')
i = layers.cast(i, 'int64')
i.stop_gradient = True
array_len = layers.fill_constant(shape=[1], dtype='int32', value=5)
array_len = layers.cast(array_len, 'int64')
array_len.stop_gradient = True
cond = layers.ones(shape=[1], dtype='int32')
cond = layers.cast(cond, 'bool')
j = layers.fill_constant(shape=[1], dtype='int32', value=1)
j = layers.cast(j, 'int64')
j.stop_gradient = True
array_len2 = layers.fill_constant(shape=[1], dtype='int32', value=3)
array_len2 = layers.cast(array_len2, 'int64')
array_len2.stop_gradient = True
cond2 = layers.logical_or(x=j, y=array_len2)
cond2 = layers.ones(shape=[1], dtype='int32')
cond2 = layers.cast(cond2, 'bool')
while_op = layers.While(cond=cond)
while_op2 = layers.While(cond=cond2)
with while_op.block():
d = layers.array_read(array=data_array, i=i)
prev = layers.array_read(array=mem_array, i=i)
result = layers.sums(input=[d, prev])

i = layers.increment(x=i, in_place=True)
layers.array_write(result, i=i, array=mem_array)
layers.less_than(x=i, y=array_len, cond=cond)

with while_op2.block():
d2 = layers.array_read(array=data_array, i=j)
prev2 = layers.array_read(array=mem_array, i=j)
result2 = layers.sums(input=[d2, prev2])

j = layers.increment(x=j, in_place=True)
layers.array_write(result2, i=j, array=mem_array)
layers.less_than(x=j, y=array_len2, cond=cond2)
sum_result = layers.array_read(array=mem_array, i=j)
loss = layers.mean(sum_result)
return loss, sum_result

def test_simple_net(self):
paddle.enable_static()
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
loss, sum_result = self.simple_net()

append_backward(loss)

npu_place = paddle.NPUPlace(0)
exe = Executor(npu_place)
d = []

for i in range(3):
d.append(numpy.random.random(size=[10]).astype('float32'))

outs = exe.run(feed={'d0': d[0],
'd1': d[1],
'd2': d[2]},
fetch_list=[sum_result])
self.assertAlmostEqual(numpy.sum(d), numpy.sum(outs[0]), delta=0.01)

def test_simple_net_forward(self):
paddle.enable_static()
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
self.simple_net()

npu_place = paddle.NPUPlace(0)
exe = Executor(npu_place)
d = []

for i in range(3):
d.append(numpy.random.random(size=[10]).astype('float32'))

for _ in range(2):
exe.run(main_program, feed={'d0': d[0], 'd1': d[1], 'd2': d[2]})


if __name__ == '__main__':
unittest.main()
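The new test assumes a WITH_ASCEND_CL build and an available NPU device 0 (paddle.NPUPlace(0)); on other builds it fails at place creation rather than being skipped. A possible guard, assuming paddle.is_compiled_with_npu() is available in this build (hypothetical addition, not part of the patch):

# Hypothetical: skip on builds without Ascend support instead of failing.
import unittest
import paddle

@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "requires a PaddlePaddle build with WITH_ASCEND_CL=ON")
class TestWhileOpNPUOnly(TestWhileOp):
    pass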