Skip to content

Commit 6195385

Browse files
committed
vulkan: Preprocess FA mask to detect all-neg-inf and all-zero.
Write out a 2-bit code per block and avoid loading the mask when it matches these two common cases. Apply this optimization when the mask is relatively large (i.e., during prompt processing).
1 parent 9f682fb commit 6195385

File tree

8 files changed

+362
-125
lines changed

8 files changed

+362
-125
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 89 additions & 20 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ void main() {
9494
}
9595
}
9696

97+
const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
98+
// mo_offset will point to the tile starting at row i*Br and col 0
99+
uint32_t mo_offset = mo_stride * i;
100+
97101
#if BLOCK_SIZE > 1
98102
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
99103
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
@@ -104,41 +108,41 @@ void main() {
104108
uint32_t m_offset = gqa_iq1*KV;
105109
if (p.nem2 != 1 || p.nem3 != 1) {
106110
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
111+
mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
107112
}
108113

114+
uint32_t mask_opt = 0;
115+
uint32_t mask_opt_idx = ~0;
116+
109117
[[dont_unroll]]
110118
for (uint32_t j = start_j; j < end_j; ++j) {
111119

112-
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
120+
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
121+
mask_opt_idx = j / 16;
122+
mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
123+
}
124+
uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
125+
if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
126+
// skip this block
127+
continue;
128+
}
129+
// Only load if the block is not all zeros
130+
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) {
113131
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
114132

115-
float max_mask = NEG_FLT_MAX_OVER_2;
116133
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
117134
uint32_t c = (idx + tid) % Bc;
118135
uint32_t r = (idx + tid) / Bc;
119136
if (idx + tid < Bc * Br) {
120137
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
121138
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
122139
masksh[c][r] = m;
123-
max_mask = max(max_mask, m);
124140
} else {
125141
masksh[c][r] = float(0);
126142
}
127143
}
128144
}
129-
// skip the block if the mask is entirely -inf
130-
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
131-
barrier();
132-
if (gl_SubgroupInvocationID == 0) {
133-
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
134-
}
135145
barrier();
136-
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
137-
max_mask = max(max_mask, tmpsh[s]);
138-
}
139-
if (max_mask <= NEG_FLT_MAX_OVER_2) {
140-
continue;
141-
}
142146
}
143147

144148
float Sf[Br][cols_per_thread];
@@ -185,7 +189,7 @@ void main() {
185189
}
186190
}
187191

188-
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
192+
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) {
189193
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
190194
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
191195
float mvf = masksh[c * cols_per_iter + col_tid][r];
@@ -256,9 +260,6 @@ void main() {
256260
barrier();
257261
}
258262

259-
// prevent race on tmpsh
260-
barrier();
261-
262263
// reduce across threads
263264

264265
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ layout (constant_id = 5) const uint32_t Clamp = 0;
1010
layout (constant_id = 6) const uint32_t D_split = 16;
1111
layout (constant_id = 7) const uint32_t SubGroupSize = 32;
1212
layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0;
13+
layout (constant_id = 9) const bool USE_MASK_OPT = false;
1314

1415
// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
1516
const uint32_t HSK_pad = (HSK + 15) & ~15;
@@ -66,6 +67,11 @@ layout (binding = 4) readonly buffer S {float data_s[];};
6667

6768
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
6869

70+
layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];};
71+
72+
#define MASK_OPT_ALL_NEG_INF 1
73+
#define MASK_OPT_ALL_ZERO 2
74+
6975
#define BINDING_IDX_K 0
7076
#define BINDING_IDX_V 1
7177
#if defined(DATA_A_F32)

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

Lines changed: 60 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
4242
return elem;
4343
}
4444

45-
shared float tmpsh[row_split];
46-
4745
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
4846
shared f16vec4 Qf[Br * qstride];
4947

@@ -134,6 +132,10 @@ void main() {
134132
}
135133
}
136134

