Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
b566811
Add buffer label and enable dawn-specific toggles to turn off some ch…
reeselevine Oct 15, 2025
2560412
Intermediate state
reeselevine Oct 17, 2025
833c4a8
Fast working f16/f32 vec4
reeselevine Oct 19, 2025
bd38091
Working float fast mul mat
reeselevine Oct 20, 2025
f808c48
Clean up naming of mul_mat to match logical model, start work on q mu…
reeselevine Oct 22, 2025
a3b2f67
Setup for subgroup matrix mat mul
reeselevine Oct 25, 2025
0bdd9f4
Basic working subgroup matrix
reeselevine Oct 25, 2025
a80e2bb
Working subgroup matrix tiling
reeselevine Oct 26, 2025
e4fd0b5
Handle weirder sg matrix sizes (but still % sg matrix size)
reeselevine Oct 26, 2025
b524249
Working start to gemv
reeselevine Oct 26, 2025
749a791
working f16 accumulation with shared memory staging
reeselevine Oct 26, 2025
abaf12e
Print out available subgroup matrix configurations
reeselevine Oct 26, 2025
54c31c1
Vectorize dst stores for sg matrix shader
reeselevine Oct 26, 2025
0f6e38d
Gemv working scalar
reeselevine Oct 27, 2025
f2e187c
Minor set_rows optimization (#4)
neha-ha Oct 27, 2025
2aa05c6
Merge remote-tracking branch 'upstream/master'
reeselevine Oct 27, 2025
51aae63
Comment on dawn toggles
reeselevine Oct 27, 2025
9edfcc9
Working subgroup matrix code for (semi)generic sizes
reeselevine Oct 27, 2025
f0cfae4
Remove some comments
reeselevine Oct 27, 2025
904dbe3
Merge remote-tracking branch 'origin/master' into mul_mat_opt
reeselevine Oct 27, 2025
cf0c536
Cleanup code
reeselevine Oct 28, 2025
71c7a4a
Update dawn version and move to portable subgroup size
reeselevine Oct 28, 2025
c73893e
Try to fix new dawn release
reeselevine Oct 28, 2025
f538ca3
Update subgroup size comment
reeselevine Oct 28, 2025
f5001d8
Only check for subgroup matrix configs if they are supported
reeselevine Oct 28, 2025
844ba40
Add toggles for subgroup matrix/f16 support on nvidia+vulkan
reeselevine Oct 29, 2025
d426436
Make row/col naming consistent
reeselevine Oct 29, 2025
a46d093
Refactor shared memory loading
reeselevine Oct 29, 2025
eb7150a
Move sg matrix stores to correct file
reeselevine Oct 29, 2025
4ec09e4
Working q4_0
reeselevine Oct 30, 2025
9726640
Formatting
reeselevine Oct 31, 2025
346cfa4
Merge remote-tracking branch 'origin/master' into mul_mat_opt
reeselevine Nov 5, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,16 @@ jobs:
- name: Dawn Dependency
id: dawn-depends
run: |
DAWN_VERSION="v1.0.0"
DAWN_VERSION="v2.0.0"
DAWN_OWNER="reeselevine"
DAWN_REPO="dawn"
DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-macos-latest-Release.tar.gz"
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.zip"
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
curl -L -o artifact.tar.gz \
curl -L -o artifact.zip \
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
mkdir dawn
tar -xvf artifact.tar.gz -C dawn --strip-components=1
unzip artifact.zip
tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.tar.gz -C dawn --strip-components=1

- name: Build
id: cmake_build
Expand Down Expand Up @@ -521,15 +522,16 @@ jobs:
id: dawn-depends
run: |
sudo apt-get install -y libxrandr-dev libxinerama-dev libxcursor-dev mesa-common-dev libx11-xcb-dev libxi-dev
DAWN_VERSION="v1.0.0"
DAWN_VERSION="v2.0.0"
DAWN_OWNER="reeselevine"
DAWN_REPO="dawn"
DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-ubuntu-latest-Release.tar.gz"
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.zip"
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
curl -L -o artifact.tar.gz \
curl -L -o artifact.zip \
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
mkdir dawn
tar -xvf artifact.tar.gz -C dawn --strip-components=1
unzip artifact.zip
tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.tar.gz -C dawn --strip-components=1

- name: Build
id: cmake_build
Expand Down
326 changes: 313 additions & 13 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,12 @@ def generate_variants(fname, input_dir, output_dir, outfile):
except ValueError:
decls_map = {}

with open(os.path.join(input_dir, "common_decls.tmpl"), "r", encoding="utf-8") as f:
common_decls = f.read()
decls_map.update(parse_decls(common_decls))
for fname in sorted(os.listdir(input_dir)):
if fname.endswith(".tmpl"):
tmpl_path = os.path.join(input_dir, fname)
with open(tmpl_path, "r", encoding="utf-8") as f_tmpl:
decls = f_tmpl.read()
decls_map.update(parse_decls(decls))

shader_template = extract_block(text, "SHADER")
for variant in variants:
Expand Down
10 changes: 5 additions & 5 deletions ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -864,8 +864,8 @@ struct MulMatParams {
broadcast3: u32
};

@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // N rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // M rows, K columns (transposed)
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns

@group(0) @binding(3) var<uniform> params: MulMatParams;
Expand All @@ -891,8 +891,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {

let dst2_rem = dst3_rem % dst2_stride;

let row = dst2_rem / params.n; // output row
let col = dst2_rem % params.n; // output column
let row = dst2_rem / params.m; // output row
let col = dst2_rem % params.m; // output column

let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01;
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;
Expand All @@ -901,7 +901,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) {
sum += multiply_add(src0_idx_base, src1_idx_base, i);
}
dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum;
dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum;
}

