Skip to content

Commit c2b4565

Browse files
joeldushouyushaobo.xie
authored and committed
ggml-hexagon: gelu operation (#17921)
* feat: inital support for gelu using sigmoid approximation * snapshot: faster gelu using polynomial approximation * test: disable l2-block prefetch in polynomail approximation * Revert "test: disable l2-block prefetch in polynomail approximation" This reverts commit 72339994d45b2bed887e79994403c378d90b62b5. * Revert "snapshot: faster gelu using polynomial approximation" This reverts commit 2a787a61d11f9e63e5943a2e6d134b2f0c402ace. * debug: temporarily disable unnecessary log message for debug purpose * Feat: optiized unaligned sigmoid_f32 * Feat: larger l2prefetch block * feat: apply unaligned-load optimization on mul and mul_scalar * Revert "debug: temporarily disable unnecessary log message for debug purpose" This reverts commit 84f2f23aa9f17e2fa826db969cd825d0ab192995. * refactor: cleanup commented unused code * chore: reformat code with clang-formatter to pass cli test * Revert "chore: reformat code with clang-formatter to pass cli test" This reverts commit 952877ec24732b12010c7fa7ed3fc8de4b74e718. * fix: fix loop overflow * chore: fix formating ci error
1 parent baa2a66 commit c2b4565

File tree

6 files changed

+256
-21
lines changed

6 files changed

+256
-21
lines changed

ggml/src/ggml-hexagon/ggml-hexagon.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2161,8 +2161,14 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session
21612161
}
21622162

21632163
// src0, src1 & dst must be mapped to the same session
2164-
if (!hex_supported_buffer(sess, src0, src1, dst)) {
2165-
return false;
2164+
if(src1){
2165+
if (!hex_supported_buffer(sess, src0, src1, dst)) {
2166+
return false;
2167+
}
2168+
}else{
2169+
if (!hex_supported_buffer(sess, src0, dst)) {
2170+
return false;
2171+
}
21662172
}
21672173

21682174
return true;
@@ -2662,6 +2668,10 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) {
26622668
req.op = HTP_OP_UNARY_SILU;
26632669
supported = true;
26642670
}
2671+
else if (ggml_get_unary_op(dst) == GGML_UNARY_OP_GELU){
2672+
req.op = HTP_OP_UNARY_GELU;
2673+
supported = true;
2674+
}
26652675
break;
26662676

