diff --git a/paddle/fluid/inference/api/api_anakin_engine.cc b/paddle/fluid/inference/api/api_anakin_engine.cc index 86e6fdaabf2dd4..4c51c239f6d444 100644 --- a/paddle/fluid/inference/api/api_anakin_engine.cc +++ b/paddle/fluid/inference/api/api_anakin_engine.cc @@ -208,6 +208,7 @@ bool PaddleInferenceAnakinPredictor::RunImpl( LOG(FATAL) << " input " << input.name << "'s shape size should be equal to that of net"; } +#ifndef ANAKIN_MLU_PLACE int sum = 1; for_each(input.shape.begin(), input.shape.end(), [&](int n) { sum *= n; }); if (sum > net_shape.count()) { @@ -221,6 +222,7 @@ bool PaddleInferenceAnakinPredictor::RunImpl( "memory."; } } +#endif std::vector tmp_shape; for (auto s : input.shape) { tmp_shape.push_back(s); @@ -229,8 +231,9 @@ bool PaddleInferenceAnakinPredictor::RunImpl( anakin::saber::Tensor::Host_type> h_tensor(data, typename anakin::DefaultHostType::Host_type(), 0, tmp_shape); +#ifndef ANAKIN_MLU_PLACE d_tensor_p->reshape(tmp_shape); - +#endif if (input.lod.size() > 0) { if (input.lod.size() > 1) { LOG(FATAL) << " input lod first dim should <=1, but you set " @@ -256,14 +259,18 @@ bool PaddleInferenceAnakinPredictor::RunImpl( LOG(FATAL) << output.name << " is not in the outputs of the graph."; } auto *d_tensor_p = this->executor_p_->get_out(output.name); - output.shape = d_tensor_p->valid_shape(); - if (output.data.length() < d_tensor_p->valid_size() * sizeof(float)) { - output.data.Resize(d_tensor_p->valid_size() * sizeof(float)); + auto tmp_shape = d_tensor_p->valid_shape(); +#ifdef ANAKIN_MLU_PLACE + tmp_shape.set_num(batch_size); +#endif + output.shape = tmp_shape; + if (output.data.length() < tmp_shape.count() * sizeof(float)) { + output.data.Resize(tmp_shape.count() * sizeof(float)); } auto *data = static_cast(output.data.data()); anakin::saber::Tensor::Host_type> h_tensor(data, typename anakin::DefaultHostType::Host_type(), 0, - d_tensor_p->valid_shape()); + tmp_shape); h_tensor.copy_from(*d_tensor_p); } return true; @@ -316,6 +323,7 @@ void PaddleInferenceAnakinMLUPredictor::SetContext() { this->config_.device_id, this->config_.data_stream_id, this->config_.compute_stream_id); this->ctx_p_->set_model_parallel(this->config_.model_parallel); + this->ctx_p_->set_fusion(this->config_.op_fuse); this->ctx_p_->enable_batch_changable(); this->ctx_p_->enable_channel_duplicate(); }