Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ namespace nnadapter {

void ConstantFoldShapeAndAssociatedOperations(core::Model* model);

void FoldShapeSliceConcatTile(core::Model* model);

} // namespace nnadapter
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class NCHW2NHWCDataLayoutConverter {
void ConvertSplit(core::Operation* operation);
void ConvertSqueeze(core::Operation* operation);
void ConvertStack(core::Operation* operation);
void ConvertTile(core::Operation* operation);
void ConvertTranspose(core::Operation* operation);
void ConvertMatMul(core::Operation* operation);
void ConvertUnsqueeze(core::Operation* operation);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class PatternMatcher {
bool MatchAllConditions(const Node* node) const;
// Utility conditions
Pattern* IsOperand();
Pattern* IsOperation(NNAdapterOperationType type);
Pattern* IsOperation(NNAdapterOperationType type = NNADAPTER_UNKNOWN);
Pattern* IsConstantOperand();
Pattern* IsVariableOperand();
Pattern* IsConstantCopyOperand();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "core/types.h"

namespace nnadapter {

// Model pass: removes NNADAPTER_CAST operations judged useless — presumably
// no-op casts whose output precision equals the input precision; confirm the
// exact matching conditions in the corresponding .cc implementation.
void RemoveUselessCast(core::Model *model);

// Model pass: removes NNADAPTER_MUL operations judged useless — presumably
// multiplications by an identity constant; confirm against the .cc
// implementation.
void RemoveUselessMul(core::Model *model);

}  // namespace nnadapter
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,8 @@ void FillLikeCumSumSubGatherAddFuser::BuildPattern() {
CreatePattern("cum_sum_axis")
->IsOperationInputOperand(NNADAPTER_CUM_SUM, 1)
->MatchCondition([](const Node* node) -> bool {
auto operand = node->operand;
return operand != nullptr &&
operand->type.precision == NNADAPTER_INT32 &&
operand->length == sizeof(int32_t) &&
*reinterpret_cast<int32_t*>(operand->buffer) == 1;
int32_t axis = *reinterpret_cast<int32_t*>(node->operand->buffer);
return axis == 1 || axis == -1;
})
->IsIntermediate();
auto cum_sum_exclusive =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@

#include "optimizer/constant_fold_shape_and_associated_operations.h"
#include <set>
#include "optimizer/pattern_matcher.h"
#include "utility/debug.h"
#include "utility/hints.h"
#include "utility/logging.h"
#include "utility/micros.h"
#include "utility/modeling.h"
#include "utility/utility.h"

namespace nnadapter {

Expand Down Expand Up @@ -132,4 +134,116 @@ NNADAPTER_EXPORT void ConstantFoldShapeAndAssociatedOperations(
}
}

/*
Folds the subgraph that assembles tile's `repeat_times` input from a static
input shape into a single constant operand, so that tile no longer depends on
shape/slice/concat at runtime. The two side inputs of concat must already be
constant operands (e.g. previously folded fill_like outputs).

before:
      shape   fill_like fill_like
        |         |         |
      slice       |         |
        \---------|---------/
                concat
                  |
                 tile

after:
                 tile
*/
// BuildPattern() declares the subgraph above; HandleMatchedResults() replaces
// tile's second input with a freshly created constant operand.
class FoldShapeSliceConcatTileFuser : public PatternMatcher {
 public:
  FoldShapeSliceConcatTileFuser() {}
  void BuildPattern() override;
  bool HandleMatchedResults(
      core::Model *model, const std::map<std::string, Node *> &nodes) override;
};

void FoldShapeSliceConcatTileFuser::BuildPattern() {
// Create patterns
auto shape_in = CreatePattern("shape_in")
->IsOperationInputOperand(NNADAPTER_SHAPE, 0)
->MatchCondition([](const Node *node) -> bool {
auto operand = node->operand;
return operand != nullptr &&
!IsDynamicShapeOperandType(operand->type);
});
auto shape_dtype = CreatePattern("shape_dtype")
->IsOperationInputOperand(NNADAPTER_SHAPE, 1)
->IsIntermediate();
auto shape = CreatePattern("shape", NNADAPTER_SHAPE)->IsIntermediate();
auto shape_out = CreatePattern("shape_out")
->IsOperationOutputOperand(NNADAPTER_SHAPE, 0)
->IsOperationInputOperand(NNADAPTER_SLICE, 0)
->IsIntermediate();

auto slice_axes = CreatePattern("slice_axes")
->IsOperationInputOperand(NNADAPTER_SLICE, 1)
->IsIntermediate();
auto slice_starts = CreatePattern("slice_starts")
->IsOperationInputOperand(NNADAPTER_SLICE, 2)
->IsIntermediate();
auto slice_ends = CreatePattern("slice_ends")
->IsOperationInputOperand(NNADAPTER_SLICE, 3)
->IsIntermediate();
auto slice_steps = CreatePattern("slice_steps")
->IsOperationInputOperand(NNADAPTER_SLICE, 4)
->IsIntermediate();
auto slice = CreatePattern("slice", NNADAPTER_SLICE)->IsIntermediate();
auto slice_out = CreatePattern("slice_out")
->IsOperationOutputOperand(NNADAPTER_SLICE, 0)
->IsOperationInputOperand(NNADAPTER_CONCAT, 1)
->IsIntermediate();

auto concat_in0 = CreatePattern("concat_in0")
->IsConstantOperand()
->IsOperationInputOperand(NNADAPTER_CONCAT, 0)
->IsIntermediate();
auto concat_in2 = CreatePattern("concat_in2")
->IsConstantOperand()
->IsOperationInputOperand(NNADAPTER_CONCAT, 2)
->IsIntermediate();
auto concat_axis = CreatePattern("concat_axis")
->IsOperationInputOperand(NNADAPTER_CONCAT, 3)
->IsIntermediate();
auto concat = CreatePattern("concat", NNADAPTER_CONCAT)->IsIntermediate();
auto concat_out = CreatePattern("concat_out")
->IsOperationOutputOperand(NNADAPTER_CONCAT, 0)
->IsOperationInputOperand(NNADAPTER_TILE, 1)
->IsIntermediate();

auto tile = CreatePattern("tile", NNADAPTER_TILE);

// Create the topological connections for the above patterns
std::vector<Pattern *> shape_ins{shape_in, shape_dtype};
shape_ins >> *shape >> *shape_out;
std::vector<Pattern *> slice_ins{
shape_out, slice_axes, slice_starts, slice_ends, slice_steps};
slice_ins >> *slice >> *slice_out;
std::vector<Pattern *> concat_ins{
concat_in0, slice_out, concat_in2, concat_axis};
concat_ins >> *concat >> *concat_out >> *tile;
}

bool FoldShapeSliceConcatTileFuser::HandleMatchedResults(
    core::Model *model, const std::map<std::string, Node *> &nodes) {
  // Build a constant `repeat_times` operand from the matched subgraph:
  //   { concat_in0, input_shape[slice_start], concat_in2 }
  // and rewire tile's second input to it. Returning false skips this match
  // and leaves the model unchanged.
  auto shape_in_operand = nodes.at("shape_in")->operand;
  const auto &input_dimensions = shape_in_operand->type.dimensions;
  auto slice_starts_operand = nodes.at("slice")->operation->input_operands[2];
  auto concat_in0_operand = nodes.at("concat_in0")->operand;
  auto concat_in2_operand = nodes.at("concat_in2")->operand;
  // The pattern does not force `slice_starts` to be a constant operand, so
  // every folded value must be checked for a materialized buffer before it is
  // dereferenced.
  if (slice_starts_operand->buffer == nullptr ||
      concat_in0_operand->buffer == nullptr ||
      concat_in2_operand->buffer == nullptr) {
    return false;
  }
  auto start_index =
      *reinterpret_cast<int32_t *>(slice_starts_operand->buffer);
  // Normalize a negative start index (Paddle-style slice semantics) and
  // reject out-of-range indices instead of reading past dimensions.data.
  auto dimension_count = static_cast<int32_t>(input_dimensions.count);
  if (start_index < 0) start_index += dimension_count;
  if (start_index < 0 || start_index >= dimension_count) {
    return false;
  }
  std::vector<int32_t> repeat_times{
      *reinterpret_cast<int32_t *>(concat_in0_operand->buffer),
      input_dimensions.data[start_index],
      *reinterpret_cast<int32_t *>(concat_in2_operand->buffer),
  };
  // repeat_times is a 1-D constant tensor with 3 elements.
  std::vector<int32_t> dims{3};
  auto tile_in1 = AddInt32ConstantOperand(model, repeat_times.data(), dims);
  nodes.at("tile")->operation->input_operands[1] = tile_in1;
  return true;
}

NNADAPTER_EXPORT void FoldShapeSliceConcatTile(core::Model *model) {
  // Re-run the fuser until it reports zero newly matched subgraphs, so that
  // chains uncovered by an earlier rewrite are folded as well.
  NNADAPTER_VLOG(5) << "Apply FoldShapeSliceConcatTileFuser";
  for (;;) {
    FoldShapeSliceConcatTileFuser fuser;
    if (fuser.Apply(model) == 0) break;
  }
}

} // namespace nnadapter
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,31 @@ void NCHW2NHWCDataLayoutConverter::ConvertStack(core::Operation* operation) {
}
}

void NCHW2NHWCDataLayoutConverter::ConvertTile(core::Operation* operation) {
  // Reorders the constant `repeats` operand of TILE to follow the permutation
  // already applied to the input operand, then propagates that permutation to
  // the output operand.
  auto& input_operands = operation->input_operands;
  auto& output_operands = operation->output_operands;
  auto input_count = input_operands.size();
  auto output_count = output_operands.size();
  NNADAPTER_CHECK_EQ(input_count, 2);
  NNADAPTER_CHECK_EQ(output_count, 1);
  auto input_operand = input_operands[0];
  auto repeats_operand = input_operands[1];
  auto output_operand = output_operands[0];
  // The repeats must be known at this point so they can be permuted in place.
  NNADAPTER_CHECK(IsConstantOperand(repeats_operand));
  auto* repeats = reinterpret_cast<int32_t*>(repeats_operand->buffer);
  int32_t count = repeats_operand->length / sizeof(int32_t);
  // The input and output operands share the same dimorder vector
  auto permutation = GetPermutation(input_operand);
  NNADAPTER_CHECK_EQ(count, static_cast<int32_t>(permutation.size()));
  std::vector<int32_t> permuted_repeats;
  permuted_repeats.reserve(count);
  for (int32_t i = 0; i < count; i++) {
    permuted_repeats.push_back(repeats[permutation[i]]);
  }
  memcpy(repeats, permuted_repeats.data(), repeats_operand->length);
  TransposeOperand(output_operand, permutation);
  SetPermutation(output_operand, permutation);
}

void NCHW2NHWCDataLayoutConverter::ConvertTranspose(
core::Operation* operation) {
auto& input_operands = operation->input_operands;
Expand Down Expand Up @@ -1295,6 +1320,9 @@ void NCHW2NHWCDataLayoutConverter::Apply(core::Model* model) {
case NNADAPTER_STACK:
ConvertStack(operation);
break;
case NNADAPTER_TILE:
ConvertTile(operation);
break;
case NNADAPTER_TRANSPOSE:
ConvertTranspose(operation);
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ NNADAPTER_EXPORT void ConvertQuantizationSymmToAsymm(core::Model* model) {
case NNADAPTER_SQUEEZE:
case NNADAPTER_SWISH:
case NNADAPTER_TANH:
case NNADAPTER_TILE:
case NNADAPTER_TRANSPOSE:
case NNADAPTER_UNSQUEEZE: {
ConvertOperandSymmToAsymm(input_operands[0], 128);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ PatternMatcher::Pattern::IsOperationOutputOperand(NNAdapterOperationType type,
NNADAPTER_EXPORT PatternMatcher::Pattern *PatternMatcher::Pattern::IsOperation(
    NNAdapterOperationType type) {
  // Adds a condition matching operation nodes. NNADAPTER_UNKNOWN (the declared
  // default) acts as a wildcard that accepts any operation type.
  conditions.emplace_back([type](const Node *node) {
    if (node == nullptr || !node->IsOperation()) return false;
    return type == NNADAPTER_UNKNOWN || node->operation->type == type;
  });
  return this;
}
Expand Down Expand Up @@ -214,8 +215,11 @@ NNADAPTER_EXPORT size_t PatternMatcher::Apply(core::Model *model) {
return 0;
}
auto subgraphs = DetectPatterns();
NNADAPTER_VLOG(5) << subgraphs.size() << " subgraphs detected!";
UniquePatterns(&subgraphs);
NNADAPTER_VLOG(5) << subgraphs.size() << " subgraphs unique!";
ValidatePatterns(&subgraphs);
NNADAPTER_VLOG(5) << subgraphs.size() << " subgraphs valid!";
RemoveOverlappedPatterns(&subgraphs);
NNADAPTER_VLOG(5) << subgraphs.size() << " subgraphs matched!";
// Notify to handle the matched subgraphs, and collect the intermediate
Expand Down
Loading