Skip to content

Commit c7577d0

Browse files
pwilkins authored and haobo.xie committed
Add DIAG for CUDA (#17873)
* Add DIAG for CUDA
* Refactor parameters
1 parent 63e2534 commit c7577d0

File tree

4 files changed

+116
-0
lines changed

4 files changed

+116
-0
lines changed

ggml/src/ggml-cuda/diag.cu

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#include "convert.cuh"
2+
#include "diag.cuh"
3+
#include "ggml.h"
4+
5+
// Writes the diagonal embedding of src into dst:
//   dst[i3][i2][i1][i0] = src[i3][i2][0][i0]  when i0 == i1, else 0.
// Launched on a 1D grid with one thread per destination element.
// Preconditions (enforced by the host wrapper): both tensors contiguous,
// src has ne[1] == 1 and matches dst in the remaining dimensions.
template <typename T>
static __global__ void diag_kernel(T * __restrict__ dst,
                                   const T * __restrict__ src,
                                   const int64_t ne0,
                                   const int64_t ne1,
                                   const int64_t ne2,
                                   const int64_t ne3,
                                   const int64_t total_elements) {
    // Promote to 64-bit BEFORE the multiply: blockIdx.x * blockDim.x is
    // otherwise evaluated in unsigned 32-bit arithmetic and silently wraps
    // for tensors with more than 2^32 elements.
    const int64_t global_idx = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;

    if (global_idx >= total_elements) {
        return;
    }

    // Unflatten the global index into (i0, i1, i2, i3) dst coordinates.
    const int64_t i0 = global_idx % ne0;
    const int64_t i1 = (global_idx / ne0) % ne1;
    const int64_t i2 = (global_idx / (ne0 * ne1)) % ne2;
    const int64_t i3 = global_idx / (ne0 * ne1 * ne2);

    const int64_t dst_idx = ((i3 * ne2 + i2) * ne1 + i1) * ne0 + i0;

    if (i0 == i1) {
        // src is a batch of vectors (src ne[1] == 1): read element i0 of
        // batch (i2, i3).
        const int64_t batch_idx = i3 * ne2 + i2;
        const int64_t src_idx = batch_idx * ne0 + i0;
        dst[dst_idx] = src[src_idx];
    } else {
        dst[dst_idx] = ggml_cuda_cast<T>(0);
    }
    GGML_UNUSED_VARS(ne3);
}
35+
36+
// Compute GGML_OP_DIAG on the CUDA backend: expand a batch of row vectors
// (src0, with ne[1] == 1) into a batch of diagonal matrices (dst).
// Supports F32 and F16 only; aborts on any other type.
void ggml_cuda_op_diag(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    cudaStream_t stream = ctx.stream();

    // The kernel addresses both tensors with flat offsets, so they must be
    // densely packed in memory.
    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_is_contiguous(src0));

    const int64_t ne0 = dst->ne[0];
    const int64_t ne1 = dst->ne[1];
    const int64_t ne2 = dst->ne[2];
    const int64_t ne3 = dst->ne[3];

    // src0 is a batch of row vectors whose length and batch dims match dst.
    GGML_ASSERT(src0->ne[0] == ne0);
    GGML_ASSERT(src0->ne[1] == 1);
    GGML_ASSERT(src0->ne[2] == ne2);
    GGML_ASSERT(src0->ne[3] == ne3);

    const int64_t n_elems    = ggml_nelements(dst);
    // ceil-div: one thread per destination element.
    const int64_t num_blocks = (n_elems + CUDA_DIAG_BLOCK_SIZE - 1) / CUDA_DIAG_BLOCK_SIZE;

    void *       dst_d  = dst->data;
    const void * src0_d = src0->data;

    switch (dst->type) {
        case GGML_TYPE_F32:
            diag_kernel<<<num_blocks, CUDA_DIAG_BLOCK_SIZE, 0, stream>>>(
                (float *) dst_d, (const float *) src0_d, ne0, ne1, ne2, ne3, n_elems);
            break;
        case GGML_TYPE_F16:
            diag_kernel<<<num_blocks, CUDA_DIAG_BLOCK_SIZE, 0, stream>>>(
                (half *) dst_d, (const half *) src0_d, ne0, ne1, ne2, ne3, n_elems);
            break;
        default:
            GGML_ABORT("unsupported type");
    }
}

