@@ -42,8 +42,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
4242 return elem;
4343}
4444
45- shared float tmpsh[row_split];
46-
4745const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
4846shared f16vec4 Qf[Br * qstride];
4947
@@ -134,6 +132,10 @@ void main() {
134132 }
135133 }
136134
135+ const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
136+ // mo_offset will point to the tile starting at row i*Br and col 0
137+ uint32_t mo_offset = mo_stride * i;
138+
137139#if BLOCK_SIZE > 1
138140 uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
139141 uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
@@ -144,66 +146,74 @@ void main() {
144146 uint32_t m_offset = gqa_iq1*KV;
145147 if (p.nem2 != 1 || p.nem3 != 1) {
146148 m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
149+ mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
147150 }
148151
152+ uint32_t mask_opt = 0;
153+ uint32_t mask_opt_idx = ~0;
154+
149155 [[dont_unroll]]
150156 for (uint32_t j = start_j; j < end_j; ++j) {
151157
152158 f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
159+ [[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) {
160+ mask_cache[idx] = f16vec4(0);
161+ }
162+
153163 if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
154- bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
155164
156- float max_mask = NEG_FLT_MAX_OVER_2;
157- [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
158- uint32_t c = (idx + tid) / (Br / 4);
159- uint32_t r = (idx + tid) % (Br / 4);
160- if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
161- if ((!KV_bounds_check || j * Bc + c < KV)) {
162- f16vec4 m;
163- if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) {
164- m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
165- data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
166- data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
167- data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]);
168- max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3]));
169- } else if (i * Br + r * 4 + 2 < p.nem1) {
170- m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
171- data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
172- data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
173- 0.0);
174- max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2]));
175- } else if (i * Br + r * 4 + 1 < p.nem1) {
176- m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
177- data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
178- 0.0,
179- 0.0);
180- max_mask = max(max(max_mask, float(m[0])), float(m[1]));
181- } else if (i * Br + r * 4 < p.nem1) {
182- m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
183- 0.0,
184- 0.0,
185- 0.0);
186- max_mask = max(max_mask, float(m[0]));
187- } else {
188- m = f16vec4(0.0);
165+ if (USE_MASK_OPT && mask_opt_idx != j / 16) {
166+ mask_opt_idx = j / 16;
167+ mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
168+ }
169+ uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
170+ if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
171+ // skip this block
172+ continue;
173+ }
174+ // Only load if the block is not all zeros
175+ if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
176+ bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
177+
178+ float max_mask = NEG_FLT_MAX_OVER_2;
179+ [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
180+ uint32_t c = (idx + tid) / (Br / 4);
181+ uint32_t r = (idx + tid) % (Br / 4);
182+ if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
183+ if ((!KV_bounds_check || j * Bc + c < KV)) {
184+ f16vec4 m;
185+ if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) {
186+ m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
187+ data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
188+ data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
189+ data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]);
190+ max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3]));
191+ } else if (i * Br + r * 4 + 2 < p.nem1) {
192+ m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
193+ data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
194+ data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
195+ 0.0);
196+ max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2]));
197+ } else if (i * Br + r * 4 + 1 < p.nem1) {
198+ m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
199+ data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
200+ 0.0,
201+ 0.0);
202+ max_mask = max(max(max_mask, float(m[0])), float(m[1]));
203+ } else if (i * Br + r * 4 < p.nem1) {
204+ m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)],
205+ 0.0,
206+ 0.0,
207+ 0.0);
208+ max_mask = max(max_mask, float(m[0]));
209+ } else {
210+ m = f16vec4(0.0);
211+ }
212+ mask_cache[idx / WorkGroupSize] = m;
189213 }
190- mask_cache[idx / WorkGroupSize] = m;
191214 }
192215 }
193216 }
194- // skip the block if the mask is entirely -inf
195- bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
196- barrier();
197- if (gl_SubgroupInvocationID == 0) {
198- tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
199- }
200- barrier();
201- [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
202- max_mask = max(max_mask, tmpsh[s]);
203- }
204- if (max_mask <= NEG_FLT_MAX_OVER_2) {
205- continue;
206- }
207217 }
208218
209219 if (K_LOAD_SHMEM != 0) {