diff --git a/onnxruntime/core/optimizer/bias_gelu_fusion.cc b/onnxruntime/core/optimizer/bias_gelu_fusion.cc
index 94b8c601dccbc..370e8b144bba6 100644
--- a/onnxruntime/core/optimizer/bias_gelu_fusion.cc
+++ b/onnxruntime/core/optimizer/bias_gelu_fusion.cc
@@ -30,14 +30,28 @@ Status BiasGelu::ApplyImpl(Graph& graph, bool& modified, int graph_level, const
     }
 
     std::vector<NodeArg*> gelu_input;
-    const TensorShapeProto* add_input1_shape = node.MutableInputDefs()[0]->Shape();
-    const TensorShapeProto* add_input2_shape = node.MutableInputDefs()[1]->Shape();
-    if (add_input1_shape != nullptr &&
-        add_input1_shape->dim_size() == 1) {
+    const TensorShapeProto* input1_shape = node.MutableInputDefs()[0]->Shape();
+    const TensorShapeProto* input2_shape = node.MutableInputDefs()[1]->Shape();
+
+    if (input1_shape == nullptr ||
+        input2_shape == nullptr ||
+        input1_shape->dim_size() < 1 ||
+        input2_shape->dim_size() < 1) {
+      continue;
+    }
+
+    int last_dim_shape1 = input1_shape->dim_size() - 1;
+    int last_dim_shape2 = input2_shape->dim_size() - 1;
+    if (!utils::HasDimValue(input1_shape->dim(last_dim_shape1)) ||
+        !utils::HasDimValue(input2_shape->dim(last_dim_shape2)) ||
+        input1_shape->dim(last_dim_shape1).dim_value() != input2_shape->dim(last_dim_shape2).dim_value()) {
+      continue;
+    }
+
+    if (input1_shape->dim_size() == 1) {
       gelu_input.push_back(node.MutableInputDefs()[1]);
       gelu_input.push_back(node.MutableInputDefs()[0]);
-    } else if (add_input2_shape != nullptr &&
-               add_input2_shape->dim_size() == 1) {
+    } else if (input2_shape->dim_size() == 1) {
       gelu_input.push_back(node.MutableInputDefs()[0]);
       gelu_input.push_back(node.MutableInputDefs()[1]);
     } else {
diff --git a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc
index 03449fdc69ce3..ebb565af4a4f6 100644
--- a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc
+++ b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc
@@ -43,7 +43,9 @@ static bool CheckFirstAdd(Node& add, ProviderType providertype) {
   // "Add" inputs have to be of same dimensions.
   bool is_valid_input = true;
   for (int i = 0; i < 3; i++) {
-    if (add_input1_shape->dim(i).dim_value() != add_input2_shape->dim(i).dim_value()) {
+    if (!utils::HasDimValue(add_input1_shape->dim(i)) ||
+        !utils::HasDimValue(add_input2_shape->dim(i)) ||
+        add_input1_shape->dim(i).dim_value() != add_input2_shape->dim(i).dim_value()) {
       is_valid_input = false;
       break;
     }
@@ -68,7 +70,11 @@ static bool CheckSecondAdd(Node& add, ProviderType providertype) {
     return false;
   }
 
-  return add_input1_shape->dim_size() == 3 && add_input2_shape->dim_size() == 1;
+  return add_input1_shape->dim_size() == 3 &&
+         add_input2_shape->dim_size() == 1 &&
+         utils::HasDimValue(add_input1_shape->dim(2)) &&
+         utils::HasDimValue(add_input2_shape->dim(0)) &&
+         add_input1_shape->dim(2).dim_value() == add_input2_shape->dim(0).dim_value();
 }
 
 /**
diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1.onnx
index a8b29909eed25..4e72ab0dd5f90 100644
Binary files a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1.onnx and b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1.onnx differ
diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2.onnx
index f4c7e8cb5e9f7..501bf2a5e9d1e 100644
Binary files a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2.onnx and b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2.onnx differ
diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3.onnx
index 4c60dd91b0592..8259df0b2b2b6 100644
Binary files a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3.onnx and b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3.onnx differ
diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_gen.py b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_gen.py
index c511935fc812a..1d8afa4b5803e 100644
--- a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_gen.py
+++ b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_gen.py
@@ -46,11 +46,11 @@ def GenerateModel(format, model_name):
         nodes,
         "SkipLayerNorm_format3",  #name
         [  # inputs
-            helper.make_tensor_value_info('A', TensorProto.FLOAT, ['unk_1', 32, 4]),
-            helper.make_tensor_value_info('B', TensorProto.FLOAT, ['unk_1', 32, 4]),
+            helper.make_tensor_value_info('A', TensorProto.FLOAT, [16, 32, 4]),
+            helper.make_tensor_value_info('B', TensorProto.FLOAT, [16, 32, 4]),
         ],
         [  # outputs
-            helper.make_tensor_value_info('C', TensorProto.FLOAT, ['unk_3', 32, 4]),
+            helper.make_tensor_value_info('C', TensorProto.FLOAT, [16, 32, 4]),
         ],
         initializers
     )
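
The common pattern in both files is the same: a dimension read from a TensorShapeProto may be symbolic (e.g. 'unk_1') rather than a concrete integer, and the old code compared dim_value() without first checking that a value exists, so symbolically-shaped inputs could be fused incorrectly. Below is a minimal standalone sketch of that guard, outside the ORT sources; the helper names (HasConcreteValue, LastDimsMatch) are made up for illustration, with HasConcreteValue standing in for utils::HasDimValue, and it assumes the ONNX protobuf headers are available as "onnx/onnx_pb.h".

// Sketch of the shape guard this diff adds: only treat two inputs as
// fusable when both trailing dimensions are concrete and equal.
#include <iostream>
#include "onnx/onnx_pb.h"

// Stand-in for utils::HasDimValue: true only for a concrete integer dim,
// false for a symbolic dim such as 'unk_1' (dim_param) or an unset dim.
static bool HasConcreteValue(const onnx::TensorShapeProto_Dimension& dim) {
  return dim.has_dim_value();
}

// True iff both shapes are non-empty and their last dims are known and equal.
static bool LastDimsMatch(const onnx::TensorShapeProto& a,
                          const onnx::TensorShapeProto& b) {
  if (a.dim_size() < 1 || b.dim_size() < 1) return false;
  const auto& da = a.dim(a.dim_size() - 1);
  const auto& db = b.dim(b.dim_size() - 1);
  return HasConcreteValue(da) && HasConcreteValue(db) &&
         da.dim_value() == db.dim_value();
}

int main() {
  onnx::TensorShapeProto x, bias;
  x.add_dim()->set_dim_param("unk_1");  // symbolic batch, like the old test model
  x.add_dim()->set_dim_value(32);
  x.add_dim()->set_dim_value(4);
  bias.add_dim()->set_dim_value(4);

  std::cout << LastDimsMatch(x, bias) << "\n";  // 1: both trailing dims are 4

  x.mutable_dim(2)->set_dim_param("unk_2");     // make the last dim symbolic
  std::cout << LastDimsMatch(x, bias) << "\n";  // 0: guard rejects the fusion
  return 0;
}

A symbolic batch dimension (like the removed 'unk_1' in skip_layer_norm_gen.py) is still fine under this guard, since only the compared dimensions need concrete values; the test models were regenerated with fixed shapes so the fusion still fires in the transformer tests.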