Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions paddle/fluid/inference/api/api_anakin_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ bool PaddleInferenceAnakinPredictor<T, P, R>::RunImpl(
LOG(FATAL) << " input " << input.name
<< "'s shape size should be equal to that of net";
}
#ifndef ANAKIN_MLU_PLACE
int sum = 1;
for_each(input.shape.begin(), input.shape.end(), [&](int n) { sum *= n; });
if (sum > net_shape.count()) {
Expand All @@ -221,6 +222,7 @@ bool PaddleInferenceAnakinPredictor<T, P, R>::RunImpl(
"memory.";
}
}
#endif
std::vector<int> tmp_shape;
for (auto s : input.shape) {
tmp_shape.push_back(s);
Expand All @@ -229,8 +231,9 @@ bool PaddleInferenceAnakinPredictor<T, P, R>::RunImpl(
anakin::saber::Tensor<typename anakin::DefaultHostType<T>::Host_type>
h_tensor(data, typename anakin::DefaultHostType<T>::Host_type(), 0,
tmp_shape);
#ifndef ANAKIN_MLU_PLACE
d_tensor_p->reshape(tmp_shape);

#endif
if (input.lod.size() > 0) {
if (input.lod.size() > 1) {
LOG(FATAL) << " input lod first dim should <=1, but you set "
Expand All @@ -256,14 +259,18 @@ bool PaddleInferenceAnakinPredictor<T, P, R>::RunImpl(
LOG(FATAL) << output.name << " is not in the outputs of the graph.";
}
auto *d_tensor_p = this->executor_p_->get_out(output.name);
output.shape = d_tensor_p->valid_shape();
if (output.data.length() < d_tensor_p->valid_size() * sizeof(float)) {
output.data.Resize(d_tensor_p->valid_size() * sizeof(float));
auto tmp_shape = d_tensor_p->valid_shape();
#ifdef ANAKIN_MLU_PLACE
tmp_shape.set_num(batch_size);
#endif
output.shape = tmp_shape;
if (output.data.length() < tmp_shape.count() * sizeof(float)) {
output.data.Resize(tmp_shape.count() * sizeof(float));
}
auto *data = static_cast<float *>(output.data.data());
anakin::saber::Tensor<typename anakin::DefaultHostType<T>::Host_type>
h_tensor(data, typename anakin::DefaultHostType<T>::Host_type(), 0,
d_tensor_p->valid_shape());
tmp_shape);
h_tensor.copy_from(*d_tensor_p);
}
return true;
Expand Down Expand Up @@ -316,6 +323,7 @@ void PaddleInferenceAnakinMLUPredictor<P, R>::SetContext() {
this->config_.device_id, this->config_.data_stream_id,
this->config_.compute_stream_id);
this->ctx_p_->set_model_parallel(this->config_.model_parallel);
this->ctx_p_->set_fusion(this->config_.op_fuse);
this->ctx_p_->enable_batch_changable();
this->ctx_p_->enable_channel_duplicate();
}
Expand Down