Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ggml/src/ggml-cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ if (CUDAToolkit_FOUND)
template-instances/fattn-vec-instance-f16-f16.cu
template-instances/fattn-vec-instance-q4_0-q4_0.cu
template-instances/fattn-vec-instance-q8_0-q8_0.cu
template-instances/fattn-vec-instance-bf16-bf16.cu)
template-instances/fattn-vec-instance-bf16-bf16.cu
template-instances/fattn-vec-instance-turbo3_0-turbo3_0.cu)
endif()

ggml_add_backend_library(ggml-cuda
Expand Down
9 changes: 9 additions & 0 deletions ggml/src/ggml-cuda/convert.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "convert.cuh"
#include "dequantize.cuh"
#include "turbo-quant.cuh"

#include <cstdint>

Expand Down Expand Up @@ -756,6 +757,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_mxfp4_cuda;
case GGML_TYPE_NVFP4:
return dequantize_row_nvfp4_cuda;
case GGML_TYPE_TURBO3_0:
return dequantize_block_cont_cuda<QK_TURBO3, QR_TURBO3, dequantize_turbo3_0>;
case GGML_TYPE_F32:
return convert_unary_cont_cuda<float>;
case GGML_TYPE_BF16:
Expand Down Expand Up @@ -809,6 +812,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_mxfp4_cuda;
case GGML_TYPE_NVFP4:
return dequantize_row_nvfp4_cuda;
case GGML_TYPE_TURBO3_0:
return dequantize_block_cont_cuda<QK_TURBO3, QR_TURBO3, dequantize_turbo3_0>;
case GGML_TYPE_F16:
return convert_unary_cont_cuda<half>;
case GGML_TYPE_BF16:
Expand All @@ -832,6 +837,8 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
case GGML_TYPE_Q8_0:
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_TURBO3_0:
return dequantize_block_cuda<QK_TURBO3, QR_TURBO3, dequantize_turbo3_0>;
case GGML_TYPE_BF16:
return convert_unary_cuda<nv_bfloat16>;
default:
Expand Down Expand Up @@ -874,6 +881,8 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
case GGML_TYPE_Q8_0:
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_TURBO3_0:
return dequantize_block_cuda<QK_TURBO3, QR_TURBO3, dequantize_turbo3_0>;
case GGML_TYPE_BF16:
return convert_unary_cuda<nv_bfloat16, float>;
default:
Expand Down
10 changes: 10 additions & 0 deletions ggml/src/ggml-cuda/dequantize.cuh
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "common.cuh"
#include "turbo-quant.cuh"

static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
const block_q4_0 * x = (const block_q4_0 *) vx;
Expand Down Expand Up @@ -75,3 +76,12 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
v.x *= d;
v.y *= d;
}

// Turbo3: 3-bit PolarQuant (2-bit magnitude index + 1-bit sign), block size 32.
// Dequantizes the element pair (iqs, iqs+1) of block ib into v; iqs is the
// (even) element index within the block.
static __device__ __forceinline__ void dequantize_turbo3_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
    const block_turbo3_0 * blocks = (const block_turbo3_0 *) vx;
    const block_turbo3_0 * blk    = &blocks[ib];
    const float scale = __half2float(blk->norm);
    v = make_float2(turbo3_dequant_element(blk, iqs + 0, scale),
                    turbo3_dequant_element(blk, iqs + 1, scale));
}
129 changes: 129 additions & 0 deletions ggml/src/ggml-cuda/fattn-common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "common.cuh"
#include "convert.cuh"
#include "vecdotq.cuh"
#include "turbo-quant.cuh"

#include <cstdint>

Expand Down Expand Up @@ -288,6 +289,66 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q8_0(
return sum;
}

// Turbo3 KQ dot product: dequantize K from turbo3 blocks, dot with Q (float2/half2)
// Uses float Q path (like f16), not q8_1 integer path.
// Q_v is half2[] or float2[] with D/2 pairs, partitioned nthreads-strided.
//
// Matches the f16 pattern: outer loop steps by nthreads*cpy_ne, inner loop
// processes cpy_ne pairs per thread per iteration so Q_v and K indices stay aligned.
// elem0 = 2*k_KQ is always even, so elem0 and elem0+1 always share the same
// turbo3 block (ib), qs byte, and signs byte — loaded once per pair.
template <int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_turbo3_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

const block_turbo3_0 * K_turbo = (const block_turbo3_0 *) K_c;
GGML_UNUSED(Q_q8);  // q8_1 integer-dot inputs are unused: turbo3 K uses the float Q path
GGML_UNUSED(Q_ds_v);

// Per-thread copy width, mirroring the f16 kernel: cpy_ne element *pairs*
// are handled by each thread per outer-loop iteration.
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;

float sum = 0.0f;

#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
// k_KQ indexes element pairs; each of the nthreads threads owns a
// contiguous cpy_ne-wide slice within this outer-loop step.
const int k_KQ = k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne + k_KQ_1;

