Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/fluid/lite/api/cxx_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ USE_LITE_OP(elementwise_sub)
USE_LITE_OP(square)
USE_LITE_OP(softmax)
USE_LITE_OP(dropout)
USE_LITE_OP(concat)
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);

Expand All @@ -142,6 +143,7 @@ USE_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(dropout, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(concat, kX86, kFloat, kNCHW, def);
#endif

#ifdef LITE_WITH_CUDA
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/lite/kernels/x86/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ cc_library(scale_compute_x86 SRCS scale_compute.cc DEPS ${lite_kernel_deps})
cc_library(elementwise_compute_x86 SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} elementwise_sub_op elementwise_add_op)
cc_library(softmax_compute_x86 SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
cc_library(dropout_compute_x86 SRCS dropout_compute.cc DEPS ${lite_kernel_deps} )
cc_library(concat_compute_x86 SRCS concat_compute.cc DEPS ${lite_kernel_deps} )

set(x86_kernels
activation_compute_x86
Expand All @@ -26,6 +27,7 @@ set(x86_kernels
scale_compute_x86
softmax_compute_x86
dropout_compute_x86
concat_compute_x86
)

set(x86_kernels "${x86_kernels}" CACHE INTERNAL "x86 kernels")
102 changes: 102 additions & 0 deletions paddle/fluid/lite/kernels/x86/concat_compute.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <Eigen/Core>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/operators/strided_memcpy.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {

template <typename T>
class ConcatCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  using param_t = operators::ConcatParam;

  /// Concatenates the tensors in `param.x` along `param.axis` into
  /// `param.output`. Null or zero-sized inputs are skipped.
  void Run() override {
    auto& param = *param_.get_mutable<param_t>();
    int64_t axis = static_cast<int64_t>(param.axis);
    auto* out = param.output;

    if (axis == 0 && param.x.size() < 10) {
      // Fast path: for axis-0 concat of a small number of inputs, every
      // input occupies one contiguous slice of the output, so a single
      // strided copy per input suffices.
      size_t output_offset = 0;
      for (auto* in : param.x) {
        // Skip absent or empty inputs.
        if (!in || in->dims().production() == 0UL) {
          continue;
        }
        auto in_stride = framework::stride_numel(in->dims().data());
        auto out_stride = framework::stride_numel(out->dims().data());
        paddle::operators::StridedNumelCopyWithAxis<T>(
            platform::CPUDeviceContext(), axis,
            out->mutable_data<T>() + output_offset, out_stride, in->data<T>(),
            in_stride, in_stride[axis]);

        output_offset += in_stride[axis];
      }
    } else {
      // General path: view each input as a (rows x cols) matrix, where
      // `rows` is the product of the dimensions before `axis`, and copy the
      // per-row column slices side by side into the output.
      std::vector<lite::Tensor> inputs;
      for (auto* t : param.x) {
        if (t && t->dims().production() > 0) {
          inputs.push_back(*t);
        }
      }
      if (inputs.empty()) {
        // All inputs absent/empty: nothing to concatenate. The original
        // code would have indexed inputs[0] here (undefined behavior).
        return;
      }

      const int num = static_cast<int>(inputs.size());
      // Use int64_t throughout: `int` can overflow for large tensors.
      int64_t rows = 1;
      const auto& dim_0 = inputs[0].dims();
      for (int64_t i = 0; i < axis; ++i) {
        rows *= dim_0[i];
      }
      const int64_t out_rows = rows;
      int64_t out_cols = 0;

      std::vector<int64_t> input_cols(inputs.size());
      for (int i = 0; i < num; ++i) {
        int64_t t_cols = inputs[i].dims().production() / rows;
        out_cols += t_cols;
        input_cols[i] = t_cols;
      }

      // computation
      auto* output_data = param.output->template mutable_data<T>();
      int64_t col_idx = 0;
      for (int j = 0; j < num; ++j) {
        int64_t col_len = input_cols[j];
        // data<T>(), not data<float>(): the original hard-coded float and
        // would read the wrong element type for other instantiations.
        auto* input_data = inputs[j].data<T>();
        for (int64_t k = 0; k < out_rows; ++k) {
          std::memcpy(output_data + k * out_cols + col_idx,
                      input_data + k * col_len, sizeof(T) * col_len);
        }
        col_idx += col_len;
      }
    }
  }

  virtual ~ConcatCompute() = default;
};

} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle

// Register the float x86 concat kernel: input "X" is a list of tensors,
// output "Out" is a single tensor; layout kNCHW, precision kFloat.
REGISTER_LITE_KERNEL(concat, kX86, kFloat, kNCHW,
                     paddle::lite::kernels::x86::ConcatCompute<float>, def)
    .BindInput("X", {LiteType::GetTensorListTy(TARGET(kX86))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
    .Finalize();
3 changes: 3 additions & 0 deletions paddle/fluid/lite/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS})
#cc_library(sgd_op_lite SRCS sgd_op.cc DEPS ${op_DEPS})
cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framework_proto_lite)
cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS})
cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS})

