Skip to content

Commit 4f711db

Browse files
authored
[Layers] Improve the first-token perf. in w8a8(f32 InT) by enabling flashAttn (#43)
1 parent 3729f2a commit 4f711db

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

src/layers/attention.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ class Attention {
372372

373373
TimeLine t4("MHA");
374374
if (pastSeqLen == 0) {
375-
if (ctx->inputSeqLen > Env::getInstance().getFlashAttnThreshold()) {
375+
if (Env::getInstance().getFlashAttnEnabled<InT>(ctx->inputSeqLen)) {
376376
flashAttention(ctx, query, key, value, attnSplit, presentKey, presentValue, attnMask, pastSeqLen);
377377
} else if constexpr ((std::is_same_v<InT, bfloat16_t> && std::is_same_v<OutT, bfloat16_t>)
378378
#if defined(AMX_FP16_WEIGHT_ONLY_FP16)
@@ -581,7 +581,7 @@ class Attention {
581581

582582
TimeLine t4("MHA");
583583
if (seqs[0]->getStep() == 0) { // First token generation
584-
if (totInSeqLen > Env::getInstance().getFlashAttnThreshold() * seqs.size()) {
584+
if (Env::getInstance().getFlashAttnEnabled<InT>(totInSeqLen / seqs.size())) {
585585
flashAttention(ctx, query, key, value, attnSplit, keyCaches, valueCaches, seqs);
586586
} else if constexpr ((std::is_same_v<InT, bfloat16_t> && std::is_same_v<OutT, bfloat16_t>)
587587
#if defined(AMX_FP16_WEIGHT_ONLY_FP16)

src/utils/environment.h

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,14 @@ class Env {
4848
int getAMXThresholdM() { return AMXThresholdMValue; }
4949

5050
// get FLASH_ATTN_THRESHOLD
51-
int getFlashAttnThreshold() { return FlashAttnThresholdValue; }
51+
template <typename T>
52+
bool getFlashAttnEnabled(int inputLen) {
53+
if (FlashAttnThresholdValue >= 0 &&
54+
(std::is_same_v<T, float> || inputLen > FlashAttnThresholdValue))
55+
return true;
56+
else
57+
return false;
58+
}
5259

5360
// get ENABLE_CAT_MLP
5461
bool getMlpCatEnabled() { return MlpCatEnabled; }
@@ -245,13 +252,12 @@ class Env {
245252
// > threshold to enable flash attention, default 8192
246253
char *flashAttnThresholdValue = getenv("FLASH_ATTN_THRESHOLD");
247254
if (flashAttnThresholdValue != NULL) {
248-
int value = atoi(flashAttnThresholdValue);
249-
if (value >= 0)
250-
FlashAttnThresholdValue = value;
251-
else
252-
printf("[ERROR] FLASH_ATTN_THRESHOLD value need to be greater than or equal to 0.\n");
255+
FlashAttnThresholdValue = atoi(flashAttnThresholdValue);
253256
}
254-
printf("[INFO] SeqLen > FLASH_ATTN_THRESHOLD(%d) will enable FlashAttn.\n", FlashAttnThresholdValue);
257+
if (FlashAttnThresholdValue < 0)
258+
printf("[INFO] FlashAttn is disabled (FLASH_ATTN_THRESHOLD = %d).\n", FlashAttnThresholdValue);
259+
else
260+
printf("[INFO] SeqLen > FLASH_ATTN_THRESHOLD(%d) will enable FlashAttn.\n", FlashAttnThresholdValue);
255261
}
256262

257263
// ENABLE_CAT_MLP

0 commit comments

Comments
 (0)