135+
const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
136+
// mo_offset will point to the tile starting at row i*Br and col 0
137+
uint32_t mo_offset = mo_stride * i;
138+
137139
#if BLOCK_SIZE > 1
138140
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
139141
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
@@ -144,66 +146,74 @@ void main() {
144146
uint32_t m_offset = gqa_iq1*KV;
145147
if (p.nem2 != 1 || p.nem3 != 1) {
146148
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
149+
mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
147150
}
148151

152+
uint32_t mask_opt = 0;
153+
uint32_t mask_opt_idx = ~0;
154+
149155
[[dont_unroll]]
150156
for (uint32_t j = start_j; j < end_j; ++j) {
151157

152158
f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
159+
[[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) {
160+
mask_cache[idx] = f16vec4(0);
161+
}
162+
153163
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
154-
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
155164

156-
float max_mask = NEG_FLT_MAX_OVER_2;
157-
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
158-
uint32_t c = (idx + tid) / (Br / 4);
159-
uint32_t r = (idx + tid) % (Br / 4);
160-
if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
161-
if ((!KV_bounds_check || j * Bc + c < KV)) {
162-
f16vec4 m;
163-
if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) {
164-
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
165-
data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
166-
data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
167-
data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]);
168-
max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3]));
169-
} else if (i * Br + r * 4 + 2 < p.nem1) {
170-
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
171-
data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
172-
data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
173-
0.0);
174-
max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2]));
175-
} else if (i * Br + r * 4 + 1 < p.nem1) {
176-
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
177-
data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
178-
0.0,
179-
0.0);
180-
max_mask = max(max(max_mask, float(m[0])), float(m[1]));
181-
} else if (i * Br + r * 4 < p.nem1) {
182-
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
183-
0.0,
184-
0.0,
185-
0.0);
186-
max_mask = max(max_mask, float(m[0]));
187-
} else {
188-
m = f16vec4(0.0);
165+
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
166+
mask_opt_idx = j / 16;
167+
mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
168+
}
169+
uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
170+
if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
171+
// skip this block
172+
continue;
173+
}
174+
// Only load if the block is not all zeros
175+
if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
176+
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
177+
178+
float max_mask = NEG_FLT_MAX_OVER_2;
179+
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
180+
uint32_t c = (idx + tid) / (Br / 4);
181+
uint32_t r = (idx + tid) % (Br / 4);
182+
if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
183+
if ((!KV_bounds_check || j * Bc + c < KV)) {
184+
f16vec4 m;
185+
if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) {
186+
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
187+
data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
188+
data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
189+
data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]);
190+
max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3]));
191+
} else if (i * Br + r * 4 + 2 < p.nem1) {
192+
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
193+
data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
194+
data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
195+
0.0);
196+
max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2]));
197+
} else if (i * Br + r * 4 + 1 < p.nem1) {
198+
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
199+
data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
200+
0.0,
201+
0.0);
202+
max_mask = max(max(max_mask, float(m[0])), float(m[1]));
203+
} else if (i * Br + r * 4 < p.nem1) {
204+
m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
205+
0.0,
206+
0.0,
207+
0.0);
208+
max_mask = max(max_mask, float(m[0]));
209+
} else {
210+
m = f16vec4(0.0);
211+
}
212+
mask_cache[idx / WorkGroupSize] = m;
189213
}
190-
mask_cache[idx / WorkGroupSize] = m;
191214
}
192215
}
193216
}
194-
// skip the block if the mask is entirely -inf
195-
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
196-
barrier();
197-
if (gl_SubgroupInvocationID == 0) {
198-
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
199-
}
200-
barrier();
201-
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
202-
max_mask = max(max_mask, tmpsh[s]);
203-
}
204-
if (max_mask <= NEG_FLT_MAX_OVER_2) {
205-
continue;
206-
}
207217
}
208218

209219
if (K_LOAD_SHMEM != 0) {

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -138,48 +138,53 @@ void main() {
138138
coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
139139
}
140140

141+
const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
142+
// mo_offset will point to the tile starting at row i*Br and col 0
143+
uint32_t mo_offset = mo_stride * i;
144+
141145
uint32_t m_offset = gqa_iq1*KV * 2 /*sizeof(float16_t)*/;
142146
if (p.nem2 != 1 || p.nem3 != 1) {
143147
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
148+
mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
144149
}
145150

151+
uint32_t mask_opt = 0;
152+
uint32_t mask_opt_idx = ~0;
153+
146154
[[dont_unroll]]
147155
for (uint32_t j = start_j; j < end_j; ++j) {
148156

149-
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
157+
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
150158
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
151-
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
152-
153-
if (nem1_bounds_check) {
154-
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
155-
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
156-
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
157-
tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
158-
159-
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
160-
161-
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
162-
163-
// skip the block if the mask is entirely -inf
164-
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
165-
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
166-
continue;
167-
}
168-
} else {
169-
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
170-
// Don't clamp against nem1 when GQA is enabled
171-
uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
172-
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
173-
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
174-
175-
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
176159

177-
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
178-
179-
// skip the block if the mask is entirely -inf
180-
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
181-
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
182-
continue;
160+
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
161+
mask_opt_idx = j / 16;
162+
mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
163+
}
164+
uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
165+
if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
166+
// skip this block
167+
continue;
168+
}
169+
// Only load if the block is not all zeros
170+
if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
171+
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
172+
173+
if (nem1_bounds_check) {
174+
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
175+
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
176+
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
177+
tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
178+
179+
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
180+
} else {
181+
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
182+
// Don't clamp against nem1 when GQA is enabled
183+
uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
184+
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
185+
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
186+
187+
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
183188
}
184189
}
185190
}

0 commit comments

Comments
 (0)