Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions onnxruntime/core/optimizer/layer_norm_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ static std::vector<std::string> supported_data_types{"tensor(float16)", "tensor(
static bool IsSupportedDataType(const Node& node) {
for (const auto& input_arg : node.InputDefs()) {
if (std::find(supported_data_types.begin(), supported_data_types.end(),
*(input_arg->Type())) == supported_data_types.end()) {
*(input_arg->Type())) == supported_data_types.end()) {
return false;
}
}
Expand Down Expand Up @@ -56,7 +56,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
Node& reduce_mean_node = *p_reduce_mean;
ORT_RETURN_IF_ERROR(Recurse(reduce_mean_node, modified, graph_level, logger));

if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11}) ||
!graph_utils::IsSupportedProvider(reduce_mean_node, GetCompatibleExecutionProviders()) ||
(reduce_mean_node.GetOutputEdgesCount() != 1 && reduce_mean_node.GetOutputEdgesCount() != 2) ||
!IsSupportedDataType(reduce_mean_node)) {
Expand Down Expand Up @@ -95,7 +95,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}
nodes_to_remove.push_back(sub_node);

// Find the "Div" node after "Sub".
const Node* p_div = nullptr;
p_div = graph_utils::FirstChildByType(sub_node, "Div");
Expand All @@ -110,7 +110,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}
nodes_to_remove.push_back(sub_node_dup);
// Find Div node after the duplicated sub node if it's not found after the first sub node.
// Find Div node after the duplicated sub node if it's not found after the first sub node.
if (p_div == nullptr) {
p_div = graph_utils::FirstChildByType(sub_node_dup, "Div");
}
Expand Down Expand Up @@ -138,7 +138,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sqrt_node, "Sqrt", {6}) ||
sqrt_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
sqrt_node.GetOutputEdgesCount() != 1 ||
!IsSupportedDataType(sqrt_node) ||
!IsSupportedDataType(sqrt_node) ||
sqrt_node.GetInputEdgesCount() == 0) {
continue;
}
Expand All @@ -162,10 +162,10 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}
Node& reduce_mean2_node = *graph.GetNode(p_reduce_mean2->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1, 11}) ||
reduce_mean2_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
reduce_mean2_node.GetOutputEdgesCount() != 1 ||
!IsSupportedDataType(reduce_mean2_node) ||
!IsSupportedDataType(reduce_mean2_node) ||
reduce_mean2_node.GetInputEdgesCount() == 0) {
continue;
}
Expand Down Expand Up @@ -222,7 +222,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
NodeArg* scale = nullptr;
NodeArg* bias = nullptr;
for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) {
if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) {
// Scale must be 1d.
if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) {
Expand All @@ -244,7 +244,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}

// Scale and bias must have the same dimension.
// Scale and bias must have the same dimension.
if (scale->Shape()->dim(0).dim_value() != bias->Shape()->dim(0).dim_value()) {
continue;
}
Expand All @@ -267,4 +267,4 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
}
return Status::OK();
}
} // namespace onnxruntime
} // namespace onnxruntime