@@ -16,7 +16,7 @@ static std::vector<std::string> supported_data_types{"tensor(float16)", "tensor(
1616static bool IsSupportedDataType (const Node& node) {
1717 for (const auto & input_arg : node.InputDefs ()) {
1818 if (std::find (supported_data_types.begin (), supported_data_types.end (),
19- *(input_arg->Type ())) == supported_data_types.end ()) {
19+ *(input_arg->Type ())) == supported_data_types.end ()) {
2020 return false ;
2121 }
2222 }
@@ -56,7 +56,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
5656 Node& reduce_mean_node = *p_reduce_mean;
5757 ORT_RETURN_IF_ERROR (Recurse (reduce_mean_node, modified, graph_level, logger));
5858
59- if (!graph_utils::IsSupportedOptypeVersionAndDomain (reduce_mean_node, " ReduceMean" , {1 }) ||
59+ if (!graph_utils::IsSupportedOptypeVersionAndDomain (reduce_mean_node, " ReduceMean" , {1 , 11 }) ||
6060 !graph_utils::IsSupportedProvider (reduce_mean_node, GetCompatibleExecutionProviders ()) ||
6161 (reduce_mean_node.GetOutputEdgesCount () != 1 && reduce_mean_node.GetOutputEdgesCount () != 2 ) ||
6262 !IsSupportedDataType (reduce_mean_node)) {
@@ -95,7 +95,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
9595 continue ;
9696 }
9797 nodes_to_remove.push_back (sub_node);
98-
98+
9999 // Find the "Div" node after "Sub".
100100 const Node* p_div = nullptr ;
101101 p_div = graph_utils::FirstChildByType (sub_node, " Div" );
@@ -110,7 +110,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
110110 continue ;
111111 }
112112 nodes_to_remove.push_back (sub_node_dup);
113- // Find Div node after the duplicated sub node if it's not found after the first sub node.
113+ // Find Div node after the duplicated sub node if it's not found after the first sub node.
114114 if (p_div == nullptr ) {
115115 p_div = graph_utils::FirstChildByType (sub_node_dup, " Div" );
116116 }
@@ -138,7 +138,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
138138 if (!graph_utils::IsSupportedOptypeVersionAndDomain (sqrt_node, " Sqrt" , {6 }) ||
139139 sqrt_node.GetExecutionProviderType () != reduce_mean_node.GetExecutionProviderType () ||
140140 sqrt_node.GetOutputEdgesCount () != 1 ||
141- !IsSupportedDataType (sqrt_node) ||
141+ !IsSupportedDataType (sqrt_node) ||
142142 sqrt_node.GetInputEdgesCount () == 0 ) {
143143 continue ;
144144 }
@@ -162,10 +162,10 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
162162 continue ;
163163 }
164164 Node& reduce_mean2_node = *graph.GetNode (p_reduce_mean2->Index ());
165- if (!graph_utils::IsSupportedOptypeVersionAndDomain (reduce_mean2_node, " ReduceMean" , {1 }) ||
165+ if (!graph_utils::IsSupportedOptypeVersionAndDomain (reduce_mean2_node, " ReduceMean" , {1 , 11 }) ||
166166 reduce_mean2_node.GetExecutionProviderType () != reduce_mean_node.GetExecutionProviderType () ||
167167 reduce_mean2_node.GetOutputEdgesCount () != 1 ||
168- !IsSupportedDataType (reduce_mean2_node) ||
168+ !IsSupportedDataType (reduce_mean2_node) ||
169169 reduce_mean2_node.GetInputEdgesCount () == 0 ) {
170170 continue ;
171171 }
@@ -222,7 +222,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
222222 NodeArg* scale = nullptr ;
223223 NodeArg* bias = nullptr ;
224224 for (size_t i = 0 ; i < mul_node.MutableInputDefs ().size (); i++) {
225- if (graph_utils::NodeArgIsConstant (graph, *(mul_node.MutableInputDefs ()[i])) ||
225+ if (graph_utils::NodeArgIsConstant (graph, *(mul_node.MutableInputDefs ()[i])) ||
226226 graph_utils::IsGraphInput (graph, mul_node.MutableInputDefs ()[i])) {
227227 // Scale must be 1d.
228228 if (mul_node.MutableInputDefs ()[i]->Shape ()->dim_size () == 1 ) {
@@ -244,7 +244,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
244244 continue ;
245245 }
246246
247- // Scale and bias must have the same dimension.
247+ // Scale and bias must have the same dimension.
248248 if (scale->Shape ()->dim (0 ).dim_value () != bias->Shape ()->dim (0 ).dim_value ()) {
249249 continue ;
250250 }
@@ -267,4 +267,4 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
267267 }
268268 return Status::OK ();
269269}
270- } // namespace onnxruntime
270+ } // namespace onnxruntime
0 commit comments