Skip to content

Commit efa6cb0

Browse files
authored
Merge pull request #16807 from Shixiaowei02/engine2-interface
update fc convert
2 parents a7b6291 + 4af8f6d commit efa6cb0

File tree

2 files changed

+19
-46
lines changed

2 files changed

+19
-46
lines changed

paddle/fluid/inference/anakin/convert/fc.cc

Lines changed: 13 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/inference/anakin/convert/fc.h"
+#include "paddle/fluid/inference/anakin/convert/helper.h"
 #include <algorithm>
 #include <string>
 #include <vector>
@@ -45,72 +46,39 @@ void FcBaseOpConverter<TargetT>::operator()(
   // get weights
   auto *y_v = scope.FindVar(op_desc.Input(w_name).front());
   PADDLE_ENFORCE_NOT_NULL(y_v);
-  auto *y_t = y_v->GetMutable<framework::LoDTensor>();
+  auto weight_tensor = tensor_from_var(*y_v, platform::CPUPlace());
+  auto weight_shape = framework::vectorize2int(weight_tensor->dims());
+
+  int out_dim = weight_shape[1];
+  const int w_m = weight_shape[0];
+  const int w_k = weight_shape[1];
 
   auto input_name = op_desc.Input(i_name).front();
   auto output_name = op_desc.Output("Out").front();
 
   this->engine_->AddOp(op_name, "Dense", {input_name}, {output_name});
   this->engine_->AddOpAttr(op_name, "bias_term", with_bias);
   this->engine_->AddOpAttr(op_name, "axis", 1);
-
-  auto weight_shape = framework::vectorize2int(y_t->dims());
-  int out_dim = weight_shape[1];
   this->engine_->AddOpAttr(op_name, "out_dim", out_dim);
-  const int w_m = weight_shape[0];
-  const int w_k = weight_shape[1];
 
-  if (weight_shape.size() < 4UL) {
-    weight_shape.insert(weight_shape.begin(), 4UL - weight_shape.size(), 1);
-  }
-  Shape anakin_shape(weight_shape);
+  auto *weight_data = weight_tensor->data<float>();
+  PADDLE_ENFORCE(w_m * w_k == weight_tensor->numel());
 
-  framework::LoDTensor weight_tensor;
-  weight_tensor.Resize(y_t->dims());
-  TensorCopySync((*y_t), platform::CPUPlace(), &weight_tensor);
-  auto *weight_data = weight_tensor.data<float>();
-  PADDLE_ENFORCE(w_m * w_k == weight_tensor.numel());
-
-  std::vector<float> trans_weight_data(weight_tensor.numel());
+  std::vector<float> trans_weight_data(weight_tensor->numel());
   for (int i = 0; i < w_m; i++) {
     for (int j = 0; j < w_k; j++) {
       trans_weight_data[i + j * w_m] = weight_data[i * w_k + j];
     }
   }
-  auto *weight1 =
-      GraphGlobalMem<TargetT>::Global().template new_block<AK_FLOAT>(
-          anakin_shape);
-  float *cpu_data = static_cast<float *>(weight1->h_tensor().mutable_data());
-  std::copy_n(trans_weight_data.data(), weight_tensor.numel(), cpu_data);
-  weight1->d_tensor().set_shape(anakin_shape);
-  weight1->d_tensor().copy_from(weight1->h_tensor());
+
+  auto *weight1 = pblock_from_vector<TargetT>(trans_weight_data);
   this->engine_->AddOpAttr(op_name, "weight_1", *weight1);
 
   // get bias
   if (with_bias) {
     auto *b_v = scope.FindVar(op_desc.Input("Bias").front());
     PADDLE_ENFORCE_NOT_NULL(b_v);
-    auto *b_t = b_v->GetMutable<framework::LoDTensor>();
-
-    auto bias_shape = framework::vectorize2int(b_t->dims());
-    framework::LoDTensor bias_tensor;
-    bias_tensor.Resize(b_t->dims());
-    TensorCopySync((*b_t), platform::CPUPlace(), &bias_tensor);
-    auto *bias_data = bias_tensor.data<float>();
-    bias_shape.insert(bias_shape.begin(), 1);
-    bias_shape.insert(bias_shape.begin(), 1);
-    bias_shape.insert(bias_shape.begin(), 1);
-    // bias_shape.push_back(1);
-    // bias_shape.push_back(1);
-    Shape anakin_bias_shape(bias_shape);
-
-    auto *weight2 =
-        GraphGlobalMem<TargetT>::Global().template new_block<AK_FLOAT>(
-            anakin_bias_shape);
-    float *cpu_data2 = static_cast<float *>(weight2->h_tensor().mutable_data());
-    std::copy_n(bias_data, bias_tensor.numel(), cpu_data2);
-    weight2->d_tensor().set_shape(anakin_bias_shape);
-    weight2->d_tensor().copy_from(weight2->h_tensor());
+    auto weight2 = pblock_from_var<TargetT>(*b_v);
     this->engine_->AddOpAttr(op_name, "weight_2", *weight2);
   }
 }

paddle/fluid/inference/anakin/convert/helper.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,17 @@ PBlock<T>* pblock_from_tensor(const framework::LoDTensor& tensor,
 
 template <typename T>
 PBlock<T>* pblock_from_vector(const std::vector<float>& vec,
-                              const std::vector<int>& shape_vec) {
+                              std::vector<int> shape_vec) {
+  while (shape_vec.size() < 4) {
+    shape_vec.insert(shape_vec.begin(), 1);
+  }
   Shape shape(shape_vec);
   auto *weight =
       GraphGlobalMem<T>::Global().template new_block<AK_FLOAT>(shape);
   auto *weight_data = static_cast<float *>(weight->h_tensor().mutable_data());
   std::copy(std::begin(vec), std::end(vec), weight_data);
+  weight->d_tensor().set_shape(shape);
+  weight->d_tensor().copy_from(weight->h_tensor());
   return weight;
 }
6671

0 commit comments

Comments (0)