Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 169 additions & 1 deletion ggml/src/ggml-cpu/repack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1576,6 +1576,11 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR

size += sizeof_mmid_row_mapping*ne02*(ne12 + 1);

return true;
}
case GGML_OP_GET_ROWS:
{
size = 0;
return true;
}
default:
Expand All @@ -1593,6 +1598,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
case GGML_OP_MUL_MAT_ID:
forward_mul_mat_id(params, op);
return true;
case GGML_OP_GET_ROWS:
forward_get_rows(params, op);
return true;
default:
// GGML_ABORT("fatal error");
break;
Expand Down Expand Up @@ -1801,6 +1809,155 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
#undef MMID_MATRIX_ROW
}

// Dispatch GET_ROWS to the dequantizing gather kernel matching the
// repacked source tensor's quantization type. Aborts on types that
// have no repacked GET_ROWS implementation.
void forward_get_rows(const ggml_compute_params * params, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    if (src0->type == GGML_TYPE_Q6_K) {
        ggml_compute_forward_get_rows_q6_Kx8(params, dst);
        return;
    }

    GGML_ABORT("fatal error");
}

static void ggml_compute_forward_get_rows_q6_Kx8(const ggml_compute_params * params, ggml_tensor * dst) {

const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];

GGML_TENSOR_BINARY_OP_LOCALS
const int64_t nc = ne00;
const int64_t nr = ggml_nelements(src1);

assert(ne0 == nc);
assert(ne02 == ne11);
assert(nb00 == ggml_type_size(src0->type));
assert(ggml_nrows(dst) == nr);

const int ith = params->ith;
const int nth = params->nth;

// rows per thread
const int dr = (nr + nth - 1) / nth;

// row range for this thread
const int ir0 = dr * ith;
const int ir1 = MIN(ir0 + dr, nr);

constexpr int nrows_interleaved = 8;
const size_t sizeof_one_repacked_block = sizeof(block_q6_Kx8);

const int num_repacked_blocks_per_row_width = nc / QK_K;

const size_t stride_between_actual_row_groups = num_repacked_blocks_per_row_width * sizeof_one_repacked_block;

for (int64_t i = ir0; i < ir1; ++i) {
const int64_t i12 = i / (ne11 * ne10);
const int64_t i11 = (i - i12 * ne11 * ne10) / ne10;
const int64_t i10 = (i - i12 * ne11 * ne10 - i11 * ne10);
const int64_t i01 = *(int32_t *)((char *)src1->data + i10 * nb10 + i11 * nb11 + i12 * nb12); // original logical row

GGML_ASSERT(i01 >= 0 && i01 < ne01);

const int row_group_idx = i01 / nrows_interleaved;
const int row_idx_in_group = i01 % nrows_interleaved;

const char * base_ptr_for_higher_dims_in_src0 = (const char *)src0->data + i11 * nb02 + i12 * nb03;

// Pointer to the first block_q6_Kx8 of the identified row_group_idx
const block_q6_Kx8 * p_first_repacked_block_of_group_x8 = (const block_q6_Kx8 *)(base_ptr_for_higher_dims_in_src0 + row_group_idx * stride_between_actual_row_groups);

dequantize_row_q6_Kx8(
p_first_repacked_block_of_group_x8,
(float *)((char *)dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3), nc, row_idx_in_group);
}
}


/**
 * Dequantizes a single logical row from data repacked with quant interleaving
 * for repacked block_q6_Kx8 (8 q6_K rows interleaved per block).
 *
 * @param p_repacked_blocks Pointer to the first 'block_q6_Kx8' of the row group.
 * @param y                 Output buffer for the dequantized float values.
 * @param k                 Total number of elements (columns) in the logical row.
 * @param row_idx_in_group  Index (0-7) of the logical row to dequantize.
 */