26672677
case GGML_OP_GLU:
@@ -2677,6 +2687,7 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) {
26772687
case GGML_OP_SOFT_MAX:
26782688
req.op = HTP_OP_SOFTMAX;
26792689
supported = true;
2690+
break;
26802691

26812692
default:
26822693
break;
@@ -2956,6 +2967,8 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
29562967
case GGML_OP_UNARY:
29572968
if (ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) {
29582969
ggml_hexagon_unary(node, flags);
2970+
} else if (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU) {
2971+
ggml_hexagon_unary(node, flags);
29592972
}
29602973
break;
29612974
case GGML_OP_GLU:
@@ -3254,7 +3267,6 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
32543267
auto sess = static_cast<ggml_hexagon_session *>(dev->context);
32553268

32563269
bool supp = false;
3257-
32583270
switch (op->op) {
32593271
case GGML_OP_NONE:
32603272
case GGML_OP_RESHAPE:
@@ -3294,6 +3306,9 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
32943306
if (ggml_get_unary_op(op) == GGML_UNARY_OP_SILU) {
32953307
supp = ggml_hexagon_supported_activations(sess, op);
32963308
}
3309+
else if (ggml_get_unary_op(op) == GGML_UNARY_OP_GELU){
3310+
supp = ggml_hexagon_supported_activations(sess, op);
3311+
}
32973312
break;
32983313

32993314
case GGML_OP_GLU:

ggml/src/ggml-hexagon/htp/act-ops.c

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,91 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
255255
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
256256
}
257257

258+
259+
static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0,
260+
struct htp_tensor * dst,
261+
const int32_t * op_params,
262+
struct htp_spad * src0_spad,
263+
struct htp_spad * dst_spad,
264+
uint32_t nth,
265+
uint32_t ith,
266+
uint32_t src0_nrows_per_thread) {
267+
htp_act_preamble2;
268+
269+
uint64_t t1, t2;
270+
t1 = HAP_perf_get_qtimer_count();
271+
272+
const size_t src0_row_size = nb01;
273+
const size_t dst_row_size = nb1;
274+
275+
const uint32_t src0_nrows = ne01 * ne02 * ne03;
276+
277+
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
278+
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
279+
280+
// no work for this thread
281+
if (src0_start_row >= src0_end_row) {
282+
return;
283+
}
284+
285+
int is_aligned = 1;
286+
int opt_path = 0;
287+
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
288+
is_aligned = 0;
289+
FARF(HIGH, "silu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
290+
}
291+
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
292+
opt_path = 1;
293+
}
294+
295+
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
296+
uint8_t * restrict data_dst = (uint8_t *) dst->data;
297+
298+
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
299+
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
300+
301+
const int BLOCK = 8;
302+
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
303+
const uint32_t block_end = MIN(ir + BLOCK, src0_end_row);
304+
305+
// Prefetch next block
306+
if (block_end < src0_end_row) {
307+
const float * restrict prefetch_ptr = (float *) (data_src0 + (block_end * src0_row_size));
308+
htp_l2fetch(prefetch_ptr, 1, block_end * src0_row_size, src0_row_size);
309+
}
310+
311+
// Process rows in current block
312+
for (uint32_t ib = ir; ib < block_end; ib++) {
313+
const float * restrict src0 = (float *) (data_src0 + (ib * src0_row_size));
314+
float * restrict dst = (float *) (data_dst + (ib * dst_row_size));
315+
316+
// gelu = x * sigmoid(1.702 * x) // current implementation
317+
if (1 == opt_path) {
318+
hvx_mul_scalar_f32((const uint8_t *) src0, (float) 1.702, (uint8_t *) src0_spad_data, ne0);
319+
hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_data, (uint8_t *) src0_spad_data, ne0);
320+
hvx_mul_f32_opt((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
321+
} else {
322+
hvx_mul_scalar_f32( (const uint8_t *) src0, (float)1.702, (uint8_t *) src0_spad_data, ne0);
323+
hvx_sigmoid_f32((const uint8_t *) src0_spad_data, (uint8_t *) src0_spad_data, ne0);
324+
hvx_mul_f32((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
325+
}
326+
}
327+
}
328+
329+
t2 = HAP_perf_get_qtimer_count();
330+
331+
FARF(HIGH, "gelu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, ne00, ne01, ne02,
332+
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
333+
}
334+
335+
static void unary_gelu_fp32(unsigned int n, unsigned int i, void * data) {
336+
struct htp_ops_context * octx = (struct htp_ops_context *) data;
337+
unary_gelu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
338+
octx->src0_nrows_per_thread);
339+
}
340+
341+
342+
258343
static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
259344
struct htp_tensor * dst,
260345
const int32_t * op_params,
@@ -371,7 +456,10 @@ static int execute_op_activations_fp32(struct htp_ops_context * octx) {
371456
act_op_func = glu_swiglu_oai_fp32;
372457
op_type = "swiglu-oai-f32";
373458
break;
374-
459+
case HTP_OP_UNARY_GELU:
460+
act_op_func = unary_gelu_fp32;
461+
op_type = "gelu-f32";
462+
break;
375463
default:
376464
FARF(ERROR, "Unsupported activations Op %u\n", octx->op);
377465
return HTP_STATUS_NO_SUPPORT;

ggml/src/ggml-hexagon/htp/htp-msg.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,12 @@ enum htp_op {
5151
HTP_OP_MUL_MAT_ID = 5,
5252
HTP_OP_RMS_NORM = 6,
5353
HTP_OP_UNARY_SILU = 7,
54-
HTP_OP_GLU_SWIGLU = 8,
55-
HTP_OP_GLU_SWIGLU_OAI = 9,
56-
HTP_OP_SOFTMAX = 10,
57-
HTP_OP_ADD_ID = 11,
58-
HTP_OP_ROPE = 12,
54+
HTP_OP_UNARY_GELU = 8,
55+
HTP_OP_GLU_SWIGLU = 9,
56+
HTP_OP_GLU_SWIGLU_OAI = 10,
57+
HTP_OP_SOFTMAX = 11,
58+
HTP_OP_ADD_ID = 12,
59+
HTP_OP_ROPE = 13,
5960
INVALID
6061
};
6162

ggml/src/ggml-hexagon/htp/hvx-utils.c

Lines changed: 85 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ void hvx_mul_f32(const uint8_t * restrict src0,
4949
FARF(HIGH, "hvx_mul_f32: unaligned loop in hvx op, possibly slower execution\n");
5050
}
5151

52+
53+
bool handled_leftover = false;
5254
if (0 == unaligned_loop) {
5355
HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0;
5456
HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1;
@@ -60,18 +62,59 @@ void hvx_mul_f32(const uint8_t * restrict src0,
6062
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
6163
}
6264
} else {
65+
int step_of_1 = num_elems_whole >> 5; // divby 32, because 32 float = 128 bytes per HVX vector
66+
int leftover_size = left_over * sizeof(float);
67+
68+
69+
HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0;
70+
HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1;
71+
HVX_UVector * restrict vec_out = (HVX_UVector *) dst;
72+
73+
HVX_Vector slinep;
74+
HVX_Vector slinec;
75+
HVX_Vector sline;
76+
HVX_Vector sline2p;
77+
HVX_Vector sline2c;
78+
HVX_Vector sline2;
79+
80+
slinep = *vec_in1++;
81+
sline2p = *vec_in2++;
6382
#pragma unroll(4)
64-
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
65-
HVX_Vector in1 = *(HVX_UVector *) (src0 + i * SIZEOF_FP32);
66-
HVX_Vector in2 = *(HVX_UVector *) (src1 + i * SIZEOF_FP32);
83+
for (int i = step_of_1 - 1; i > 0; i--) {
84+
slinec = *vec_in1++;
85+
sline2c = *vec_in2++;
86+
sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src0);
87+
sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t) src1);
88+
89+
*((HVX_UVector *) (vec_out++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, sline2));
90+
slinep = slinec;
91+
sline2p = sline2c;
92+
}
93+
if (step_of_1 > 1) {
94+
slinec = htp_is_aligned(vec_in1, VLEN) && left_over == 0 ? slinep : *vec_in1++;
95+
sline2c = htp_is_aligned(vec_in2, VLEN) && left_over == 0 ? sline2p : *vec_in2++;
96+
97+
sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src0);
98+
sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t) src1);
99+
*((HVX_UVector *) (vec_out++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, sline2));
100+
slinep = slinec;
101+
sline2p = sline2c;
102+
}
103+
if (left_over > 0) {
104+
slinec = (is_in_one_chunk(vec_in1, leftover_size, VLEN) ? slinep : *vec_in1++);
67105

68-
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in1, in2);
106+
sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src0);
107+
sline2c = (is_in_one_chunk(vec_in2, leftover_size, VLEN) ? sline2p : *vec_in2++);
108+
sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t) src1);
69109

