Skip to content

Commit dbcee05

Browse files
authored
[NNAdapter][HuaweiAscendNPU] Support transformer model (#8594)
1 parent 2c5feaa commit dbcee05

19 files changed

Lines changed: 239 additions & 22 deletions

lite/backends/nnadapter/nnadapter/src/driver/huawei_ascend_npu/converter/range.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ namespace huawei_ascend_npu {
2323

2424
int ConvertRange(Converter* converter, core::Operation* operation) {
2525
RANGE_OPERATION_EXTRACT_INPUTS_OUTPUTS
26-
NNADAPTER_CHECK(IsOperationWithAllInputConstantOperands(operation))
27-
<< "Range input operands only support constant!";
2826

2927
// Convert to GE operators
3028
auto start_operator = converter->GetMappedOperator(start_operand);

lite/backends/nnadapter/nnadapter/src/operation/cast.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "operation/cast.h"
1616
#include "core/types.h"
1717
#include "utility/debug.h"
18+
#include "utility/hints.h"
1819
#include "utility/logging.h"
1920
#include "utility/modeling.h"
2021
#include "utility/utility.h"
@@ -28,6 +29,7 @@ int PrepareCast(core::Operation* operation) {
2829
// Infer the shape and type of output operands
2930
CopyOperandTypeExceptQuantParams(&output_operand->type, input_operand->type);
3031
output_operand->type.precision = dtype;
32+
SetTemporaryShape(output_operand, input_operand->type.dimensions);
3133
NNADAPTER_VLOG(5) << "output: " << OperandToString(output_operand);
3234
return NNADAPTER_NO_ERROR;
3335
}

lite/backends/nnadapter/nnadapter/src/operation/range.cc

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,28 @@
1515
#include "operation/range.h"
1616
#include "core/types.h"
1717
#include "utility/debug.h"
18+
#include "utility/hints.h"
1819
#include "utility/logging.h"
1920
#include "utility/modeling.h"
2021
#include "utility/utility.h"
2122

2223
namespace nnadapter {
2324
namespace operation {
2425

26+
// Extracts the scalar carried by one input of a range operation (start, limit
// or delta) into `*data`, so the caller can infer the output length.
//
// - Constant operand: the first element of `operand->buffer` is decoded
//   according to `operand->type.precision`. Only integer precisions can be
//   represented exactly in an int64_t; for any other precision `*data` is set
//   to NNADAPTER_UNKNOWN, which the caller in this file treats as "output
//   length is dynamic".
// - Temporary shape operand: the single element of the temporary shape is
//   used (the shape must hold exactly one value).
// - Any other operand: fatal error.
void GetRangeOperandValue(core::Operand* operand, int64_t* data) {  // NOLINT
  if (IsConstantOperand(operand)) {
    // Decode by precision instead of blindly reading 8 bytes: an INT32
    // constant only owns 4 bytes, and a floating-point constant would be
    // reinterpreted as a meaningless integer.
    switch (operand->type.precision) {
      case NNADAPTER_INT32:
        *data = static_cast<int64_t>(
            reinterpret_cast<const int32_t*>(operand->buffer)[0]);
        break;
      case NNADAPTER_INT64:
        *data = reinterpret_cast<const int64_t*>(operand->buffer)[0];
        break;
      default:
        NNADAPTER_LOG(WARNING)
            << "Unable to infer the range length from a constant operand "
               "with precision "
            << OperandPrecisionCodeToString(operand->type.precision)
            << ", treat it as unknown.";
        *data = NNADAPTER_UNKNOWN;
        break;
    }
  } else if (IsTemporaryShapeOperand(operand)) {
    auto& temporary_shape = *(GetTemporaryShape(operand));
    NNADAPTER_CHECK_EQ(temporary_shape.count, 1);
    *data = temporary_shape.data[0];
  } else {
    NNADAPTER_LOG(FATAL)
        << "Unsupported range operand: expected a constant or a temporary "
           "shape operand, precision = "
        << OperandPrecisionCodeToString(operand->type.precision);
  }
}
39+
2540
int PrepareRange(core::Operation* operation) {
2641
RANGE_OPERATION_EXTRACT_INPUTS_OUTPUTS
2742

@@ -31,19 +46,22 @@ int PrepareRange(core::Operation* operation) {
3146
NNADAPTER_CHECK_EQ(limit_operand->type.dimensions.count, 1);
3247
NNADAPTER_CHECK_EQ(delta_operand->type.dimensions.count, 1);
3348

34-
if (IsConstantOperand(start_operand) && IsConstantOperand(limit_operand) &&
35-
IsConstantOperand(delta_operand)) {
36-
auto start_data = reinterpret_cast<float*>(start_operand->buffer)[0];
37-
auto limit_data = reinterpret_cast<float*>(limit_operand->buffer)[0];
38-
auto delta_data = reinterpret_cast<float*>(delta_operand->buffer)[0];
49+
int64_t start_data, limit_data, delta_data;
50+
start_data = limit_data = delta_data = -1;
51+
GetRangeOperandValue(start_operand, &start_data);
52+
GetRangeOperandValue(limit_operand, &limit_data);
53+
GetRangeOperandValue(delta_operand, &delta_data);
54+
55+
output_type.dimensions.count = 1;
56+
output_type.precision = start_operand->type.precision;
57+
output_type.lifetime = NNADAPTER_TEMPORARY_VARIABLE;
58+
if (start_data == NNADAPTER_UNKNOWN || limit_data == NNADAPTER_UNKNOWN ||
59+
delta_data == NNADAPTER_UNKNOWN) {
60+
output_type.dimensions.data[0] = NNADAPTER_UNKNOWN;
61+
} else {
3962
output_type.dimensions.data[0] =
4063
GetSpanCount(start_data, limit_data, delta_data);
41-
} else {
42-
output_type.dimensions.data[0] = NNADAPTER_UNKNOWN;
4364
}
44-
output_type.precision = start_operand->type.precision;
45-
output_type.lifetime = NNADAPTER_TEMPORARY_VARIABLE;
46-
output_type.dimensions.count = 1;
4765

4866
NNADAPTER_VLOG(5) << "output: " << OperandToString(output_operand);
4967
return NNADAPTER_NO_ERROR;

lite/core/optimizer/mir/elimination/assign_value_calc_offline_pass.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "lite/core/optimizer/mir/pass.h"
2323
#include "lite/core/optimizer/mir/pass_registry.h"
2424
#include "lite/core/optimizer/mir/pattern_matcher.h"
25+
#include "lite/core/optimizer/mir/ssa_graph_utils.h"
2526
#include "lite/model_parser/cpp_desc.h"
2627

2728
namespace paddle {
@@ -36,6 +37,20 @@ void AssignValueCalcOfflinePass::RemoveAssignValuePattern(
3637
const std::unique_ptr<SSAGraph>& graph) {
3738
for (auto& node : graph->StmtTopologicalOrder()) {
3839
if (node->AsStmt().op_type() != "assign_value") continue;
40+
auto outlinks = node->outlinks;
41+
bool has_extra_producers = false;
42+
for (auto& out_link : outlinks) {
43+
if (HasExtraProducers(
44+
graph.get(), out_link->arg()->name, {"assign_value"})) {
45+
has_extra_producers = true;
46+
break;
47+
}
48+
}
49+
if (has_extra_producers) {
50+
LOG(WARNING)
51+
<< "The output var of op is not supported with multiple producers";
52+
continue;
53+
}
3954

4055
std::set<const Node*> nodes2rm_;
4156
auto& assign_value_instruct = node->AsStmt();

lite/core/optimizer/mir/elimination/fill_constant_calc_offline_pass.cc

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "lite/core/optimizer/mir/pass.h"
2323
#include "lite/core/optimizer/mir/pass_registry.h"
2424
#include "lite/core/optimizer/mir/pattern_matcher.h"
25+
#include "lite/core/optimizer/mir/ssa_graph_utils.h"
2526
#include "lite/model_parser/cpp_desc.h"
2627

2728
namespace paddle {
@@ -45,7 +46,20 @@ void FillConstantCalcOfflinePass::RemoveFillConstantPattern(
4546
const std::unique_ptr<SSAGraph>& graph) {
4647
for (auto& node : graph->StmtTopologicalOrder()) {
4748
if (node->AsStmt().op_type() != "fill_constant") continue;
48-
49+
auto outlinks = node->outlinks;
50+
bool has_extra_producers = false;
51+
for (auto& out_link : outlinks) {
52+
if (HasExtraProducers(
53+
graph.get(), out_link->arg()->name, {"fill_constant"})) {
54+
has_extra_producers = true;
55+
break;
56+
}
57+
}
58+
if (has_extra_producers) {
59+
LOG(WARNING)
60+
<< "Unsupported for op output var containing multiple producers";
61+
continue;
62+
}
4963
std::set<const Node*> nodes2rm_;
5064
auto& fill_constant_instruct = node->AsStmt();
5165
auto* scope = fill_constant_instruct.op()->scope();

lite/core/optimizer/mir/elimination/range_calc_offline_pass.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "lite/core/optimizer/mir/pass.h"
2323
#include "lite/core/optimizer/mir/pass_registry.h"
2424
#include "lite/core/optimizer/mir/pattern_matcher.h"
25+
#include "lite/core/optimizer/mir/ssa_graph_utils.h"
2526
#include "lite/model_parser/cpp_desc.h"
2627

2728
namespace paddle {
@@ -43,6 +44,19 @@ void RangeCalcOfflinePass::RemoveRangePattern(
4344
const std::unique_ptr<SSAGraph>& graph) {
4445
for (auto& node : graph->StmtTopologicalOrder()) {
4546
if (node->AsStmt().op_type() != "range") continue;
47+
auto outlinks = node->outlinks;
48+
bool has_extra_producers = false;
49+
for (auto& out_link : outlinks) {
50+
if (HasExtraProducers(graph.get(), out_link->arg()->name, {"range"})) {
51+
has_extra_producers = true;
52+
break;
53+
}
54+
}
55+
if (has_extra_producers) {
56+
LOG(WARNING)
57+
<< "Unsupported for op output var containing multiple producers";
58+
continue;
59+
}
4660

4761
std::set<const Node*> nodes2rm_;
4862
auto& range_instruct = node->AsStmt();

lite/core/optimizer/mir/elimination/scale_calc_offline_pass.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "lite/core/optimizer/mir/pass.h"
2323
#include "lite/core/optimizer/mir/pass_registry.h"
2424
#include "lite/core/optimizer/mir/pattern_matcher.h"
25+
#include "lite/core/optimizer/mir/ssa_graph_utils.h"
2526
#include "lite/model_parser/cpp_desc.h"
2627

2728
namespace paddle {
@@ -36,6 +37,19 @@ void ScaleCalcOfflinePass::RemoveScalePattern(
3637
const std::unique_ptr<SSAGraph>& graph) {
3738
for (auto& node : graph->StmtTopologicalOrder()) {
3839
if (node->AsStmt().op_type() != "scale") continue;
40+
auto outlinks = node->outlinks;
41+
bool has_extra_producers = false;
42+
for (auto& out_link : outlinks) {
43+
if (HasExtraProducers(graph.get(), out_link->arg()->name, {"scale"})) {
44+
has_extra_producers = true;
45+
break;
46+
}
47+
}
48+
if (has_extra_producers) {
49+
LOG(WARNING)
50+
<< "Unsupported for op output var containing multiple producers";
51+
continue;
52+
}
3953

4054
std::set<const Node*> nodes2rm_;
4155
auto& scale_instruct = node->AsStmt();

lite/core/optimizer/mir/elimination/ssd_boxes_calc_offline_pass.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "lite/core/optimizer/mir/pass.h"
2323
#include "lite/core/optimizer/mir/pass_registry.h"
2424
#include "lite/core/optimizer/mir/pattern_matcher.h"
25+
#include "lite/core/optimizer/mir/ssa_graph_utils.h"
2526
#include "lite/model_parser/cpp_desc.h"
2627

2728
namespace paddle {
@@ -41,6 +42,21 @@ void SSDBoxesCalcOfflinePass::RemovePriorboxPattern(
4142
if (node->AsStmt().op_type() != "prior_box" &&
4243
node->AsStmt().op_type() != "density_prior_box")
4344
continue;
45+
auto outlinks = node->outlinks;
46+
bool has_extra_producers = false;
47+
for (auto& out_link : outlinks) {
48+
if (HasExtraProducers(graph.get(),
49+
out_link->arg()->name,
50+
{"prior_box", "density_prior_box"})) {
51+
has_extra_producers = true;
52+
break;
53+
}
54+
}
55+
if (has_extra_producers) {
56+
LOG(WARNING)
57+
<< "Unsupported for op output var containing multiple producers";
58+
continue;
59+
}
4460

4561
std::set<const Node*> nodes2rm_;
4662
auto& priorbox_instruct = node->AsStmt();

lite/core/optimizer/mir/elimination/unsqueeze_calc_offline_pass.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "lite/core/optimizer/mir/pass.h"
2323
#include "lite/core/optimizer/mir/pass_registry.h"
2424
#include "lite/core/optimizer/mir/pattern_matcher.h"
25+
#include "lite/core/optimizer/mir/ssa_graph_utils.h"
2526
#include "lite/model_parser/cpp_desc.h"
2627

2728
namespace paddle {
@@ -38,6 +39,21 @@ void UnsqueezeCalcOfflinePass::RemoveUnsqueezePattern(
3839
if (node->AsStmt().op_type() != "unsqueeze" &&
3940
node->AsStmt().op_type() != "unsqueeze2")
4041
continue;
42+
auto outlinks = node->outlinks;
43+
bool has_extra_producers = false;
44+
for (auto& out_link : outlinks) {
45+
if (HasExtraProducers(graph.get(),
46+
out_link->arg()->name,
47+
{"unsqueeze", "unsqueeze2"})) {
48+
has_extra_producers = true;
49+
break;
50+
}
51+
}
52+
if (has_extra_producers) {
53+
LOG(WARNING)
54+
<< "Unsupported for op output var containing multiple producers";
55+
continue;
56+
}
4157

4258
std::set<const Node*> nodes2rm_;
4359
auto& unsqueeze_instruct = node->AsStmt();
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/core/optimizer/mir/ssa_graph_utils.h"
16+
17+
namespace paddle {
18+
namespace lite {
19+
namespace mir {
20+
21+
// Returns true if any statement node of `graph` whose op type is not in
// `exclude_op_list` writes to `var_name` or to a mangled alias of it.
//
// graph            graph whose statement nodes are scanned.
// var_name         name of the variable whose producers are looked up.
// exclude_op_list  op types that are ignored as producers.
// candidate_op     if non-empty, only these op types are considered.
//
// NOTE(review): mangled aliases are assumed to be named
// "<var_name>__Mangled_<suffix>", so the check is a prefix match; a plain
// substring match would also hit unrelated variables whose names merely end
// with `var_name`. Confirm against the code that generates mangled names.
bool HasExtraProducers(mir::SSAGraph *graph,
                       const std::string &var_name,
                       const std::set<std::string> &exclude_op_list,
                       const std::set<std::string> &candidate_op) {
  // Loop-invariant, so build it once instead of once per output variable.
  const std::string mangled_prefix = var_name + "__Mangled_";
  for (auto &op_node : graph->StmtTopologicalOrder()) {
    if (!op_node->IsStmt()) continue;
    auto op_info = op_node->AsStmt().op_info();
    const auto &op_type = op_info->Type();
    if (exclude_op_list.count(op_type)) continue;
    if (!candidate_op.empty() && !candidate_op.count(op_type)) continue;
    for (auto &var_node : op_node->outlinks) {
      const std::string &out_name = var_node->AsArg().name;
      const bool is_same_var = (out_name == var_name);
      const bool is_mangled_alias =
          out_name.compare(0, mangled_prefix.size(), mangled_prefix) == 0;
      if (is_same_var || is_mangled_alias) {
        return true;
      }
    }
  }
  return false;
}
42+
43+
} // namespace mir
44+
} // namespace lite
45+
} // namespace paddle

0 commit comments

Comments
 (0)