From 1703c0ca5416cd03da442262d13cf69fde8e07d5 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Sun, 6 Jul 2025 10:29:36 +0000 Subject: [PATCH 1/3] vulkan: increase LOAD_VEC_A to 8 (IQ1/IQ2) or 4 (IQ3) (#14485) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Commit taken from remyoudompheng's PR https://github.com/ggml-org/llama.cpp/pull/12260 Co-authored-by: Rémy Oudompheng --- ggml/src/vulkan-shaders/mul_mm.comp | 152 ++++++++++-------- .../src/vulkan-shaders/vulkan-shaders-gen.cpp | 4 +- 2 files changed, 83 insertions(+), 73 deletions(-) diff --git a/ggml/src/vulkan-shaders/mul_mm.comp b/ggml/src/vulkan-shaders/mul_mm.comp index 26163b167..888ce79f6 100644 --- a/ggml/src/vulkan-shaders/mul_mm.comp +++ b/ggml/src/vulkan-shaders/mul_mm.comp @@ -500,10 +500,9 @@ void main() { const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint ib8 = (idx % 128) / 4; - const int i8 = 2 * int(idx % 4); + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 32; const float d = float(data_a[ib].d); const uint qh = data_a[ib].qh[ib32]; @@ -512,22 +511,16 @@ void main() { const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); - const ivec2 gvec = ivec2( - bitfieldExtract(grid, 2 * (i8), 2), - bitfieldExtract(grid, 2 * (i8 + 1), 2) - ); - const vec2 v = dl * (vec2(gvec) + delta); - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + [[unroll]] for (int k = 0; k < 8; ++k) { + buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta)); + } #elif defined(DATA_A_IQ1_M) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib8 = (idx % 128) / 4; + const uint ib = idx / 32; // 8 values per idx + const uint ib8 = idx % 32; const uint ib16 = ib8 / 2; - const int i8 = 2 * int(idx % 4); const uint16_t[4] scales = data_a[ib].scales; const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; @@ -538,21 +531,17 @@ void main() { const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); - const ivec2 gvec = ivec2( - bitfieldExtract(grid, 2 * (i8), 2), - bitfieldExtract(grid, 2 * (i8 + 1), 2) - ); - const vec2 v = dl * (vec2(gvec) + delta); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + [[unroll]] for (int k = 0; k < 8; ++k) { + buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta)); + } #elif defined(DATA_A_IQ2_XXS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint ib8 = (idx / 4) % 4; + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 4; const float d = float(data_a[ib].d); const uint qs = data_a[ib].qs[8 * ib32 + ib8]; @@ -562,63 +551,81 @@ void main() { data_a[ib].qs[8*ib32 + 6], data_a[ib].qs[8*ib32 + 7] )); - const float db = d * 0.25 * (0.5 + (signs >> 28)); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28))); const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + const uint sign = sign7 | (bitCount(sign7) << 7); + const uvec2 grid = iq2xxs_grid[qs]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); + buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); + buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); + buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); + buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ2_XS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint ib8 = (idx / 4) % 4; // 0..3 + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 4; // 0..3 const float d = float(data_a[ib].d); const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; - const float db = d * 0.25 * (0.5 + scale); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); const uint qs = data_a[ib].qs[4 * ib32 + ib8]; const uint sign7 = qs >> 9; - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + const uint sign = sign7 | (bitCount(sign7) << 7); + const uvec2 grid = iq2xs_grid[qs & 511]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); + buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); + buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); + buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); + buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ2_S) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib8 = (idx % 128) / 4; // 0..31 - const uint ib32 = ib8 / 4; // 0..7 + const uint ib = idx / 32; // 8 values per idx + const uint ib8 = idx % 32; // 0..31 + const uint ib32 = ib8 / 4; // 0..7 const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; const uint qs = data_a[ib].qs[ib8]; const uint qh = data_a[ib].qh[ib32]; const uint qhshift = 2 * (ib8 % 4); - const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4)); + const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8]; const float d = float(data_a[ib].d); - const float db = d * 0.25 * (0.5 + scale); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1]; - const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147 - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); + const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); + buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); + buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); + buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); + buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ3_XXS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = (idx % 128) / 2; // 0..63 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = idx % 64; // 0..63 const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values const float d = float(data_a[ib].d); @@ -631,33 +638,36 @@ void main() { )); const float db = d * 0.5 * (0.5 + (signs >> 28)); const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2)); + const uint grid = iq3xxs_grid[qs]; + const vec4 v = db * vec4(unpack8(grid)); + + buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z); + buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w); #elif defined(DATA_A_IQ3_S) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = (idx % 128) / 2; // 0..63 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = idx % 64; // 0..63 const uint iqh = iqs / 8; const float d = float(data_a[ib].d); const uint qs = data_a[ib].qs[iqs]; const uint qh = data_a[ib].qh[iqh]; - const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4))); + const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2))); const uint scale = data_a[ib].scales[iqs / 16]; const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); - const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)]; + const vec4 v = db * vec4(unpack8(grid)); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z); + buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w); #elif defined(DATA_A_IQ4_XS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; diff --git a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp index 65dd82de0..1af71e165 100644 --- a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp @@ -360,9 +360,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool for (const auto& tname : type_names) { std::string load_vec_quant = "2"; - if ((tname == "q4_0") || (tname == "q4_1")) + if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) load_vec_quant = "8"; - else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl")) + else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl")) load_vec_quant = "4"; if (tname == "bf16") { From d539037c680d0e4b9eea7c12c20736ec3a06bb65 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 8 Jul 2025 13:11:42 -0500 Subject: [PATCH 2/3] vulkan: optimize flash attention split_k_reduce (#14554) * vulkan: allow FA split_k with smaller KV values * vulkan: spread split_k_reduce work across more threads k_num can get rather large. Use the whole workgroup to reduce the M/L values. Launch a thread for each element in the HSV dimension of the output. Helps a lot for large HSV (like deepseek). # Conflicts: # ggml/src/ggml-vulkan.cpp --- ggml/src/ggml-vulkan.cpp | 8 ++-- .../flash_attn_split_k_reduce.comp | 46 +++++++++++++++---- 2 files changed, 42 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp index 4091c89fd..26b0568f9 100644 --- a/ggml/src/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan.cpp @@ -2688,7 +2688,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) { @@ -6197,13 +6197,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; // Try to use split_k when KV is large enough to be worth the overhead - if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) { + if (workgroups_x == 1 && shader_core_count > 0) { // Try to run two workgroups per SM. split_k = shader_core_count * 2 / (workgroups_y * workgroups_z); if (split_k > 1) { // Try to evenly split KV into split_k chunks, but it needs to be a multiple // of "align", so recompute split_k based on that. - split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align); + split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align); split_k = CEIL_DIV(KV, split_kv); workgroups_x = split_k; } @@ -6336,7 +6336,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, }, - pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, 1, 1 }); + pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, HSV, 1 }); } else { ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { diff --git a/ggml/src/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/vulkan-shaders/flash_attn_split_k_reduce.comp index a7e395685..90f8ad87e 100644 --- a/ggml/src/vulkan-shaders/flash_attn_split_k_reduce.comp +++ b/ggml/src/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -2,9 +2,9 @@ #extension GL_EXT_control_flow_attributes : enable -#define BLOCK_SIZE 32 +layout(constant_id = 0) const uint BLOCK_SIZE = 32; -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {float data_a[];}; layout (binding = 1) writeonly buffer D {float data_d[];}; @@ -15,6 +15,8 @@ layout (push_constant) uniform parameter { uint k_num; } p; +shared float tmpsh[BLOCK_SIZE]; + void main() { // Each workgroup handles a row const uint n = gl_WorkGroupID.x; @@ -30,23 +32,51 @@ void main() { // Compute the max m value for the row float m_max = -1.0/0.0; - [[unroll]] for (uint k = 0; k < k_num; ++k) { - float m = data_a[m_offset + k * lm_stride]; + for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) { + float m = data_a[m_offset + (k + tid) * lm_stride]; m_max = max(m_max, m); } + // reduce across the workgroup + tmpsh[tid] = m_max; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + if (tid < s) { + m_max = max(m_max, tmpsh[tid + s]); + tmpsh[tid] = m_max; + } + barrier(); + } + m_max = tmpsh[0]; + + barrier(); + // Compute L based on m_max float L = 0; - [[unroll]] for (uint k = 0; k < k_num; ++k) { - float l = data_a[l_offset + k * lm_stride]; - float m = data_a[m_offset + k * lm_stride]; + for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) { + float l = data_a[l_offset + (k + tid) * lm_stride]; + float m = data_a[m_offset + (k + tid) * lm_stride]; L += exp(m - m_max) * l; } + // reduce across the workgroup + tmpsh[tid] = L; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + if (tid < s) { + L += tmpsh[tid + s]; + tmpsh[tid] = L; + } + barrier(); + } + L = tmpsh[0]; + L = 1.0 / L; + // D dimension is split across workgroups in the y dimension + uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE; // Scale and sum the O contributions based on m_max and store the result to memory - for (uint d = tid; d < D; d += BLOCK_SIZE) { + if (d < D) { float O = 0.0; [[unroll]] for (uint k = 0; k < k_num; ++k) { uint o_offset = D * N * k + D * n + d; From 3ef6de294f34ae7a3259c7b919eae3a664783506 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sat, 12 Jul 2025 04:51:58 -0500 Subject: [PATCH 3/3] vulkan: optimizations for deepseek prompt processing (#14555) * vulkan: allow unclamped loads in coopmat2 mul_mat_id shader * vulkan: increase coopmat2 mul_mat_id tile size * vulkan: optimize mat_mul_id row_ids search to batch loads, and port to coopmat1 path * vulkan: use smaller FA row size when head size is large. applies to both scalar and CM2 paths (CM1 isn't used due to shared memory limits) --- ggml/src/ggml-vulkan.cpp | 27 +++++++++----- ggml/src/vulkan-shaders/mul_mm.comp | 48 ++++++++++++++++++++++++- ggml/src/vulkan-shaders/mul_mm_cm2.comp | 47 +++++++++++++++++++----- 3 files changed, 104 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp index 26b0568f9..bb8b9fd46 100644 --- a/ggml/src/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan.cpp @@ -1717,7 +1717,14 @@ static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) { // number of rows/cols for flash attention shader static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; -static constexpr uint32_t scalar_flash_attention_num_large_rows = 8; + +static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) { + if (hsv >= 512) { + return 2; + } else { + return 8; + } +} // The FA coopmat1 shader assumes 16x16x16 matrix multiply support. // 128 threads split into four subgroups, each subgroup does 1/4 @@ -1742,7 +1749,7 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 if (small_rows) { return {scalar_flash_attention_num_small_rows, 64}; } else { - return {scalar_flash_attention_num_large_rows, 32}; + return {get_fa_scalar_num_large_rows(hsv), 32}; } } @@ -1761,7 +1768,11 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 // small cols to reduce register count if (ggml_is_quantized(type) || hsk >= 256) { - return {64, 32}; + if (hsk >= 512) { + return {32, 32}; + } else { + return {64, 32}; + } } return {64, 64}; } @@ -1803,7 +1814,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec const uint32_t warps = warptile[0] / warptile[10]; const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size; - const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0; + const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0; const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0; const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size; @@ -1928,10 +1939,10 @@ static void ggml_vk_load_shaders(vk_device& device) { s_mmq_wg_denoms_k = { 32, 32, 1 }; // spec constants and tile sizes for quant matmul_id - l_warptile_mmqid = { 256, 128, 64, 16, 0 }; + l_warptile_mmqid = { 256, 128, 128, 16, 0 }; m_warptile_mmqid = { 256, 128, 64, 16, 0 }; s_warptile_mmqid = { 256, 128, 64, 16, 0 }; - l_mmqid_wg_denoms = { 128, 64, 1 }; + l_mmqid_wg_denoms = { 128, 128, 1 }; m_mmqid_wg_denoms = { 128, 64, 1 }; s_mmqid_wg_denoms = { 128, 64, 1 }; @@ -5994,7 +6005,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con // Needs to be kept up to date on shader changes GGML_UNUSED(hsv); const uint32_t wg_size = scalar_flash_attention_workgroup_size; - const uint32_t Br = scalar_flash_attention_num_large_rows; + const uint32_t Br = get_fa_scalar_num_large_rows(hsv); const uint32_t Bc = scalar_flash_attention_Bc; const uint32_t tmpsh = wg_size * sizeof(float); @@ -6118,7 +6129,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx case FA_SCALAR: case FA_COOPMAT1: // We may switch from coopmat1 to scalar, so use the scalar limit for both - max_gqa = scalar_flash_attention_num_large_rows; + max_gqa = get_fa_scalar_num_large_rows(HSV); break; case FA_COOPMAT2: max_gqa = get_fa_num_small_rows(FA_COOPMAT2); diff --git a/ggml/src/vulkan-shaders/mul_mm.comp b/ggml/src/vulkan-shaders/mul_mm.comp index 888ce79f6..f48154991 100644 --- a/ggml/src/vulkan-shaders/mul_mm.comp +++ b/ggml/src/vulkan-shaders/mul_mm.comp @@ -18,6 +18,7 @@ #extension GL_KHR_cooperative_matrix : enable #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_ballot : enable #endif #ifdef MUL_MAT_ID @@ -104,6 +105,10 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE]; #ifdef MUL_MAT_ID shared u16vec2 row_ids[4096]; +uint _ne1; +#ifdef COOPMAT +shared uint _ne1_sh; +#endif #endif // MUL_MAT_ID #define NUM_WARPS (BLOCK_SIZE / WARP) @@ -172,7 +177,47 @@ void main() { const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK; #ifdef MUL_MAT_ID - uint _ne1 = 0; +#ifdef COOPMAT + // Spread the search across all elements in the first subgroup + if (gl_SubgroupID == 0) { + _ne1 = 0; + uint num_elements = p.nei1 * p.nei0; + + uint ids[16]; + uint iter = 0; + + for (uint j = 0; j < num_elements; j += gl_SubgroupSize) { + // prefetch up to 16 elements + if (iter == 0) { + [[unroll]] for (uint k = 0; k < 16; ++k) { + uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize; + bool in_range = i < num_elements; + uint ii1 = i / p.nei0; + uint ii0 = i % p.nei0; + ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + } + } + uint i = j + gl_SubgroupInvocationID; + bool in_range = i < num_elements; + uint ii1 = i / p.nei0; + uint ii0 = i % p.nei0; + uint id = ids[iter++]; + uvec4 ballot = subgroupBallot(in_range && id == expert_idx); + uint idx = subgroupBallotExclusiveBitCount(ballot); + if (in_range && id == expert_idx) { + row_ids[_ne1 + idx] = u16vec2(ii0, ii1); + } + _ne1 += subgroupBallotBitCount(ballot); + iter &= 15; + } + _ne1_sh = _ne1; + } + + barrier(); + + _ne1 = _ne1_sh; +#else + _ne1 = 0; for (uint ii1 = 0; ii1 < p.nei1; ii1++) { for (uint ii0 = 0; ii0 < p.nei0; ii0++) { if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { @@ -183,6 +228,7 @@ void main() { } barrier(); +#endif // Workgroup has no work if (ic * BN >= _ne1) return; diff --git a/ggml/src/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/vulkan-shaders/mul_mm_cm2.comp index 918465757..29e4b5c9c 100644 --- a/ggml/src/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/vulkan-shaders/mul_mm_cm2.comp @@ -162,17 +162,32 @@ void main() { _ne1 = 0; uint num_elements = p.nei1 * p.nei0; - for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) { + uint ids[16]; + uint iter = 0; + + for (uint j = 0; j < num_elements; j += gl_SubgroupSize) { + // prefetch up to 16 elements + if (iter == 0) { + [[unroll]] for (uint k = 0; k < 16; ++k) { + uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize; + bool in_range = i < num_elements; + uint ii1 = i / p.nei0; + uint ii0 = i % p.nei0; + ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + } + } + uint i = j + gl_SubgroupInvocationID; bool in_range = i < num_elements; - uint ii0 = i % p.nei0; uint ii1 = i / p.nei0; - uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + uint ii0 = i % p.nei0; + uint id = ids[iter++]; uvec4 ballot = subgroupBallot(in_range && id == expert_idx); uint idx = subgroupBallotExclusiveBitCount(ballot); if (in_range && id == expert_idx) { row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0); } _ne1 += subgroupBallotBitCount(ballot); + iter &= 15; } _ne1_sh = _ne1; } @@ -414,17 +429,31 @@ void main() { fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false); } - coopmat mat_a; - coopmat mat_b; + if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) { + coopmat mat_a; + coopmat mat_b; - coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); #ifdef MUL_MAT_ID - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); #else - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); #endif - sum = coopMatMulAdd(mat_a, mat_b, sum); + sum = coopMatMulAdd(mat_a, mat_b, sum); + } else { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); +#ifdef MUL_MAT_ID + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); +#else + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); +#endif + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } } // Convert from ACC_TYPE to D_TYPE