Skip to content
This repository was archived by the owner on Feb 24, 2026. It is now read-only.

Commit da9695a

Browse files
authored
[Dev] Fix GEMV Dynamic Scheduling with Splitk (#52)
* improve e4m3 decoding. * append fp16xint1 * Update submodule commit reference * chore: Update shared memory scope for float32 output dtype * BUGFIX: UINT8/INT8 Decoding * feat: Add rasterization options for roller module * Refactor tensorcore_legalization method to optimize tensor core usage * feat: Add function to collect variables from expression, improve for splitk * chore: Update typing import in __init__.py * chore: Refactor CPU execution of operators * Refactor matmul implementation for splitk layout * Refactor matmul implementation for splitk layout * Refactor matmul implementation for splitk layout * chore: Update version to 0.0.1.dev8 * chore: Enable debug output in bitblas.set_debug_level() * Refactor Linear module matmul implementation for splitk layout * Refactor matmul implementation for splitk layout * Refactor CUDA kernel launch string for dynamic symbolic set * Bump version to v0.0.1.dev9 * Refactor CUDA kernel launch string for dynamic symbolic set * Bump version to v0.0.1.dev10 --------- Co-authored-by: LeiWang199 <leiwang199>
1 parent b78dcfe commit da9695a

File tree

4 files changed

+17
-3
lines changed

4 files changed

+17
-3
lines changed

VERSION

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
0.0.1.dev9
1+
0.0.1.dev10

python/bitblas/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -81,4 +81,4 @@ def _init_logger():
8181

8282
_init_logger()
8383

84-
__version__ = "0.0.1.dev9"
84+
__version__ = "0.0.1.dev10"

python/bitblas/gpu/gemv.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -775,7 +775,8 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring
775775
return None
776776

777777
block_info = block_infos[0]
778-
if len(block_info.iters) not in [2, 3]:
778+
if len(block_info.iters) not in [2, 3, 4]:
779+
# either [SK, B, S, R] = [SK, B, S, R] * [SK, B, R]
779780
# either [B, S, R] = [B, S, R] * [B, R]
780781
# or [S, R] = [S, R] * [R]
781782
return None

python/bitblas/gpu/gemv_dequantize.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -110,6 +110,11 @@ def get_vectorize_factor(target_format):
110110
if len(sch.get_loops(block_b)) == 3:
111111
i = sch.get_loops(block_b)[0]
112112
sch.bind(i, "blockIdx.z")
113+
elif len(sch.get_loops(block_b)) == 4:
114+
# splitk case
115+
sk, i = sch.get_loops(block_b)[:2]
116+
sch.bind(sk, "blockIdx.y")
117+
sch.bind(i, "blockIdx.z")
113118

114119
# get target dequantize buffer's idx
115120
def get_idx(weight_decode_info: Dict):
@@ -274,6 +279,14 @@ def get_vectorize_factor(target_format):
274279
if len(sch.get_loops(block_b)) == 3:
275280
i = sch.get_loops(block_b)[0]
276281
sch.bind(i, "blockIdx.z")
282+
elif len(sch.get_loops(block_b)) == 4:
283+
# splitk case
284+
sk, i = sch.get_loops(block_b)[:2]
285+
sch.bind(sk, "blockIdx.y")
286+
sch.bind(i, "blockIdx.z")
287+
assert len(config.thread) == 2, "SplitK only support 2D thread config"
288+
num_warps = int(num_warps // config.thread[0])
289+
277290

278291
# get target dequantize buffer's idx
279292
def get_idx(weight_decode_info: Dict):

0 commit comments

Comments
 (0)