#end(SHADER)
97 changes: 97 additions & 0 deletions ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#decl(SHMEM_VEC)
// Vectorized shared-memory store: writes a vec4<f16> into the f16 shared
// array `shmem` as four consecutive scalars starting at `idx`.
// Selected for shader variants where {{VEC_SIZE}} is 4 (presumably — confirm
// against the variant configuration in embed_wgsl.py).
fn store_shmem(val: vec4<f16>, idx: u32) {
shmem[idx] = val.x;
shmem[idx + 1] = val.y;
shmem[idx + 2] = val.z;
shmem[idx + 3] = val.w;
}
#enddecl(SHMEM_VEC)

#decl(SHMEM_SCALAR)
// Scalar shared-memory store: same interface as the SHMEM_VEC variant but
// takes a single f16, so the two decls are interchangeable in the shader body.
fn store_shmem(val: f16, idx: u32) {
shmem[idx] = val;
}
#enddecl(SHMEM_SCALAR)

#decl(INIT_SRC0_SHMEM_FLOAT)

// Cooperatively loads one (TILE_M x TILE_K) tile of the float-typed src0
// matrix into the first TILE_SRC0_SHMEM elements of shared memory.
// - thread_id:    this invocation's linear index within the workgroup
// - batch_offset: element offset of the current batch/slice within src0
// - offset_m:     first global row (m) covered by this workgroup's tile
// - k_outer:      first global column (k) of the current k-tile
// Each thread strides through the tile {{VEC_SIZE}} elements at a time so the
// whole workgroup covers TILE_SRC0_SHMEM elements. Shared memory is laid out
// row-major with row length TILE_K (see the elem_idx / % TILE_K split).
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
// stride_01 is the element stride between consecutive rows of src0.
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
// Out-of-range tile elements are zero-filled rather than branching around
// the load; the load itself is unconditional, so the index is divided by
// {{VEC_SIZE}} to address the (possibly vectorized) src0 array.
let src0_val = select( // taking a slight performance hit to avoid oob
{{SRC0_TYPE}}(0.0),
src0[src0_idx/{{VEC_SIZE}}],
global_m < params.m && global_k < params.k);
store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx);
}
}

#enddecl(INIT_SRC0_SHMEM_FLOAT)

#decl(INIT_SRC1_SHMEM)

// Cooperatively loads one (TILE_N x TILE_K) tile of src1 into shared memory,
// placed immediately after the src0 tile (offset TILE_SRC0_SHMEM).
// Mirrors init_shmem_src0: same striding scheme, same zero-fill of
// out-of-bounds elements via select, but indexes by n (rows of src1) using
// stride_11 and bounds-checks against params.n.
fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
let tile_n = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_n = offset_n + tile_n;
let global_k = k_outer + tile_k;
let src1_idx = batch_offset + global_n * params.stride_11 + global_k;
let src1_val = select(
{{SRC1_TYPE}}(0.0),
src1[src1_idx/{{VEC_SIZE}}],
global_n < params.n && global_k < params.k);
// Offset by TILE_SRC0_SHMEM: src1's tile lives after src0's in shmem.
store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx);
}
}

#enddecl(INIT_SRC1_SHMEM)

#decl(INIT_SRC0_SHMEM_Q4_0)

// Q4_0 variant of the src0 tile load: dequantizes q4_0 blocks into f16
// shared memory. A q4_0 block is BLOCK_SIZE (32) weights stored as one f16
// scale followed by 16 bytes of packed 4-bit quants (F16_PER_BLOCK = 9 f16s
// when the byte payload is viewed as 8 f16s).
const BLOCK_SIZE = 32u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;

// Each thread dequantizes NQ (16) weights per loop pass: it reads the block's
// f16 scale `d`, then F16_PER_THREAD packed f16s in pairs, bitcasts each pair
// to a u32, and unpacks 4 bytes -> 8 weights per pair. Low nibbles are the
// block's first 16 weights and high nibbles the last 16 (hence the +16u shmem
// offset), each mapped from [0,15] to a signed value via the -8 bias and
// scaled by `d`. Out-of-range (m,k) blocks are skipped entirely, leaving
// their shmem region unwritten — presumably the caller zeroed it or never
// reads it; TODO(review) confirm.
// NOTE(review): here global_k and stride_01 are in units of q4_0 blocks, not
// elements (see the bounds check against params.k / BLOCK_SIZE).
// `get_byte` is declared elsewhere; assumed to extract byte k of a u32.
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
// Destination of this thread's slice of the dequantized block in shmem.
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;

let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_k = k_outer / BLOCK_SIZE + block_k;

if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];

// +1u skips the scale; read two packed f16s per iteration.
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 1u + block_offset + j];
let q_1 = src0[scale_idx + 1u + block_offset + j + 1];

let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
// Each byte holds two weights: low nibble -> first half of the
// block, high nibble -> second half (offset 16).
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
shmem[shmem_idx + j * 2 + k] = q_lo;
shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
}
}
}
}
}

#enddecl(INIT_SRC0_SHMEM_Q4_0)
Loading
Loading