diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc
index 15363bde8cf74..e57ada81f53e7 100644
--- a/onnxruntime/core/optimizer/layer_norm_fusion.cc
+++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc
@@ -16,7 +16,7 @@ static std::vector<std::string> supported_data_types{"tensor(float16)", "tensor(
 static bool IsSupportedDataType(const Node& node) {
   for (const auto& input_arg : node.InputDefs()) {
     if (std::find(supported_data_types.begin(), supported_data_types.end(),
-                  *(input_arg->Type())) == supported_data_types.end()) { 
+                  *(input_arg->Type())) == supported_data_types.end()) {
       return false;
     }
   }
@@ -56,7 +56,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
     Node& reduce_mean_node = *p_reduce_mean;
     ORT_RETURN_IF_ERROR(Recurse(reduce_mean_node, modified, graph_level, logger));
 
-    if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1}) ||
+    if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11}) ||
         !graph_utils::IsSupportedProvider(reduce_mean_node, GetCompatibleExecutionProviders()) ||
         (reduce_mean_node.GetOutputEdgesCount() != 1 && reduce_mean_node.GetOutputEdgesCount() != 2) ||
         !IsSupportedDataType(reduce_mean_node)) {
@@ -95,7 +95,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
       continue;
     }
     nodes_to_remove.push_back(sub_node);
-    
+
     // Find the "Div" node after "Sub".
     const Node* p_div = nullptr;
     p_div = graph_utils::FirstChildByType(sub_node, "Div");
@@ -110,7 +110,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
       continue;
     }
     nodes_to_remove.push_back(sub_node_dup);
-    // Find Div node after the duplicated sub node if it's not found after the first sub node. 
+    // Find Div node after the duplicated sub node if it's not found after the first sub node.
     if (p_div == nullptr) {
       p_div = graph_utils::FirstChildByType(sub_node_dup, "Div");
     }
@@ -138,7 +138,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
     if (!graph_utils::IsSupportedOptypeVersionAndDomain(sqrt_node, "Sqrt", {6}) ||
         sqrt_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
         sqrt_node.GetOutputEdgesCount() != 1 ||
-        !IsSupportedDataType(sqrt_node) || 
+        !IsSupportedDataType(sqrt_node) ||
         sqrt_node.GetInputEdgesCount() == 0) {
       continue;
     }
@@ -162,10 +162,10 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
       continue;
     }
     Node& reduce_mean2_node = *graph.GetNode(p_reduce_mean2->Index());
-    if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1}) ||
+    if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1, 11}) ||
         reduce_mean2_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
         reduce_mean2_node.GetOutputEdgesCount() != 1 ||
-        !IsSupportedDataType(reduce_mean2_node) || 
+        !IsSupportedDataType(reduce_mean2_node) ||
         reduce_mean2_node.GetInputEdgesCount() == 0) {
       continue;
     }
@@ -222,7 +222,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
     NodeArg* scale = nullptr;
     NodeArg* bias = nullptr;
     for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) {
-      if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) || 
+      if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
           graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) {
         // Scale must be 1d.
         if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) {
@@ -244,7 +244,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
       continue;
     }
 
-    // Scale and bias must have the same dimension. 
+    // Scale and bias must have the same dimension.
    if (scale->Shape()->dim(0).dim_value() != bias->Shape()->dim(0).dim_value()) {
      continue;
    }
@@ -267,4 +267,4 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
   }
   return Status::OK();
 }
-} // namespace onnxruntime
\ No newline at end of file
+} // namespace onnxruntime