From 492ba85ffd3be1618a6d6966bcc8232d53db501d Mon Sep 17 00:00:00 2001 From: rentianyue-jk Date: Thu, 13 Jun 2024 17:49:11 +0800 Subject: [PATCH 1/3] Add 3696 bgmv-kernel to support qwen2-72b-instruct lora --- csrc/punica/bgmv/bgmv_config.h | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 4b376261d30d..7b700bc74eac 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -31,6 +31,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 3328) \ f(in_T, out_T, W_T, narrow, 3456) \ f(in_T, out_T, W_T, narrow, 3584) \ + f(in_T, out_T, W_T, narrow, 3696) \ f(in_T, out_T, W_T, narrow, 4096) \ f(in_T, out_T, W_T, narrow, 4608) \ f(in_T, out_T, W_T, narrow, 5120) \ From ecafe7eb62ff056e31d098824486447f1851178f Mon Sep 17 00:00:00 2001 From: rentianyue-jk Date: Thu, 13 Jun 2024 18:04:07 +0800 Subject: [PATCH 2/3] add missing value --- csrc/punica/bgmv/bgmv_config.h | 1 + tests/lora/test_punica.py | 1 + 2 files changed, 2 insertions(+) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 7b700bc74eac..bf702fc7bbb3 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -103,6 +103,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 3328, narrow) \ f(in_T, out_T, W_T, 3456, narrow) \ f(in_T, out_T, W_T, 3584, narrow) \ + f(in_T, out_T, W_T, 3696, narrow) \ f(in_T, out_T, W_T, 4096, narrow) \ f(in_T, out_T, W_T, 4608, narrow) \ f(in_T, out_T, W_T, 5120, narrow) \ diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index f021c003b132..0f80d3ffdace 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -61,6 +61,7 @@ def _lora_ref_impl( 3328, 3456, 3584, + 3696, 4096, 4608, 5120, From 09486dbdd2b2c3beb0bc1d30060aa61d3537fccd Mon Sep 17 00:00:00 2001 From: rentianyue-jk Date: Thu, 13 Jun 2024 18:23:02 +0800 Subject: [PATCH 3/3] add missing values --- csrc/punica/bgmv/bgmv_config.h | 6 ++++++ tests/lora/test_punica.py | 3 +++ 2 files changed, 9 insertions(+) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index bf702fc7bbb3..f428059487b1 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -42,6 +42,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 6848) \ f(in_T, out_T, W_T, narrow, 6912) \ f(in_T, out_T, W_T, narrow, 7168) \ + f(in_T, out_T, W_T, narrow, 7392) \ f(in_T, out_T, W_T, narrow, 8192) \ f(in_T, out_T, W_T, narrow, 9216) \ f(in_T, out_T, W_T, narrow, 10240) \ @@ -50,6 +51,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 13696) \ f(in_T, out_T, W_T, narrow, 13824) \ f(in_T, out_T, W_T, narrow, 14336) \ + f(in_T, out_T, W_T, narrow, 14784) \ f(in_T, out_T, W_T, narrow, 15360) \ f(in_T, out_T, W_T, narrow, 16384) \ f(in_T, out_T, W_T, narrow, 20480) \ @@ -58,6 +60,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 27392) \ f(in_T, out_T, W_T, narrow, 27648) \ f(in_T, out_T, W_T, narrow, 28672) \ + f(in_T, out_T, W_T, narrow, 29568) \ f(in_T, out_T, W_T, narrow, 32000) \ f(in_T, out_T, W_T, narrow, 32256) \ f(in_T, out_T, W_T, narrow, 32512) \ @@ -114,6 +117,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 6848, narrow) \ f(in_T, out_T, W_T, 6912, narrow) \ f(in_T, out_T, W_T, 7168, narrow) \ + f(in_T, out_T, W_T, 7392, narrow) \ f(in_T, out_T, W_T, 8192, narrow) \ f(in_T, out_T, W_T, 9216, narrow) \ f(in_T, out_T, W_T, 10240, narrow) \ @@ -122,6 +126,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 13696, narrow) \ f(in_T, out_T, W_T, 13824, narrow) \ f(in_T, out_T, W_T, 14336, narrow) \ + f(in_T, out_T, W_T, 14784, narrow) \ f(in_T, out_T, W_T, 15360, narrow) \ f(in_T, out_T, W_T, 16384, narrow) \ f(in_T, out_T, W_T, 20480, narrow) \ @@ -130,6 +135,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 27392, narrow) \ f(in_T, out_T, W_T, 27648, narrow) \ f(in_T, out_T, W_T, 28672, narrow) \ + f(in_T, out_T, W_T, 29568, narrow) \ f(in_T, out_T, W_T, 32000, narrow) \ f(in_T, out_T, W_T, 32256, narrow) \ f(in_T, out_T, W_T, 32512, narrow) \ diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index 0f80d3ffdace..fcc2b0cfd973 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -72,17 +72,20 @@ def _lora_ref_impl( 6848, 6912, 7168, + 7392, 8192, 9216, 10240, 11008, 13824, 14336, + 14784, 15360, 22016, 24576, 27392, 27648, + 29568, 32000, 32256, 32512,