From d847f77cd46ae7d0ddedc8121b440f6f150c8d64 Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Thu, 5 Dec 2019 16:10:37 -0800 Subject: [PATCH 1/3] Add more check on SkipLayerNorm fusion --- onnxruntime/core/optimizer/skip_layer_norm_fusion.cc | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc index 03449fdc69ce3..ebb565af4a4f6 100644 --- a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc @@ -43,7 +43,9 @@ static bool CheckFirstAdd(Node& add, ProviderType providertype) { // "Add" inputs have to be of same dimensions. bool is_valid_input = true; for (int i = 0; i < 3; i++) { - if (add_input1_shape->dim(i).dim_value() != add_input2_shape->dim(i).dim_value()) { + if (!utils::HasDimValue(add_input1_shape->dim(i)) || + !utils::HasDimValue(add_input2_shape->dim(i)) || + add_input1_shape->dim(i).dim_value() != add_input2_shape->dim(i).dim_value()) { is_valid_input = false; break; } @@ -68,7 +70,11 @@ static bool CheckSecondAdd(Node& add, ProviderType providertype) { return false; } - return add_input1_shape->dim_size() == 3 && add_input2_shape->dim_size() == 1; + return add_input1_shape->dim_size() == 3 && + add_input2_shape->dim_size() == 1 && + utils::HasDimValue(add_input1_shape->dim(2)) && + utils::HasDimValue(add_input2_shape->dim(0)) && + add_input1_shape->dim(2).dim_value() == add_input2_shape->dim(0).dim_value(); } /** From 8db36f4c4c80f344acc21e28f4dca6948066ec14 Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Fri, 6 Dec 2019 09:51:35 -0800 Subject: [PATCH 2/3] add more check on gelu --- .../core/optimizer/bias_gelu_fusion.cc | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/optimizer/bias_gelu_fusion.cc b/onnxruntime/core/optimizer/bias_gelu_fusion.cc index 94b8c601dccbc..370e8b144bba6 100644 --- a/onnxruntime/core/optimizer/bias_gelu_fusion.cc +++ b/onnxruntime/core/optimizer/bias_gelu_fusion.cc @@ -30,14 +30,28 @@ Status BiasGelu::ApplyImpl(Graph& graph, bool& modified, int graph_level, const } std::vector gelu_input; - const TensorShapeProto* add_input1_shape = node.MutableInputDefs()[0]->Shape(); - const TensorShapeProto* add_input2_shape = node.MutableInputDefs()[1]->Shape(); - if (add_input1_shape != nullptr && - add_input1_shape->dim_size() == 1) { + const TensorShapeProto* input1_shape = node.MutableInputDefs()[0]->Shape(); + const TensorShapeProto* input2_shape = node.MutableInputDefs()[1]->Shape(); + + if (input1_shape == nullptr || + input2_shape == nullptr || + input1_shape->dim_size() < 1 || + input2_shape->dim_size() < 1) { + continue; + } + + int last_dim_shape1 = input1_shape->dim_size() - 1; + int last_dim_shape2 = input2_shape->dim_size() - 1; + if (!utils::HasDimValue(input1_shape->dim(last_dim_shape1)) || + !utils::HasDimValue(input2_shape->dim(last_dim_shape2)) || + input1_shape->dim(last_dim_shape1).dim_value() != input2_shape->dim(last_dim_shape2).dim_value()) { + continue; + } + + if (input1_shape->dim_size() == 1) { gelu_input.push_back(node.MutableInputDefs()[1]); gelu_input.push_back(node.MutableInputDefs()[0]); - } else if (add_input2_shape != nullptr && - add_input2_shape->dim_size() == 1) { + } else if (input2_shape->dim_size() == 1) { gelu_input.push_back(node.MutableInputDefs()[0]); gelu_input.push_back(node.MutableInputDefs()[1]); } else { From 1faaf04b769391de62ffbf5daa28843523b4488f Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Fri, 6 Dec 2019 10:15:12 -0800 Subject: [PATCH 3/3] change the test data --- .../fusion/skip_layer_norm_format1.onnx | Bin 779 -> 764 bytes .../fusion/skip_layer_norm_format2.onnx | Bin 779 -> 764 bytes .../fusion/skip_layer_norm_format3.onnx | Bin 710 -> 695 bytes .../transform/fusion/skip_layer_norm_gen.py | 6 +++--- 4 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format1.onnx index a8b29909eed255ec074b9d98af9b49dc076501b3..4e72ab0dd5f90c554581658e40af3619321f1d7d 100644 GIT binary patch delta 93 zcmeBX`@_o0!D{uHbtCIaCO>g5Mn@qbE&&ckAs#L!4gnxm0AiLXpo9}Pi6o$eGd2k) HCIK!0MOg_; delta 108 zcmeyv+Retw!D`jSwvlxuQ>+XZqoa^Gmna9LkRTVk5Nl~(cDx}M6Nds2vqS-9op8z~ Q0cD-hWsPBaotOl;0GC+~&Hw-a diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format2.onnx index f4c7e8cb5e9f7c4cca028a50bc2e276056c678f4..501bf2a5e9d1eba9164932db7dcbae181a4ab5ab 100644 GIT binary patch delta 93 zcmeBX`@_o0!D{uHbtCIaCO>g5Mn@qbE&&ckAs#L!4gnxm0AiLXpo9}Pi6o$eGd2k) HCIK!0MOg_; delta 108 zcmeyv+Retw!D`jSwvlxuQ>+XZqoa^Gmna9LkRTVk5Nl~(cDx}M6Nds2vqS-9op8z~ Q0cD-hWsPBaotOl;0GC+~&Hw-a diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3.onnx b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_format3.onnx index 4c60dd91b0592e9a28a5a6190e04fad732d4d80f..8259df0b2b2b6407ffbcfcc8f8263ace56c4a25b 100644 GIT binary patch delta 92 zcmX@cx}BAUgVk!?MwV12KXEQbM