Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions lite/core/op_lite.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,18 +178,19 @@ class OpLite : public Registry {
std::vector<Place> valid_places_;
Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
std::unique_ptr<OpInfo> op_info_;
// Infer Shape according to memory, if current input shapes are consistent
// with that of previous inputs, output shapes of last time will be reused.
std::vector<const Tensor *> input_tensor_ptrs_cache_{};
std::vector<Tensor *> output_tensor_ptrs_cache_{};

private:
// TODO: it's preferred to combine last_input_shapes and
// last_input_lods into a single hash value to decrease
// memory usage.
private:
std::vector<DDimLite> last_input_shapes_{};
std::vector<LoD> last_input_lods_{};
std::vector<DDimLite> last_output_shapes_{};
std::vector<LoD> last_output_lods_{};
std::vector<const Tensor *> input_tensor_ptrs_cache_{};
std::vector<Tensor *> output_tensor_ptrs_cache_{};
// Infer Shape according to memory, if current input shapes are consistent
// with that of previous inputs, output shapes of last time will be reused.
};

/*
Expand Down
2 changes: 2 additions & 0 deletions lite/operators/batch_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ bool BatchNormOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
param_.variance =
scope->FindVar(op_desc.Input("Variance").front())->GetMutable<Tensor>();
param_.y = scope->FindVar(op_desc.Output("Y").front())->GetMutable<Tensor>();
input_tensor_ptrs_cache_.push_back(param_.x);
output_tensor_ptrs_cache_.push_back(param_.y);

auto is_test_type = op_desc.GetAttrType("is_test");
switch (is_test_type) {
Expand Down
3 changes: 3 additions & 0 deletions lite/operators/box_coder_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ bool BoxCoderOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
if (opdesc.HasAttr("variance")) {
param_.variance = opdesc.GetAttr<std::vector<float>>("variance");
}
input_tensor_ptrs_cache_.push_back(param_.prior_box);
input_tensor_ptrs_cache_.push_back(param_.target_box);
output_tensor_ptrs_cache_.push_back(param_.proposals);
return true;
}

Expand Down
3 changes: 3 additions & 0 deletions lite/operators/concat_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,13 @@ bool ConcatOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
param_.x.clear();
for (auto var : inputs) {
param_.x.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
input_tensor_ptrs_cache_.push_back(
scope->FindVar(var)->GetMutable<lite::Tensor>());
}
CHECK(scope->FindVar(out));
param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.axis = op_desc.GetAttr<int>("axis");
output_tensor_ptrs_cache_.push_back(param_.output);

std::vector<std::string> input_arg_names = op_desc.InputArgumentNames();
if (std::find(input_arg_names.begin(), input_arg_names.end(), "AxisTensor") !=
Expand Down
2 changes: 2 additions & 0 deletions lite/operators/conv_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class ConvOpLite : public OpLite {
CHECK(param_.x);
CHECK(param_.filter);
CHECK(param_.output);
input_tensor_ptrs_cache_.push_back(param_.x);
output_tensor_ptrs_cache_.push_back(param_.output);

param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
std::vector<int> paddings = op_desc.GetAttr<std::vector<int>>("paddings");
Expand Down
2 changes: 2 additions & 0 deletions lite/operators/deformable_conv_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ class DeformableConvOpLite : public OpLite {
param_.conv_param.dilations = std::make_shared<std::vector<int>>(dilations);
std::vector<int> paddings = op_desc.GetAttr<std::vector<int>>("paddings");
param_.conv_param.paddings = std::make_shared<std::vector<int>>(paddings);
input_tensor_ptrs_cache_.push_back(param_.x);
output_tensor_ptrs_cache_.push_back(param_.output);

// optional params
std::vector<std::string> input_arg_names = op_desc.InputArgumentNames();
Expand Down
3 changes: 3 additions & 0 deletions lite/operators/elementwise_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
param_.alpha = opdesc.GetAttr<float>("alpha");
param_.bias = opdesc.GetAttr<float>("bias");
}
input_tensor_ptrs_cache_.push_back(param_.X);
input_tensor_ptrs_cache_.push_back(param_.Y);
output_tensor_ptrs_cache_.push_back(param_.Out);

return true;
}
Expand Down
2 changes: 2 additions & 0 deletions lite/operators/fc_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ bool FcOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
CHECK(scope->FindVar(out));
param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.in_num_col_dims = op_desc.GetAttr<int>("in_num_col_dims");
input_tensor_ptrs_cache_.push_back(param_.input);
output_tensor_ptrs_cache_.push_back(param_.output);

if (op_desc.HasAttr("activation_type")) {
param_.activation_type = op_desc.GetAttr<std::string>("activation_type");
Expand Down
3 changes: 3 additions & 0 deletions lite/operators/matmul_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ bool MatMulOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
param_.transpose_X = op_desc.GetAttr<bool>("transpose_X");
param_.transpose_Y = op_desc.GetAttr<bool>("transpose_Y");
param_.alpha = op_desc.GetAttr<float>("alpha");
input_tensor_ptrs_cache_.push_back(param_.X);
input_tensor_ptrs_cache_.push_back(param_.Y);
output_tensor_ptrs_cache_.push_back(param_.Out);

const OpInfo *op_info = static_cast<const OpInfo *>(&op_desc);
if (op_info != nullptr && op_info->HasAttr("enable_int8")) {
Expand Down
3 changes: 3 additions & 0 deletions lite/operators/matmul_v2_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ bool MatMulV2OpLite::AttachImpl(const cpp::OpDesc &op_desc,
if (op_desc.HasAttr("alpha")) {
param_.alpha = op_desc.GetAttr<float>("alpha");
}
input_tensor_ptrs_cache_.push_back(param_.X);
input_tensor_ptrs_cache_.push_back(param_.Y);
output_tensor_ptrs_cache_.push_back(param_.Out);
return true;
}

Expand Down
3 changes: 3 additions & 0 deletions lite/operators/mul_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class MulOpLite : public OpLite {
if (op_info->HasOutputScale(out_scale_name, true))
param_.output_scale = op_info->GetOutputScale(out_scale_name, true)[0];
}
input_tensor_ptrs_cache_.push_back(param_.x);
input_tensor_ptrs_cache_.push_back(param_.y);
output_tensor_ptrs_cache_.push_back(param_.output);

return true;
}
Expand Down
2 changes: 2 additions & 0 deletions lite/operators/pool_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class PoolOpLite : public OpLite {
CHECK(scope->FindVar(out));
param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
input_tensor_ptrs_cache_.push_back(param_.x);
output_tensor_ptrs_cache_.push_back(param_.output);

param_.pooling_type = op_desc.GetAttr<std::string>("pooling_type");
param_.ksize = op_desc.GetAttr<std::vector<int>>("ksize");
Expand Down
2 changes: 2 additions & 0 deletions lite/operators/reshape_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ bool ReshapeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.output =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
CHECK(param_.output);
input_tensor_ptrs_cache_.push_back(param_.x);
output_tensor_ptrs_cache_.push_back(param_.output);

// priority: input(ShapeTensor) > input(Shape) > attr(shape)
param_.shape_tensor_vct.clear();
Expand Down
3 changes: 3 additions & 0 deletions lite/operators/split_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,10 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
auto outs_name = opdesc.Output("Out");
for (auto name : outs_name) {
param_.output.push_back(scope->FindMutableTensor(name));
output_tensor_ptrs_cache_.push_back(scope->FindMutableTensor(name));
}
input_tensor_ptrs_cache_.push_back(param_.x);

return true;
}

Expand Down
2 changes: 2 additions & 0 deletions lite/operators/squeeze_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ bool SqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
if (opdesc.HasAttr("inplace")) {
param_.inplace = opdesc.GetAttr<bool>("inplace");
}
input_tensor_ptrs_cache_.push_back(param_.X);
output_tensor_ptrs_cache_.push_back(param_.Out);
return true;
}

Expand Down
2 changes: 2 additions & 0 deletions lite/operators/transpose_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ bool TransposeOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
if (op_desc.HasAttr("data_format")) {
param_.data_format = op_desc.GetAttr<std::string>("data_format");
}
input_tensor_ptrs_cache_.push_back(param_.x);
output_tensor_ptrs_cache_.push_back(param_.output);
return true;
}

Expand Down
2 changes: 2 additions & 0 deletions lite/operators/unsqueeze_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ bool UnsqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
if (opdesc.HasAttr("inplace")) {
param_.inplace = opdesc.GetAttr<bool>("inplace");
}
input_tensor_ptrs_cache_.push_back(param_.X);
output_tensor_ptrs_cache_.push_back(param_.Out);
return true;
}

Expand Down