Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions onnxruntime/core/optimizer/bias_gelu_fusion.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -30,14 +30,28 @@ Status BiasGelu::ApplyImpl(Graph& graph, bool& modified, int graph_level, const
}

std::vector<NodeArg*> gelu_input;
const TensorShapeProto* add_input1_shape = node.MutableInputDefs()[0]->Shape();
const TensorShapeProto* add_input2_shape = node.MutableInputDefs()[1]->Shape();
if (add_input1_shape != nullptr &&
add_input1_shape->dim_size() == 1) {
const TensorShapeProto* input1_shape = node.MutableInputDefs()[0]->Shape();
const TensorShapeProto* input2_shape = node.MutableInputDefs()[1]->Shape();

if (input1_shape == nullptr ||
input2_shape == nullptr ||
input1_shape->dim_size() < 1 ||
input2_shape->dim_size() < 1) {
continue;
}

int last_dim_shape1 = input1_shape->dim_size() - 1;
int last_dim_shape2 = input2_shape->dim_size() - 1;
if (!utils::HasDimValue(input1_shape->dim(last_dim_shape1)) ||
!utils::HasDimValue(input2_shape->dim(last_dim_shape2)) ||
input1_shape->dim(last_dim_shape1).dim_value() != input2_shape->dim(last_dim_shape2).dim_value()) {
continue;
}

if (input1_shape->dim_size() == 1) {
gelu_input.push_back(node.MutableInputDefs()[1]);
gelu_input.push_back(node.MutableInputDefs()[0]);
} else if (add_input2_shape != nullptr &&
add_input2_shape->dim_size() == 1) {
} else if (input2_shape->dim_size() == 1) {
gelu_input.push_back(node.MutableInputDefs()[0]);
gelu_input.push_back(node.MutableInputDefs()[1]);
} else {
Expand Down
10 changes: 8 additions & 2 deletions onnxruntime/core/optimizer/skip_layer_norm_fusion.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -43,7 +43,9 @@ static bool CheckFirstAdd(Node& add, ProviderType providertype) {
// Both "Add" inputs must have the same dimensions.
bool is_valid_input = true;
for (int i = 0; i < 3; i++) {
if (add_input1_shape->dim(i).dim_value() != add_input2_shape->dim(i).dim_value()) {
if (!utils::HasDimValue(add_input1_shape->dim(i)) ||
!utils::HasDimValue(add_input2_shape->dim(i)) ||
add_input1_shape->dim(i).dim_value() != add_input2_shape->dim(i).dim_value()) {
is_valid_input = false;
break;
}
Expand All @@ -68,7 +70,11 @@ static bool CheckSecondAdd(Node& add, ProviderType providertype) {
return false;
}

return add_input1_shape->dim_size() == 3 && add_input2_shape->dim_size() == 1;
return add_input1_shape->dim_size() == 3 &&
add_input2_shape->dim_size() == 1 &&
utils::HasDimValue(add_input1_shape->dim(2)) &&
utils::HasDimValue(add_input2_shape->dim(0)) &&
add_input1_shape->dim(2).dim_value() == add_input2_shape->dim(0).dim_value();
}

/**
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number | Diff line number | Diff line change
Expand Up @@ -46,11 +46,11 @@ def GenerateModel(format, model_name):
nodes,
"SkipLayerNorm_format3", #name
[ # inputs
helper.make_tensor_value_info('A', TensorProto.FLOAT, ['unk_1', 32, 4]),
helper.make_tensor_value_info('B', TensorProto.FLOAT, ['unk_1', 32, 4]),
helper.make_tensor_value_info('A', TensorProto.FLOAT, [16, 32, 4]),
helper.make_tensor_value_info('B', TensorProto.FLOAT, [16, 32, 4]),
],
[ # outputs
helper.make_tensor_value_info('C', TensorProto.FLOAT, ['unk_3', 32, 4]),
helper.make_tensor_value_info('C', TensorProto.FLOAT, [16, 32, 4]),
],
initializers
)
Expand Down