set(ops_lite
fc_op_lite
Expand All @@ -32,6 +33,7 @@ set(ops_lite
fill_constant_op_lite
activation_ops_lite
dropout_op_lite
concat_op_lite
PARENT_SCOPE)

lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc
Expand All @@ -41,3 +43,4 @@ lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc
lite_cc_test(test_scale_op_lite SRCS scale_op_test.cc DEPS scale_op_lite memory_lite)
lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite)
lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite)
lite_cc_test(test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite)
75 changes: 75 additions & 0 deletions paddle/fluid/lite/operators/concat_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/lite/operators/concat_op.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace operators {

// Validates the attached parameters: concat needs more than one input
// tensor and a non-null output tensor. Returns false (via the OR_FALSE
// macros) instead of aborting when a check fails.
bool ConcatOpLite::CheckShape() const {
  CHECK_GT_OR_FALSE(param_.x.size(), 1UL);
  CHECK_OR_FALSE(param_.output);
  return true;
}

bool ConcatOpLite::InferShape() const {
std::vector<framework::DDim> input_dims;
for (auto p : param_.x) {
input_dims.push_back(p->dims().data());
}
size_t axis = static_cast<size_t>(param_.axis);
const size_t n = input_dims.size();
CHECK_GT_OR_FALSE(n, 0);
auto &out_dims = input_dims[0];
size_t in_zero_dims_size = out_dims.size();
for (size_t i = 1; i < n; i++) {
for (size_t j = 0; j < in_zero_dims_size; j++) {
if (j == axis) {
out_dims[axis] += input_dims[i][j];
} else {
CHECK_EQ_OR_FALSE(out_dims[j], input_dims[i][j]);
}
}
}
if (out_dims[axis] < 0) {
out_dims[axis] = -1;
}
// Set output dims
param_.output->Resize(lite::DDim(out_dims));
return true;
}

// TODO(Superjomn) replace framework::OpDesc with a lite one.
// Binds the input/output tensors and the `axis` attribute from the op desc
// into param_. Every referenced variable must already exist in `scope`.
bool ConcatOpLite::AttachImpl(const OpDesc &op_desc, lite::Scope *scope) {
  auto inputs = op_desc.Input("X");
  auto out = op_desc.Output("Out").front();

  for (auto var : inputs) {
    // The original dereferenced FindVar() unchecked for inputs (while
    // checking the output), crashing on a missing variable. Check both.
    auto *in_var = scope->FindVar(var);
    CHECK(in_var);
    param_.x.push_back(in_var->GetMutable<lite::Tensor>());
  }
  // Look the output up once instead of twice.
  auto *out_var = scope->FindVar(out);
  CHECK(out_var);
  param_.output = out_var->GetMutable<lite::Tensor>();
  param_.axis = GetAttr<int>(op_desc.GetAttr("axis"));

  return true;
}

} // namespace operators
} // namespace lite
} // namespace paddle

REGISTER_LITE_OP(concat, paddle::lite::operators::ConcatOpLite);
46 changes: 46 additions & 0 deletions paddle/fluid/lite/operators/concat_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/utils/all.h"

namespace paddle {
namespace lite {
namespace operators {

// Lite operator for concat: joins a list of input tensors along a given
// axis into a single output tensor. Parameters live in ConcatParam.
class ConcatOpLite : public OpLite {
 public:
  ConcatOpLite() {}
  explicit ConcatOpLite(const std::string &op_type) : OpLite(op_type) {}

  // Verifies there is more than one input and a non-null output.
  bool CheckShape() const override;

  // Derives the output dims from the input dims and the concat axis.
  bool InferShape() const override;

  // Binds tensors and the `axis` attribute from the op desc into param_.
  bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override;

  // Hands the populated param_ to the chosen kernel.
  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
  std::string DebugString() const override { return "concat"; }

 private:
  mutable ConcatParam param_;
};

} // namespace operators
} // namespace lite
} // namespace paddle
59 changes: 59 additions & 0 deletions paddle/fluid/lite/operators/concat_op_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/lite/operators/concat_op.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace operators {

TEST(concat_op_lite, test) {
  // Build a scope holding two 10x20 inputs and a 20x20 output slot
  // (the expected shape of an axis-0 concat of the two inputs).
  lite::Scope scope;
  auto* x0 = scope.Var("x0")->GetMutable<lite::Tensor>();
  auto* x1 = scope.Var("x1")->GetMutable<lite::Tensor>();
  auto* output = scope.Var("output")->GetMutable<lite::Tensor>();
  x0->Resize(lite::DDim(std::vector<int64_t>({10, 20})));
  x1->Resize(lite::DDim(std::vector<int64_t>({10, 20})));
  output->Resize(lite::DDim(std::vector<int64_t>{20, 20}));

  // Fill both inputs with ascending values and zero the output buffer.
  const int numel = 10 * 20;
  for (int i = 0; i < numel; i++) {
    x0->mutable_data<float>()[i] = i;
    x1->mutable_data<float>()[i] = i;
    output->mutable_data<float>()[i] = 0.;
  }

  // Describe a concat over {x0, x1} along axis 0 writing to "output".
  lite::OpDesc desc;
  desc.SetType("concat");
  desc.SetInput("X", {"x0", "x1"});
  desc.SetOutput("Out", {"output"});
  desc.SetAttr("axis", static_cast<int>(0));

  ConcatOpLite concat("concat");

  concat.SetValidPlaces({Place{TARGET(kX86), PRECISION(kFloat)}});
  concat.Attach(desc, &scope);
}

} // namespace operators
} // namespace lite
} // namespace paddle
7 changes: 7 additions & 0 deletions paddle/fluid/lite/operators/op_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ struct ReshapeParam {

std::vector<int> shape{};
bool inplace{false};
};  // fixed: this struct's closing brace was missing its semicolon,
    // which is a compile error in op_params.h

// For Concat op: `x` lists the input tensors that are concatenated along
// dimension `axis` (default 0) into `output`.
struct ConcatParam {
  std::vector<lite::Tensor*> x{};
  lite::Tensor* output{};
  int axis{0};
};

// For Convolution op
Expand Down