Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lite/api/paddle_use_passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ USE_MIR_PASS(restrict_quantized_op_with_same_input_output_scale_pass);
USE_MIR_PASS(control_flow_op_unused_inputs_and_outputs_eliminate_pass)
USE_MIR_PASS(lite_scale_activation_fuse_pass);
USE_MIR_PASS(lite_instance_norm_activation_fuse_pass);
USE_MIR_PASS(lite_fc_prelu_fuse_pass);
USE_MIR_PASS(__xpu__resnet_fuse_pass);
USE_MIR_PASS(__xpu__resnet_d_fuse_pass);
USE_MIR_PASS(__xpu__resnet_cbam_fuse_pass);
Expand Down
101 changes: 55 additions & 46 deletions lite/backends/opencl/cl_kernel/buffer/fc_kernel.cl
Original file line number Diff line number Diff line change
Expand Up @@ -250,14 +250,14 @@ void fc_gemv_naive(__global const CL_DTYPE* a,
// b: param.w {K, N}
// c: param.output {M, N}
__kernel
void fc_gemv_1x4(__global const CL_DTYPE* a,
__global const CL_DTYPE* b,
__global const CL_DTYPE* bias,
__global CL_DTYPE* c,
const int M, const int N, const int K) {
void fc_gemv_1x4(__global const float* a,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议使用CL_DTYPE类型

__global const float* b,
__global const float* bias,
__global float* c,
const int M, const int N, const int K,
__global const float* alpha) {
const int col = get_global_id(0) << 2; // gws[0]: [0, N >> 2) height of B == N

half alpha;
if (col + 3 < N) {
half4 c0 = 0.0f;
if (bias) {
Expand Down Expand Up @@ -306,35 +306,31 @@ void fc_gemv_1x4(__global const CL_DTYPE* a,
c0 += a0.y * b1;
c0 += a0.z * b2;

// store res
#ifdef RELU
if (col % 4 == 0) {
float4 act_res = convert_float4(fmax(c0, (half4)0.f));
vstore4(act_res, 0, c + col);
} else {
switch (col % 4) {
case 3:
c[col + 2] = activation(c0.z, alpha);
case 2:
c[col + 1] = activation(c0.y, alpha);
case 1:
c[col] = activation(c0.x, alpha);
}
}
half4 alpha0 = 0.0f;
#ifdef PRELU_MORE
alpha0.x = alpha[col];
alpha0.y = alpha[col+1];
alpha0.z = alpha[col+2];
alpha0.w = alpha[col+3];
#else
alpha0.x = alpha[0];
alpha0.y = alpha[0];
alpha0.z = alpha[0];
alpha0.w = alpha[0];
#endif
if (col % 4 == 0) {
vstore4(convert_float4(c0), 0, c + col);
float4 act_res = convert_float4(activation_type4(c0, alpha0));
vstore4(act_res, 0, c + col);
} else {
switch (col % 4) {
case 3:
c[col + 2] = c0.z;
c[col + 2] = activation(c0.z, alpha0.z);
case 2:
c[col + 1] = c0.y;
c[col + 1] = activation(c0.y, alpha0.y);
case 1:
c[col] = c0.x;
c[col] = activation(c0.x, alpha0.x);
}
}
#endif
} else {
const int left_col = N - col;
for (int col_offset = 0; col_offset < left_col; ++col_offset) {
Expand All @@ -344,11 +340,13 @@ void fc_gemv_1x4(__global const CL_DTYPE* a,
half a0 = *(a + p);
c0 += a0 * b0;
}
#ifdef RELU
c[col + col_offset] = activation(c0, alpha);
half alpha0 = 0.0f;
#ifdef PRELU_MORE
alpha0 = alpha[col];
#else
c[col + col_offset] = c0;
alpha0 = alpha[0];
#endif
c[col + col_offset] = activation(c0, alpha0);
}
}
}
Expand All @@ -359,11 +357,12 @@ void fc_gemv_1x4(__global const CL_DTYPE* a,
// b: param.w {K, N}
// c: param.output {M, N}
__kernel
void fc_gemm_4x4(__global const CL_DTYPE* a,
__global const CL_DTYPE* b,
__global const CL_DTYPE* bias,
__global CL_DTYPE* c,
const int M, const int N, const int K) {
void fc_gemm_4x4(__global const float* a,
__global const float* b,
__global const float* bias,
__global float* c,
const int M, const int N, const int K,
__global const float* alpha) {
const int row = get_global_id(0) << 2; // id: [0, M>>2) height of out == M
const int col = get_global_id(1) << 2; // id: [0, N>>2) width of out == N

Expand Down Expand Up @@ -395,17 +394,25 @@ void fc_gemm_4x4(__global const CL_DTYPE* a,
c20 += a20 * b00; c21 += a20 * b01; c22 += a20 * b02; c23 += a20 * b03;
c30 += a30 * b00; c31 += a30 * b01; c32 += a30 * b02; c33 += a30 * b03;
}
#if defined(RELU)
c[row*N+col] = fmax(c00, 0); c[row*N+(col+1)] = fmax(c01, 0); c[row*N+(col+2)] = fmax(c02, 0); c[row*N+(col+3)] = fmax(c03, 0);
c[(row+1)*N+col] = fmax(c10, 0); c[(row+1)*N+(col+1)] = fmax(c11, 0); c[(row+1)*N+(col+2)] = fmax(c12, 0); c[(row+1)*N+(col+3)] = fmax(c13, 0);
c[(row+2)*N+col] = fmax(c20, 0); c[(row+2)*N+(col+1)] = fmax(c21, 0); c[(row+2)*N+(col+2)] = fmax(c22, 0); c[(row+2)*N+(col+3)] = fmax(c23, 0);
c[(row+3)*N+col] = fmax(c30, 0); c[(row+3)*N+(col+1)] = fmax(c31, 0); c[(row+3)*N+(col+2)] = fmax(c32, 0); c[(row+3)*N+(col+3)] = fmax(c33, 0);
half alpha0 = 0.0f;
half alpha1 = 0.0f;
half alpha2 = 0.0f;
half alpha3 = 0.0f;
#ifdef PRELU_MORE
alpha0 = alpha[col];
alpha1 = alpha[col+1];
alpha2 = alpha[col+2];
alpha3 = alpha[col+3];
#else
c[row*N+col] = c00; c[row*N+(col+1)] = c01; c[row*N+(col+2)] = c02; c[row*N+(col+3)] = c03;
c[(row+1)*N+col] = c10; c[(row+1)*N+(col+1)] = c11; c[(row+1)*N+(col+2)] = c12; c[(row+1)*N+(col+3)] = c13;
c[(row+2)*N+col] = c20; c[(row+2)*N+(col+1)] = c21; c[(row+2)*N+(col+2)] = c22; c[(row+2)*N+(col+3)] = c23;
c[(row+3)*N+col] = c30; c[(row+3)*N+(col+1)] = c31; c[(row+3)*N+(col+2)] = c32; c[(row+3)*N+(col+3)] = c33;
alpha0 = alpha[0];
alpha1 = alpha[0];
alpha2 = alpha[0];
alpha3 = alpha[0];
#endif
c[row*N+col] = activation(c00, alpha0); c[row*N+(col+1)] = activation(c01, alpha1); c[row*N+(col+2)] = activation(c02, alpha2); c[row*N+(col+3)] = activation(c03, alpha3);
c[(row+1)*N+col] = activation(c10, alpha0); c[(row+1)*N+(col+1)] = activation(c11, alpha1); c[(row+1)*N+(col+2)] = activation(c12, alpha2); c[(row+1)*N+(col+3)] = activation(c13, alpha3);
c[(row+2)*N+col] = activation(c20, alpha0); c[(row+2)*N+(col+1)] = activation(c21, alpha1); c[(row+2)*N+(col+2)] = activation(c22, alpha2); c[(row+2)*N+(col+3)] = activation(c23, alpha3);
c[(row+3)*N+col] = activation(c30, alpha0); c[(row+3)*N+(col+1)] = activation(c31, alpha1); c[(row+3)*N+(col+2)] = activation(c32, alpha2); c[(row+3)*N+(col+3)] = activation(c33, alpha3);
} else {
for (int cidx = col; cidx < N; ++cidx) {
for (int ridx = row; ridx < M; ++ridx) {
Expand All @@ -417,11 +424,13 @@ void fc_gemm_4x4(__global const CL_DTYPE* a,
b0 = *(b + p * N + cidx),
c0 += a0 * b0;
}
#if defined(RELU)
c[ridx * N + cidx] = fmax(c0, 0);
half alpha0 = 0.0f;
#ifdef PRELU_MORE
alpha0 = alpha[cidx];
#else
c[ridx * N + cidx] = c0;
alpha0 = alpha[0];
#endif
c[ridx * N + cidx] = activation(c0, alpha0);
}
}
}
Expand Down
1 change: 1 addition & 0 deletions lite/core/mir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ lite_cc_library(mir_passes
fusion/sequence_reverse_embedding_fuse_pass.cc
fusion/instance_norm_activation_fuse_pass.cc
fusion/elementwise_add_scale_fuse_pass.cc
fusion/fc_prelu_fuse_pass.cc
elimination/identity_scale_eliminate_pass.cc
elimination/identity_dropout_eliminate_pass.cc
elimination/elementwise_mul_constant_eliminate_pass.cc
Expand Down
4 changes: 4 additions & 0 deletions lite/core/mir/fusion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ lite_cc_library(fuse_instance_norm_activation
lite_cc_library(fuse_elementwise_add_scale
SRCS elementwise_add_scale_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_fc_prelu
SRCS fc_prelu_fuser.cc
DEPS pattern_matcher_high_api)

set(mir_fusers
fuse_reshape2_matmul
Expand All @@ -91,6 +94,7 @@ set(mir_fusers
fuse_sequence_reverse_embedding
fuse_instance_norm_activation
fuse_elementwise_add_scale
fuse_fc_prelu
fuse_conv_scale
CACHE INTERNAL "fusers")

Expand Down
36 changes: 36 additions & 0 deletions lite/core/mir/fusion/fc_prelu_fuse_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/core/mir/fusion/fc_prelu_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/fc_prelu_fuser.h"
#include "lite/core/mir/pass_registry.h"

namespace paddle {
namespace lite {
namespace mir {

// Applies the fc + prelu fusion over the whole SSA graph: every matched
// fc -> prelu chain is collapsed into a single fc op (see FcPreluFuser).
void FcPreluFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
  // "prelu" names the activation op this fuser looks for.
  fusion::FcPreluFuser prelu_fuser("prelu");
  prelu_fuser(graph.get());
}

}  // namespace mir
}  // namespace lite
}  // namespace paddle

// Only the OpenCL fc kernel understands the fused prelu attributes, so the
// pass is restricted to the kOpenCL target.
REGISTER_MIR_PASS(lite_fc_prelu_fuse_pass, paddle::lite::mir::FcPreluFusePass)
    .BindTargets({TARGET(kOpenCL)})
    .BindKernel("fc");
32 changes: 32 additions & 0 deletions lite/core/mir/fusion/fc_prelu_fuse_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include "lite/core/mir/pass.h"

namespace paddle {
namespace lite {
namespace mir {

// Graph pass that fuses an fc op followed by a prelu op into a single fc op.
// Registered as "lite_fc_prelu_fuse_pass"; the actual pattern matching and
// node rewriting is delegated to fusion::FcPreluFuser (see Apply's
// implementation in fc_prelu_fuse_pass.cc).
class FcPreluFusePass : public ProgramPass {
 public:
  // Runs the fusion over the given SSA graph in place.
  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};

}  // namespace mir
}  // namespace lite
}  // namespace paddle
83 changes: 83 additions & 0 deletions lite/core/mir/fusion/fc_prelu_fuser.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/core/mir/fusion/fc_prelu_fuser.h"
#include <memory>
#include <vector>

namespace paddle {
namespace lite {
namespace mir {
namespace fusion {

void FcPreluFuser::BuildPattern() {
  // Nodes of the fc half of the pattern: fc(Input, W, Bias) -> fc_out.
  auto* input = VarNode("input")->assert_is_op_input("fc", "Input")->AsInput();
  auto* weights = VarNode("weights")->assert_is_op_input("fc", "W")->AsInput();
  auto* bias = VarNode("bias")->assert_is_op_input("fc", "Bias")->AsInput();
  auto* fc = OpNode("fc", "fc")->AsIntermediate();

  // fc's output must feed directly into prelu's X; both the intermediate
  // tensor and the two ops disappear after fusion.
  auto* fc_out = VarNode("fc_out")
                     ->assert_is_op_output("fc", "Out")
                     ->assert_is_op_input("prelu", "X")
                     ->AsIntermediate();

  // Nodes of the prelu half: prelu(X, Alpha) -> output.
  auto* alpha =
      VarNode("alpha")->assert_is_op_input("prelu", "Alpha")->AsInput();
  auto* prelu = OpNode("prelu", "prelu")->AsIntermediate();
  auto* out =
      VarNode("output")->assert_is_op_output("prelu", "Out")->AsOutput();

  // Wire the topology: {bias, weights, input} -> fc -> fc_out -> prelu -> out,
  // with alpha as prelu's second input.
  std::vector<PMNode*> fc_inputs{bias, weights, input};
  fc_inputs >> *fc >> *fc_out >> *prelu >> *out;
  *alpha >> *prelu;
}

void FcPreluFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
  // Reuse the matched fc's scope and valid places for the fused op.
  auto old_fc = matched.at("fc")->stmt()->op();
  auto* scope = old_fc->scope();

  // Create a fresh fc op carrying the merged fc + prelu description.
  auto fused_fc = LiteOpRegistry::Global().Create("fc");
  fused_fc->Attach(GenOpDesc(matched), scope);

  auto* fused_node =
      graph->GraphCreateInstructNode(fused_fc, old_fc->valid_places());

  // Rewire all surviving variable nodes to the fused op.
  IR_NODE_LINK_TO(matched.at("input"), fused_node);
  IR_NODE_LINK_TO(matched.at("weights"), fused_node);
  IR_NODE_LINK_TO(matched.at("bias"), fused_node);
  IR_NODE_LINK_TO(matched.at("alpha"), fused_node);
  IR_NODE_LINK_TO(fused_node, matched.at("output"));
}

cpp::OpDesc FcPreluFuser::GenOpDesc(const key2nodes_t& matched) {
  // Start from a copy of the matched fc's descriptor, then graft the prelu
  // inputs/outputs and activation attributes onto it.
  cpp::OpDesc desc = *matched.at("fc")->stmt()->op_info();
  desc.SetInput("Alpha", {matched.at("alpha")->arg()->name});
  desc.SetOutput("Out", {matched.at("output")->arg()->name});

  // Carry the prelu "mode" attribute over so the fc kernel can pick the
  // right alpha layout (e.g. per-channel vs single value).
  cpp::OpDesc prelu_desc = *matched.at("prelu")->stmt()->op_info();
  desc.SetAttr("prelu_mode", prelu_desc.GetAttr<std::string>("mode"));
  desc.SetAttr("activation_type", std::string{"prelu"});
  return desc;
}

} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
40 changes: 40 additions & 0 deletions lite/core/mir/fusion/fc_prelu_fuser.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"

namespace paddle {
namespace lite {
namespace mir {
namespace fusion {

// Fuses the pattern fc -> prelu into a single fc op: the prelu's Alpha input
// and "mode" attribute are attached to the fc op description (see
// fc_prelu_fuser.cc). Used by lite_fc_prelu_fuse_pass.
class FcPreluFuser : public FuseBase {
 public:
  // act_type: name of the activation op following fc (e.g. "prelu").
  // Previously the argument was silently discarded (unused parameter);
  // it is now retained so the fuser records which activation it targets.
  explicit FcPreluFuser(const std::string& act_type) : act_type_(act_type) {}

  // Declares the fc -> prelu graph pattern to match.
  void BuildPattern() override;
  // Replaces a matched subgraph with the fused fc node.
  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;

 private:
  // Builds the op description of the fused fc from the matched nodes.
  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;

  // Activation op name this fuser was constructed for.
  std::string act_type_;
};

}  // namespace fusion
}  // namespace mir
}  // namespace lite
}  // namespace paddle
1 change: 1 addition & 0 deletions lite/core/optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class Optimizer {
"lite_scale_activation_fuse_pass", //
"lite_elementwise_scale_fuse_pass", //
"lite_instance_norm_activation_fuse_pass", //
"lite_fc_prelu_fuse_pass", //
#if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA) || \
(defined LITE_WITH_ARM)
"lite_elementwise_activation_fuse_pass", //
Expand Down
Loading