Skip to content

Commit 27dd6e3

Browse files
committed
Make LayerNorm fusion support opset 11
1 parent ace132f commit 27dd6e3

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

onnxruntime/core/optimizer/layer_norm_fusion.cc

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ static std::vector<std::string> supported_data_types{"tensor(float16)", "tensor(
1616
static bool IsSupportedDataType(const Node& node) {
1717
for (const auto& input_arg : node.InputDefs()) {
1818
if (std::find(supported_data_types.begin(), supported_data_types.end(),
19-
*(input_arg->Type())) == supported_data_types.end()) {
19+
*(input_arg->Type())) == supported_data_types.end()) {
2020
return false;
2121
}
2222
}
@@ -56,7 +56,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
5656
Node& reduce_mean_node = *p_reduce_mean;
5757
ORT_RETURN_IF_ERROR(Recurse(reduce_mean_node, modified, graph_level, logger));
5858

59-
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1}) ||
59+
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11}) ||
6060
!graph_utils::IsSupportedProvider(reduce_mean_node, GetCompatibleExecutionProviders()) ||
6161
(reduce_mean_node.GetOutputEdgesCount() != 1 && reduce_mean_node.GetOutputEdgesCount() != 2) ||
6262
!IsSupportedDataType(reduce_mean_node)) {
@@ -95,7 +95,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
9595
continue;
9696
}
9797
nodes_to_remove.push_back(sub_node);
98-
98+
9999
// Find the "Div" node after "Sub".
100100
const Node* p_div = nullptr;
101101
p_div = graph_utils::FirstChildByType(sub_node, "Div");
@@ -110,7 +110,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
110110
continue;
111111
}
112112
nodes_to_remove.push_back(sub_node_dup);
113-
// Find Div node after the duplicated sub node if it's not found after the first sub node.
113+
// Find Div node after the duplicated sub node if it's not found after the first sub node.
114114
if (p_div == nullptr) {
115115
p_div = graph_utils::FirstChildByType(sub_node_dup, "Div");
116116
}
@@ -138,7 +138,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
138138
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sqrt_node, "Sqrt", {6}) ||
139139
sqrt_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
140140
sqrt_node.GetOutputEdgesCount() != 1 ||
141-
!IsSupportedDataType(sqrt_node) ||
141+
!IsSupportedDataType(sqrt_node) ||
142142
sqrt_node.GetInputEdgesCount() == 0) {
143143
continue;
144144
}
@@ -162,10 +162,10 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
162162
continue;
163163
}
164164
Node& reduce_mean2_node = *graph.GetNode(p_reduce_mean2->Index());
165-
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1}) ||
165+
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1, 11}) ||
166166
reduce_mean2_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
167167
reduce_mean2_node.GetOutputEdgesCount() != 1 ||
168-
!IsSupportedDataType(reduce_mean2_node) ||
168+
!IsSupportedDataType(reduce_mean2_node) ||
169169
reduce_mean2_node.GetInputEdgesCount() == 0) {
170170
continue;
171171
}
@@ -222,7 +222,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
222222
NodeArg* scale = nullptr;
223223
NodeArg* bias = nullptr;
224224
for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) {
225-
if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
225+
if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
226226
graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) {
227227
// Scale must be 1d.
228228
if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) {
@@ -244,7 +244,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
244244
continue;
245245
}
246246

247-
// Scale and bias must have the same dimension.
247+
// Scale and bias must have the same dimension.
248248
if (scale->Shape()->dim(0).dim_value() != bias->Shape()->dim(0).dim_value()) {
249249
continue;
250250
}
@@ -267,4 +267,4 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
267267
}
268268
return Status::OK();
269269
}
270-
} // namespace onnxruntime
270+
} // namespace onnxruntime

0 commit comments

Comments
 (0)