70-
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
110+
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(sline, sline2);
111+
hvx_vec_store_u(vec_out, leftover_size, Q6_Vsf_equals_Vqf32(out));
112+
handled_leftover = true;
71113
}
72114
}
73115

74-
if (left_over > 0) {
116+
117+
if (left_over > 0 && !handled_leftover) {
75118
const float * src0f = (const float *) src0 + num_elems_whole;
76119
const float * src1f = (const float *) src1 + num_elems_whole;
77120
float * dstf = (float *) dst + num_elems_whole;
@@ -464,7 +507,7 @@ void hvx_mul_scalar_f32(const uint8_t * restrict src, const float val, uint8_t *
464507
}
465508

466509
HVX_Vector val_vec = hvx_vec_splat_fp32(val);
467-
510+
bool handled_leftover = false;
468511
if (0 == unaligned_loop) {
469512
HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
470513
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
@@ -475,17 +518,47 @@ void hvx_mul_scalar_f32(const uint8_t * restrict src, const float val, uint8_t *
475518
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
476519
}
477520
} else {
521+
int step_of_1 = num_elems >> 5; // divby 32, because 32 float = 128 bytes per HVX vector
522+
int leftover_size = left_over * sizeof(float);
523+
524+
HVX_Vector * input_v_ptr = (HVX_Vector *) src;
525+
HVX_UVector * output_v_ptr = (HVX_UVector *) dst;
526+
527+
HVX_Vector slinep;
528+
HVX_Vector slinec;
529+
HVX_Vector sline;
530+
531+
slinep = *input_v_ptr++;
532+
478533
#pragma unroll(4)
479-
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
480-
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
534+
for (int i = step_of_1 - 1; i > 0; i--) {
535+
slinec = *input_v_ptr++;
536+
sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src);
537+
*((HVX_UVector *) (output_v_ptr++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, val_vec));
538+
/* Prepare slinep for next iteration */
539+
slinep = slinec;
540+
}
481541

482-
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, val_vec);
542+
if (step_of_1 > 0) {
543+
slinec = htp_is_aligned(input_v_ptr, VLEN) && left_over == 0 ? slinep : *input_v_ptr++;
544+
sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src);
545+
*((HVX_UVector *) (output_v_ptr++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, val_vec));
483546

484-
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
547+
slinep = slinec;
548+
}
549+
550+
if (leftover_size > 0) {
551+
slinec = (is_in_one_chunk(input_v_ptr, leftover_size, VLEN) ? slinep : *input_v_ptr++);
552+
553+
sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src);
554+
555+
HVX_Vector sout = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, val_vec));
556+
hvx_vec_store_u(output_v_ptr, leftover_size, sout);
557+
handled_leftover = true;
485558
}
486559
}
487560