static void dequantize_row_q6_Kx8(
    const void * GGML_RESTRICT p_repacked_blocks,
    float * GGML_RESTRICT y,
    int64_t k,
    int row_idx_in_group) {

    assert(k % QK_K == 0);
    assert(row_idx_in_group >= 0 && row_idx_in_group < 8);

    const int nb = k / QK_K;
    const block_q6_Kx8 * blocks = (const block_q6_Kx8 *)p_repacked_blocks;

    for (int i = 0; i < nb; i++) {
        const block_q6_Kx8 * current_block = &blocks[i];

        // per-row super-block scale
        const float d_super_block = GGML_FP16_TO_FP32(current_block->d[row_idx_in_group]);

        const uint8_t * ptr_ql_base = current_block->ql;
        const uint8_t * ptr_qh_base = current_block->qh;

        for (int n = 0; n < QK_K; n += 128) {
            // each 128-value half of the super-block uses its own 64-byte slice
            // of the repacked scales (8 sub-block scales x 8 rows); keep the
            // pointer const - the source block must never be written through
            const uint8_t * ptr_repacked_scales = (const uint8_t *) current_block->scales + (n/128) * 64;

            for (int l = 0; l < 32; ++l) {
                const int is = l/16;

                // get the 4 scales needed for q1, q2, q3 and q4
                const int8_t sc0 = read_scale_from_repacked(ptr_repacked_scales, row_idx_in_group, is + 0);
                const int8_t sc1 = read_scale_from_repacked(ptr_repacked_scales, row_idx_in_group, is + 2);
                const int8_t sc2 = read_scale_from_repacked(ptr_repacked_scales, row_idx_in_group, is + 4);
                const int8_t sc3 = read_scale_from_repacked(ptr_repacked_scales, row_idx_in_group, is + 6);

                // get the right ql & qh values from the interleaved data
                const uint8_t ql_l0  = read_ql_qh_from_repacked(ptr_ql_base, row_idx_in_group, n/2 + l + 0);
                const uint8_t ql_l32 = read_ql_qh_from_repacked(ptr_ql_base, row_idx_in_group, n/2 + l + 32);
                const uint8_t qh_l   = read_ql_qh_from_repacked(ptr_qh_base, row_idx_in_group, n/4 + l);

                // reassemble the four 6-bit quants and recenter to [-32, 31]
                const int8_t q1 = (int8_t) ((ql_l0 & 0xF) | (((qh_l >> 0) & 3) << 4)) - 32;
                const int8_t q2 = (int8_t) ((ql_l32 & 0xF) | (((qh_l >> 2) & 3) << 4)) - 32;
                const int8_t q3 = (int8_t) ((ql_l0 >> 4) | (((qh_l >> 4) & 3) << 4)) - 32;
                const int8_t q4 = (int8_t) ((ql_l32 >> 4) | (((qh_l >> 6) & 3) << 4)) - 32;

                y[l]      = d_super_block * sc0 * q1;
                y[l + 32] = d_super_block * sc1 * q2;
                y[l + 64] = d_super_block * sc2 * q3;
                y[l + 96] = d_super_block * sc3 * q4;
            }
            y += 128;
        }
    }
}


/**
 * Reads one sub-block scale for a given row from the repacked scale layout.
 *
 * Scales are stored as consecutive pairs per row with the 8 rows interleaved,
 * so byte offset = (scale_idx/2)*16 + row_idx_in_group*2 + (scale_idx%2).
 */
static inline int8_t read_scale_from_repacked(const uint8_t* ptr_repacked_scales, int row_idx_in_group, int scale_idx) {
    const int pair_idx    = scale_idx / 2;
    const int idx_in_pair = scale_idx % 2;
    return ptr_repacked_scales[pair_idx * 16 + row_idx_in_group * 2 + idx_in_pair];
}

/**
 * Reads one ql/qh byte for a given row from the interleaved quant layout.
 *
 * Quant bytes are grouped in chunks of 8 per row, with the 8 rows' chunks
 * stored back to back (64 bytes per chunk index):
 * byte offset = chunk*64 + row_idx_in_group*8 + position-within-chunk.
 */
static inline uint8_t read_ql_qh_from_repacked(const uint8_t* ptr_qh_ql_base, int row_idx_in_group, int ql_0_idx) {
    constexpr int interleave = 8;
    const int chunk_idx = ql_0_idx / interleave;
    const int in_chunk  = ql_0_idx % interleave;
    return ptr_qh_ql_base[(chunk_idx * 8 + row_idx_in_group) * interleave + in_chunk];
}

int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
(int) NB_COLS, (int) INTER_SIZE);
Expand Down Expand Up @@ -1949,12 +2106,23 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
//if (op->src[1]->type == GGML_TYPE_Q8_0) {
// return true;
//}
} else if (op->op == GGML_OP_GET_ROWS
&& op->src[0]->buffer
&& (ggml_n_dims(op->src[0]) == 2)
&& op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()
&& ggml_repack_get_optimal_repack_type(op->src[0])) {
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
return false;
}
if (op->src[0]->type == GGML_TYPE_Q6_K) {
return true;
}
}
return false;
}

ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) {
if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID || op->op == GGML_OP_GET_ROWS) {
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()) {
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
}
Expand Down