// elem0 is always even; elem0 and elem1 are always in the same block,
// the same qs byte (j0%4 ∈ {0,2}), and the same signs byte (j0%8 ∈ {0,2,4,6}).
const int elem0 = k_KQ * 2; // always even
const int ib = elem0 / QK_TURBO3; // shared block index
const int j0 = elem0 % QK_TURBO3; // always even, 0..30

// Single loads for the shared block fields
const float norm = __half2float(K_turbo[ib].norm);
const uint8_t qs_byte = K_turbo[ib].qs[j0 / 4]; // covers both j0 and j0+1
const uint8_t sgn_byte = K_turbo[ib].signs[j0 / 8]; // covers both j0 and j0+1

// Extract 3-bit indices (sign bit becomes bit 2 of the centroid index)
// for elem0 and elem1 from the shared bytes.
const int shift = (j0 % 4) * 2; // 0 or 4
const uint8_t idx0 = ((qs_byte >> shift) & 0x3) | (((sgn_byte >> (j0 % 8)) & 0x1) << 2);
const uint8_t idx1 = ((qs_byte >> (shift+2)) & 0x3) | (((sgn_byte >> (j0 % 8 + 1)) & 0x1) << 2);

float2 kv;
kv.x = TURBO_CENTROIDS_3BIT[idx0] * norm;
kv.y = TURBO_CENTROIDS_3BIT[idx1] * norm;

#ifdef V_DOT2_F32_F16_AVAILABLE
// Q_v holds half2 pairs on this path; k_KQ_0/nthreads + k_KQ_1 is this
// thread's slice offset into its Q partition (matches the f16 kernel).
const half2 qv = ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1];
ggml_cuda_mad(sum, make_float2(kv.x, kv.y), __half22float2(qv));
#else
// Q_v holds float2 pairs when half fma is unavailable.
const float2 qv = ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1];
sum += kv.x * qv.x + kv.y * qv.y;
#endif // V_DOT2_F32_F16_AVAILABLE
}
}

return sum;
}

template <typename Tds, int ni>
static __device__ __forceinline__ void quantize_q8_1_to_shared(
const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) {
Expand Down Expand Up @@ -577,6 +638,70 @@ static __device__ __forceinline__ void dequantize_V_q8_0(const void * __restrict
}
}

// Turbo3 V dequantize: extract `ne` values (as half or float) starting at
// element position i0 and write them to dst.
//
// The ne==4 path (used by the VEC kernel with turbo3 V) relies on i0 being a
// multiple of 4, so all 4 elements share a single qs byte and a single signs
// byte — each is loaded exactly once.
template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_turbo3_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
    const block_turbo3_0 * blocks = (const block_turbo3_0 *) vx;

    const int64_t ib   = i0 / QK_TURBO3;
    const int     j0   = i0 % QK_TURBO3;
    const float   norm = __half2float(blocks[ib].norm);

    static_assert(ne == 2 || ne == 4, "bad ne");

    if constexpr (ne == 4) {
        // With j0 % 4 == 0 (always true from the VEC kernel), the 4 elements
        // occupy one full qs byte (4 x 2-bit) and half of one signs byte.
        const uint8_t qs_byte  = blocks[ib].qs[j0 / 4];
        const uint8_t sgn_byte = blocks[ib].signs[j0 / 8];
        const int     sbit0    = j0 % 8; // first sign-bit position: 0 or 4

        // 3-bit centroid indices: 2-bit magnitude with the sign as bit 2.
        float vals[4];
#pragma unroll
        for (int l = 0; l < 4; ++l) {
            const uint8_t idx = ((qs_byte >> (2*l)) & 0x3) | (((sgn_byte >> (sbit0 + l)) & 0x1) << 2);
            vals[l] = TURBO_CENTROIDS_3BIT[idx] * norm;
        }

#ifdef FP16_AVAILABLE
        if constexpr (std::is_same_v<T, half>) {
            ((half2 *) dst)[0] = make_half2(__float2half(vals[0]), __float2half(vals[1]));
            ((half2 *) dst)[1] = make_half2(__float2half(vals[2]), __float2half(vals[3]));
        } else
#endif // FP16_AVAILABLE
        if constexpr (std::is_same_v<T, float>) {
            ((float2 *) dst)[0] = make_float2(vals[0], vals[1]);
            ((float2 *) dst)[1] = make_float2(vals[2], vals[3]);
        } else {
            static_assert(std::is_same_v<T, void>, "unsupported type");
        }
    } else { // ne == 2
        const float v0 = turbo3_dequant_element(&blocks[ib], j0 + 0, norm);
        const float v1 = turbo3_dequant_element(&blocks[ib], j0 + 1, norm);
#ifdef FP16_AVAILABLE
        if constexpr (std::is_same_v<T, half>) {
            ((half2 *) dst)[0] = make_half2(__float2half(v0), __float2half(v1));
        } else
#endif // FP16_AVAILABLE
        if constexpr (std::is_same_v<T, float>) {
            ((float *) dst)[0] = v0;
            ((float *) dst)[1] = v1;
        } else {
            static_assert(std::is_same_v<T, void>, "unsupported type");
        }
    }
}