488-
if (left_over > 0) {
561+
if (left_over > 0 && !handled_leftover) {
489562
const float * srcf = (const float *) src + num_elems_whole;
490563
float * dstf = (float *) dst + num_elems_whole;
491564

ggml/src/ggml-hexagon/htp/hvx-utils.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,12 +265,16 @@ static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t
265265
}
266266
}
267267

268+
269+
/* Return whether 'n' elements from vector are in the one chunk of 'chunk_size'. */
268270
static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
269271
uint32_t left_off = (size_t) addr & (chunk_size - 1);
270272
uint32_t right_off = left_off + n;
271273
return right_off <= chunk_size;
272274
}
273275

276+
277+
274278
static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) {
275279
HVX_VectorAlias u = { .v = v };
276280

@@ -994,6 +998,59 @@ static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t *
994998
}
995999
}
9961000

1001+
1002+
static inline void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems){
1003+
int step_of_1 = num_elems >> 5; // divby 32, because 32 float = 128 bytes per HVX vector
1004+
int leftover = num_elems - (step_of_1 * VLEN_FP32);
1005+
1006+
int32_t leftover_size = leftover * sizeof(float);
1007+
1008+
static const float kMinExp = -87.f; // 0
1009+
static const float kMaxExp = 87.f; // 1
1010+
1011+
const HVX_Vector one = hvx_vec_splat_fp32(1.f);
1012+
const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
1013+
const HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp);
1014+
1015+
const float *input = (float *)src;
1016+
float *output = (float *)dst;
1017+
1018+
HVX_Vector * input_v_ptr = (HVX_Vector *) input;
1019+
HVX_UVector * output_v_ptr = (HVX_UVector *) output;
1020+
1021+
HVX_Vector slinep;
1022+
HVX_Vector slinec;
1023+
HVX_Vector sline;
1024+
1025+
slinep = *input_v_ptr++;
1026+
#pragma unroll(4)
1027+
for (int i = step_of_1 - 1; i > 0; i--) {
1028+
slinec = *input_v_ptr++;
1029+
sline = Q6_V_valign_VVR(slinec, slinep, (size_t) input);
1030+
*((HVX_UVector *) (output_v_ptr++)) = hvx_vec_fast_sigmoid_fp32_guard(sline, one, max_exp, min_exp);
1031+
/* Prepare slinep for next iteration */
1032+
slinep = slinec;
1033+
}
1034+
1035+
if (step_of_1 > 0) {
1036+
slinec = htp_is_aligned(input_v_ptr, 128) && leftover == 0 ? slinep : *input_v_ptr++;
1037+
sline = Q6_V_valign_VVR(slinec, slinep, (size_t) input);
1038+
*((HVX_UVector *) (output_v_ptr++)) = hvx_vec_fast_sigmoid_fp32_guard(sline, one, max_exp, min_exp);
1039+
;
1040+
1041+
slinep = slinec;
1042+
}
1043+
if (leftover > 0) {
1044+
slinec = (is_in_one_chunk(input_v_ptr, leftover_size, 128) ? slinep : *input_v_ptr++);
1045+
1046+
sline = Q6_V_valign_VVR(slinec, slinep, (size_t) input);
1047+
1048+
HVX_Vector sout = hvx_vec_fast_sigmoid_fp32_guard(sline, one, max_exp, min_exp);
1049+
hvx_vec_store_u(output_v_ptr, leftover_size, sout);
1050+
}
1051+
}
1052+
1053+
9971054
float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems);
9981055
void hvx_mul_f32(const uint8_t * restrict src0,
9991056
const uint8_t * restrict src1,

0 commit comments

Comments
 (0)