Skip to content

Commit 1f093ac

Browse files
[Metal] fix_relu and relu6 Add_exp (#8082)
* fix_relu_exp_relu6 * fix_relu_exp_relu6 * swish
1 parent 1dcc5b4 commit 1f093ac

4 files changed

Lines changed: 22 additions & 4 deletions

File tree

lite/backends/metal/metal_kernel/texture/ActivationKernel.metal

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,17 @@ kernel void hard_sigmoid(texture2d_array<ftype, access::sample> inTexture[[textu
107107
outTexture.write(output, gid.xy, gid.z);
108108
}
109109

110+
// activation function: swish
110111
kernel void swish(texture2d_array<ftype, access::sample> inTexture[[texture(0)]],
111112
texture2d_array<ftype, access::write> outTexture[[texture(1)]],
113+
constant SwishParam& param[[buffer(0)]],
112114
uint3 gid[[thread_position_in_grid]]) {
113115
if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height() ||
114116
gid.z >= outTexture.get_array_size())
115117
return;
116118
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
117119
const ftype4 input = inTexture.read(gid.xy, gid.z);
118-
const ftype4 output = input / (1.0 + exp(-input));
120+
const ftype4 output = input / (1.0 + exp(-(input * param.beta)));
119121
outTexture.write(output, gid.xy, gid.z);
120122
}
121123

lite/backends/metal/metal_kernel/texture/Common.metal

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,10 @@ struct HardSigmoidParam {
196196
float offset;
197197
};
198198

199+
struct SwishParam {
200+
float beta;
201+
};
202+
199203
struct HardSwishParam {
200204
float offset;
201205
float threshold;

lite/kernels/metal/image_op/activation_image_compute.mm

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,16 @@
6464
case 5:
6565
function_name_ = "sigmoid";
6666
break;
67-
case 7:
67+
case 7: {
68+
SwishMetalParam metal_param{param.Swish_beta};
69+
param_buffer_ =
70+
std::make_shared<MetalBuffer>(metal_context_, sizeof(metal_param), &metal_param);
71+
}
6872
function_name_ = "swish";
6973
break;
74+
case 8:
75+
function_name_ = "exp";
76+
break;
7077
case 10: {
7178
HardSwishMetalParam metal_param{
7279
param.hard_swish_offset, param.hard_swish_threshold, param.hard_swish_scale};
@@ -106,8 +113,9 @@
106113
auto encoder = [backend commandEncoder];
107114
[encoder setTexture:input_buffer_->image() atIndex:(0)];
108115
[encoder setTexture:output_buffer_->image() atIndex:(1)];
109-
if (function_name_ == "relu" || function_name_ == "leaky_relu" ||
110-
function_name_ == "hard_swish" || function_name_ == "hard_sigmoid") {
116+
if (function_name_ == "leaky_relu" || function_name_ == "relu6" ||
117+
function_name_ == "hard_swish" || function_name_ == "hard_sigmoid" ||
118+
function_name_ == "swish") {
111119
[encoder setBuffer:param_buffer_->buffer() offset:(0) atIndex:(0)];
112120
}
113121

lite/kernels/metal/image_op/metal_params.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,10 @@ struct HardSigmoidMetalParam {
184184
float offset;
185185
};
186186

187+
struct SwishMetalParam {
188+
float beta;
189+
};
190+
187191
struct HardSwishMetalParam {
188192
float offset;
189193
float threshold;

0 commit comments

Comments
 (0)