ggml/src/ggml-cuda/diag.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
#include "common.cuh"

// Threads per block used when launching the DIAG kernel.
#define CUDA_DIAG_BLOCK_SIZE 256

// Compute GGML_OP_DIAG for dst on the CUDA backend; the input tensor is
// taken from dst->src[0]. Implemented in diag.cu.
void ggml_cuda_op_diag(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ extern "C" void ggml_dl_flash_attn_ext_dldnn_prepare_varlen_buffers(ggml_backend
3939
#include "ggml-cuda/cpy.cuh"
4040
#include "ggml-cuda/cross-entropy-loss.cuh"
4141
#include "ggml-cuda/diagmask.cuh"
42+
#include "ggml-cuda/diag.cuh"
4243
#include "ggml-cuda/fattn.cuh"
4344
#include "ggml-cuda/getrows.cuh"
4445
#include "ggml-cuda/im2col.cuh"
@@ -2803,6 +2804,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
28032804
case GGML_OP_PERMUTE:
28042805
case GGML_OP_TRANSPOSE:
28052806
break;
2807+
case GGML_OP_DIAG:
2808+
ggml_cuda_op_diag(ctx, dst);
2809+
break;
28062810
case GGML_OP_DIAG_MASK_INF:
28072811
ggml_cuda_op_diag_mask_inf(ctx, dst);
28082812
break;
@@ -4897,6 +4901,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
48974901
case GGML_OP_FILL:
48984902
case GGML_OP_CUMSUM:
48994903
case GGML_OP_TRI:
4904+
case GGML_OP_DIAG:
49004905
return true;
49014906
#ifdef GGML_USE_DLCU
49024907
case GGML_OP_MOE_SUM:

tests/test-backend-ops.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6754,6 +6754,31 @@ struct test_solve_tri : public test_case {
67546754
}
67556755
};
67566756

6757+
// GGML_OP_DIAG
6758+
struct test_diag : public test_case {
6759+
const ggml_type type;
6760+
const std::array<int64_t, 4> ne;
6761+
6762+
std::string vars() override { return VARS_TO_STR2(type, ne); }
6763+
6764+
test_diag(ggml_type type = GGML_TYPE_F32,
6765+
std::array<int64_t, 4> ne = { 10, 1, 4, 3 })
6766+
: type(type), ne(ne) {}
6767+
6768+
ggml_tensor * build_graph(ggml_context * ctx) override {
6769+
GGML_ASSERT(ne[1] == 1);
6770+
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
6771+
ggml_set_param(a);
6772+
ggml_set_name(a, "a");
6773+
6774+
ggml_tensor * out = ggml_diag(ctx, a);
6775+
ggml_set_name(out, "out");
6776+
6777+
return out;
6778+
}
6779+
};
6780+
6781+
67576782
enum llm_norm_type {
67586783
LLM_NORM,
67596784
LLM_NORM_RMS,
@@ -8425,6 +8450,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
84258450
test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F32, { 800, 600, 4, 4 }));
84268451
test_cases.emplace_back(new test_fill(3.5f, GGML_TYPE_F32, { 2048, 512, 2, 2 }));
84278452

8453+
test_cases.emplace_back(new test_diag());
8454+
test_cases.emplace_back(new test_diag(GGML_TYPE_F32, { 79, 1, 19, 13 }));
8455+
test_cases.emplace_back(new test_diag(GGML_TYPE_F32, { 256, 1, 8, 16 }));
8456+
84288457
test_cases.emplace_back(new test_solve_tri());
84298458
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 11, 11, 1, 1 }, { 5, 11, 1, 1 }));
84308459
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 17, 17, 2, 4 }, { 9, 17, 2, 4 }));

0 commit comments

Comments
 (0)