template <ggml_type type_K, int D, int nthreads>
constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
if constexpr (type_K == GGML_TYPE_F16) {
Expand All @@ -593,6 +718,8 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;
} else if constexpr (type_K == GGML_TYPE_BF16) {
return vec_dot_fattn_vec_KQ_bf16<D, nthreads>;
} else if constexpr (type_K == GGML_TYPE_TURBO3_0) {
return vec_dot_fattn_vec_KQ_turbo3_0<D, nthreads>;
} else {
static_assert(type_K == -1, "bad type");
return nullptr;
Expand All @@ -615,6 +742,8 @@ constexpr __device__ dequantize_V_t get_dequantize_V() {
return dequantize_V_q8_0<T, ne>;
} else if constexpr (type_V == GGML_TYPE_BF16) {
return dequantize_V_bf16<float, ne>;
} else if constexpr (type_V == GGML_TYPE_TURBO3_0) {
return dequantize_V_turbo3_0<T, ne>;
} else {
static_assert(type_V == -1, "bad type");
return nullptr;
Expand Down
16 changes: 12 additions & 4 deletions ggml/src/ggml-cuda/fattn-vec.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,20 @@ static __global__ void flash_attn_ext_vec(
#endif // GGML_USE_HIP

constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device();
constexpr int nthreads_KQ = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_KQ_q;
constexpr int nthreads_V = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_V_q;
// Turbo3 uses the float Q path (like f16/bf16), not q8_1 integer path
constexpr bool K_is_unquantized = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16 || type_K == GGML_TYPE_TURBO3_0);
constexpr bool V_is_unquantized = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16 || type_V == GGML_TYPE_TURBO3_0);
constexpr int nthreads_KQ = K_is_unquantized ? 128 / cpy_nb : nthreads_KQ_q;
constexpr int nthreads_V = V_is_unquantized ? (type_V == GGML_TYPE_TURBO3_0 ? nthreads_V_q : 128 / cpy_nb) : nthreads_V_q;

static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K");
static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V");

constexpr int V_rows_per_thread = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 2*cpy_ne : 4;
constexpr int V_rows_per_thread = V_is_unquantized ? (type_V == GGML_TYPE_TURBO3_0 ? 4 : 2*cpy_ne) : 4;
constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V;

constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16 && type_K != GGML_TYPE_BF16;
constexpr bool Q_q8_1 = !K_is_unquantized;
#ifdef V_DOT2_F32_F16_AVAILABLE
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
#else
Expand Down Expand Up @@ -598,3 +601,8 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_BF16)

// TurboQuant3 — turbo3 K + turbo3 V (KV cache uses same type)
extern DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO3_0);
extern DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO3_0);
extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO3_0);
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/fattn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,9 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16)
#endif // GGML_CUDA_FA_ALL_QUANTS

// TurboQuant3 KV cache types (always enabled)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO3_0)

GGML_ABORT("fatal error");
}

Expand Down Expand Up @@ -371,6 +374,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_BF16:
case GGML_TYPE_TURBO3_0:
break;
default:
return BEST_FATTN_KERNEL_NONE;
Expand Down
10 changes: 9 additions & 1 deletion ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#include "ggml-cuda/gated_delta_net.cuh"
#include "ggml-cuda/set.cuh"
#include "ggml-cuda/set-rows.cuh"
#include "ggml-cuda/turbo-wht.cuh"
#include "ggml-cuda/pad_reflect_1d.cuh"
#include "ggml-cuda/solve_tri.cuh"
#include "ggml-cuda/tri.cuh"
Expand Down Expand Up @@ -2510,6 +2511,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SET_ROWS:
ggml_cuda_op_set_rows(ctx, dst);
break;
case GGML_OP_TURBO_WHT:
ggml_cuda_turbo_wht(ctx, dst);
break;
case GGML_OP_SET:
ggml_cuda_op_set(ctx, dst);
break;
Expand Down Expand Up @@ -4837,7 +4841,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
{
return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 ||
op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) &&
op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL ||
op->type == GGML_TYPE_TURBO3_0) &&
op->src[0]->type == GGML_TYPE_F32 &&
(op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
} break;
Expand Down Expand Up @@ -4964,6 +4969,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CLAMP:
case GGML_OP_LOG:
return true;
case GGML_OP_TURBO_WHT:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
ggml_nelements(op->src[0]) % 128 == 0;
case GGML_OP_SSM_SCAN: {
if (op->src[3]->ne[0] == 1) {
// Mamba2
Expand Down
Loading