From 6512937de1d7b4738938e0bb3004be86b6883729 Mon Sep 17 00:00:00 2001
From: HandH1998 <1335248067@qq.com>
Date: Wed, 31 Jul 2024 21:55:21 +0800
Subject: [PATCH 0001/3246] Support W4A8 quantization for vllm (#5218)

---
 .../configs/Meta-Llama-3-8B-QQQ.yaml          |   11 +
 .../lm-eval-harness/configs/models-small.txt  |    1 +
 CMakeLists.txt                                |    1 +
 csrc/ops.h                                    |    7 +
 csrc/quantization/marlin/dense/common/base.h  |   32 +
 csrc/quantization/marlin/dense/common/mem.h   |   89 ++
 .../marlin/dense/marlin_cuda_kernel.cu        |   90 +-
 .../marlin/qqq/marlin_qqq_gemm_kernel.cu      | 1243 +++++++++++++++++
 csrc/torch_bindings.cpp                       |    4 +
 tests/kernels/test_marlin_gemm.py             |   66 +
 vllm/_custom_ops.py                           |    9 +
 .../layers/quantization/__init__.py           |    2 +
 .../model_executor/layers/quantization/qqq.py |  285 ++++
 .../utils/marlin_utils_test_qqq.py            |  125 ++
 .../layers/quantization/utils/quant_utils.py  |   82 ++
 15 files changed, 1963 insertions(+), 84 deletions(-)
 create mode 100644 .buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml
 create mode 100644 csrc/quantization/marlin/dense/common/base.h
 create mode 100644 csrc/quantization/marlin/dense/common/mem.h
 create mode 100644 csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu
 create mode 100644 vllm/model_executor/layers/quantization/qqq.py
 create mode 100644 vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py

diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml
new file mode 100644
index 000000000000..c457468902c9
--- /dev/null
+++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1
+model_name: "HandH1998/QQQ-Llama-3-8b-g128"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.409
+  - name: "exact_match,flexible-extract"
+    value: 0.406
+limit: 1000
+num_fewshot: 5
diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt
index e4df4b547aa5..bca89f00653e 100644
--- a/.buildkite/lm-eval-harness/configs/models-small.txt
+++ b/.buildkite/lm-eval-harness/configs/models-small.txt
@@ -7,3 +7,4 @@ Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
 Minitron-4B-Base.yaml
 Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
 Qwen2-1.5B-Instruct-FP8W8.yaml
+Meta-Llama-3-8B-QQQ.yaml
diff --git a/CMakeLists.txt b/CMakeLists.txt
index bf00a36edc50..28b8879a7ba1 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -170,6 +170,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/quantization/awq/gemm_kernels.cu"
     "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
     "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
+    "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
     "csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
diff --git a/csrc/ops.h b/csrc/ops.h
index f075850248d1..f274a7e647b9 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -115,6 +115,13 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                        torch::Tensor const& b_scales,
                        c10::optional<torch::Tensor> const& bias);
 
+torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
+                              torch::Tensor const& b_q_weight,
+                              torch::Tensor const& s_tok,
+                              torch::Tensor const& s_ch,
+                              torch::Tensor const& s_group,
+                              torch::Tensor& workspace, int64_t size_m,
+                              int64_t size_n, int64_t size_k);
 #endif
 
 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
diff --git a/csrc/quantization/marlin/dense/common/base.h b/csrc/quantization/marlin/dense/common/base.h
new file mode 100644
index 000000000000..68c83d5478cf
--- /dev/null
+++ b/csrc/quantization/marlin/dense/common/base.h
@@ -0,0 +1,32 @@
+/*
+ * Modified by HandH1998
+ * Modified by Neural Magic
+ * Copyright (C) Marlin.2024 Elias Frantar
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
+
+// Instances of `Vec` are used to organize groups of >>registers<<, as needed
+// for instance as inputs to tensor core operations. Consequently, all
+// corresponding index accesses must be compile-time constants, which is why we
+// extensively use `#pragma unroll` throughout the kernel code to guarantee
+// this.
+template <typename T, int n>
+struct Vec {
+  T elems[n];
+  __device__ T& operator[](int i) { return elems[i]; }
+};
diff --git a/csrc/quantization/marlin/dense/common/mem.h b/csrc/quantization/marlin/dense/common/mem.h
new file mode 100644
index 000000000000..64f9c393d77c
--- /dev/null
+++ b/csrc/quantization/marlin/dense/common/mem.h
@@ -0,0 +1,89 @@
+/*
+ * Modified by HandH1998
+ * Modified by Neural Magic
+ * Copyright (C) Marlin.2024 Elias Frantar
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+// Predicated asynchronous global->shared copy; used for inputs A where we apply
+// predication to handle batchsizes that are not multiples of 16.
+__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
+                                      bool pred = true) {
+  const int BYTES = 16;
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile(
+      "{\n"
+      "   .reg .pred p;\n"
+      "   setp.ne.b32 p, %0, 0;\n"
+      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
+      "}\n" ::"r"((int)pred),
+      "r"(smem), "l"(glob_ptr), "n"(BYTES));
+}
+
+// Asynchronous global->shared copy
+__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
+  const int BYTES = 16;
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile(
+      "{\n"
+      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
+      "}\n" ::"r"(smem),
+      "l"(glob_ptr), "n"(BYTES));
+}
+
+// Async copy fence.
+__device__ inline void cp_async_fence() {
+  asm volatile("cp.async.commit_group;\n" ::);
+}
+
+// Wait until at most `n` async copy stages are still pending.
+template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} diff --git a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu index efbcc182a3ae..1ce734c9d90d 100644 --- a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu @@ -25,6 +25,12 @@ #include +#include "common/base.h" + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + #include "common/mem.h" +#endif + template inline std::string str(T x) { return std::to_string(x); @@ -32,23 +38,9 @@ inline std::string str(T x) { namespace marlin_dense { -constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -// Instances of `Vec` are used to organize groups of >>registers<<, as needed -// for instance as inputs to tensor core operations. Consequently, all -// corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee -// this. -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; - using I4 = Vec; - // Matrix fragments for tensor core instructions; their precise layout is // documented here: // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type @@ -57,43 +49,6 @@ using FragB = Vec; using FragC = Vec; using FragS = Vec; // quantization scales -// Predicated asynchronous global->shared copy; used for inputs A where we apply -// predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -// Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// Async copy fence. -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -// Wait until at most `n` async copy stages are still pending. 
-template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // output/accumulation. __device__ inline void mma(const FragA& a_frag, const FragB& frag_b, @@ -164,39 +119,6 @@ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { frag_b[1] = __hmul2(frag_b[1], s); } -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} - template + +#include +#include +#include +#include +#include + +#include + +#include "../dense/common/base.h" + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + #include "../dense/common/mem.h" +#endif + +template +inline std::string str(T x) { + return std::to_string(x); +} + +namespace { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +using I4 = Vec; +// Matrix fragments for tensor core instructions; their precise layout is +// documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-integer-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS_GROUP = Vec; // weight per-group quantization scales +using FragS_CHANNEL = + Vec; // weight per-channel quantization scales or activaton + // per-token quantization scales + +// NOTE(HandH1998): cp.async.cg only support BYTES = 16, however, +// cp.async.ca can support BYTES = 4, 8, 16; +// as s_tok's shape is equal to prob_m, we need set s_tok to float type, +// and cp_size = 1 float, i.e., 4 BYTES +// Asynchronous global->shared copy for activation quantizaton scales s_tok +__device__ inline void cp_async1(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 4; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.ca.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +// m16n8k16 tensor core mma instruction with int8 inputs and int32 +// output/accumulation. +__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, + FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + int* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.satfinite.s32.s8.s8.s32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), + "r"(c[3])); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in int8 tensor core layout. 
+__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" + : "=r"(a[0]), "=r"(a[1]) + : "r"(smem)); +} + +inline __device__ half2 float2_to_half2(float2 f) { + uint32_t res; + // NOTE(HandH1998): h0,h1 should be uint16_t, not half + uint16_t h0, h1; + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h0) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h1) : "f"(f.y)); + asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(h0), "h"(h1)); + return reinterpret_cast(res); +} + +inline __device__ float int32_to_float(int h) { + float res; + asm volatile("cvt.rn.f32.s32 %0, %1;\n" : "=f"(res) : "r"(h)); + return res; +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values +// for weight per channel dequant. +__device__ inline FragB dequant_per_channel(int q) { + static constexpr int MASK = 0xf0f0f0f0; + FragB frag_b; + frag_b[0] = (q & MASK); + return frag_b; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values +// for weight per group dequant. +__device__ inline FragB dequant_per_group(int q, FragS_GROUP& frag_s, int i) { + static constexpr uint32_t LO = 0x000f000f; + static constexpr uint32_t HI = 0x00f000f0; + static constexpr uint32_t EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + uint32_t t0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + uint32_t t1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. 
+ static constexpr uint32_t SUB = 0x64086408; + static constexpr uint32_t MUL = 0x2c002c00; + static constexpr uint32_t ADD = 0xd480d480; + *reinterpret_cast(&t0) = __hsub2( + *reinterpret_cast(&t0), *reinterpret_cast(&SUB)); + *reinterpret_cast(&t1) = __hfma2( + *reinterpret_cast(&t1), *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + + uint16_t s = reinterpret_cast(&frag_s)[i]; + uint32_t double_s; + // pack 2xfp16 to half2 + asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(double_s) : "h"(s), "h"(s)); + // dequant and convert 4 half to 4 uint8 (be placed at the low 8 bits of 4 + // half, respectively) + static constexpr uint32_t MAGIC_NUM = 0x64806480; + *reinterpret_cast(&t0) = __hfma2( + *reinterpret_cast(&t0), *reinterpret_cast(&double_s), + *reinterpret_cast(&MAGIC_NUM)); + *reinterpret_cast(&t1) = __hfma2( + *reinterpret_cast(&t1), *reinterpret_cast(&double_s), + *reinterpret_cast(&MAGIC_NUM)); + // take out the 4 uint8 from 4 half, then convert them to 4 int8 and pack 4 + // int8 into 1 uint32 + FragB frag_b; + uint32_t uint8s; + static constexpr uint32_t MASK_0246 = 0x6420; + static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(uint8s) + : "r"(t0), "r"(t1), "n"(MASK_0246)); + frag_b[0] = (uint8s ^ UINT8s_TO_INT8s_MASK); + return frag_b; +} + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin( + const int4* __restrict__ A, // int8 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // int32 global_reduce buffer of shape + // (max_par*16*4)xn, as int8 tensor core's output is + // int32 dtype + int4* __restrict__ D, // fp16 output buffer of shape mxn + const float* __restrict__ s_tok, // fp32 activation per-token quantization + // scales of shape mx1 + const int4* __restrict__ s_ch, // fp32 weight per-channel quantization + // scales of shape 1xn + const int4* __restrict__ s_group, // fp16 weight per-group quantization + // scales of shape (k/groupsize)xn, when + // group_blocks=-1, it should be nullptr + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts in + // the middle of group. 
+ if constexpr (group_blocks != -1) + iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 16; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 4; + D += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + s_tok += (slice_col_par / n_tiles) * 16 * thread_m_blocks; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 16; + C += 16 * thread_m_blocks * prob_n / 4; + D += 16 * thread_m_blocks * prob_n / 8; + s_tok += 16 * thread_m_blocks; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + int a_gl_stride = prob_k / 16; // stride of the A matrix in global memory + // We typically use `constexpr` to indicate that this value is a compile-time + // constant + constexpr int a_sh_stride = + 16 * thread_k_blocks / 16; // stride of an A matrix tile in shared memory + constexpr int a_gl_rd_delta_o = + 16 * thread_k_blocks / + 16; // delta between subsequent A tiles in global memory + int a_gl_rd_delta_i = + a_gl_stride * + (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile + constexpr int a_sh_wr_delta = + a_sh_stride * + (threads / a_gl_rd_delta_o); // between shared memory writes + constexpr int a_sh_rd_delta_o = + 1 * ((threads / 32) / + (thread_n_blocks / 4)); // between shared memory tile reads + constexpr int a_sh_rd_delta_i = + a_sh_stride * 16; // within a shared memory tile + constexpr int a_sh_stage = + a_sh_stride * (16 * thread_m_blocks); // overall size of a tile + constexpr int a_sh_wr_iters = + ceildiv(a_sh_stage, + a_sh_wr_delta); // number of shared write iterations for a tile + + int b_gl_stride = 16 * prob_n / 32; + constexpr int b_sh_stride = 32 * thread_n_blocks / 4; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); + constexpr int b_sh_wr_delta = threads; + constexpr int b_sh_rd_delta = 
threads; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + constexpr int s_tok_sh_stride = 16 * thread_m_blocks; + + constexpr int s_ch_sh_stride = 16 * thread_n_blocks / 4; + + int s_group_gl_stride = prob_n / 8; + constexpr int s_group_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_group_sh_stage = s_group_sh_stride; + int s_group_gl_rd_delta = s_group_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + // NOTE(HandH1998): int8 input a only need 16 threads to load 16x16 matrix + int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16); + a_sh_rd += 1 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = + b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x; + int b_sh_rd = threadIdx.x; + + int s_tok_gl_rd = threadIdx.x; + // NOTE(HandH1998): activation scale s_tok need shuffle to [0, 8, 1, 9, 2, 10, + // 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] for example, 0, 8 row scales serve for + // thread 0, 1, 2, 3. For more details, refer to mma operand A layout as + // s_tok's size is not fixed, we can not shuffle before inference we shuffle + // it when fetching s_tok from global memory to shared memory, that's why + // s_tok_sh_wr is like this + int s_tok_sh_wr = + (threadIdx.x / 16) * 16 + (threadIdx.x % 8) * 2 + (threadIdx.x % 16) / 8; + int s_tok_sh_rd = (threadIdx.x % 32) / 4; + bool s_tok_sh_wr_pred = threadIdx.x < prob_m; + + int s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x; + int s_ch_sh_wr = threadIdx.x; + int s_ch_sh_rd = 16 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + 2 * ((threadIdx.x % 32) % 4); + bool s_ch_sh_wr_pred = threadIdx.x < s_ch_sh_stride; + + int s_group_gl_rd, s_group_sh_wr, s_group_sh_rd; + bool s_group_sh_wr_pred; + if constexpr (group_blocks != -1) { + s_group_gl_rd = + s_group_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_group_sh_stride * slice_col + threadIdx.x; + s_group_sh_wr = threadIdx.x; + // NOTE(HandH1998): s_group_sh_rd is related to mma output C + s_group_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + s_group_sh_wr_pred = threadIdx.x < s_group_sh_stride; + } + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? 
+ auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + // NOTE(HandH1998): stages need >= 4, otherwise, sh_s_tok = sh + max(stages * + // a_sh_stage + stages * b_sh_stage, 4 * stages * a_sh_stage) + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s_tok = sh_b + (stages * b_sh_stage); + int4* sh_s_ch = sh_s_tok + s_tok_sh_stride; + int4* sh_s_group = sh_s_ch + s_ch_sh_stride; + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS_GROUP frag_s_group[2][4]; + FragS_CHANNEL frag_s_tok[thread_m_blocks]; + FragS_CHANNEL frag_s_ch[2][4]; + + // Zero accumulators. + auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + B_ptr[i] += b_gl_rd_delta_o; + } + // Only fetch scales if this tile starts a new group + if constexpr (group_blocks != -1) { + if (pipe % (group_blocks / thread_k_blocks) == 0) { + int4* sh_s_group_stage = sh_s_group + s_group_sh_stage * pipe; + if (s_group_sh_wr_pred) + cp_async4(&sh_s_group_stage[s_group_sh_wr], + &s_group[s_group_gl_rd]); + s_group_gl_rd += s_group_gl_rd_delta; + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). 
+    cp_async_wait<stages - 2>();
+    __syncthreads();
+  };
+
+  // Load the next sub-tile from the current location in the shared memory pipe
+  // into the current register buffer.
+  auto fetch_to_registers = [&](int k, int pipe) {
+    // It may seem inefficient that we reload the groups for every sub-tile;
+    // however, this does not seem to be a significant bottleneck, while some
+    // theoretically better attempts have led to bad instruction ordering by
+    // the compiler and correspondingly a noticeable drop in performance.
+    if constexpr (group_blocks != -1) {
+      int4* sh_s_group_stage =
+          sh_s_group +
+          s_group_sh_stage * ((group_blocks / thread_k_blocks) *
+                              (pipe / (group_blocks / thread_k_blocks)));
+      reinterpret_cast<int4*>(&frag_s_group[k % 2])[0] =
+          sh_s_group_stage[s_group_sh_rd];
+    }
+    int4* sh_a_stage = sh_a + a_sh_stage * pipe;
+  #pragma unroll
+    for (int i = 0; i < thread_m_blocks; i++)
+      ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
+    int4* sh_b_stage = sh_b + b_sh_stage * pipe;
+    frag_b_quant[k % 2] = *reinterpret_cast<I4*>(
+        &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
+  };
+
+  // Execute the actual tensor core matmul of a sub-tile.
+  auto matmul = [&](int k) {
+  // We have the m dimension as the inner loop in order to encourage overlapping
+  // dequantization and matmul operations.
+  #pragma unroll
+    for (int j = 0; j < 4; j++) {
+      int b_quant = frag_b_quant[k % 2][j];
+      // int b_quant_shift = b_quant << 4;
+      FragB frag_b0, frag_b1;
+      // If there are no groups, we can just scale the final output once and can
+      // avoid doing so for each weight.
+      if constexpr (group_blocks != -1) {
+        int b_quant_shift = b_quant >> 8;
+        frag_b0 = dequant_per_group(b_quant, frag_s_group[k % 2][j], 0);
+        frag_b1 = dequant_per_group(b_quant_shift, frag_s_group[k % 2][j], 1);
+      } else {
+        int b_quant_shift = b_quant << 4;
+        frag_b0 = dequant_per_channel(b_quant);
+        frag_b1 = dequant_per_channel(b_quant_shift);
+      }
+  #pragma unroll
+      for (int i = 0; i < thread_m_blocks; i++) {
+        mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
+        mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
+      }
+    }
+  };
+
+  // Since we slice across the k dimension of a tile in order to increase the
+  // number of warps while keeping the n dimension of a tile reasonable, we have
+  // multiple warps that accumulate their partial sums of the same output
+  // location; which we have to reduce over in the end. We do this in shared
+  // memory.
+  auto thread_block_reduce = [&]() {
+    constexpr int red_off = threads / b_sh_stride / 2;
+    if (red_off >= 1) {
+      int red_idx = threadIdx.x / b_sh_stride;
+      constexpr int red_sh_stride = b_sh_stride * 4 * 2;
+      constexpr int red_sh_delta = b_sh_stride;
+      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
+                      (threadIdx.x % b_sh_stride);
+
+      // Parallel logarithmic shared memory reduction. We make sure to avoid any
+      // unnecessary read or write iterations, e.g., for two warps we write only
+      // once by warp 1 and read only once by warp 0.
+
+  #pragma unroll
+      for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
+  #pragma unroll
+        for (int i = red_off; i > 0; i /= 2) {
+          if (i <= red_idx && red_idx < 2 * i) {
+  #pragma unroll
+            for (int j = 0; j < 4 * 2; j++) {
+              int red_sh_wr =
+                  red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
+              if (i < red_off) {
+                int* c_rd =
+                    reinterpret_cast<int*>(&sh[red_sh_delta * j + red_sh_rd]);
+                int* c_wr = reinterpret_cast<int*>(&sh[red_sh_wr]);
+  #pragma unroll
+                for (int k = 0; k < 4; k++)
+                  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
+                      c_rd[k] + c_wr[k];
+              }
+              sh[red_sh_wr] =
+                  reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
+            }
+          }
+          __syncthreads();
+        }
+        if (red_idx == 0) {
+  #pragma unroll
+          for (int i = 0; i < 4 * 2; i++) {
+            int* c_rd =
+                reinterpret_cast<int*>(&sh[red_sh_delta * i + red_sh_rd]);
+  #pragma unroll
+            for (int j = 0; j < 4; j++)
+              reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
+                  c_rd[j];
+          }
+        }
+        __syncthreads();
+      }
+    }
+  };
+
+  // Since multiple threadblocks may process parts of the same column slice, we
+  // finally have to globally reduce over the results. As the striped
+  // partitioning minimizes the number of such reductions and our outputs are
+  // usually rather small, we perform this reduction serially in L2 cache.
+  // global_reduce works on INT32 elements, which are the results of INT8 GEMM.
+  // This is why we need another INT32 matrix `C` to reduce instead of the
+  // original half matrix `D`.
+  auto global_reduce = [&](bool first = false, bool last = false) {
+    // We are very careful here to reduce directly in the output buffer to
+    // maximize L2 cache utilization in this step. To do this, we write out
+    // results in FP16 (but still reduce with FP32 compute).
+    constexpr int active_threads = 32 * thread_n_blocks / 4;
+    if (threadIdx.x < active_threads) {
+      int c_gl_stride = prob_n / 4;
+      int c_gl_wr_delta_o = 8 * c_gl_stride;
+      int c_gl_wr_delta_i = 8 * (active_threads / 32);
+      int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
+                    8 * (threadIdx.x / 32) + (threadIdx.x % 4) * 2;
+      c_gl_wr += (4 * thread_n_blocks) * slice_col;
+      constexpr int c_sh_wr_delta = active_threads * 2;
+      int c_sh_wr = 2 * threadIdx.x;
+
+      int row = (threadIdx.x % 32) / 4;
+
+      if (!first) {
+  // Interestingly, doing direct global accesses here really seems to mess up
+  // the compiler and lead to slowdowns, hence we also use async-copies even
+  // though these fetches are not actually asynchronous.
+ #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i + 1], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2) + 1], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { + if (!first) { + int4 d_red1 = sh[c_sh_wr + i * c_sh_wr_delta]; + int4 d_red2 = sh[c_sh_wr + i * c_sh_wr_delta + 1]; + #pragma unroll + for (int j = 0; j < 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + reinterpret_cast(&d_red1)[j]; + } + #pragma unroll + for (int j = 0; j < 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)] += + reinterpret_cast(&d_red2)[j]; + } + } + if (!last) { + int4 d1, d2; + #pragma unroll + for (int j = 0; j < 4; j++) { + reinterpret_cast(&d1)[j] = reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]; + } + #pragma unroll + for (int j = 0; j < 4; j++) { + reinterpret_cast(&d2)[j] = reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)]; + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = + d1; + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2) + + 1] = d2; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. 
+ auto write_result = [&]() { + int d_gl_stride = prob_n / 8; + constexpr int d_sh_stride = 2 * thread_n_blocks + 1; + int d_gl_wr_delta = d_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int d_sh_rd_delta = + d_sh_stride * (threads / (2 * thread_n_blocks)); + + int d_gl_wr = d_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + d_gl_wr += (2 * thread_n_blocks) * slice_col; + int d_sh_wr = + (4 * d_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + d_sh_wr += 32 * (threadIdx.x / 32); + int d_sh_rd = d_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int d_gl_wr_end = d_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, int c0, int c1, float a_s, FragS_CHANNEL& w_s) { + float2 deq_res; + deq_res.x = int32_to_float(c0) * w_s[0] * a_s; + deq_res.y = int32_to_float(c1) * w_s[1] * a_s; + ((half2*)sh)[idx] = float2_to_half2(deq_res); + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = d_sh_wr + 8 * j; + write(wr + (4 * d_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s_tok[i][0], + frag_s_ch[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * d_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s_tok[i][1], + frag_s_ch[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * d_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s_tok[i][0], + frag_s_ch[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * d_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s_tok[i][1], + frag_s_ch[j / 2][2 * (j % 2) + 1]); + } + d_sh_wr += 16 * (4 * d_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (d_gl_wr < d_gl_wr_end) { + D[d_gl_wr] = sh[d_sh_rd]; + d_gl_wr += d_gl_wr_delta; + d_sh_rd += d_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + #pragma unroll + for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); + zero_accums(); + wait_for_stage(); + fetch_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + }; + start_pipes(); + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines have + // even length meaning that the next iteration will always start at index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) break; + } + a_gl_rd += a_gl_rd_delta_o * stages; + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. 
+ if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if (last) { + if (s_tok_sh_wr_pred) { + cp_async1(&sh_s_tok[s_tok_sh_wr], &s_tok[s_tok_gl_rd]); + } + if (s_ch_sh_wr_pred) { + cp_async4(&sh_s_ch[s_ch_sh_wr], &s_ch[s_ch_gl_rd]); + } + cp_async_fence(); + } + thread_block_reduce(); + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + frag_s_tok[i][0] = + *reinterpret_cast(&sh_s_tok[16 * i + 2 * s_tok_sh_rd]); + frag_s_tok[i][1] = *reinterpret_cast( + &sh_s_tok[16 * i + 2 * s_tok_sh_rd + 1]); + } + reinterpret_cast(&frag_s_ch)[0] = sh_s_ch[s_ch_sh_rd + 0]; + reinterpret_cast(&frag_s_ch)[1] = sh_s_ch[s_ch_sh_rd + 1]; + reinterpret_cast(&frag_s_ch)[2] = sh_s_ch[s_ch_sh_rd + 8]; + reinterpret_cast(&frag_s_ch)[3] = sh_s_ch[s_ch_sh_rd + 9]; + } + } + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + s_group_gl_rd = s_group_sh_stride * slice_col + threadIdx.x; + s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x; + start_pipes(); + } + } + } +} + +#else + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin( + const int4* __restrict__ A, // int8 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // int32 global_reduce buffer of shape + // (max_par*16*4)xn, as int8 tensor core's output is + // int32 dtype + int4* __restrict__ D, // fp16 output buffer of shape mxn + const float* __restrict__ s_tok, // fp32 activation per-token quantization + // scales of shape mx1 + const int4* __restrict__ s_ch, // fp32 weight per-channel quantization + // scales of shape 1xn + const int4* __restrict__ s_group, // fp16 weight per-group quantization + // scales of shape (k/groupsize)xn, when + // group_blocks=-1, it should be nullptr + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +#endif + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. 
+const int USER_THREADS = + 256; // Note: This is only used with user-provided thread_k/n +const int STAGES = 4; // 4 pipeline stages fit into shared memory + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; + +static constexpr int tile_size = 16; +static constexpr int max_par = 16; + +static constexpr int pack_factor_4bit = + 8; // We have 8 4-bit vals inside a 32 bit + +#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + GROUP_BLOCKS, NUM_THREADS) \ + else if (thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute(Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem); \ + Marlin \ + <<>>( \ + A_ptr, B_ptr, C_ptr, D_ptr, s_tok_ptr, s_ch_ptr, s_group_ptr, \ + prob_m, prob_n, prob_k, locks); \ + } + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, // Default + {128, 64, 128}, // Reduce N 2X, same K + {64, 256, 256}, // Reduce K 2X, increase N 2X + {64, 128, 128}, // Reduce K 2X, same N +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, // Default + {128, 128, 256}, // Reduce N 2X, increase K 2X + {64, 128, 128}, // Reduce N 2X, same K + {128, 64, 128}, // Reduce N 4X, increase K 2X +}; + +bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, + int prob_k) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // thread_k can be only 128 or 64 (because it must be less than groupsize + // which is 128) + if (th_config.thread_k != 128 && th_config.thread_k != 64) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + return true; +} + +thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { + return th_config; + } + } + + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { + return th_config; + } + } + } + + return thread_config_t{-1, -1, -1}; +} + +#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) + +void marlin_qqq_cuda(const void* A, const void* B, void* C, void* D, + void* s_tok, void* s_ch, void* 
s_group, int prob_m, + int prob_n, int prob_k, void* workspace, + int groupsize = -1, int dev = 0, cudaStream_t stream = 0, + int thread_k = -1, int thread_n = -1, int sms = -1, + int max_par = 16) { + int tot_m = prob_m; + int tot_m_blocks = ceildiv(tot_m, 16); + int pad = 16 * tot_m_blocks - tot_m; + + if (sms == -1) + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + // Set thread config + thread_config_t th_config; + if (thread_k != -1 && thread_n != -1) { + // User-defined config + th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; + } else { + // Auto config + th_config = determine_thread_config(prob_m, prob_n, prob_k); + } + + if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) { + throw std::runtime_error( + "Invalid thread config: thread_k = " + str(th_config.thread_k) + + ", thread_n = " + str(th_config.thread_n) + + ", num_threads = " + str(th_config.num_threads) + " for MKN = [" + + str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]"); + } + + int num_threads = th_config.num_threads; + thread_k = th_config.thread_k; + thread_n = th_config.thread_n; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; + int blocks = sms; + + if (prob_m == 0 || prob_n == 0 || prob_k == 0) { + return; + } + + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + if (group_blocks != -1) { + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } + + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + int4* D_ptr = (int4*)D; + const float* s_tok_ptr = (const float*)s_tok; + const int4* s_ch_ptr = (const int4*)s_ch; + const int4* s_group_ptr = (const int4*)s_group; + + int* locks = (int*)workspace; + + for (int i = 0; i < tot_m_blocks; i += 4) { + int thread_m_blocks = tot_m_blocks - i; + prob_m = tot_m - 16 * i; + int par = 1; + if (thread_m_blocks > 4) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * thread_m_blocks - pad) / 64; + if (par > max_par) par = max_par; + prob_m = 64 * par; + i += 4 * (par - 1); + thread_m_blocks = 4; + } + + // For compilation speed, we only define the kernel configurations that have + // seemed useful (in terms of performance) in our testing, however many more + // are, in principle, possible. 
+ if (false) { + } + CALL_IF(8, 8, 256) + CALL_IF(16, 4, 256) + CALL_IF(8, 4, 128) + CALL_IF(4, 8, 128) + else { + throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + + ", " + str(prob_k) + ", " + str(prob_n) + "]" + + ", groupsize = " + str(groupsize) + + ", thread_m_blocks = " + str(thread_m_blocks) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } + + A_ptr += 16 * thread_m_blocks * (prob_k / 16) * par; + D_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; + s_tok_ptr += 16 * thread_m_blocks * par; + } +} +} // anonymous namespace + +torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, + torch::Tensor const& b_q_weight, + torch::Tensor const& s_tok, + torch::Tensor const& s_ch, + torch::Tensor const& s_group, + torch::Tensor& workspace, int64_t size_m, + int64_t size_n, int64_t size_k) { + // Verify M + TORCH_CHECK(size_m == a.size(0), + "Shape mismatch: a.size(0) = " + str(a.size(0)) + + ", size_m = " + str(size_m)); + TORCH_CHECK(size_m == s_tok.numel(), + "Shape mismatch: s_tok.numel() = " + str(s_tok.numel()) + + ", size_m = " + str(size_m)); + + // Verify K + TORCH_CHECK(size_k == a.size(1), + "Shape mismatch: a.size(1) = " + str(a.size(1)) + + ", size_k = " + str(size_k)); + TORCH_CHECK(size_k % tile_size == 0, + "size_k = " + str(size_k) + + " is not divisible by tile_size = " + str(tile_size)); + TORCH_CHECK( + (size_k / tile_size) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = " + str(b_q_weight.size(0)) + + ", size_k = " + str(size_k) + ", tile_size = " + str(tile_size)); + + int groupsize = (s_group.numel() == 0) ? -1 : size_k / s_group.size(0); + // Verify groupsize + TORCH_CHECK(groupsize == -1 || groupsize == 128, + "Unexpected groupsize = " + str(groupsize)); + + // Verify N + TORCH_CHECK(s_ch.numel() == size_n, + "Shape mismatch: s_ch.numel() = " + str(s_ch.numel()) + + ", size_n = " + str(size_n)); + TORCH_CHECK(b_q_weight.size(1) % tile_size == 0, + "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + + " is not divisible by tile_size = " + str(tile_size)); + if (groupsize != -1) { + TORCH_CHECK(s_group.size(1) == size_n, + "Shape mismatch: s_group.size(1) = " + str(s_group.size(1)) + + ", size_n = " + str(size_n)); + TORCH_CHECK( + size_k % s_group.size(0) == 0, + "size_k = " + str(size_k) + + ", is not divisible by s_group.size(0) = " + str(s_group.size(0))); + } + + int actual_size_n = (b_q_weight.size(1) / tile_size) * pack_factor_4bit; + TORCH_CHECK(size_n == actual_size_n, + "Shape mismatch: size_n = " + str(size_n) + + ", actual_size_n = " + str(actual_size_n)); + + // Verify A device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + // Verify B device and strides + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + // Verify s_tok device, strides and dtype + TORCH_CHECK(s_tok.device().is_cuda(), "s_tok is not on GPU"); + TORCH_CHECK(s_tok.is_contiguous(), "s_tok is not contiguous"); + TORCH_CHECK(s_tok.dtype() == torch::kFloat32, "s_tok's dtype is not float32"); + + // Verify s_ch device, strides and dtype + TORCH_CHECK(s_ch.device().is_cuda(), "s_ch is not on GPU"); + TORCH_CHECK(s_ch.is_contiguous(), "s_ch is not contiguous"); + TORCH_CHECK(s_ch.dtype() == torch::kFloat32, "s_ch's dtype is not float32"); + + // Verify s_group device, strides and dtype + TORCH_CHECK(s_group.device().is_cuda(), 
"s_group is not on GPU"); + TORCH_CHECK(s_group.is_contiguous(), "s_group is not contiguous"); + TORCH_CHECK(s_group.dtype() == torch::kFloat16, + "s_group's dtype is not float16"); + + // Verify workspace size + TORCH_CHECK(size_n % min_thread_n == 0, + "size_n = " + str(size_n) + + ", is not divisible by min_thread_n = " + str(min_thread_n)); + int min_workspace_size = (size_n / min_thread_n) * max_par; + TORCH_CHECK(workspace.numel() >= min_workspace_size, + "workspace.numel = " + str(workspace.numel()) + + " is below min_workspace_size = " + str(min_workspace_size)); + + // Alloc C matrix + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options_c = torch::TensorOptions().dtype(torch::kInt).device(a.device()); + torch::Tensor c = torch::empty({max_par * 64, size_n}, options_c); + + // Alloc D matrix + auto options_d = + torch::TensorOptions().dtype(torch::kFloat16).device(a.device()); + torch::Tensor d = torch::empty({size_m, size_n}, options_d); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel (can usually be left as auto -1) + int sms = -1; + + int dev = a.get_device(); + marlin_qqq_cuda( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), d.data_ptr(), + s_tok.data_ptr(), s_ch.data_ptr(), s_group.data_ptr(), size_m, size_n, + size_k, workspace.data_ptr(), groupsize, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par); + + return d; +} diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 3027b63ba2b3..bf8cefa8d471 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -149,6 +149,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("fp8_marlin_gemm", &fp8_marlin_gemm); ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm); + // marlin_qqq_gemm for QQQ. + ops.def("marlin_qqq_gemm", &marlin_qqq_gemm); + ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm); + // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column // quantization. 
ops.def( diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py index bd35ef2eb255..a9e34ac8a7aa 100644 --- a/tests/kernels/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -10,6 +10,9 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.qqq import ( + MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N, + MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS, @@ -21,6 +24,8 @@ marlin_weights) from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( marlin_24_quantize) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import ( # noqa: E501 + marlin_qqq_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( awq_pack, gptq_pack, quantize_weights, quantize_weights_with_zp, sort_weights) @@ -425,3 +430,64 @@ def test_awq_marlin_gemm( print("max_diff = {}".format(max_diff)) assert max_diff < 0.04 + + +@pytest.mark.skipif(not is_quant_method_supported("qqq"), + reason="Marlin is not supported on this GPU type.") +@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) +@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) +@pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +def test_marlin_qqq_gemm( + k_chunk, + n_chunk, + num_bits, + group_size, + mnk_factors, +): + int8_traits = torch.iinfo(torch.int8) + m_factor, n_factor, k_factor = mnk_factors + + size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + print(f"MNK = {size_m} {size_n} {size_k}") + print(f"groupsize = {group_size}") + + a_input = rand_data((size_m, size_k)) + b_weight = rand_data((size_k, size_n)) + + # Quantize activations + s_a = a_input.abs().max(dim=-1, keepdim=True)[0].div(int8_traits.max).to( + torch.float) + q_a = (a_input / s_a).round().clamp(int8_traits.min, + int8_traits.max).to(torch.int8) + + # Quantize weights + w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = \ + marlin_qqq_quantize(b_weight, num_bits, group_size) + + workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N, + MARLIN_QQQ_MAX_PARALLEL) + + output = ops.marlin_qqq_gemm( + q_a, + marlin_qqq_q_w, + s_a, + marlin_qqq_s_channel, + marlin_qqq_s_group, + workspace.scratch, + a_input.shape[0], + b_weight.shape[1], + a_input.shape[1], + ) + output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref) + + torch.cuda.synchronize() + + max_diff = compute_max_diff(output, output_ref) + print("max_diff = {}".format(max_diff)) + + assert max_diff < 0.04 diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 2c09ca2c1407..9e09b9a32eab 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -389,6 +389,15 @@ def scaled_int8_quant( return output, input_scales +# qqq ops +def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + s_tok: torch.Tensor, s_ch: torch.Tensor, + s_group: torch.Tensor, workspace: torch.Tensor, + size_m: int, size_n: int, size_k: int) -> torch.Tensor: + return torch.ops._C.marlin_qqq_gemm(a, b_q_weight, s_tok, s_ch, s_group, + workspace, size_m, 
size_n, size_k) + + # moe def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index bd574512e343..13da6376ec29 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -19,6 +19,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQMarlin24Config) from vllm.model_executor.layers.quantization.marlin import MarlinConfig +from vllm.model_executor.layers.quantization.qqq import QQQConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { @@ -37,6 +38,7 @@ "squeezellm": SqueezeLLMConfig, "compressed-tensors": CompressedTensorsConfig, "bitsandbytes": BitsAndBytesConfig, + "qqq": QQQConfig, } diff --git a/vllm/model_executor/layers/quantization/qqq.py b/vllm/model_executor/layers/quantization/qqq.py new file mode 100644 index 000000000000..be10cee2cf68 --- /dev/null +++ b/vllm/model_executor/layers/quantization/qqq.py @@ -0,0 +1,285 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) + +MARLIN_QQQ_TILE = 16 +MARLIN_QQQ_MIN_THREAD_N = 64 +MARLIN_QQQ_MIN_THREAD_K = 128 +MARLIN_QQQ_MAX_PARALLEL = 16 + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] +MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] +MARLIN_QQQ_SUPPORTED_SYM = [True] + + +class QQQConfig(QuantizationConfig): + """Config class for QQQ + + Reference: https://arxiv.org/pdf/2406.09904 + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + is_sym: bool = True, + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + self.is_sym = is_sym + + # Verify + if self.weight_bits not in MARLIN_QQQ_SUPPORTED_NUM_BITS: + raise ValueError( + f"QQQ does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {MARLIN_QQQ_SUPPORTED_NUM_BITS} " + "are supported.") + if self.group_size not in MARLIN_QQQ_SUPPORTED_GROUP_SIZES: + raise ValueError( + f"QQQ does not support group_size = {self.group_size}. " + f"Only group_sizes = {MARLIN_QQQ_SUPPORTED_GROUP_SIZES} " + "are supported.") + if self.is_sym not in MARLIN_QQQ_SUPPORTED_SYM: + raise ValueError( + f"QQQ does not support is_sym = {self.is_sym}. " + f"Only sym = {MARLIN_QQQ_SUPPORTED_SYM} are supported.") + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = 32 // self.weight_bits + + # Tile size used by QQQ kernels. + self.tile_size = MARLIN_QQQ_TILE + + # Min out_features dim + self.min_n_threads = MARLIN_QQQ_MIN_THREAD_N + + # Min in_features dim + self.min_k_threads = MARLIN_QQQ_MIN_THREAD_K + + # Max parallel problems to solve at once (improves large + # batch performance) + self.max_parallel = MARLIN_QQQ_MAX_PARALLEL + + # Permutation length used by the QQQ kernels. 
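+        # Note: 1024 = 4 * MARLIN_QQQ_TILE**2, i.e. one permutation group
+        # spans four 16x16 tiles; create_weights() below checks weight shards
+        # against this (this reading is an inference from that check, not a
+        # documented spec).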
+ self.perm_len = 1024 + + def __repr__(self) -> str: + return "QQQConfig(weight_bits={}, group_size={})".format( + self.weight_bits, self.group_size) + + @classmethod + def get_name(cls) -> str: + return "qqq" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + """List of filenames to search for in the model directory.""" + return [ + "quant_config.json", + "quantize_config.json", + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "QQQConfig": + weight_bits = cls.get_from_keys(config, ["wbits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits, group_size) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QQQLinearMethod"]: + if isinstance(layer, LinearBase): + return QQQLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class QQQLinearMethod(LinearMethodBase): + """Linear method for QQQ. + + Args: + quant_config: The QQQ quantization config. + """ + + def __init__(self, quant_config: QQQConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + if params_dtype != torch.float16: + raise ValueError( + f"The params dtype must be float16, but got {params_dtype}") + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.min_n_threads != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"min_n_threads = {self.quant_config.min_n_threads}.") + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"pack_factor = {self.quant_config.pack_factor}.") + + # Validate input_size_per_partition + if input_size_per_partition % self.quant_config.min_k_threads != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"min_k_threads = {self.quant_config.min_k_threads}.") + if (self.quant_config.group_size != -1 and + input_size_per_partition % self.quant_config.group_size != 0): + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"group_size = {self.quant_config.group_size}.") + + # Check that we have at least 4 tiles horizontally in the shard + num_tiles_per_perm = self.quant_config.perm_len // ( + self.quant_config.tile_size**2) + if output_size_per_partition % num_tiles_per_perm != 0: + raise ValueError( + "Each permutation group must reside on the same gpu") + + # Quantized 4Bit weights packed into Int32. 
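+        # Layout (as implied by the shape below): rows index K in units of
+        # tile_size, while the N dimension is widened by tile_size and then
+        # packed pack_factor 4-bit values per int32.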
+ qweight = Parameter( + torch.empty( + input_size_per_partition // self.quant_config.tile_size, + output_size_per_partition * self.quant_config.tile_size // + self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, + { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + "marlin_tile_size": self.quant_config.tile_size, + }, + ) + + s_channel = Parameter( + torch.empty( + 1, + output_size_per_partition, + device="cuda", + dtype=torch.float, + ), + requires_grad=False, + ) + set_weight_attrs( + s_channel, + { + "input_dim": None, + "output_dim": 1, + }, + ) + + if self.quant_config.group_size == -1: + s_group = Parameter( + torch.tensor( + [], + device="cuda", + dtype=torch.half, + ), + requires_grad=False, + ) + else: + s_group = Parameter( + torch.empty( + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition, + device="cuda", + dtype=torch.half, + ), + requires_grad=False, + ) + + set_weight_attrs( + s_group, + { + "input_dim": None if self.quant_config.group_size == -1 else 0, + "output_dim": + None if self.quant_config.group_size == -1 else 1, + }, + ) + + # Allocate workspace (Used for internal locking mechanism) + max_workspace_size = ( + output_size_per_partition // + self.quant_config.min_n_threads) * self.quant_config.max_parallel + workspace = Parameter(torch.zeros(max_workspace_size, + device="cuda", + dtype=torch.int), + requires_grad=False) + + layer.register_parameter("B", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("s_channel", s_channel) + set_weight_attrs(s_channel, extra_weight_attrs) + layer.register_parameter("s_group", s_group) + set_weight_attrs(s_group, extra_weight_attrs) + layer.register_parameter("workspace", workspace) + set_weight_attrs(workspace, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = layer.B + s_ch = layer.s_channel + s_group = layer.s_group + workspace = layer.workspace + + x_2d = x.view(-1, x.shape[-1]) + + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = s_ch.shape[1] + + x_int8, s_tok = ops.scaled_int8_quant(x_2d) + + output_2d = ops.marlin_qqq_gemm(x_int8, qweight, s_tok, s_ch, s_group, + workspace, size_m, size_n, size_k) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + + if bias is not None: + output.add_(bias) # In-place add + + return output diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py new file mode 100644 index 000000000000..cb58eb945836 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py @@ -0,0 +1,125 @@ +from typing import List + +import numpy +import torch + +from .marlin_utils_test import marlin_permute_weights +from .quant_utils import get_pack_factor, qqq_quantize_weights + + +def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + if group_size == size_k: + for i in range(pack_factor): + q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * 
i + else: + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def get_qqq_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501 +def get_qqq_weight_perm(num_bits: int, quant_type: str): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + assert quant_type in ["per-channel", + "per-group"], "not supported quantization type" + if num_bits == 4: + if quant_type == "per-channel": + interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) + else: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + else: + raise Exception("num_bits must be 4, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): + scale_perm, scale_perm_single = get_qqq_scale_perms() + if group_size < size_k and group_size != -1: + s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_group = s_group.reshape((-1, size_n)).contiguous() + else: + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_channel = s_channel.reshape((-1, size_n)).contiguous() + + return s_group, s_channel + + +def marlin_qqq_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + quant_type = "per-channel" if group_size == size_k else "per-group" + + # Quantize + w_ref, q_w, s_group, s_channel = qqq_quantize_weights( + w, num_bits, group_size) + + # Reformat to marlin_qqq + weight_perm = get_qqq_weight_perm(num_bits, quant_type) + marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, + weight_perm, group_size) + marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( + s_group, s_channel, size_k, size_n, group_size) + + # Create result + res_list = [ + w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel + ] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 2ba6a9a810ec..7ade8bf664cc 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -205,6 +205,88 @@ def reshape_w(w): ) +# QQQ employs different quant schemes for per-group and +# per-channel quantization. 
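+# Roughly: with a finite group size, a symmetric 4-bit group scale (s_group)
+# is fused with a per-channel int8 scale (s_channel); with per-channel-only
+# quantization (group_size == size_k), s_group is left empty and only
+# s_channel is used.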
+def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + if group_size < size_k: + # Reshape to [groupsize, -1] + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Compute scale for each group + s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_group *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s_group).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s_group + + # Restore original shapes + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + # Compute int8 quantization scale for each channel + s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] + s_channel /= 127.0 + t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) + w_ref = t_int8.half() * s_channel + s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) + + # Fuse scales + s_group = (s_group.reshape(-1, size_n).contiguous() / + s_channel).to(dtype=torch.half) + else: + max_q_val = 2**(num_bits - 1) - 1 + + # Compute scale for each channel + s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_channel /= max_q_val + + # Quantize + q_w = torch.round(w / s_channel).int() + q_w = torch.clamp(q_w, -max_q_val, max_q_val) + # Compute ref (dequantized) + w_ref = q_w.half() * s_channel + + s_group = torch.tensor([], dtype=torch.half) + # div 2 ** (8 - self.bits)) to offset right shift in unpacking + s_channel /= (2**(8 - num_bits)) + s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s_group.to(device=orig_device), + s_channel.to(device=orig_device), + ) + + def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): orig_device = q_w.device From 2f4e108f75c817bf2f323e306db590e13d2863f6 Mon Sep 17 00:00:00 2001 From: Alphi <52458637+HwwwwwwwH@users.noreply.github.com> Date: Wed, 31 Jul 2024 22:39:19 +0800 Subject: [PATCH 0002/3246] [Bugfix] Clean up MiniCPM-V (#6939) Co-authored-by: hezhihui Co-authored-by: Cyrus Leung --- docs/source/models/supported_models.rst | 6 +- vllm/model_executor/models/llama.py | 4 +- vllm/model_executor/models/minicpm.py | 4 +- vllm/model_executor/models/minicpmv.py | 249 +++++--- vllm/model_executor/models/na_vit.py | 804 ++++++++++++++++++++++++ vllm/model_executor/models/qwen2.py | 2 +- 6 files changed, 975 insertions(+), 94 deletions(-) create mode 100644 vllm/model_executor/models/na_vit.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 4fe33e5ab5d8..a1ea366b82b0 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -222,9 +222,13 @@ Vision Language Models - * - :code:`MiniCPM-V` - MiniCPM-V - - :code:`openbmb/MiniCPM-V-2`, :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc. 
+ - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc. - +.. note:: + For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. + For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 + ---- If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 306d22e42ed1..2052c443a888 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -418,11 +418,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - input_embeds: Optional[torch.Tensor] = None ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - input_embeds) + attn_metadata, intermediate_tensors) return model_output def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 7a8ac0bb1f94..b46e88f5fc58 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -370,6 +370,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: if inputs_embeds is not None: @@ -463,11 +464,10 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - input_embeds: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, input_embeds) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 8563216d9c39..2a7fe7ba0eba 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -20,32 +20,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Inference-only MiniCPM-V-2 model compatible with HuggingFace weights.""" +"""Inference-only MiniCPM-V model compatible with HuggingFace weights.""" import math import re from functools import partial -from typing import Iterable, List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F +import torch.types from PIL import Image from torch import nn from torch.nn.init import trunc_normal_ from transformers.configuration_utils import PretrainedConfig -from transformers.models.idefics2.modeling_idefics2 import ( - Idefics2VisionTransformer) from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsVision -from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.model_executor.models.minicpm import MiniCPMForCausalLM +from vllm.model_executor.models.llama import LlamaModel +from vllm.model_executor.models.minicpm import MiniCPMModel +from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import (cached_get_image_processor, @@ -53,12 +55,12 @@ from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData _KEYS_TO_MODIFY_MAPPING = { - "language_model.lm_head": "lm_head", - "language_model.model": "language_model", + "llm.lm_head": "lm_head", + "llm.model": "llm", } -def get_abs_pos(abs_pos, tgt_size): +def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor): # abs_pos: L, C # tgt_size: (H, W) # return: M, C @@ -75,10 +77,10 @@ def get_abs_pos(abs_pos, tgt_size): # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 -def get_2d_sincos_pos_embed(embed_dim, - grid_size, - cls_token=False, - version=2.0): +def get_2d_sincos_pos_embed(embed_dim: int, + grid_size: Union[int, Tuple[int, int]], + cls_token: bool = False, + version: Tuple[int, int] = (2, 0)): """ grid_size: int of the grid height and width return: @@ -95,7 +97,7 @@ def get_2d_sincos_pos_embed(embed_dim, grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) - if version == 2.0: + if version == (2, 0): grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) if cls_token: @@ -106,7 +108,9 @@ def get_2d_sincos_pos_embed(embed_dim, return pos_embed -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version=2.0): +def get_2d_sincos_pos_embed_from_grid(embed_dim: int, + grid: Union[int, Tuple[int, int]], + version: Tuple[int, int] = (2, 0)): assert embed_dim % 2 == 0 # use half of dimensions to encode grid_h @@ -115,14 +119,16 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version=2.0): emb_w = get_1d_sincos_pos_embed_from_grid( embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2) - if version == 2.0: + if version == (2, 0): emb = 
np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) else: emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D) return emb -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, version=2.0): +def get_1d_sincos_pos_embed_from_grid(embed_dim: int, + pos: int, + version: Tuple[int, int] = (2, 0)): """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) / (H, W) @@ -133,7 +139,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, version=2.0): omega /= embed_dim / 2. omega = 1. / 10000**omega # (D/2,) - if version == 2.0: + if version == (2, 0): pos = pos.reshape(-1) # (M,) out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) @@ -158,19 +164,19 @@ class Resampler(nn.Module): default_norm_layer = partial(nn.LayerNorm, eps=1e-6) def __init__(self, - num_queries, - grid_size, - embed_dim, - num_heads, - kv_dim=None, - norm_layer=default_norm_layer, - adaptive=False, - max_size=(70, 70), - version=2.0): + num_queries: int, + grid_size: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: nn.Module = default_norm_layer, + adaptive: bool = False, + max_size: Tuple[int, int] = (70, 70), + version: Tuple[int, int] = (2, 0)): super().__init__() self.version = version - if self.version == 2.0: + if self.version == (2, 0): self.num_queries = grid_size**2 else: self.num_queries = num_queries @@ -195,7 +201,7 @@ def __init__(self, self.proj = nn.Parameter( (embed_dim**-0.5) * torch.randn(embed_dim, embed_dim)) - if self.version == 2.0: + if self.version == (2, 0): self.pos_embed = nn.Parameter( torch.from_numpy( get_2d_sincos_pos_embed( @@ -206,14 +212,17 @@ def __init__(self, self.apply(self._init_weights) - def _set_2d_pos_cache(self, max_size, device='cpu'): + def _set_2d_pos_cache(self, + max_size: Tuple[int, int], + device: torch.types.Device = 'cpu'): pos_embed = torch.from_numpy( get_2d_sincos_pos_embed(self.embed_dim, max_size, version=self.version)).float().to(device) self.register_buffer("pos_embed", pos_embed, persistent=False) - def _adjust_pos_cache(self, tgt_sizes, device): + def _adjust_pos_cache(self, tgt_sizes: torch.Tensor, + device: torch.types.Device): max_h = torch.max(tgt_sizes[:, 0]) max_w = torch.max(tgt_sizes[:, 1]) if max_h > self.max_size[0] or max_w > self.max_size[1]: @@ -223,7 +232,7 @@ def _adjust_pos_cache(self, tgt_sizes, device): ] self._set_2d_pos_cache(self.max_size, device) - def _init_weights(self, m): + def _init_weights(self, m: nn.Module): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: @@ -232,7 +241,9 @@ def _init_weights(self, m): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) - def forward_2_5(self, x, tgt_sizes=None): + def forward_2_5(self, + x: torch.Tensor, + tgt_sizes: Optional[torch.Tensor] = None): assert x.shape[0] == tgt_sizes.shape[0] bs = x.shape[0] @@ -278,7 +289,10 @@ def forward_2_5(self, x, tgt_sizes=None): x = x @ self.proj return x - def forward_2(self, x, tgt_sizes=None, attn_mask=None): + def forward_2(self, + x: torch.Tensor, + tgt_sizes: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None): if self.adaptive: pos_embed = torch.Tensor( get_2d_sincos_pos_embed(self.embed_dim, @@ -302,8 +316,11 @@ def forward_2(self, x, tgt_sizes=None, attn_mask=None): x = x @ self.proj return x - def forward(self, x, tgt_sizes=None, attn_mask=None): - if self.version == 2.0: + def forward(self, + x: torch.Tensor, + tgt_sizes: 
Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None): + if self.version == (2, 0): return self.forward_2(x, tgt_sizes=tgt_sizes, attn_mask=attn_mask) else: return self.forward_2_5(x, tgt_sizes=tgt_sizes) @@ -322,7 +339,7 @@ def dummy_seq_data_for_minicpmv(seq_len: int): return SequenceData(token_ids) -def dummy_image_for_minicpmv(hf_config): +def dummy_image_for_minicpmv(hf_config: PretrainedConfig): width = height = hf_config.image_size image = Image.new("RGB", (width, height), color=0) return {"image": image} @@ -381,7 +398,7 @@ class MiniCPMV(nn.Module, SupportsVision): def __init__( self, - config, + config: PretrainedConfig, multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, @@ -390,30 +407,48 @@ def __init__( self.config = config self.multimodal_config = multimodal_config - self.version = float(self.config.version) + if not hasattr(self.config, "version"): + if self.config.hidden_size == 2304 and self.config.query_num == 64: + self.version = (2, 0) + else: + self.version = (2, 5) + else: + self.version = str(self.config.version).split(".") + self.version = tuple([int(x) for x in self.version]) self.llm = self.init_llm(config, cache_config, quant_config) self.vpm = self.init_vision_module() param_dtype = torch.get_default_dtype() self.vpm.to(dtype=param_dtype) - self.vision_dim = self.vpm.embed_dim if self.version == 2.0 \ + self.vision_dim = self.vpm.embed_dim if self.version == (2, 0) \ else self.vpm.embeddings.embed_dim - self.embed_dim = self.llm.config.hidden_size + self.embed_dim = self.config.hidden_size self.resampler = self.init_resampler(self.embed_dim, self.vision_dim) self.resampler.to(device="cuda", dtype=param_dtype) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() - def init_llm(self, config, cache_config, quant_config): - if self.version == 2.0: - return MiniCPMForCausalLM(config, - cache_config=cache_config, - quant_config=quant_config) + def init_llm(self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + if self.version == (2, 0): + return MiniCPMModel(config, + cache_config=cache_config, + quant_config=quant_config) + elif self.version == (2, 5): + return LlamaModel(config, + cache_config=cache_config, + quant_config=quant_config) else: - return LlamaForCausalLM(config, - cache_config=cache_config, - quant_config=quant_config) + return Qwen2Model(config, + cache_config=cache_config, + quant_config=quant_config) def init_vision_module(self): - if self.version == 2.0: + if self.version == (2, 0): try: import timm except ImportError: @@ -433,16 +468,30 @@ def init_vision_module(self): if self.config.drop_vision_last_layer: model.blocks = model.blocks[:-1] - else: + elif self.version == (2, 5): + from transformers.models.idefics2.modeling_idefics2 import ( + Idefics2VisionTransformer) model = Idefics2VisionTransformer(self.config.vision_config) if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] + else: + from vllm.model_executor.models.na_vit import ( + SiglipVisionTransformer) + if self.config._attn_implementation == 'flash_attention_2': + self.config.vision_config._attn_implementation \ + = 'flash_attention_2' + else: + # not support sdpa + self.config.vision_config._attn_implementation = 'eager' + model = 
SiglipVisionTransformer(self.config.vision_config) + if self.config.drop_vision_last_layer: + model.encoder.layers = model.encoder.layers[:-1] return model - def init_resampler(self, embed_dim, vision_dim): + def init_resampler(self, embed_dim: int, vision_dim: int): default_dtype = torch.get_default_dtype() torch.set_default_dtype(torch.float16) - if self.version == 2.0: + if self.version == (2, 0): resampler = Resampler(grid_size=int( math.sqrt(self.config.query_num)), num_queries=None, @@ -463,11 +512,11 @@ def init_resampler(self, embed_dim, vision_dim): return resampler def get_vision_embedding(self, - pixel_values, - patch_attn_mask=None, - tgt_sizes=None, - version=2.0): - if version == 2.0: + pixel_values: List[List[torch.Tensor]], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + version: Tuple[int, int] = (2, 0)): + if version == (2, 0): res = [] dtype = self.vpm.pos_embed.data.dtype for pixel_value in pixel_values: @@ -484,21 +533,32 @@ def get_vision_embedding(self, num_prefix_tokens:] res.append(self.resampler(vision_embedding, tgt_size)) return torch.vstack(res) - else: + elif version == (2, 5): vision_embedding = self.vpm( pixel_values.type(dtype), patch_attention_mask=patch_attn_mask).last_hidden_state vision_embedding = self.resampler(vision_embedding, tgt_sizes) + else: + vision_embedding = self.vpm(pixel_values.type(dtype), + patch_attention_mask=patch_attn_mask, + tgt_sizes=tgt_sizes).last_hidden_state - def get_image_bounds(self, input_ids): + def get_image_bounds(self, input_ids: torch.Tensor): tokenizer = cached_get_tokenizer(self.config._name_or_path, trust_remote_code=True) - im_start_token_id = tokenizer.im_start_id - im_end_token_id = tokenizer.im_end_id - image_start_tokens = torch.where(input_ids == im_start_token_id)[0] + if not hasattr(tokenizer, "slice_start_id"): + start_cond = input_ids == tokenizer.im_start_id + end_cond = input_ids == tokenizer.im_end_id + else: + start_cond = (input_ids == tokenizer.im_start_id) | ( + input_ids == tokenizer.slice_start_id) + end_cond = (input_ids == tokenizer.im_end_id) | ( + input_ids == tokenizer.slice_end_id) + + image_start_tokens = torch.where(start_cond)[0] image_start_tokens += 1 - image_end_tokens = torch.where(input_ids == im_end_token_id)[0] - valid_image_nums = min(len(image_start_tokens), len(image_end_tokens)) + image_end_tokens = torch.where(end_cond)[0] + valid_image_nums = max(len(image_start_tokens), len(image_end_tokens)) if valid_image_nums == 0: return [] image_bound = torch.hstack([ @@ -508,12 +568,14 @@ def get_image_bounds(self, input_ids): return image_bound - def get_vision_hidden_states(self, data): + def get_vision_hidden_states(self, data: Dict[str, + Union[List[torch.Tensor], + torch.Tensor]]): if "vision_hidden_states" not in data: pixel_values = data["pixel_values"] tgt_sizes = data["tgt_sizes"] vision_hidden_states = [] - if self.version == 2.0: + if self.version == (2, 0): if pixel_values is not None and len(pixel_values) > 0: vision_hidden_states = self.get_vision_embedding( pixel_values) @@ -534,17 +596,26 @@ def get_vision_hidden_states(self, data): B, L, _ = all_pixel_values.shape all_pixel_values = all_pixel_values.permute( 0, 2, 1).reshape(B, 3, -1, L) - patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device) - for i in range(B): - patch_attn_mask[i, :tgt_sizes[i][0] * - tgt_sizes[i][1]] = True + if self.version == (2, 5): + for i in range(B): + patch_attn_mask[i, :tgt_sizes[i][0] * + tgt_sizes[i][1]] = True 
+ vision_embedding = self.vpm( + all_pixel_values.type(dtype), + patch_attention_mask=patch_attn_mask + ).last_hidden_state + else: + for i in range(B): + patch_attn_mask[i, 0, :tgt_sizes[i][0] * + tgt_sizes[i][1]] = True + vision_embedding = self.vpm( + all_pixel_values.type(dtype), + patch_attention_mask=patch_attn_mask, + tgt_sizes=tgt_sizes).last_hidden_state - vision_embedding = self.vpm( - all_pixel_values.type(dtype), - patch_attention_mask=patch_attn_mask).last_hidden_state vision_hidden_states = self.resampler( vision_embedding, tgt_sizes) @@ -556,7 +627,8 @@ def get_vision_hidden_states(self, data): return vision_hidden_states - def get_embedding(self, data): + def get_embedding(self, data: Dict[str, Union[List[torch.Tensor], + torch.Tensor]]): input_ids = data["input_ids"] vision_hidden_states = self.get_vision_hidden_states(data) @@ -565,11 +637,11 @@ def get_embedding(self, data): else: image_bounds = [] - if hasattr(self.llm.config, 'scale_emb'): - vlm_embedding = self.llm.model.embed_tokens( - input_ids) * self.llm.config.scale_emb + if hasattr(self.config, 'scale_emb'): + vlm_embedding = self.llm.embed_tokens( + input_ids) * self.config.scale_emb else: - vlm_embedding = self.llm.model.embed_tokens(input_ids) + vlm_embedding = self.llm.embed_tokens(input_ids) vision_hidden_states = [ i.type(vlm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states @@ -587,7 +659,9 @@ def get_embedding(self, data): vision_hidden_states.view(-1, vision_hidden_states.shape[-1])) return vlm_embedding, vision_hidden_states - def process_multimodal_inputs(self, inputs): + def process_multimodal_inputs(self, inputs: Dict[str, + Union[List[torch.Tensor], + torch.Tensor]]): pixel_values = [] tgt_sizes = [] for b in range(len(inputs["pixel_values"])): @@ -613,7 +687,6 @@ def forward( "input_ids": input_ids, "tgt_sizes": kwargs.pop("tgt_sizes", None), } - inputs = self.process_multimodal_inputs(inputs) vlm_embeddings, vision_hidden_states = self.get_embedding(inputs) @@ -623,19 +696,21 @@ def forward( kv_caches=kv_caches, attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, - input_embeds=vlm_embeddings) + inputs_embeds=vlm_embeddings) return output def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - return self.llm.compute_logits(hidden_states, sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.llm.sample(logits, sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): @@ -649,9 +724,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - # for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): - # if key_to_modify in name: - # name = name.replace(key_to_modify, new_key) + for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in name: + name = name.replace(key_to_modify, new_key) if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name diff --git a/vllm/model_executor/models/na_vit.py b/vllm/model_executor/models/na_vit.py new file mode 100644 index 000000000000..871e4128b66e --- /dev/null +++ b/vllm/model_executor/models/na_vit.py @@ -0,0 
+1,804 @@ +import logging +import math +import os +import warnings +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.init import _calculate_fan_in_and_fan_out +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask +from transformers.modeling_outputs import (BaseModelOutput, + BaseModelOutputWithPooling) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import (ModelOutput, is_flash_attn_2_available, + replace_return_docstrings) + +logger = logging.getLogger("vllm") + + +# For Siglip: copied from +# HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes +# Remove hints as there's little possibility to change these code. +class SiglipVisionConfig(PretrainedConfig): + + model_type = "siglip_vision_model" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, + os.PathLike], + **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from SiglipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr( + cls, + "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + "You are using a model of type %s to " + "instantiate a model of type %s. 
" + "This is not supported for all configurations" + "of models and can yield errors.", config_dict['model_type'], + cls.model_type) + + return cls.from_dict(config_dict, **kwargs) + + +_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224" + +SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "google/siglip-base-patch16-224", + # See all SigLIP models at https://huggingface.co/models?filter=siglip +] + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import pad_input # noqa + from flash_attn.bert_padding import index_first_axis, unpad_input + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _trunc_normal_(tensor, mean, std, a, b): + + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l_ = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l_ - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + if tensor.dtype in [torch.float16, torch.bfloat16]: + # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu + og_dtype = tensor.dtype + tensor = tensor.to(torch.float32) + tensor.erfinv_() + tensor = tensor.to(og_dtype) + else: + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + if tensor.dtype == torch.float16: + # The `clamp_` op is not (yet?) 
defined in float16+cpu + tensor = tensor.to(torch.float32) + tensor.clamp_(min=a, max=b) + tensor = tensor.to(torch.float16) + else: + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_(tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0) -> torch.Tensor: + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +class SiglipVisionModelOutput(ModelOutput): + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class SiglipVisionEmbeddings(nn.Module): + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, + self.embed_dim) + + def forward(self, + pixel_values: torch.FloatTensor, + patch_attention_mask: torch.BoolTensor, + tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor: + batch_size = pixel_values.size(0) + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) + max_nb_patches_h, max_nb_patches_w = (max_im_h // self.patch_size, + max_im_w // self.patch_size) + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, + 1 / self.num_patches_per_side) + position_ids = torch.full( + size=( + batch_size, + max_nb_patches_h * max_nb_patches_w, + ), + fill_value=0, + ) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + if tgt_sizes is not None: + nb_patches_h = tgt_sizes[batch_idx][0] + nb_patches_w = tgt_sizes[batch_idx][1] + else: + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize(fractional_coords_h, + boundaries, + right=True) + bucket_coords_w = 
torch.bucketize(fractional_coords_w, + boundaries, + right=True) + + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = position_ids.to(self.position_embedding.weight.device) + + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class SiglipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads (got `embed_dim`: " + f"{self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = torch.matmul(query_states, key_states.transpose( + 2, 3)) * self.scale + + if attn_weights.size() != (batch_size, self.num_heads, q_len, + k_v_seq_len): + raise ValueError( + "Attention weights should be of size " + f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}") + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + "Attention mask should be of size " + f"{(batch_size, 1, q_len, k_v_seq_len)}", + f"but is {attention_mask.size()}") + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, + dim=-1, + dtype=torch.float32).to( + query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, + p=self.dropout, + training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, + self.head_dim): + raise ValueError( + "`attn_output` should be of size " + f"{(batch_size, self.num_heads, q_len, self.head_dim)}, " + "but is" + f" {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class SiglipFlashAttention2(SiglipAttention): + + def __init__(self, *args, 
**kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False # Hack to make sure we don't use a causal mask + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length( + kv_seq_len, self.layer_idx) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.dropout if self.training else 0.0 + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning( + "The input hidden states seems to be " + "silently casted in float32, " + "this might be related to the fact " + "you have upcasted embedding or layer norm layers in float32. 
" + "We will cast back the input in" + " %s.", target_dtype) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward(query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate) + + attn_output = attn_output.reshape(bsz, q_len, + self.embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + def _flash_attention_forward(self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None): + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + (query_states, key_states, value_states, indices_q, cu_seq_lens, + max_seq_lens) = self._upad_input(query_states, key_states, + value_states, attention_mask, + query_length) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, + query_length) + else: + attn_output = flash_attn_func(query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, + query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( + attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, + head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
+ attention_mask = attention_mask[:, -query_length:] + (query_layer, indices_q, cu_seqlens_q, + max_seqlen_in_batch_q) = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip +class SiglipMLP(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer +# with CLIP->Siglip +class SiglipEncoderLayer(nn.Module): + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self._use_flash_attention_2 = ( + config._attn_implementation == "flash_attention_2") + self.self_attn = (SiglipAttention(config) + if not self._use_flash_attention_2 else + SiglipFlashAttention2(config)) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, ) + + if output_attentions: + outputs += (attn_weights, ) + + return outputs + + +class SiglipPreTrainedModel(PreTrainedModel): + config_class = SiglipVisionConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + + if isinstance(module, SiglipVisionEmbeddings): + width = self.config.hidden_size + nn.init.normal_(module.position_embedding.weight, + std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, SiglipAttention): + nn.init.normal_(module.q_proj.weight) + nn.init.normal_(module.k_proj.weight) + nn.init.normal_(module.v_proj.weight) + nn.init.normal_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, SiglipMLP): + nn.init.normal_(module.fc1.weight) + nn.init.normal_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +# 
Copied from transformers.models.clip.modeling_clip.CLIPEncoder +# with CLIP->Siglip +class SiglipEncoder(nn.Module): + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers) + ]) + self.gradient_checkpointing = False + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None \ + else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None \ + else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states, ) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1], ) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states, ) + + if not return_dict: + return tuple( + v for v in [hidden_states, encoder_states, all_attentions] + if v is not None) + return BaseModelOutput(last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions) + + +class SiglipVisionTransformer(SiglipPreTrainedModel): + config_class = SiglipVisionConfig + main_input_name = "pixel_values" + _supports_flash_attn_2 = True + + def __init__(self, config: SiglipVisionConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + self._use_flash_attention_2 = ( + config._attn_implementation == "flash_attention_2") + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.embeddings.patch_embedding + + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, + config_class=SiglipVisionConfig) + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + tgt_sizes: Optional[torch.IntTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + """ + output_attentions = output_attentions if output_attentions is not None \ + else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None \ + else self.config.use_return_dict + + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + 
patch_attention_mask = torch.ones( + size=( + batch_size, + pixel_values.size(2) // self.config.patch_size, + pixel_values.size(3) // self.config.patch_size, + ), + dtype=torch.bool, + device=pixel_values.device, + ) + + hidden_states = self.embeddings( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + tgt_sizes=tgt_sizes) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s + # (i.e. attending to the whole sequence), + # avoiding passing the attention_mask, + # which is equivalent to attending to the full sequence + if not torch.any(~patch_attention_mask): + attention_mask = None + else: + attention_mask = (_prepare_4d_attention_mask( + patch_attention_mask, hidden_states.dtype) + if not self._use_flash_attention_2 else + patch_attention_mask) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + if not return_dict: + return (last_hidden_state, None) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=None, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 3deb3d8840cc..35fd6f37589a 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -342,7 +342,7 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, From daed30c4a917c870f8fbddf45e3b027710c0842b Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 31 Jul 2024 23:46:17 +0800 Subject: [PATCH 0003/3246] [Bugfix] Fix feature size calculation for LLaVA-NeXT (#6982) --- tests/models/test_llava_next.py | 88 +++++++++++++++++++----- vllm/model_executor/models/fuyu.py | 2 +- vllm/model_executor/models/internvl.py | 6 +- vllm/model_executor/models/llava_next.py | 48 ++++++------- vllm/model_executor/models/phi3v.py | 4 +- 5 files changed, 98 insertions(+), 50 deletions(-) diff --git a/tests/models/test_llava_next.py b/tests/models/test_llava_next.py index 9c64f39eb6d0..b6d72dee5c5b 100644 --- a/tests/models/test_llava_next.py +++ b/tests/models/test_llava_next.py @@ -1,7 +1,7 @@ -from typing import List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type, overload import pytest -from transformers import AutoConfig, AutoTokenizer +from transformers import AutoTokenizer from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs @@ -50,6 +50,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, return hf_output_ids, hf_output_str, out_logprobs +@overload def run_test( hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], @@ -62,13 +63,55 @@ def run_test( num_logprobs: int, tensor_parallel_size: int, distributed_executor_backend: Optional[str] = None, +): + ... 
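# [Editor's sketch, not part of the patch] The `@overload` stub above and the
# one that follows only describe the two valid call signatures to the type
# checker (`size_factors=...` vs. `sizes=...`); the single real implementation
# further down accepts both as Optional keyword-only arguments and raises if
# neither is supplied. A minimal standalone illustration of the pattern, with
# hypothetical names:
#
#     from typing import List, Optional, Tuple, overload
#
#     @overload
#     def make_inputs(*, size_factors: List[float]) -> List[float]: ...
#
#     @overload
#     def make_inputs(*, sizes: List[Tuple[int, int]]) -> List[float]: ...
#
#     def make_inputs(*, size_factors: Optional[List[float]] = None,
#                     sizes: Optional[List[Tuple[int, int]]] = None):
#         if size_factors is not None:
#             return list(size_factors)
#         if sizes is not None:
#             return [w / h for w, h in sizes]
#         raise ValueError("You must provide either `size_factors` or `sizes`")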
+ + +@overload +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + sizes: List[Tuple[int, int]], + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + ... + + +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + size_factors: Optional[List[float]] = None, + sizes: Optional[List[Tuple[int, int]]] = None, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, ): images = [asset.pil_image for asset in image_assets] - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + if size_factors is not None: + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + elif sizes is not None: + inputs_per_image = [( + [prompt for _ in sizes], + [image.resize(size) for size in sizes], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + else: + raise ValueError("You must provide either `size_factors` or `sizes`") # max_model_len should be greater than image_feature_size with vllm_runner(model, @@ -150,15 +193,24 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, ) -@pytest.mark.parametrize("height_and_width_and_result", [(1669, 2560, 2144), - (183, 488, 776)]) -def test_image_feature_size(height_and_width_and_result): - # Avoid initializing CUDA too early in distributed tests - from vllm.model_executor.models.llava_next import ( - get_llava_next_image_feature_size) - - height, width, result = height_and_width_and_result - config = AutoConfig.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf") - assert get_llava_next_image_feature_size(config, - input_height=height, - input_width=width) == result +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "sizes", + [[(1669, 2560), (2560, 1669), (183, 488), (488, 183)]], +) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model, sizes, + dtype, max_tokens, num_logprobs) -> None: + run_test( + hf_runner, + vllm_runner, + image_assets, + model, + sizes=sizes, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index fdea8ee30ce6..c4738263c305 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -169,7 +169,7 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): raise TypeError(f"Invalid image type: {type(image_data)}") # process prompts - prompt = llm_inputs["prompt"] + prompt = llm_inputs.get("prompt") prompt_token_ids = llm_inputs["prompt_token_ids"] tokenizer = cached_get_tokenizer(model_config.model) # dim0 is batch_size, dim1 is subseq_size which will always be 1 diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index f64c78c15f8e..eabc283b1efd 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -20,7 +20,7 @@ from 
vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.intern_vit import InternVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.image import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SamplerOutput @@ -43,7 +43,7 @@ class InternVLImagePixelInputs(TypedDict): type: Literal["pixel_values"] - data: BatchedTensors + data: Union[torch.Tensor, List[torch.Tensor]] """ Shape: `(batch_size, 1 + num_patches, num_channels, height, width)` @@ -193,7 +193,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) - prompt = llm_inputs["prompt"] + prompt = llm_inputs.get("prompt") prompt_token_ids = llm_inputs["prompt_token_ids"] if prompt is None: prompt = tokenizer.decode(prompt_token_ids) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 5abb55c2cc41..4a67b9a583ea 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -21,7 +21,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors, SamplerOutput from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, @@ -43,7 +43,7 @@ class LlavaNextImagePixelInputs(TypedDict): type: Literal["pixel_values"] - data: BatchedTensors + data: Union[torch.Tensor, List[torch.Tensor]] """ Shape: `(batch_size, 1 + num_patches, num_channels, height, width)` @@ -62,31 +62,26 @@ class LlavaNextImagePixelInputs(TypedDict): LlavaNextImageInputs = LlavaNextImagePixelInputs -# Taken from: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L91 -# NOTE: new_height and new_width are further incremented to properly invert the -# floordiv operation: https://github.com/huggingface/transformers/blob/v4.42.2/src/transformers/models/llava_next/modeling_llava_next.py#L133 +# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79 def _get_llava_next_num_unpadded_features( - height: int, - width: int, + original_height: int, + original_width: int, npatches: int, num_patch_height: int, num_patch_width: int, ) -> Tuple[int, int]: current_height = npatches * num_patch_height current_width = npatches * num_patch_width - current_height = torch.tensor(current_height).to("cuda") - current_width = torch.tensor(current_width).to("cuda") - aspect_ratio: float = width / height - current_aspect_ratio: float = current_width / current_height + aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + if aspect_ratio > current_aspect_ratio: - scale_factor = current_width / width - new_height = int(height * scale_factor) + new_height = (original_height * current_width) // original_width padding = (current_height - new_height) // 2 current_height -= padding * 2 else: - scale_factor = current_height / 
height - new_width = int(width * scale_factor) + new_width = (original_width * current_height) // original_height padding = (current_width - new_width) // 2 current_width -= padding * 2 @@ -95,7 +90,7 @@ def _get_llava_next_num_unpadded_features( return (unpadded_features, newline_features) -# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L111 +# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106 def get_llava_next_image_feature_size( hf_config: LlavaNextConfig, *, @@ -111,9 +106,7 @@ def get_llava_next_image_feature_size( ) base_feature_size = num_patches * num_patches - # Note: We follow the "wrong" width/height order - # [ref: PR huggingface/transformers#31588] - num_patch_width, num_patch_height = get_anyres_image_grid_shape( + num_patch_height, num_patch_width = get_anyres_image_grid_shape( image_size=(input_height, input_width), grid_pinpoints=hf_config.image_grid_pinpoints, patch_size=vision_config.image_size, @@ -349,11 +342,12 @@ def _merge_image_patch_embeddings(self, image_size: torch.Tensor, if patch_embeddings.shape[0] > 1: other_patch_embeds = patch_embeddings[1:] + # Move to CPU to avoid floating-point errors + orig_height, orig_width = image_size.tolist() + # image_aspect_ratio == "anyres" - # Note: We follow the "wrong" width/height order - # [ref: PR huggingface/transformers#31588] - num_patch_width, num_patch_height = get_anyres_image_grid_shape( - image_size, + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + (orig_height, orig_width), self.config.image_grid_pinpoints, self.config.vision_config.image_size, ) @@ -365,7 +359,7 @@ def _merge_image_patch_embeddings(self, image_size: torch.Tensor, .permute(4, 0, 2, 1, 3).contiguous() \ .flatten(1, 2).flatten(2, 3) other_patch_embeds = unpad_image(other_patch_embeds, - image_size) + (orig_height, orig_width)) other_patch_embeds = torch.cat(( other_patch_embeds, self.image_newline[:, None, None] \ @@ -398,7 +392,7 @@ def _merge_image_patch_embeddings(self, image_size: torch.Tensor, def _process_image_pixels( self, inputs: LlavaNextImagePixelInputs, - ) -> BatchedTensors: + ) -> Union[torch.Tensor, List[torch.Tensor]]: assert self.vision_tower is not None pixel_values = inputs["data"] @@ -425,7 +419,9 @@ def _process_image_pixels( ] def _process_image_input( - self, image_input: LlavaNextImageInputs) -> BatchedTensors: + self, + image_input: LlavaNextImageInputs, + ) -> Union[torch.Tensor, List[torch.Tensor]]: patch_embeddings = self._process_image_pixels(image_input) image_sizes = image_input.get("image_sizes") diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 75e2f5fc95cb..823c34b10187 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -36,7 +36,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SamplerOutput @@ -261,7 +261,7 @@ def add_image_newline(self, image_features_hd): class Phi3VImagePixelInputs(TypedDict): type: Literal["pixel_values"] - data: BatchedTensors + 
data: Union[torch.Tensor, List[torch.Tensor]] """ Shape: `(batch_size, 1 + num_patches, num_channels, height, width)` From 2ee8d3ba55f1175162dbc8e70b76674197b127c6 Mon Sep 17 00:00:00 2001 From: Avshalom Manevich <12231371+avshalomman@users.noreply.github.com> Date: Wed, 31 Jul 2024 22:00:24 +0300 Subject: [PATCH 0004/3246] [Model] use FusedMoE layer in Jamba (#6935) --- vllm/model_executor/models/jamba.py | 157 +++++++++------------------- 1 file changed, 49 insertions(+), 108 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 344457822725..cf407c86acd7 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -1,5 +1,5 @@ # coding=utf-8 -"""Inference-only Jurassic model.""" +"""Inference-only Jamba model.""" from dataclasses import dataclass from typing import Dict, Iterable, List, Optional, Tuple @@ -15,10 +15,9 @@ from vllm.attention.layer import Attention from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, @@ -282,108 +281,50 @@ def forward(self, x): class JambaMoE(nn.Module): - """A tensor-parallel MoE implementation for Mixtral that shards each expert - across all ranks. - Each expert's weights are sharded across all ranks and a fused MoE - kernel is used for the forward pass, and finally we reduce the outputs - across ranks. 
- """ - - def __init__( - self, - config: JambaConfig, - params_dtype: Optional[torch.dtype] = None, - tp_size: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, - ): + def __init__(self, + config: JambaConfig, + num_experts: Optional[int] = None, + top_k: Optional[int] = None, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__() - self.tp_size = tp_size or get_tensor_model_parallel_world_size() - self.num_total_experts = config.num_experts - self.top_k = config.num_experts_per_tok + self.num_total_experts = num_experts or config.num_experts + self.top_k = top_k or config.num_experts_per_tok self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size // self.tp_size - - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype - - self.router = ReplicatedLinear(self.hidden_size, - self.num_total_experts, - bias=False, - params_dtype=self.params_dtype) - - self.ws = nn.Parameter( - torch.empty( - self.num_total_experts, - 2 * self.intermediate_size, - self.hidden_size, - device="cuda", - dtype=self.params_dtype, - )) - self.w2s = nn.Parameter( - torch.empty( - self.num_total_experts, - self.hidden_size, - self.intermediate_size, - device="cuda", - dtype=self.params_dtype, - )) + self.intermediate_size = config.intermediate_size - set_weight_attrs( - self.ws, - { - "weight_loader": self.weight_loader, - }, - ) - set_weight_attrs( - self.w2s, - { - "weight_loader": self.weight_loader, - }, - ) - - def weight_loader( - self, - param: nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - expert_id: int, - ): - tp_rank = get_tensor_model_parallel_rank() - param_data = param.data - shard_size = self.intermediate_size - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) - if weight_name.endswith("gate_proj.weight"): - param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] - if weight_name.endswith("up_proj.weight"): - param_data[expert_id, - shard_size:2 * shard_size, :] = loaded_weight[shard, :] - if weight_name.endswith("down_proj.weight"): - param_data[expert_id, :, :] = loaded_weight[:, shard] + if self.num_total_experts > 1: + self.router = ReplicatedLinear(self.hidden_size, + self.num_total_experts, + bias=False, + quant_config=None, + params_dtype=params_dtype) + + self.experts = FusedMoE(self.num_total_experts, + self.top_k, + self.hidden_size, + self.intermediate_size, + tp_size=tp_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=False, + use_grouped_topk=False, + quant_config=quant_config) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_size = hidden_states.shape + orig_shape = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (batch * sequence_length, n_experts) - router_logits, _ = self.router(hidden_states) - - final_hidden_states = fused_moe( - hidden_states, - self.ws, - self.w2s, - router_logits, - self.top_k, - renormalize= - False, # Mixtral normalize the expert probs to 1. We don't! 
- inplace=True, - ) - - if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) - - return final_hidden_states.view(num_tokens, hidden_size) + if self.num_total_experts > 1: + router_logits, _ = self.router(hidden_states) + else: + router_logits = torch.ones((hidden_states.shape[0], 1), + device=hidden_states.device, + dtype=hidden_states.dtype) + hidden_states = self.experts(hidden_states, router_logits) + return hidden_states.view(orig_shape) class JambaMambaDecoderLayer(nn.Module): @@ -917,15 +858,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] - expert_params_mapping = [ - # (param_name, weight_name, expert_id) - ( - "ws" if weight_name in ["gate_proj", "up_proj"] else "w2s", - f"experts.{expert_id}.{weight_name}.weight", - expert_id, - ) for expert_id in range(self.config.num_experts) - for weight_name in ["down_proj", "up_proj", "gate_proj"] - ] + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: @@ -952,7 +891,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, shard_id) break else: - for param_name, weight_name, expert_id in expert_params_mapping: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -961,6 +901,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, weight_name, + shard_id=shard_id, expert_id=expert_id) break else: From bd700134072d9513902b42f3ef20a7cd8a1c6377 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 31 Jul 2024 12:02:17 -0700 Subject: [PATCH 0005/3246] [MISC] Introduce pipeline parallelism partition strategies (#6920) Co-authored-by: youkaichao --- tests/distributed/test_pipeline_partition.py | 34 ++++++++++++++++++++ vllm/distributed/utils.py | 32 +++++++++++++++--- vllm/envs.py | 5 +++ 3 files changed, 66 insertions(+), 5 deletions(-) create mode 100644 tests/distributed/test_pipeline_partition.py diff --git a/tests/distributed/test_pipeline_partition.py b/tests/distributed/test_pipeline_partition.py new file mode 100644 index 000000000000..2d4d07dd2752 --- /dev/null +++ b/tests/distributed/test_pipeline_partition.py @@ -0,0 +1,34 @@ +import os + +import pytest + +from vllm.distributed.utils import get_pp_indices + + +def test_custom_layer_partition(): + + def _verify(partition_str, num_layers, pp_size, goldens): + bak = os.environ.get("VLLM_PP_LAYER_PARTITION", None) + os.environ["VLLM_PP_LAYER_PARTITION"] = partition_str + for pp_rank, golden in enumerate(goldens): + assert get_pp_indices(num_layers, pp_rank, pp_size) == golden + if bak is not None: + os.environ["VLLM_PP_LAYER_PARTITION"] = bak + + # Even partition + _verify("5,5,5,5", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)]) + # Balanced partition + _verify("4,6,6,4", 20, 4, [(0, 4), (4, 10), (10, 16), (16, 20)]) + # Put reminder somewhere + _verify("5,6,5,6", 22, 4, [(0, 5), (5, 11), (11, 16), (16, 22)]) + # Invalid partition strings + with pytest.raises(ValueError): + _verify("5,5,5,5,", 20, 4, [(0, 5), (5, 10), 
(10, 15), (15, 20)]) + with pytest.raises(ValueError): + _verify("5,5,5,a", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)]) + # Wrong number of partitions + with pytest.raises(ValueError): + _verify("5,5,5", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)]) + # Wrong number of layers + with pytest.raises(ValueError): + _verify("5,5,5,5", 21, 4, [(0, 5), (5, 10), (10, 15), (15, 20)]) diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index b5cf6c45f478..8c94ef8cb10c 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -6,6 +6,11 @@ import torch +import vllm.envs as envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + def ensure_divisibility(numerator, denominator): """Ensure that numerator is divisible by the denominator.""" @@ -54,11 +59,28 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int, If the number of layers is not divisible by the number of partitions, the last partition will have the remaining layers. """ - layers_per_partition = num_hidden_layers // pp_size - start_layer = pp_rank * layers_per_partition - end_layer = start_layer + layers_per_partition + partition_list_str = envs.VLLM_PP_LAYER_PARTITION + if partition_list_str is not None: + try: + partitions = [ + int(layer) for layer in partition_list_str.split(",") + ] + except ValueError as err: + raise ValueError("Invalid partition string: {}".format( + partition_list_str)) from err + if len(partitions) != pp_size: + raise ValueError(f"{len(partitions)=} does not match {pp_size=}.") + if sum(partitions) != num_hidden_layers: + raise ValueError( + f"{sum(partitions)=} does not match {num_hidden_layers=}.") + start_layer = sum(partitions[:pp_rank]) + end_layer = start_layer + partitions[pp_rank] + else: + layers_per_partition = num_hidden_layers // pp_size + start_layer = pp_rank * layers_per_partition + end_layer = start_layer + layers_per_partition - if pp_rank == pp_size - 1: - end_layer = num_hidden_layers + if pp_rank == pp_size - 1: + end_layer = num_hidden_layers return (start_layer, end_layer) diff --git a/vllm/envs.py b/vllm/envs.py index f06b6d66ea6f..aef7ac385ec6 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -28,6 +28,7 @@ VLLM_LOGGING_CONFIG_PATH: Optional[str] = None VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: Optional[str] = None + VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_CPU_OMP_THREADS_BIND: str = "" VLLM_OPENVINO_KVCACHE_SPACE: int = 0 @@ -242,6 +243,10 @@ def get_default_config_root(): "VLLM_ATTENTION_BACKEND": lambda: os.getenv("VLLM_ATTENTION_BACKEND", None), + # Pipeline stage partition strategy + "VLLM_PP_LAYER_PARTITION": + lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), + # (CPU backend only) CPU key-value cache space. 
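    # [Editor's sketch, not part of the patch] How `get_pp_indices` above
    # consumes this variable, assuming a 20-layer model on 4 pipeline stages:
    #
    #     os.environ["VLLM_PP_LAYER_PARTITION"] = "4,6,6,4"
    #     [get_pp_indices(20, rank, 4) for rank in range(4)]
    #     # -> [(0, 4), (4, 10), (10, 16), (16, 20)]
    #
    # When the variable is unset, the default even split is used and any
    # remainder goes to the last stage, i.e. [(0, 5), (5, 10), (10, 15),
    # (15, 20)] for the same model.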
# default is 4GB "VLLM_CPU_KVCACHE_SPACE": From 460c1884e3cb781730f85cb5591a85d5864bdac8 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 31 Jul 2024 15:47:46 -0400 Subject: [PATCH 0006/3246] [Bugfix] Support cpu offloading with fp8 quantization (#6960) --- tests/basic_correctness/test_cpu_offload.py | 43 +++++++++++++--- vllm/model_executor/model_loader/loader.py | 56 +++++++++++++++++++-- vllm/model_executor/models/utils.py | 50 +++++++++--------- 3 files changed, 116 insertions(+), 33 deletions(-) diff --git a/tests/basic_correctness/test_cpu_offload.py b/tests/basic_correctness/test_cpu_offload.py index 9ebcc48a9b93..180b926637ec 100644 --- a/tests/basic_correctness/test_cpu_offload.py +++ b/tests/basic_correctness/test_cpu_offload.py @@ -1,4 +1,6 @@ -from vllm.utils import is_hip +import pytest + +from tests.quantization.utils import is_quant_method_supported from ..utils import compare_two_settings @@ -6,8 +8,37 @@ def test_cpu_offload(): compare_two_settings("meta-llama/Llama-2-7b-hf", [], ["--cpu-offload-gb", "4"]) - if not is_hip(): - # compressed-tensors quantization is currently not supported in ROCm. - compare_two_settings( - "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", [], - ["--cpu-offload-gb", "1"]) + + +@pytest.mark.skipif(not is_quant_method_supported("fp8"), + reason="fp8 is not supported on this GPU type.") +def test_cpu_offload_fp8(): + # Test quantization of an unquantized checkpoint + compare_two_settings("meta-llama/Meta-Llama-3-8B-Instruct", + ["--quantization", "fp8"], + ["--quantization", "fp8", "--cpu-offload-gb", "2"]) + # Test loading a quantized checkpoint + compare_two_settings("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", [], + ["--cpu-offload-gb", "2"]) + + +@pytest.mark.skipif(not is_quant_method_supported("awq"), + reason="awq is not supported on this GPU type.") +def test_cpu_offload_awq(): + compare_two_settings("casperhansen/llama-3-8b-instruct-awq", [], + ["--cpu-offload-gb", "2"]) + + +@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), + reason="gptq_marlin is not supported on this GPU type.") +def test_cpu_offload_compressed_tensors(): + # Test wNa16 + compare_two_settings("nm-testing/tinyllama-oneshot-w4a16-channel-v2", [], + ["--cpu-offload-gb", "1"]) + # Test w4a16_marlin24 + compare_two_settings("nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", + [], ["--cpu-offload-gb", "1"]) + # Test w8a8 + compare_two_settings( + "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", [], + ["--cpu-offload-gb", "1"]) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index bbe49655020d..f72515e01482 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -7,6 +7,7 @@ import math import os from abc import ABC, abstractmethod +from contextlib import contextmanager from typing import Any, Dict, Generator, List, Optional, Tuple, Type import huggingface_hub @@ -37,7 +38,49 @@ supports_vision) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import is_tpu +from vllm.utils import is_pin_memory_available, is_tpu + + +@contextmanager +def device_loading_context(module: torch.nn.Module, + target_device: torch.device): + if target_device.type == "cpu": + # If target is CPU, no need to move anything + yield module + return + + original_device_states: Dict[str, torch.device] = {} + + # Store original device states and move parameters to GPU if they're on CPU + for name, p 
in module.named_parameters(): + if p.device.type == "cpu": + original_device_states[name] = p.device + p.data = p.data.to(target_device) + # Parameters already on target device are not touched + + try: + yield module + + finally: + # Restore parameters to their original devices, ignoring new parameters + pin_memory = is_pin_memory_available() + for name, p in module.named_parameters(): + if name in original_device_states: + original_device: torch.device = original_device_states[name] + if original_device.type == "cpu": + # `torch.empty_like` does not support `pin_memory` argument + cpu_data = torch.empty_strided(size=p.data.size(), + stride=p.data.stride(), + dtype=p.data.dtype, + layout=p.data.layout, + device="cpu", + pin_memory=pin_memory) + cpu_data.copy_(p.data) + p.data = cpu_data + else: + p.data = p.data.to(original_device) + # New parameters or parameters already on target device are untouched + logger = init_logger(__name__) @@ -275,8 +318,9 @@ def load_model(self, *, model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module: + target_device = torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): + with target_device: model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config, scheduler_config) @@ -291,7 +335,13 @@ def load_model(self, *, model_config: ModelConfig, for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) if quant_method is not None: - quant_method.process_weights_after_loading(module) + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. 
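                    # [Editor's sketch, not part of the patch] Net effect: a
                    # quantized checkpoint can now be served with CPU
                    # offloading, e.g. (mirroring the new test above and
                    # assuming the `cpu_offload_gb` engine argument):
                    #
                    #     from vllm import LLM
                    #     llm = LLM(
                    #         model="neuralmagic/Meta-Llama-3-8B-Instruct-FP8",
                    #         cpu_offload_gb=2)
                    #
                    # The offloaded parameters are moved to the GPU only while
                    # `process_weights_after_loading` runs, then copied back to
                    # (pinned) CPU memory by `device_loading_context`.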
+ with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) return model.eval() diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 197d3839a766..91b4a27814bf 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -87,6 +87,7 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: # offload parameters to CPU # use pin_memory if possible, which helps cudagraph capture speed + offloaded_parameters = False for p in module.parameters(): if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES: # we use per-parameter offloading @@ -94,35 +95,36 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: break # `torch.empty_like` does not support `pin_memory` argument - cpu_data = torch.empty(size=p.data.size(), - dtype=p.data.dtype, - layout=p.data.layout, - device='cpu', - pin_memory=pin_memory) + cpu_data = torch.empty_strided(size=p.data.size(), + stride=p.data.stride(), + dtype=p.data.dtype, + layout=p.data.layout, + device='cpu', + pin_memory=pin_memory) cpu_data.copy_(p.data) p.data = cpu_data _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size() + offloaded_parameters = True + + if offloaded_parameters: + original_forward = module.forward + + def forward(*args, **kwargs): + module.forward = original_forward + device_state = { + # here we blindly call `to(device)` + # if the parameter is already on the device, it will be a no-op + k: v.to(device, non_blocking=True) + for k, v in module.state_dict().items() + } + output = functional_call(module, + device_state, + args=args, + kwargs=kwargs) + module.forward = forward + return output - state_dict: Dict[str, torch.Tensor] = module.state_dict() - - original_forward = module.forward - - def forward(*args, **kwargs): - module.forward = original_forward - device_state = { - # here we blindly call `to(device)` - # if the parameter is already on the device, it will be a no-op - k: v.to(device, non_blocking=True) - for k, v in state_dict.items() - } - output = functional_call(module, - device_state, - args=args, - kwargs=kwargs) module.forward = forward - return output - - module.forward = forward return module From 93548eb37e952a0af035dc524a3826cdcd78d6cf Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 31 Jul 2024 17:40:22 -0400 Subject: [PATCH 0007/3246] [Kernel] Enable FP8 Cutlass for Ada Lovelace (#6950) Co-authored-by: Varun Sundar Rabindranath --- csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 605166930ccc..8d4d94ca0845 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -38,13 +38,7 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) { if (cuda_device_capability >= 90) { return CUDA_VERSION >= 12000; } else if (cuda_device_capability >= 89) { - // CUTLASS Kernels have not been tuned for Ada Lovelace systems - // and are slower than torch.mm. Return false unconditionally in this case. 
- return false; - - // Once the CUTLASS kernels have been optimized for Lovelace systems, - // use the following check: - // return CUDA_VERSION >= 12040; + return CUDA_VERSION >= 12040; } #endif From 35e9c12bfaf8f273281af897b7208dfba53f103c Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 31 Jul 2024 17:40:32 -0400 Subject: [PATCH 0008/3246] [Kernel] Tuned int8 Cutlass Kernels for SM75 (T4) (#6996) Co-authored-by: Varun Sundar Rabindranath --- .../cutlass_benchmarks/w8a8_benchmarks.py | 9 +- .../cutlass_w8a8/scaled_mm_c2x.cu | 15 +-- .../scaled_mm_c2x_sm75_dispatch.cuh | 123 ++++++++++++++++++ 3 files changed, 135 insertions(+), 12 deletions(-) create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 70247e94e63c..64011b2db239 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -112,13 +112,20 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) timers = [] - # pytorch impl + # pytorch impl - bfloat16 timers.append( bench_fn(a.to(dtype=torch.bfloat16, device="cuda"), b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b, torch.bfloat16, label, sub_label, pytorch_mm_impl, "pytorch_bf16_bf16_bf16_matmul-no-scales")) + # pytorch impl - float16 + timers.append( + bench_fn(a.to(dtype=torch.float16, device="cuda"), + b.to(dtype=torch.float16, device="cuda"), scale_a, scale_b, + torch.float16, label, sub_label, pytorch_mm_impl, + "pytorch_fp16_fp16_fp16_matmul-no-scales")) + # cutlass impl timers.append( bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu index aac4900f933a..8d0dfee7bf23 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu @@ -3,6 +3,7 @@ #include "cutlass/cutlass.h" #include "scaled_mm_c2x.cuh" +#include "scaled_mm_c2x_sm75_dispatch.cuh" #include "scaled_mm_c2x_sm80_dispatch.cuh" #include "scaled_mm_c2x_sm89_fp8_dispatch.cuh" #include "scaled_mm_c2x_sm89_int8_dispatch.cuh" @@ -20,21 +21,13 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(a.dtype() == torch::kInt8); TORCH_CHECK(b.dtype() == torch::kInt8); - using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; - using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; - using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>; - if (out.dtype() == torch::kBFloat16) { - return vllm::cutlass_gemm_caller< - vllm::cutlass_2x_gemm>( + return vllm::cutlass_gemm_sm75_dispatch( out, a, b, std::forward(epilogue_args)...); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); - return vllm::cutlass_gemm_caller>( + return vllm::cutlass_gemm_sm75_dispatch( out, a, b, std::forward(epilogue_args)...); } } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh new file mode 100644 index 000000000000..a562fd896e54 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh @@ -0,0 +1,123 @@ +#pragma once + +#include "scaled_mm_c2x.cuh" + +/** + * This file defines Gemm kernel configurations for SM75 based on the Gemm + * shape. 
+ */ + +namespace vllm { + +template typename Epilogue> +struct sm75_config_default { + // This config is used in 2 cases, + // - M in (256, inf] + // - M in (64, 128] + // Shared memory required by this Gemm 32768 + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>; + using Cutlass2xGemm = + cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm75_config_M256 { + // M in (128, 256] + // Shared memory required by this Gemm 65536 + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>; + using Cutlass2xGemm = + cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm75_config_M64 { + // M in (32, 64] + // Shared memory required by this Gemm 49152 + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>; + using Cutlass2xGemm = + cutlass_2x_gemm; +}; + +template typename Epilogue> +struct sm75_config_M32 { + // M in [1, 32] + // Shared memory required by this Gemm 49152 + static_assert(std::is_same()); + using TileShape = typename cutlass::gemm::GemmShape<32, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>; + using Cutlass2xGemm = + cutlass_2x_gemm; +}; + +template typename Epilogue, + typename... EpilogueArgs> +inline void cutlass_gemm_sm75_dispatch(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kInt8); + TORCH_CHECK(b.dtype() == torch::kInt8); + + using Cutlass2xGemmDefault = + typename sm75_config_default::Cutlass2xGemm; + using Cutlass2xGemmM256 = + typename sm75_config_M256::Cutlass2xGemm; + using Cutlass2xGemmM128 = Cutlass2xGemmDefault; + using Cutlass2xGemmM64 = + typename sm75_config_M64::Cutlass2xGemm; + using Cutlass2xGemmM32 = + typename sm75_config_M32::Cutlass2xGemm; + + // Due to shared memory requirements, some Gemms may fail to run on some + // GPUs. As the name indicates, the Fallback Gemm is used as an alternative + // in such cases. + // sm75_config_default has the least shared-memory requirements. 
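  // [Editor's note, illustrative trace, not part of the patch] The dispatch
  // below buckets on the next power of two of M, e.g.:
  //   M = 16  -> mp2 = 32  -> sm75_config_M32     (TileShape 32x128x64)
  //   M = 100 -> mp2 = 128 -> sm75_config_default (TileShape 128x128x64)
  //   M = 200 -> mp2 = 256 -> sm75_config_M256    (TileShape 128x128x128)
  //   M = 300 -> mp2 = 512 -> sm75_config_default
  // Small batches thus get narrower tiles, while anything above 256 rows
  // reuses the default configuration.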
+ using FallbackGemm = Cutlass2xGemmDefault; + + uint32_t const m = a.size(0); + uint32_t const mp2 = + std::max(static_cast(32), next_pow_2(m)); // next power of 2 + if (mp2 <= 32) { + // M in [1, 32] + return fallback_cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } else if (mp2 <= 64) { + // M in (32, 64] + return fallback_cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } else if (mp2 <= 128) { + // M in (64, 128] + return fallback_cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } else if (mp2 <= 256) { + // M in (128, 256] + return fallback_cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } else { + // M in (256, inf) + return fallback_cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } +} + +} // namespace vllm From a0dce9383ab7de0015060fb9fedadeb7d8ffdfb9 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 31 Jul 2024 17:40:44 -0400 Subject: [PATCH 0009/3246] [Misc] Add compressed-tensors to optimized quant list (#7006) --- vllm/config.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index fd48cc3a6b37..de5d0402a1bc 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -197,13 +197,17 @@ def _verify_embedding_mode(self) -> None: def _parse_quant_hf_config(self): quant_cfg = getattr(self.hf_config, "quantization_config", None) if quant_cfg is None: - # compress-tensors uses a "compression_config" key + # compressed-tensors uses a "compression_config" key quant_cfg = getattr(self.hf_config, "compression_config", None) return quant_cfg def _verify_quantization(self) -> None: supported_quantization = [*QUANTIZATION_METHODS] rocm_supported_quantization = ["gptq", "squeezellm"] + optimized_quantization_methods = [ + "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin", + "fbgemm_fp8", "compressed_tensors", "compressed-tensors" + ] if self.quantization is not None: self.quantization = self.quantization.lower() @@ -242,9 +246,7 @@ def _verify_quantization(self) -> None: raise ValueError( f"{self.quantization} quantization is currently not " f"supported in ROCm.") - if (self.quantization - not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin", - "awq_marlin", "fbgemm_fp8", "compressed_tensors")): + if self.quantization not in optimized_quantization_methods: logger.warning( "%s quantization is not fully " "optimized yet. 
The speed can be slower than " From 7eb0cb4a14ff3de84bf18fad8054d12ea8000c22 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Wed, 31 Jul 2024 16:34:26 -0700 Subject: [PATCH 0010/3246] Revert "[Frontend] Factor out code for running uvicorn" (#7012) Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> --- pyproject.toml | 1 - vllm/entrypoints/api_server.py | 74 +++++++++------------------ vllm/entrypoints/openai/api_server.py | 72 ++++++++++++++++++-------- vllm/server/__init__.py | 3 -- vllm/server/launch.py | 42 --------------- 5 files changed, 75 insertions(+), 117 deletions(-) delete mode 100644 vllm/server/__init__.py delete mode 100644 vllm/server/launch.py diff --git a/pyproject.toml b/pyproject.toml index cd5d196a1620..b0d115a091c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,6 @@ files = [ "vllm/logging", "vllm/multimodal", "vllm/platforms", - "vllm/server", "vllm/transformers_utils", "vllm/triton_utils", "vllm/usage", diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 347635765852..66941442c8c9 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -5,12 +5,12 @@ We are also not going to accept PRs modifying this file, please change `vllm/entrypoints/openai/api_server.py` instead. """ -import asyncio + import json import ssl -from argparse import Namespace -from typing import Any, AsyncGenerator, Optional +from typing import AsyncGenerator +import uvicorn from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -18,10 +18,8 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.logger import init_logger from vllm.sampling_params import SamplingParams -from vllm.server import serve_http from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser, random_uuid -from vllm.version import __version__ as VLLM_VERSION logger = init_logger("vllm.entrypoints.api_server") @@ -83,50 +81,6 @@ async def stream_results() -> AsyncGenerator[bytes, None]: return JSONResponse(ret) -def build_app(args: Namespace) -> FastAPI: - global app - - app.root_path = args.root_path - return app - - -async def init_app( - args: Namespace, - llm_engine: Optional[AsyncLLMEngine] = None, -) -> FastAPI: - app = build_app(args) - - global engine - - engine_args = AsyncEngineArgs.from_cli_args(args) - engine = (llm_engine - if llm_engine is not None else AsyncLLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.API_SERVER)) - - return app - - -async def run_server(args: Namespace, - llm_engine: Optional[AsyncLLMEngine] = None, - **uvicorn_kwargs: Any) -> None: - logger.info("vLLM API server version %s", VLLM_VERSION) - logger.info("args: %s", args) - - app = await init_app(args, llm_engine) - await serve_http( - app, - host=args.host, - port=args.port, - log_level=args.log_level, - timeout_keep_alive=TIMEOUT_KEEP_ALIVE, - ssl_keyfile=args.ssl_keyfile, - ssl_certfile=args.ssl_certfile, - ssl_ca_certs=args.ssl_ca_certs, - ssl_cert_reqs=args.ssl_cert_reqs, - **uvicorn_kwargs, - ) - - if __name__ == "__main__": parser = FlexibleArgumentParser() parser.add_argument("--host", type=str, default=None) @@ -151,5 +105,25 @@ async def run_server(args: Namespace, parser.add_argument("--log-level", type=str, default="debug") parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLMEngine.from_engine_args( + 
engine_args, usage_context=UsageContext.API_SERVER) + + app.root_path = args.root_path - asyncio.run(run_server(args)) + logger.info("Available routes are:") + for route in app.routes: + if not hasattr(route, 'methods'): + continue + methods = ', '.join(route.methods) + logger.info("Route: %s, Methods: %s", route.path, methods) + + uvicorn.run(app, + host=args.host, + port=args.port, + log_level=args.log_level, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index c1640a10a407..0fe4dd245b5e 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -2,12 +2,14 @@ import importlib import inspect import re -from argparse import Namespace +import signal from contextlib import asynccontextmanager from http import HTTPStatus -from typing import Any, Optional, Set +from typing import Optional, Set -from fastapi import APIRouter, FastAPI, Request +import fastapi +import uvicorn +from fastapi import APIRouter, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -36,7 +38,6 @@ from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) from vllm.logger import init_logger -from vllm.server import serve_http from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser from vllm.version import __version__ as VLLM_VERSION @@ -56,7 +57,7 @@ @asynccontextmanager -async def lifespan(app: FastAPI): +async def lifespan(app: fastapi.FastAPI): async def _force_log(): while True: @@ -74,7 +75,7 @@ async def _force_log(): router = APIRouter() -def mount_metrics(app: FastAPI): +def mount_metrics(app: fastapi.FastAPI): # Add prometheus asgi middleware to route /metrics requests metrics_route = Mount("/metrics", make_asgi_app()) # Workaround for 307 Redirect for /metrics @@ -164,8 +165,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): return JSONResponse(content=generator.model_dump()) -def build_app(args: Namespace) -> FastAPI: - app = FastAPI(lifespan=lifespan) +def build_app(args): + app = fastapi.FastAPI(lifespan=lifespan) app.include_router(router) app.root_path = args.root_path @@ -213,8 +214,11 @@ async def authentication(request: Request, call_next): return app -async def init_app(args: Namespace, - llm_engine: Optional[AsyncLLMEngine] = None) -> FastAPI: +async def build_server( + args, + llm_engine: Optional[AsyncLLMEngine] = None, + **uvicorn_kwargs, +) -> uvicorn.Server: app = build_app(args) if args.served_model_name is not None: @@ -277,17 +281,14 @@ async def init_app(args: Namespace, ) app.root_path = args.root_path - return app - - -async def run_server(args: Namespace, - llm_engine: Optional[AsyncLLMEngine] = None, - **uvicorn_kwargs: Any) -> None: - logger.info("vLLM API server version %s", VLLM_VERSION) - logger.info("args: %s", args) + logger.info("Available routes are:") + for route in app.routes: + if not hasattr(route, 'methods'): + continue + methods = ', '.join(route.methods) + logger.info("Route: %s, Methods: %s", route.path, methods) - app = await init_app(args, llm_engine) - await serve_http( + config = uvicorn.Config( app, host=args.host, port=args.port, @@ -300,6 +301,36 @@ async def run_server(args: 
Namespace, **uvicorn_kwargs, ) + return uvicorn.Server(config) + + +async def run_server(args, llm_engine=None, **uvicorn_kwargs) -> None: + logger.info("vLLM API server version %s", VLLM_VERSION) + logger.info("args: %s", args) + + server = await build_server( + args, + llm_engine, + **uvicorn_kwargs, + ) + + loop = asyncio.get_running_loop() + + server_task = loop.create_task(server.serve()) + + def signal_handler() -> None: + # prevents the uvicorn signal handler to exit early + server_task.cancel() + + loop.add_signal_handler(signal.SIGINT, signal_handler) + loop.add_signal_handler(signal.SIGTERM, signal_handler) + + try: + await server_task + except asyncio.CancelledError: + print("Gracefully stopping http server") + await server.shutdown() + if __name__ == "__main__": # NOTE(simon): @@ -308,5 +339,4 @@ async def run_server(args: Namespace, description="vLLM OpenAI-Compatible RESTful API server.") parser = make_arg_parser(parser) args = parser.parse_args() - asyncio.run(run_server(args)) diff --git a/vllm/server/__init__.py b/vllm/server/__init__.py deleted file mode 100644 index 17c98b4dad6c..000000000000 --- a/vllm/server/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .launch import serve_http - -__all__ = ["serve_http"] diff --git a/vllm/server/launch.py b/vllm/server/launch.py deleted file mode 100644 index 1a8aeb7f1022..000000000000 --- a/vllm/server/launch.py +++ /dev/null @@ -1,42 +0,0 @@ -import asyncio -import signal -from typing import Any - -import uvicorn -from fastapi import FastAPI - -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -async def serve_http(app: FastAPI, **uvicorn_kwargs: Any) -> None: - logger.info("Available routes are:") - for route in app.routes: - methods = getattr(route, "methods", None) - path = getattr(route, "path", None) - - if methods is None or path is None: - continue - - logger.info("Route: %s, Methods: %s", path, ', '.join(methods)) - - config = uvicorn.Config(app, **uvicorn_kwargs) - server = uvicorn.Server(config) - - loop = asyncio.get_running_loop() - - server_task = loop.create_task(server.serve()) - - def signal_handler() -> None: - # prevents the uvicorn signal handler to exit early - server_task.cancel() - - loop.add_signal_handler(signal.SIGINT, signal_handler) - loop.add_signal_handler(signal.SIGTERM, signal_handler) - - try: - await server_task - except asyncio.CancelledError: - logger.info("Gracefully stopping http server") - await server.shutdown() From 7ecee3432110bae563c8756a66b54e5f08dc777d Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 1 Aug 2024 08:12:24 +0800 Subject: [PATCH 0011/3246] [Kernel][RFC] Refactor the punica kernel based on Triton (#5036) --- .github/workflows/scripts/build.sh | 2 - CMakeLists.txt | 62 - Dockerfile | 2 - Dockerfile.rocm | 3 +- csrc/punica/LICENSE | 217 --- csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu | 5 - csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu | 5 - csrc/punica/bgmv/bgmv_config.h | 218 --- csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu | 5 - csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu | 5 - csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu | 5 - csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu | 5 - csrc/punica/bgmv/bgmv_impl.cuh | 451 ------ csrc/punica/bgmv/generator.py | 48 - csrc/punica/bgmv/vec_dtypes.cuh | 1325 ------------------ csrc/punica/punica_ops.cu | 569 -------- csrc/punica/punica_ops.h | 11 - csrc/punica/torch_bindings.cpp | 18 - csrc/punica/type_convert.h | 82 -- docs/source/getting_started/installation.rst | 1 - setup.py | 10 - tests/kernels/test_sampler.py | 48 +- 
tests/lora/test_gemma.py | 2 +- tests/lora/test_layers.py | 140 +- tests/lora/test_lora.py | 224 --- tests/lora/test_punica.py | 258 ---- tests/lora/test_punica_sizes.py | 408 ++++++ tests/lora/test_punica_variation.py | 342 +++++ tests/lora/test_quant_model.py | 48 +- tests/lora/utils.py | 148 ++ vllm/_custom_ops.py | 42 +- vllm/envs.py | 5 - vllm/lora/fully_sharded_layers.py | 137 +- vllm/lora/layers.py | 437 ++---- vllm/lora/models.py | 171 +-- vllm/lora/ops/__init__.py | 0 vllm/lora/ops/bgmv_expand.py | 169 +++ vllm/lora/ops/bgmv_expand_slice.py | 182 +++ vllm/lora/ops/bgmv_shrink.py | 150 ++ vllm/lora/ops/sgmv_expand.py | 192 +++ vllm/lora/ops/sgmv_expand_slice.py | 205 +++ vllm/lora/ops/sgmv_shrink.py | 189 +++ vllm/lora/ops/utils.py | 46 + vllm/lora/punica.py | 765 +++++++--- vllm/triton_utils/__init__.py | 3 +- vllm/triton_utils/libentry.py | 167 +++ vllm/worker/model_runner.py | 12 +- 47 files changed, 3175 insertions(+), 4364 deletions(-) delete mode 100644 csrc/punica/LICENSE delete mode 100644 csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu delete mode 100644 csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu delete mode 100644 csrc/punica/bgmv/bgmv_config.h delete mode 100644 csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu delete mode 100644 csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu delete mode 100644 csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu delete mode 100644 csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu delete mode 100644 csrc/punica/bgmv/bgmv_impl.cuh delete mode 100644 csrc/punica/bgmv/generator.py delete mode 100644 csrc/punica/bgmv/vec_dtypes.cuh delete mode 100644 csrc/punica/punica_ops.cu delete mode 100644 csrc/punica/punica_ops.h delete mode 100644 csrc/punica/torch_bindings.cpp delete mode 100644 csrc/punica/type_convert.h delete mode 100644 tests/lora/test_lora.py delete mode 100644 tests/lora/test_punica.py create mode 100644 tests/lora/test_punica_sizes.py create mode 100644 tests/lora/test_punica_variation.py create mode 100644 vllm/lora/ops/__init__.py create mode 100644 vllm/lora/ops/bgmv_expand.py create mode 100644 vllm/lora/ops/bgmv_expand_slice.py create mode 100644 vllm/lora/ops/bgmv_shrink.py create mode 100644 vllm/lora/ops/sgmv_expand.py create mode 100644 vllm/lora/ops/sgmv_expand_slice.py create mode 100644 vllm/lora/ops/sgmv_shrink.py create mode 100644 vllm/lora/ops/utils.py create mode 100644 vllm/triton_utils/libentry.py diff --git a/.github/workflows/scripts/build.sh b/.github/workflows/scripts/build.sh index 60a3978f9abd..0a759d303238 100644 --- a/.github/workflows/scripts/build.sh +++ b/.github/workflows/scripts/build.sh @@ -13,8 +13,6 @@ $python_executable -m pip install -r requirements-cuda.txt # Limit the number of parallel jobs to avoid OOM export MAX_JOBS=1 -# Make sure punica is built for the release (for LoRA) -export VLLM_INSTALL_PUNICA_KERNELS=1 # Make sure release wheels are built for the following architectures export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX" # Build diff --git a/CMakeLists.txt b/CMakeLists.txt index 28b8879a7ba1..0d599c547070 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -223,61 +223,7 @@ define_gpu_extension_target( USE_SABI 3 WITH_SOABI) -# -# _punica_C extension -# - -set(VLLM_PUNICA_EXT_SRC - "csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu" - "csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu" - "csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu" - "csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu" - "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu" - "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu" - "csrc/punica/punica_ops.cu" - "csrc/punica/torch_bindings.cpp") - -# -# Copy GPU 
compilation flags+update for punica -# -set(VLLM_PUNICA_GPU_FLAGS ${VLLM_GPU_FLAGS}) -list(REMOVE_ITEM VLLM_PUNICA_GPU_FLAGS - "-D__CUDA_NO_HALF_OPERATORS__" - "-D__CUDA_NO_HALF_CONVERSIONS__" - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" - "-D__CUDA_NO_HALF2_OPERATORS__") - -# -# Filter out CUDA architectures < 8.0 for punica. -# -if (${VLLM_GPU_LANG} STREQUAL "CUDA") - set(VLLM_PUNICA_GPU_ARCHES) - foreach(ARCH ${VLLM_GPU_ARCHES}) - string_to_ver(CODE_VER ${ARCH}) - if (CODE_VER GREATER_EQUAL 8.0) - list(APPEND VLLM_PUNICA_GPU_ARCHES ${ARCH}) - endif() - endforeach() - message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") -elseif(${VLLM_GPU_LANG} STREQUAL "HIP") - set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES}) - message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") -endif() -if (VLLM_PUNICA_GPU_ARCHES) - define_gpu_extension_target( - _punica_C - DESTINATION vllm - LANGUAGE ${VLLM_GPU_LANG} - SOURCES ${VLLM_PUNICA_EXT_SRC} - COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS} - ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES} - USE_SABI 3 - WITH_SOABI) -else() - message(WARNING "Unable to create _punica_C target because none of the " - "requested architectures (${VLLM_GPU_ARCHES}) are supported, i.e. >= 8.0") -endif() # # Add the `default` target which detects which extensions should be @@ -301,12 +247,4 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling moe extension.") add_dependencies(default _moe_C) - # Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or - # VLLM_INSTALL_PUNICA_KERNELS is set in the environment and - # there are supported target arches. - if (VLLM_PUNICA_GPU_ARCHES AND - (ENV{VLLM_INSTALL_PUNICA_KERNELS} OR VLLM_INSTALL_PUNICA_KERNELS)) - message(STATUS "Enabling punica extension.") - add_dependencies(default _punica_C) - endif() endif() diff --git a/Dockerfile b/Dockerfile index b9a56e67e8d7..db4453ab0efc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -88,8 +88,6 @@ ENV MAX_JOBS=${max_jobs} # number of threads used by nvcc ARG nvcc_threads=8 ENV NVCC_THREADS=$nvcc_threads -# make sure punica kernels are built (for LoRA) -ENV VLLM_INSTALL_PUNICA_KERNELS=1 ARG buildkite_commit ENV BUILDKITE_COMMIT=${buildkite_commit} diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 64bc0f3c12c7..33423fde4ff9 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -131,8 +131,7 @@ COPY . . RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install --upgrade numba scipy huggingface-hub[cli] -# Make sure punica kernels are built (for LoRA) -ENV VLLM_INSTALL_PUNICA_KERNELS=1 + # Workaround for ray >= 2.10.0 ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 # Silences the HF Tokenizers warning diff --git a/csrc/punica/LICENSE b/csrc/punica/LICENSE deleted file mode 100644 index a46e2cdcadf7..000000000000 --- a/csrc/punica/LICENSE +++ /dev/null @@ -1,217 +0,0 @@ -Contains code from https://github.com/punica-ai/punica - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. 
For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. 
This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright {yyyy} {name of copyright owner} - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- ------------------------------------------------------------------------------------- - -This product bundles various third-party components under other open source licenses. -This section summarizes those components and their licenses. See licenses/ -for text of these licenses. - - -Apache-2.0 -* third_party/nvbench (with LLVM exception) -* third_party/flashinfer - -BSD-3-Clause: -* third_party/cutlass \ No newline at end of file diff --git a/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu deleted file mode 100644 index 86846c274c90..000000000000 --- a/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu deleted file mode 100644 index de39c3121f5d..000000000000 --- a/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h deleted file mode 100644 index 2c8d007d8719..000000000000 --- a/csrc/punica/bgmv/bgmv_config.h +++ /dev/null @@ -1,218 +0,0 @@ -#pragma once - -template -void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, - const W_T *__restrict__ W, - const int64_t *__restrict__ indicies, int64_t y_offset, - int64_t full_y_size, int64_t batch_size, int64_t num_layers, - int64_t layer_idx, float scale); - -// clang-format off - -#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \ - f(in_T, out_T, W_T, narrow, 128) \ - f(in_T, out_T, W_T, narrow, 256) \ - f(in_T, out_T, W_T, narrow, 512) \ - f(in_T, out_T, W_T, narrow, 640) \ - f(in_T, out_T, W_T, narrow, 768) \ - f(in_T, out_T, W_T, narrow, 896) \ - f(in_T, out_T, W_T, narrow, 1024) \ - f(in_T, out_T, W_T, narrow, 1152) \ - f(in_T, out_T, W_T, narrow, 1216) \ - f(in_T, out_T, W_T, narrow, 1280) \ - f(in_T, out_T, W_T, narrow, 1536) \ - f(in_T, out_T, W_T, narrow, 1664) \ - f(in_T, out_T, W_T, narrow, 1728) \ - f(in_T, out_T, W_T, narrow, 1792) \ - f(in_T, out_T, W_T, narrow, 2048) \ - f(in_T, out_T, W_T, narrow, 2240) \ - f(in_T, out_T, W_T, narrow, 2304) \ - f(in_T, out_T, W_T, narrow, 2368) \ - f(in_T, out_T, W_T, narrow, 2432) \ - f(in_T, out_T, W_T, narrow, 2560) \ - f(in_T, out_T, W_T, narrow, 2752) \ - f(in_T, out_T, W_T, narrow, 2816) \ - f(in_T, out_T, W_T, narrow, 3072) \ - f(in_T, out_T, W_T, narrow, 3328) \ - f(in_T, out_T, W_T, narrow, 3456) \ - f(in_T, out_T, W_T, narrow, 3584) \ - f(in_T, out_T, W_T, narrow, 3712) \ - f(in_T, out_T, W_T, narrow, 4096) \ - f(in_T, out_T, W_T, narrow, 4480) \ - f(in_T, out_T, W_T, narrow, 4608) \ - f(in_T, out_T, W_T, narrow, 4736) \ - f(in_T, out_T, W_T, narrow, 4864) \ - f(in_T, out_T, W_T, narrow, 5120) \ - f(in_T, out_T, W_T, narrow, 5504) \ - f(in_T, out_T, W_T, narrow, 5632) \ - f(in_T, out_T, W_T, narrow, 5888) \ - f(in_T, out_T, W_T, narrow, 6144) \ - f(in_T, out_T, W_T, narrow, 6400) \ - f(in_T, out_T, W_T, narrow, 6848) \ - f(in_T, out_T, W_T, narrow, 6912) \ - f(in_T, out_T, W_T, narrow, 7168) \ - f(in_T, out_T, W_T, narrow, 7424) \ - f(in_T, out_T, W_T, narrow, 8192) \ - f(in_T, out_T, W_T, narrow, 8960) \ - 
f(in_T, out_T, W_T, narrow, 9216) \ - f(in_T, out_T, W_T, narrow, 9472) \ - f(in_T, out_T, W_T, narrow, 10240) \ - f(in_T, out_T, W_T, narrow, 11008) \ - f(in_T, out_T, W_T, narrow, 11264) \ - f(in_T, out_T, W_T, narrow, 12288) \ - f(in_T, out_T, W_T, narrow, 13696) \ - f(in_T, out_T, W_T, narrow, 13824) \ - f(in_T, out_T, W_T, narrow, 14336) \ - f(in_T, out_T, W_T, narrow, 14784) \ - f(in_T, out_T, W_T, narrow, 14848) \ - f(in_T, out_T, W_T, narrow, 15360) \ - f(in_T, out_T, W_T, narrow, 16384) \ - f(in_T, out_T, W_T, narrow, 18944) \ - f(in_T, out_T, W_T, narrow, 20480) \ - f(in_T, out_T, W_T, narrow, 22016) \ - f(in_T, out_T, W_T, narrow, 22528) \ - f(in_T, out_T, W_T, narrow, 24576) \ - f(in_T, out_T, W_T, narrow, 27392) \ - f(in_T, out_T, W_T, narrow, 27648) \ - f(in_T, out_T, W_T, narrow, 28672) \ - f(in_T, out_T, W_T, narrow, 29568) \ - f(in_T, out_T, W_T, narrow, 29696) \ - f(in_T, out_T, W_T, narrow, 32000) \ - f(in_T, out_T, W_T, narrow, 32256) \ - f(in_T, out_T, W_T, narrow, 32512) \ - f(in_T, out_T, W_T, narrow, 32768) \ - f(in_T, out_T, W_T, narrow, 33024) \ - f(in_T, out_T, W_T, narrow, 36864) \ - f(in_T, out_T, W_T, narrow, 43264) \ - f(in_T, out_T, W_T, narrow, 49152) \ - f(in_T, out_T, W_T, narrow, 49408) \ - f(in_T, out_T, W_T, narrow, 60544) \ - f(in_T, out_T, W_T, narrow, 60672) \ - f(in_T, out_T, W_T, narrow, 64000) \ - f(in_T, out_T, W_T, narrow, 64256) \ - f(in_T, out_T, W_T, narrow, 64512) \ - f(in_T, out_T, W_T, narrow, 102400) \ - f(in_T, out_T, W_T, narrow, 102656) \ - f(in_T, out_T, W_T, narrow, 102912) \ - f(in_T, out_T, W_T, narrow, 128000) \ - f(in_T, out_T, W_T, narrow, 128256) \ - f(in_T, out_T, W_T, narrow, 128512) \ - - -// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA -// and vllm/tests/lora/test_punica.py - -// Used for defining kernels going from the variety of -// dim in to the narrow dim out - // Using it for the fully sharded column - // parallel LoRA A which splits the rank dim -#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \ - f(in_T, out_T, W_T, 128, narrow) \ - f(in_T, out_T, W_T, 256, narrow) \ - f(in_T, out_T, W_T, 512, narrow) \ - f(in_T, out_T, W_T, 640, narrow) \ - f(in_T, out_T, W_T, 768, narrow) \ - f(in_T, out_T, W_T, 896, narrow) \ - f(in_T, out_T, W_T, 1024, narrow) \ - f(in_T, out_T, W_T, 1152, narrow) \ - f(in_T, out_T, W_T, 1216, narrow) \ - f(in_T, out_T, W_T, 1280, narrow) \ - f(in_T, out_T, W_T, 1536, narrow) \ - f(in_T, out_T, W_T, 1664, narrow) \ - f(in_T, out_T, W_T, 1728, narrow) \ - f(in_T, out_T, W_T, 1792, narrow) \ - f(in_T, out_T, W_T, 2048, narrow) \ - f(in_T, out_T, W_T, 2240, narrow) \ - f(in_T, out_T, W_T, 2304, narrow) \ - f(in_T, out_T, W_T, 2368, narrow) \ - f(in_T, out_T, W_T, 2432, narrow) \ - f(in_T, out_T, W_T, 2560, narrow) \ - f(in_T, out_T, W_T, 2752, narrow) \ - f(in_T, out_T, W_T, 2816, narrow) \ - f(in_T, out_T, W_T, 3072, narrow) \ - f(in_T, out_T, W_T, 3328, narrow) \ - f(in_T, out_T, W_T, 3456, narrow) \ - f(in_T, out_T, W_T, 3584, narrow) \ - f(in_T, out_T, W_T, 3712, narrow) \ - f(in_T, out_T, W_T, 4096, narrow) \ - f(in_T, out_T, W_T, 4480, narrow) \ - f(in_T, out_T, W_T, 4608, narrow) \ - f(in_T, out_T, W_T, 4736, narrow) \ - f(in_T, out_T, W_T, 4864, narrow) \ - f(in_T, out_T, W_T, 5120, narrow) \ - f(in_T, out_T, W_T, 5504, narrow) \ - f(in_T, out_T, W_T, 5632, narrow) \ - f(in_T, out_T, W_T, 5888, narrow) \ - f(in_T, out_T, W_T, 6144, narrow) \ - f(in_T, out_T, W_T, 6400, narrow) \ - f(in_T, out_T, W_T, 6848, narrow) \ - f(in_T, out_T, W_T, 6912, narrow) \ - 
f(in_T, out_T, W_T, 7168, narrow) \ - f(in_T, out_T, W_T, 7424, narrow) \ - f(in_T, out_T, W_T, 8192, narrow) \ - f(in_T, out_T, W_T, 8960, narrow) \ - f(in_T, out_T, W_T, 9216, narrow) \ - f(in_T, out_T, W_T, 9472, narrow) \ - f(in_T, out_T, W_T, 10240, narrow) \ - f(in_T, out_T, W_T, 11008, narrow) \ - f(in_T, out_T, W_T, 11264, narrow) \ - f(in_T, out_T, W_T, 12288, narrow) \ - f(in_T, out_T, W_T, 13696, narrow) \ - f(in_T, out_T, W_T, 13824, narrow) \ - f(in_T, out_T, W_T, 14336, narrow) \ - f(in_T, out_T, W_T, 14784, narrow) \ - f(in_T, out_T, W_T, 14848, narrow) \ - f(in_T, out_T, W_T, 15360, narrow) \ - f(in_T, out_T, W_T, 16384, narrow) \ - f(in_T, out_T, W_T, 18944, narrow) \ - f(in_T, out_T, W_T, 20480, narrow) \ - f(in_T, out_T, W_T, 22016, narrow) \ - f(in_T, out_T, W_T, 22528, narrow) \ - f(in_T, out_T, W_T, 24576, narrow) \ - f(in_T, out_T, W_T, 27392, narrow) \ - f(in_T, out_T, W_T, 27648, narrow) \ - f(in_T, out_T, W_T, 28672, narrow) \ - f(in_T, out_T, W_T, 29568, narrow) \ - f(in_T, out_T, W_T, 29696, narrow) \ - f(in_T, out_T, W_T, 32000, narrow) \ - f(in_T, out_T, W_T, 32256, narrow) \ - f(in_T, out_T, W_T, 32512, narrow) \ - f(in_T, out_T, W_T, 32768, narrow) \ - f(in_T, out_T, W_T, 33024, narrow) \ - f(in_T, out_T, W_T, 36864, narrow) \ - f(in_T, out_T, W_T, 43264, narrow) \ - f(in_T, out_T, W_T, 49152, narrow) \ - f(in_T, out_T, W_T, 49408, narrow) \ - f(in_T, out_T, W_T, 60544, narrow) \ - f(in_T, out_T, W_T, 60672, narrow) \ - f(in_T, out_T, W_T, 64000, narrow) \ - f(in_T, out_T, W_T, 64256, narrow) \ - f(in_T, out_T, W_T, 64512, narrow) \ - f(in_T, out_T, W_T, 102400, narrow) \ - f(in_T, out_T, W_T, 102656, narrow) \ - f(in_T, out_T, W_T, 102912, narrow) \ - f(in_T, out_T, W_T, 128000, narrow) \ - f(in_T, out_T, W_T, 128256, narrow) \ - f(in_T, out_T, W_T, 128512, narrow) \ -// Keep above in sync with vllm/lora/layers::SamplerWithLoRA - - -// Keep this in sync with vllm/config::LoRAConfig -#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) - - -#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ - FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \ - FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \ - FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \ - f(in_T, out_T, W_T, 8, 64) \ - f(in_T, out_T, W_T, 16, 64) \ - f(in_T, out_T, W_T, 32, 64) \ - f(in_T, out_T, W_T, 64, 64) - -// clang-format on diff --git a/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu deleted file mode 100644 index d225a1eaa82b..000000000000 --- a/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu deleted file mode 100644 index b37d288a7556..000000000000 --- a/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu deleted file mode 100644 index a1ab2deecbab..000000000000 --- 
a/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu deleted file mode 100644 index 0b35bf569989..000000000000 --- a/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu +++ /dev/null @@ -1,5 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_impl.cuh b/csrc/punica/bgmv/bgmv_impl.cuh deleted file mode 100644 index 8a3b8403b4a6..000000000000 --- a/csrc/punica/bgmv/bgmv_impl.cuh +++ /dev/null @@ -1,451 +0,0 @@ -#pragma once - -#include -#ifndef USE_ROCM -#include -#else -#include -#endif -#ifndef USE_ROCM -#include -#endif -#include -#include -#include - -#include "vec_dtypes.cuh" - -namespace cg = cooperative_groups; - -#ifdef USE_ROCM -template -__host__ __device__ -inline void* memcpy_blocking(void *dst, const void *src) { - // Does not handle the case of long datatypes - char *d = reinterpret_cast(dst); - const char *s = reinterpret_cast(src); - size_t i = 0; -#pragma unroll - for (i = 0; i < len; ++i) { - d[i] = s[i]; - } - return dst; -} -#endif - -#ifndef USE_ROCM - -// nthrs = (32, 4) -template -__global__ void -bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, - const W_T *__restrict__ W, - const int64_t *__restrict__ indicies, int64_t y_offset, - int64_t full_y_size, int64_t num_layers, int64_t layer_idx, - float scale) { - size_t batch_idx = blockIdx.y; - int64_t idx = indicies[batch_idx] * num_layers + layer_idx; - if (idx < 0) { - return; - } - - auto block = cg::this_thread_block(); - size_t j = blockIdx.x; - constexpr size_t num_pipeline_stages = 2; - constexpr size_t tile_size = tx * ty * vec_size; - __shared__ W_T W_shared[num_pipeline_stages * tile_size]; - __shared__ in_T X_shared[num_pipeline_stages * tile_size]; - __shared__ float y_warpwise[ty]; - - size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; - size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; - auto pipe = cuda::make_pipeline(); - - // pipeline load W/X and compute WX; - pipe.producer_acquire(); - cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, - W + (idx * feat_out + j) * feat_in + - (threadIdx.y * tx + threadIdx.x) * vec_size, - cuda::aligned_size_t(W_copy_size), pipe); - cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, - X + (batch_idx * feat_in) + - (threadIdx.y * tx + threadIdx.x) * vec_size, - cuda::aligned_size_t(X_copy_size), pipe); - pipe.producer_commit(); - size_t copy_idx, compute_idx; - float y = 0.f; - vec_t x_vec; - vec_t w_vec; - size_t tile_idx; - -#pragma unroll - for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size; - ++tile_idx) { - copy_idx = tile_idx % num_pipeline_stages; - // pipeline stage: async copy W fragment - pipe.producer_acquire(); - if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) { - cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size, - W + (idx * feat_out + j) * feat_in + - tile_idx * tile_size + - (threadIdx.y * tx + threadIdx.x) * vec_size, - cuda::aligned_size_t(W_copy_size), pipe); - 
cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size, - X + (batch_idx * feat_in) + tile_idx * tile_size + - (threadIdx.y * tx + threadIdx.x) * vec_size, - cuda::aligned_size_t(X_copy_size), pipe); - } - pipe.producer_commit(); - - compute_idx = (tile_idx - 1) % num_pipeline_stages; - // pipeline stage: compute WX - pipe.consumer_wait(); - block.sync(); - x_vec.load(X_shared + X_shared_offset[compute_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size); - w_vec.load(W_shared + W_shared_offset[compute_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size); - float sum = 0.f; -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - sum += float(w_vec[i]) * float(x_vec[i]) * scale; - } -#pragma unroll - for (size_t offset = tx / 2; offset > 0; offset /= 2) { - sum += __shfl_down_sync(0xffffffff, sum, offset); - } - y_warpwise[threadIdx.y] = sum; - block.sync(); -#pragma unroll - for (size_t i = 0; i < ty; ++i) { - y += y_warpwise[i]; - } - - block.sync(); - pipe.consumer_release(); - } - - compute_idx = (tile_idx - 1) % num_pipeline_stages; - // final pipeline stage - pipe.consumer_wait(); - block.sync(); - x_vec.load(X_shared + X_shared_offset[compute_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size); - w_vec.load(W_shared + W_shared_offset[compute_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size); - float sum = 0.f; -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - sum += float(w_vec[i]) * float(x_vec[i]) * scale; - } -#pragma unroll - for (size_t offset = tx / 2; offset > 0; offset /= 2) { - sum += __shfl_down_sync(0xffffffff, sum, offset); - } - y_warpwise[threadIdx.y] = - ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in) - ? sum - : 0.f; - block.sync(); -#pragma unroll - for (size_t i = 0; i < ty; ++i) { - y += y_warpwise[i]; - } - - block.sync(); - pipe.consumer_release(); - - // write Y; - if (block.thread_rank() == 0) { - Y[batch_idx * full_y_size + y_offset + j] += static_cast(y); - } -} - -#else - -template -__global__ void -bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, - const W_T *__restrict__ W, - const int64_t *__restrict__ indicies, int64_t y_offset, - int64_t full_y_size, int64_t num_layers, int64_t layer_idx, - float scale) { - size_t batch_idx = blockIdx.y; - int64_t idx = indicies[batch_idx] * num_layers + layer_idx; - if (idx < 0) { - return; - } - - size_t j = blockIdx.x; - constexpr size_t tile_size = tx * ty * vec_size; - constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size; - __shared__ float y_warpwise[ty]; - - float y = 0; - vec_t x_vec; - vec_t w_vec; - size_t tile_idx; - -#pragma unroll - for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { - if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) { - x_vec.load(X + (batch_idx * feat_in) + - tile_idx * tile_size + - (threadIdx.y * tx + threadIdx.x) * vec_size); - w_vec.load(W + (idx * feat_out + j) * feat_in + - tile_idx * tile_size + - (threadIdx.y * tx + threadIdx.x) * vec_size); - } - - float sum = 0.f; -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; - } -#pragma unroll - for (size_t offset = tx / 2; offset > 0; offset /= 2) { - sum += VLLM_SHFL_DOWN_SYNC(sum, offset); - } - - __syncthreads(); - - if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) { - y += sum; - } - } - - if (threadIdx.x == 0) { - y_warpwise[threadIdx.y] = y; - } - 
__syncthreads(); - - float y_write = 0.f; -#pragma unroll - for (size_t i = 0; i < ty; ++i) { - y_write += y_warpwise[i]; - } - - // write Y; - if (threadIdx.x == 0 && threadIdx.y == 0) { - size_t y_idx = batch_idx * full_y_size + y_offset + j; - Y[y_idx] = vllm_add(Y[y_idx], convert_type(y_write)); - } -} - -#endif - -// nthrs = (2, 16, 4) -template -__global__ void -bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, - const W_T *__restrict__ W, - const int64_t *__restrict__ indicies, int64_t y_offset, - int64_t full_y_size, int64_t num_layers, int64_t layer_idx, - float scale) { - size_t batch_idx = blockIdx.y; - int64_t idx = indicies[batch_idx] * num_layers + layer_idx; - - if (idx < 0) { - return; - } - - auto block = cg::this_thread_block(); - size_t tile_idx = blockIdx.x; - - // load X; - vec_t x_vec; - x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size); - - // load W; - vec_t w_vec; - w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in + - block.thread_rank() * vec_size); - - float sum = 0.f; -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { -#ifndef USE_ROCM - sum += float(w_vec[i]) * float(x_vec[i]) * scale; -#else - sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; -#endif - } - - cg::thread_block_tile g = cg::tiled_partition(block); -#pragma unroll - for (size_t offset = tx / 2; offset > 0; offset /= 2) { - sum += g.shfl_down(sum, offset); - } - sum = g.shfl(sum, 0); - - if (threadIdx.x == 0) { -#ifndef USE_ROCM - Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + - threadIdx.z * ty + threadIdx.y] += static_cast(sum); -#else - size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + - threadIdx.z * ty + threadIdx.y; - Y[y_idx] = vllm_add(Y[y_idx], convert_type(sum)); -#endif - } -} - -template -void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, - const W_T *__restrict__ W, - const int64_t *__restrict__ indicies, int64_t y_offset, - int64_t full_y_size, int64_t batch_size, int64_t num_layers, - int64_t layer_idx, float scale) { - constexpr size_t vec_size = 8; - constexpr int tz = 4; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if constexpr (feat_in <= feat_out) { - static_assert(feat_in % vec_size == 0); - constexpr int tx = feat_in / vec_size; - - static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) || - (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) || - (8 % tx == 0 && feat_out % (8 / tx * tz) == 0)); - - if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) { - constexpr int ty = 32 / tx; - dim3 nblks(feat_out / (ty * tz), batch_size); - dim3 nthrs(tx, ty, tz); - - bgmv_expand_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) { - constexpr int ty = 16 / tx; - dim3 nblks(feat_out / (ty * tz), batch_size); - dim3 nthrs(tx, ty, tz); - - bgmv_expand_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } else { - constexpr int ty = 8 / tx; - dim3 nblks(feat_out / (ty * tz), batch_size); - dim3 nthrs(tx, ty, tz); - - bgmv_expand_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } - } else { -#ifndef USE_ROCM - static_assert(feat_in % (vec_size * 32) == 0 || - feat_in % (vec_size * 16) == 0 || - feat_in % (vec_size * 8) == 0); - - if constexpr (feat_in % (vec_size * 32) == 0) { - constexpr int tx = 32; - constexpr int ty = 4; - - dim3 nblks(feat_out, 
batch_size); - dim3 nthrs(tx, ty); - - bgmv_shrink_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) { - constexpr int tx = 32; - constexpr int ty = 4; - - dim3 nblks(feat_out, batch_size); - dim3 nthrs(tx, ty); - - bgmv_shrink_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) { - constexpr int tx = 16; - constexpr int ty = 4; - - dim3 nblks(feat_out, batch_size); - dim3 nthrs(tx, ty); - - bgmv_shrink_kernel - <<>>(Y, X, W, indicies, y_offset, - full_y_size, num_layers, layer_idx, - scale); - } -#else - constexpr size_t rocm_warp_size = warpSize; - -#define CHECK_INPUT_TILEABLE_BY(vec_size_) \ - feat_in % (rocm_warp_size * vec_size_) == 0 - -#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_) \ - if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) { \ - constexpr size_t vec_size_shrink = vec_size_; \ - constexpr int tx = tx_; \ - constexpr int ty = ty_; \ - dim3 nblks(feat_out, batch_size); \ - dim3 nthrs(tx, ty); \ - bgmv_shrink_kernel \ - <<>>(Y, X, W, indicies, y_offset, \ - full_y_size, num_layers, layer_idx, \ - scale); \ - } - - static_assert(CHECK_INPUT_TILEABLE_BY(32) || - CHECK_INPUT_TILEABLE_BY(16) || - CHECK_INPUT_TILEABLE_BY( 8) || - CHECK_INPUT_TILEABLE_BY( 4) || - CHECK_INPUT_TILEABLE_BY( 2) || - CHECK_INPUT_TILEABLE_BY( 1)); - - LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size, 8/vec_size) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2) - else - LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1) - -#undef CHECK_INPUT_TILEABLE_BY -#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM -#endif - } -} - -#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \ - template void bgmv_kernel( \ - out_T * __restrict__ Y, const in_T *__restrict__ X, \ - const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \ - int64_t y_offset, int64_t full_y_size, int64_t batch_size, \ - int64_t num_layers, int64_t layer_idx, float scale); - -#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \ - INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) - -#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \ - INST_BGMV(narrow, wide, in_T, out_T, W_T) \ - INST_BGMV(wide, narrow, in_T, out_T, W_T) diff --git a/csrc/punica/bgmv/generator.py b/csrc/punica/bgmv/generator.py deleted file mode 100644 index 972df5a7208c..000000000000 --- a/csrc/punica/bgmv/generator.py +++ /dev/null @@ -1,48 +0,0 @@ -DTYPES = ["fp16", "bf16", "fp32"] -DTYPE_MAP = { - "fp16": "nv_half", - "bf16": "nv_bfloat16", - "fp32": "float", -} - -TEMPLATE = """ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype}) -FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype}) -""".lstrip() # noqa: E501 - -for input_dtype in DTYPES: - for output_dtype in DTYPES: - for weight_dtype in DTYPES: - if weight_dtype == "fp32": - # FP32 weights are not supported. - continue - if output_dtype == "fp32": - # LoRA A matrix. 
- if input_dtype != weight_dtype: - # NOTE(woosuk): While Punica supports the case where the - # input and weight dtypes are different, we only generate - # the kernels the same dtypes to reduce the binary size. - continue - elif input_dtype == "fp32": - # LoRA B matrix. - if output_dtype != weight_dtype: - # NOTE(woosuk): While Punica supports the case where the - # output and weight dtypes are different, we only generate - # the kernels the same dtypes to reduce the binary size. - continue - elif not (input_dtype == output_dtype == weight_dtype): - # NOTE(woosuk): While Punica supports mixed data types for - # input, output, and weight, we only generate the kernels with - # the same data types to reduce the binary size. - continue - - kernel_definition = TEMPLATE.format( - input_dtype=DTYPE_MAP[input_dtype], - output_dtype=DTYPE_MAP[output_dtype], - weight_dtype=DTYPE_MAP[weight_dtype]) - filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu" - with open(filename, "w") as f: - f.write(kernel_definition) diff --git a/csrc/punica/bgmv/vec_dtypes.cuh b/csrc/punica/bgmv/vec_dtypes.cuh deleted file mode 100644 index 2738892e6dc4..000000000000 --- a/csrc/punica/bgmv/vec_dtypes.cuh +++ /dev/null @@ -1,1325 +0,0 @@ -#ifndef VEC_DTYPES_CUH_ -#define VEC_DTYPES_CUH_ - -#ifdef FLASHINFER_USE_FP8 -#include -#endif -#include - -#include - -#include "../type_convert.h" -#include "../../cuda_compat.h" - -#define FLASHINFER_INLINE \ - inline __attribute__((always_inline)) __device__ __host__ - -template -struct vec_t { - FLASHINFER_INLINE float_t &operator[](size_t i); - FLASHINFER_INLINE const float_t &operator[](size_t i) const; - FLASHINFER_INLINE void fill(float_t val); - FLASHINFER_INLINE void load(const float_t *ptr); - FLASHINFER_INLINE void store(float_t *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src); - template - FLASHINFER_INLINE void cast_load(const T *ptr); - template - FLASHINFER_INLINE void cast_store(T *ptr) const; - FLASHINFER_INLINE static void memcpy(float_t *dst, const float_t *src); -}; - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t &dst) { -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - dst[i] = tgt_float_t(src[i]); - } -} - -template -FLASHINFER_INLINE void cast_load_impl(const src_float_t *src_ptr, - vec_t &dst) { - if constexpr (std::is_same::value) { - dst.load(src_ptr); - } else { - vec_t tmp; - tmp.load(src_ptr); - dst.cast_from(tmp); - } -} - -template -FLASHINFER_INLINE void cast_store_impl(const vec_t &src, - tgt_float_t *dst_ptr) { - if constexpr (std::is_same::value) { - src.store(dst_ptr); - } else { - vec_t tmp; - tmp.cast_from(src); - tmp.store(dst_ptr); - } -} - -#ifdef FLASHINFER_USE_FP8 -/******************* vec_t<__nv_fp8_e4m3> *******************/ - -// __nv_fp8_e4m3 x 1 -template <> -struct vec_t<__nv_fp8_e4m3, 1> { - __nv_fp8_e4m3 data; - - FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { - return ((__nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { - return ((const __nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); - FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - 
cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, - const __nv_fp8_e4m3 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) { - data = val; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3 *ptr) { - data = *ptr; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::store( - __nv_fp8_e4m3 *ptr) const { - *ptr = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy( - __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { - *dst = *src; -} - -// __nv_fp8_e4m3 x 2 -template <> -struct vec_t<__nv_fp8_e4m3, 2> { - __nv_fp8x2_e4m3 data; - - FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { - return ((__nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { - return ((const __nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); - FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, - const __nv_fp8_e4m3 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) { - data.__x = - (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3 *ptr) { - data = *((__nv_fp8x2_e4m3 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::store( - __nv_fp8_e4m3 *ptr) const { - *((__nv_fp8x2_e4m3 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy( - __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { - *((__nv_fp8x2_e4m3 *)dst) = *((__nv_fp8x2_e4m3 *)src); -} - -// __nv_fp8_e4m3 x 4 - -template <> -struct vec_t<__nv_fp8_e4m3, 4> { - __nv_fp8x4_e4m3 data; - - FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { - return ((__nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { - return ((const __nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); - FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, - const __nv_fp8_e4m3 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) { - data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3 *ptr) { - data = *((__nv_fp8x4_e4m3 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::store( - __nv_fp8_e4m3 *ptr) const { - *((__nv_fp8x4_e4m3 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy( - __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { - *((__nv_fp8x4_e4m3 *)dst) = *((__nv_fp8x4_e4m3 *)src); -} 
- -// __nv_fp8_e4m3 x 8 - -template <> -struct vec_t<__nv_fp8_e4m3, 8> { - uint2 data; - - FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { - return ((__nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { - return ((const __nv_fp8_e4m3 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); - FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, - const __nv_fp8_e4m3 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) { - ((__nv_fp8x4_e4m3 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e4m3 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3 *ptr) { - data = *((uint2 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::store( - __nv_fp8_e4m3 *ptr) const { - *((uint2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy( - __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { - *((__nv_fp8_e4m3 *)dst) = *((__nv_fp8_e4m3 *)src); -} - -// __nv_fp8_e4m3 x 16 or more -template -struct vec_t<__nv_fp8_e4m3, vec_size> { - uint4 data[vec_size / 16]; - - FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { - return ((__nv_fp8_e4m3 *)data)[i]; - } - FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { - return ((const __nv_fp8_e4m3 *)data)[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((__nv_fp8x4_e4m3 *)(&(data[i].x)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e4m3 *)(&(data[i].y)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e4m3 *)(&(data[i].z)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e4m3 *)(&(data[i].w)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - } - } - FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } - } - FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T 
*ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, - const __nv_fp8_e4m3 *src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } - } -}; - -/******************* vec_t<__nv_fp8_e5m2> *******************/ - -// __nv_fp8_e5m2 x 1 -template <> -struct vec_t<__nv_fp8_e5m2, 1> { - __nv_fp8_e5m2 data; - - FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { - return ((__nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { - return ((const __nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); - FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, - const __nv_fp8_e5m2 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val) { - data = val; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2 *ptr) { - data = *ptr; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::store( - __nv_fp8_e5m2 *ptr) const { - *ptr = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::memcpy( - __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { - *dst = *src; -} - -// __nv_fp8_e5m2 x 2 -template <> -struct vec_t<__nv_fp8_e5m2, 2> { - __nv_fp8x2_e5m2 data; - - FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { - return ((__nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { - return ((const __nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); - FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, - const __nv_fp8_e5m2 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val) { - data.__x = - (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::load(const __nv_fp8_e5m2 *ptr) { - data = *((__nv_fp8x2_e5m2 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::store( - __nv_fp8_e5m2 *ptr) const { - *((__nv_fp8x2_e5m2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::memcpy( - __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { - *((__nv_fp8x2_e5m2 *)dst) = *((__nv_fp8x2_e5m2 *)src); -} - -// __nv_fp8_e5m2 x 4 - -template <> -struct vec_t<__nv_fp8_e5m2, 4> { - __nv_fp8x4_e5m2 data; - - FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { - return ((__nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { - return ((const __nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); - FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e5m2 
*ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, - const __nv_fp8_e5m2 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val) { - data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2 *ptr) { - data = *((__nv_fp8x4_e5m2 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::store( - __nv_fp8_e5m2 *ptr) const { - *((__nv_fp8x4_e5m2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy( - __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { - *((__nv_fp8x4_e5m2 *)dst) = *((__nv_fp8x4_e5m2 *)src); -} - -// __nv_fp8_e5m2 x 8 - -template <> -struct vec_t<__nv_fp8_e5m2, 8> { - uint2 data; - - FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { - return ((__nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { - return ((const __nv_fp8_e5m2 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); - FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); - FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, - const __nv_fp8_e5m2 *src); -}; - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val) { - ((__nv_fp8x4_e5m2 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e5m2 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | - __nv_fp8x4_storage_t(val.__x); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2 *ptr) { - data = *((uint2 *)ptr); -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::store( - __nv_fp8_e5m2 *ptr) const { - *((uint2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::memcpy( - __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { - *((__nv_fp8_e5m2 *)dst) = *((__nv_fp8_e5m2 *)src); -} - -// __nv_fp8_e5m2 x 16 or more - -template -struct vec_t<__nv_fp8_e5m2, vec_size> { - uint4 data[vec_size / 16]; - - FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { - return ((__nv_fp8_e5m2 *)data)[i]; - } - FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { - return ((const __nv_fp8_e5m2 *)data)[i]; - } - FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((__nv_fp8x4_e5m2 *)(&(data[i].x)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e5m2 *)(&(data[i].y)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - 
(__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e5m2 *)(&(data[i].z)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - ((__nv_fp8x4_e5m2 *)(&(data[i].w)))->__x = - (__nv_fp8x4_storage_t(val.__x) << 24) | - (__nv_fp8x4_storage_t(val.__x) << 16) | - (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); - } - } - FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } - } - FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, - const __nv_fp8_e5m2 *src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } - } -}; -#endif - -/******************* vec_t *******************/ - -// half x 1 -template <> -struct vec_t { - half data; - - FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } - FLASHINFER_INLINE const half &operator[](size_t i) const { - return ((const half *)(&data))[i]; - } - FLASHINFER_INLINE void fill(half val); - FLASHINFER_INLINE void load(const half *ptr); - FLASHINFER_INLINE void store(half *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(half *dst, const half *src); -}; - -FLASHINFER_INLINE void vec_t::fill(half val) { data = val; } - -FLASHINFER_INLINE void vec_t::load(const half *ptr) { data = *ptr; } - -FLASHINFER_INLINE void vec_t::store(half *ptr) const { *ptr = data; } - -FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { - *dst = *src; -} - -// half x 2 -template <> -struct vec_t { - half2 data; - - FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } - FLASHINFER_INLINE const half &operator[](size_t i) const { - return ((const half *)(&data))[i]; - } - FLASHINFER_INLINE void fill(half val); - FLASHINFER_INLINE void load(const half *ptr); - FLASHINFER_INLINE void store(half *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(half *dst, const half *src); -}; - -FLASHINFER_INLINE void vec_t::fill(half val) { - data = make_half2(val, val); -} - -FLASHINFER_INLINE void vec_t::load(const half *ptr) { - data = *((half2 *)ptr); -} - -FLASHINFER_INLINE void vec_t::store(half *ptr) const { - *((half2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { - *((half2 *)dst) = *((half2 
*)src); -} - -// half x 4 - -template <> -struct vec_t { - uint2 data; - - FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } - FLASHINFER_INLINE const half &operator[](size_t i) const { - return ((const half *)(&data))[i]; - } - FLASHINFER_INLINE void fill(half val); - FLASHINFER_INLINE void load(const half *ptr); - FLASHINFER_INLINE void store(half *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(half *dst, const half *src); -}; - -FLASHINFER_INLINE void vec_t::fill(half val) { - *(half2 *)(&data.x) = make_half2(val, val); - *(half2 *)(&data.y) = make_half2(val, val); -} - -FLASHINFER_INLINE void vec_t::load(const half *ptr) { - data = *((uint2 *)ptr); -} - -FLASHINFER_INLINE void vec_t::store(half *ptr) const { - *((uint2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { - *((uint2 *)dst) = *((uint2 *)src); -} - -// half x 8 or more - -template -struct vec_t { - uint4 data[vec_size / 8]; - FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)data)[i]; } - FLASHINFER_INLINE const half &operator[](size_t i) const { - return ((const half *)data)[i]; - } - FLASHINFER_INLINE void fill(half val) { -#pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - *(half2 *)(&(data[i].x)) = make_half2(val, val); - *(half2 *)(&(data[i].y)) = make_half2(val, val); - *(half2 *)(&(data[i].z)) = make_half2(val, val); - *(half2 *)(&(data[i].w)) = make_half2(val, val); - } - } - FLASHINFER_INLINE void load(const half *ptr) { -#pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } - } - FLASHINFER_INLINE void store(half *ptr) const { -#pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(half *dst, const half *src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } - } -}; - -/******************* vec_t *******************/ - -// nv_bfloat16 x 1 -template <> -struct vec_t { - nv_bfloat16 data; - - FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { - return ((nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { - return ((const nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(nv_bfloat16 val); - FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); - FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src); -}; - -FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { - data = val; -} - -FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) 
{ - data = *ptr; -} - -FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { - *ptr = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src) { - *dst = *src; -} - -// nv_bfloat16 x 2 -template <> -struct vec_t { - nv_bfloat162 data; - - FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { - return ((nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { - return ((const nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(nv_bfloat16 val); - FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); - FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src); -}; - -FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { - data = make_bfloat162(val, val); -} - -FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { - data = *((nv_bfloat162 *)ptr); -} - -FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { - *((nv_bfloat162 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src) { - *((nv_bfloat162 *)dst) = *((nv_bfloat162 *)src); -} - -// nv_bfloat16 x 4 - -template <> -struct vec_t { - uint2 data; - - FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { - return ((nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { - return ((const nv_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE void fill(nv_bfloat16 val); - FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); - FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src); -}; - -FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { - *(nv_bfloat162 *)(&data.x) = make_bfloat162(val, val); - *(nv_bfloat162 *)(&data.y) = make_bfloat162(val, val); -} - -FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { - data = *((uint2 *)ptr); -} - -FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { - *((uint2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src) { - *((uint2 *)dst) = *((uint2 *)src); -} - -// nv_bfloat16 x 8 or more - -template -struct vec_t { - uint4 data[vec_size / 8]; - - FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { - return ((nv_bfloat16 *)data)[i]; - } - FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { - return ((const nv_bfloat16 *)data)[i]; - } - FLASHINFER_INLINE void fill(nv_bfloat16 val) { -#pragma unoll - for (size_t i = 0; i < vec_size / 8; ++i) { - *(nv_bfloat162 *)(&(data[i].x)) = make_bfloat162(val, val); - *(nv_bfloat162 *)(&(data[i].y)) = make_bfloat162(val, val); - *(nv_bfloat162 *)(&(data[i].z)) = make_bfloat162(val, val); - *(nv_bfloat162 *)(&(data[i].w)) = make_bfloat162(val, val); - } - } - FLASHINFER_INLINE void load(const nv_bfloat16 *ptr) { -#pragma unoll - for (size_t i = 0; i < vec_size 
/ 8; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } - } - FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const { -#pragma unoll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, - const nv_bfloat16 *src) { -#pragma unoll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } - } -}; - -/******************* vec_t *******************/ - -// float x 1 - -template <> -struct vec_t { - float data; - - FLASHINFER_INLINE float &operator[](size_t i) { - return ((float *)(&data))[i]; - } - FLASHINFER_INLINE const float &operator[](size_t i) const { - return ((const float *)(&data))[i]; - } - FLASHINFER_INLINE void fill(float val); - FLASHINFER_INLINE void load(const float *ptr); - FLASHINFER_INLINE void store(float *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - - FLASHINFER_INLINE static void memcpy(float *dst, const float *src); -}; - -FLASHINFER_INLINE void vec_t::fill(float val) { data = val; } - -FLASHINFER_INLINE void vec_t::load(const float *ptr) { data = *ptr; } - -FLASHINFER_INLINE void vec_t::store(float *ptr) const { *ptr = data; } - -FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) { - *dst = *src; -} - -// float x 2 - -template <> -struct vec_t { - float2 data; - - FLASHINFER_INLINE float &operator[](size_t i) { - return ((float *)(&data))[i]; - } - FLASHINFER_INLINE const float &operator[](size_t i) const { - return ((const float *)(&data))[i]; - } - FLASHINFER_INLINE void fill(float val); - FLASHINFER_INLINE void load(const float *ptr); - FLASHINFER_INLINE void store(float *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - FLASHINFER_INLINE static void memcpy(float *dst, const float *src); -}; - -FLASHINFER_INLINE void vec_t::fill(float val) { - data = make_float2(val, val); -} - -FLASHINFER_INLINE void vec_t::load(const float *ptr) { - data = *((float2 *)ptr); -} - -FLASHINFER_INLINE void vec_t::store(float *ptr) const { - *((float2 *)ptr) = data; -} - -FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) { - *((float2 *)dst) = *((float2 *)src); -} - -// float x 4 or more -template -struct vec_t { - float4 data[vec_size / 4]; - - FLASHINFER_INLINE float &operator[](size_t i) { return ((float *)(data))[i]; } - FLASHINFER_INLINE const float &operator[](size_t i) const { - return ((const float *)(data))[i]; - } - FLASHINFER_INLINE void fill(float val) { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - data[i] = make_float4(val, val, val, val); - } - } - FLASHINFER_INLINE void load(const float *ptr) { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - data[i] = ((float4 *)ptr)[i]; - } - } - FLASHINFER_INLINE 
void store(float *ptr) const { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) { - cast_from_impl(src, *this); - } - template - FLASHINFER_INLINE void cast_load(const T *ptr) { - cast_load_impl(ptr, *this); - } - template - FLASHINFER_INLINE void cast_store(T *ptr) const { - cast_store_impl(*this, ptr); - } - FLASHINFER_INLINE static void memcpy(float *dst, const float *src) { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4 *)dst)[i] = ((float4 *)src)[i]; - } - } -}; - -/******************* vec_t type cast *******************/ - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((float2 *)(&dst.data))[i] = __half22float2(((half2 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = half(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((half2 *)(&dst.data))[i] = __float22half2_rn(((float2 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((float2 *)(&dst.data))[i] = - __bfloat1622float2(((nv_bfloat162 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = nv_bfloat16(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((nv_bfloat162 *)(&dst.data))[i] = - __float22bfloat162_rn(((float2 *)(&src.data))[i]); - } - } -} - -#ifdef FLASHINFER_USE_FP8 - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else if constexpr (vec_size == 2) { - *(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e4m3 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e4m3 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e4m3 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t<__nv_fp8_e4m3, vec_size> &dst) { - if constexpr (vec_size == 1) { - dst.data = __nv_fp8_e4m3(src.data); - } else if constexpr (vec_size == 2) { - *(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(float2 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((__nv_fp8x4_e4m3 *)(&dst.data))[i] = - __nv_fp8x4_e4m3(((float4 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t<__nv_fp8_e4m3, vec_size> &dst) { - if constexpr (vec_size == 1) { - dst.data = __nv_fp8_e4m3(src.data); - } else if constexpr (vec_size == 2) { - *(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(half2 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i 
< vec_size / 4; ++i) { - // NOTE(Zihao): need to double check if we properly handle flo and fhi - ((__nv_fp8x4_e4m3 *)(&dst.data))[i] = __nv_fp8x4_e4m3( - ((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else if constexpr (vec_size == 2) { - *(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e5m2 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e5m2 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src, - vec_t &dst) { - if constexpr (vec_size == 1) { - dst.data = float(src.data); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e5m2 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t<__nv_fp8_e5m2, vec_size> &dst) { - if constexpr (vec_size == 1) { - dst.data = __nv_fp8_e5m2(src.data); - } else if constexpr (vec_size == 2) { - *(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(float2 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((__nv_fp8x4_e5m2 *)(&dst.data))[i] = - __nv_fp8x4_e5m2(((float4 *)(&src.data))[i]); - } - } -} - -template -FLASHINFER_INLINE void cast_from_impl(const vec_t &src, - vec_t<__nv_fp8_e5m2, vec_size> &dst) { - if constexpr (vec_size == 1) { - dst.data = __nv_fp8_e4m3(src.data); - } else if constexpr (vec_size == 2) { - *(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(half2 *)(&src.data)); - } else { -#pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - // NOTE(Zihao): need to double check if we properly handle flo and fhi - ((__nv_fp8x4_e5m2 *)(&dst.data))[i] = __nv_fp8x4_e5m2( - ((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]); - } - } -} - -#endif // FLASHINFER_USE_FP8 - -#endif // VEC_DTYPES_CUH_ diff --git a/csrc/punica/punica_ops.cu b/csrc/punica/punica_ops.cu deleted file mode 100644 index dd29820144b3..000000000000 --- a/csrc/punica/punica_ops.cu +++ /dev/null @@ -1,569 +0,0 @@ -#include -#include -#include - -#include "type_convert.h" -#include "../cuda_compat.h" -#include "bgmv/bgmv_config.h" - - -//====== utils ====== - -inline void check_shape(const torch::Tensor &a, const torch::Tensor &b, - const char *a_name, const char *b_name) { - TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", - a.dim(), " vs ", b.dim()); - for (int i = 0; i < a.dim(); ++i) { - TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, - ".size(", i, ")"); - } -} - -inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) { - return (uint64_t(a) << 32) | uint64_t(b); -} - -#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") - -#define CHECK_CONTIGUOUS(x) \ - TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") - -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) - -#define CHECK_DIM(d, x) \ - TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") - -#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b) - -#define CHECK_EQ(a, b) \ - TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. 
", a, " vs ", b) - -//====== bgmv ====== - -template -inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, - const int64_t *lora_indices, - uint32_t in_features, uint32_t out_features, - int64_t y_offset, int64_t full_y_size, - int64_t batch_size, int64_t num_layers, - int64_t layer_idx, float scale) { - // NOTE(woosuk): While Punica supports various combinations of input/output - // data types, we limit the supported data types to reduce the binary size. - constexpr bool is_input_float = std::is_same::value; - constexpr bool is_output_float = std::is_same::value; - if (is_input_float) { - if (!std::is_same::value) { - return false; - } - } else if (is_output_float) { - if (!std::is_same::value) { - return false; - } - } else if (!(std::is_same::value && - std::is_same::value)) { - return false; - } - - switch (pack_u32(in_features, out_features)) { -#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \ - case pack_u32(feat_in, feat_out): \ - bgmv_kernel(Y, X, W, lora_indices, y_offset, \ - full_y_size, batch_size, num_layers, \ - layer_idx, scale); \ - break; -#define CASE(_in_T, _out_T, _W_T, narrow, wide) \ - CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \ - CASE_ONESIDE(in_T, out_T, W_T, wide, narrow) - - FOR_BGMV_WIDE_NARROW(CASE, _, _, _) - FOR_INST_BGMV_WIDE_NARROW(CASE_ONESIDE, _, _, _) -#undef CASE -#undef CASE_ONESIDE - default: - return false; - } - return true; -} - -void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, double scale) { - CHECK_INPUT(y); - CHECK_INPUT(x); - CHECK_INPUT(w); - CHECK_INPUT(indicies); - - CHECK_DIM(2, y); - CHECK_DIM(2, x); - CHECK_DIM(4, w); - CHECK_DIM(1, indicies); - - int64_t B = x.size(0); - int64_t h_in = x.size(1); - int64_t h_out = y.size(1); - int64_t num_layers = w.size(1); - CHECK_EQ(w.size(3), h_in); - CHECK_EQ(w.size(2), h_out); - CHECK_EQ(indicies.size(0), x.size(0)); - CHECK_EQ(y.size(0), x.size(0)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - bool ok = false; - if (h_in <= 128512 && h_out <= 128512) { - // TODO: See if we can get rid of this massive nested switch - switch (x.scalar_type()) { - case at::ScalarType::Half: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - 
static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - 
static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, 0, - h_out, B, num_layers, layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - default: - break; - } - } - TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, - " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); -} - -void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, - double scale, int64_t h_in, int64_t h_out, - int64_t y_offset) { - CHECK_INPUT(y); - CHECK_INPUT(x); - CHECK_INPUT(w); - CHECK_INPUT(indicies); - - CHECK_DIM(2, y); - CHECK_DIM(2, x); - CHECK_DIM(4, w); - CHECK_DIM(1, indicies); - - int64_t B = x.size(0); - int64_t num_layers = w.size(1); - int64_t full_y_size = y.size(1); - CHECK_EQ(w.size(3), h_in); - CHECK_EQ(w.size(2), h_out); - CHECK_EQ(indicies.size(0), x.size(0)); - CHECK_EQ(y.size(0), x.size(0)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - bool ok = false; - if (h_in <= 128512 && h_out <= 128512) { - // TODO: See if we can get rid of this massive nested switch - switch (x.scalar_type()) { - case at::ScalarType::Half: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - 
y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (y.scalar_type()) { - case at::ScalarType::Half: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::BFloat16: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - case at::ScalarType::Float: - switch (w.scalar_type()) { - case at::ScalarType::Half: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - case at::ScalarType::BFloat16: - ok = launch_bgmv_kernel(static_cast(y.data_ptr()), - static_cast(x.data_ptr()), - static_cast(w.data_ptr()), - indicies.data_ptr(), h_in, h_out, - y_offset, full_y_size, B, num_layers, - layer_idx, scale); - break; - default: - break; - } - break; - default: - break; - } - break; - default: - break; - } - } - TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, - " dtype=", x.scalar_type(), 
" out_dtype=", y.scalar_type()); -} diff --git a/csrc/punica/punica_ops.h b/csrc/punica/punica_ops.h deleted file mode 100644 index 5d625d0564f7..000000000000 --- a/csrc/punica/punica_ops.h +++ /dev/null @@ -1,11 +0,0 @@ -#pragma once - -#include - -void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, double scale); - -void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, - double scale, int64_t h_in, int64_t h_out, - int64_t y_offset); diff --git a/csrc/punica/torch_bindings.cpp b/csrc/punica/torch_bindings.cpp deleted file mode 100644 index 894e229b6d9d..000000000000 --- a/csrc/punica/torch_bindings.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include "registration.h" -#include "punica_ops.h" - -TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { - m.def( - "dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int " - "layer_idx, float scale) -> ()"); - m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv); - - m.def( - "dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w," - "Tensor indicies, int layer_idx," - "float scale, int h_in, int h_out," - "int y_offset) -> ()"); - m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level); -} - -REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/punica/type_convert.h b/csrc/punica/type_convert.h deleted file mode 100644 index dff7ce49283d..000000000000 --- a/csrc/punica/type_convert.h +++ /dev/null @@ -1,82 +0,0 @@ -#ifndef CSRC__PUNICA__TYPE_CONVERT_H__ -#define CSRC__PUNICA__TYPE_CONVERT_H__ - -#ifndef USE_ROCM - -#include -#include - -#else - -#include -#include - -#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__ - -typedef __half nv_half; -typedef __hip_bfloat16 nv_bfloat16; -typedef __hip_bfloat162 nv_bfloat162; - -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) { - return __hip_bfloat162{val, val}; -} - -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) { - return __hip_bfloat162{vall, valr}; -} - -template -__TYPE_CONVERT__HOST_DEVICE__ -inline T_dst convert_type(T_src val) { - return static_cast(val); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline float convert_type<__half, float>(__half val) { - return __half2float(val); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline __half convert_type(float val) { - return __float2half(val); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) { - return __bfloat162float(val); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat16 convert_type(float val) { - return __float2bfloat16(val); -} - -template -__TYPE_CONVERT__HOST_DEVICE__ -inline T vllm_add(T a, T b) { - return a + b; -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline __half vllm_add<__half>(__half a, __half b) { - return __hadd(a, b); -} - -template <> -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) { - return __hadd(a, b); -} - -#undef __TYPE_CONVERT__HOST_DEVICE__ - -#endif // USE_ROCM - -#endif // CSRC__PUNICA__TYPE_CONVERT_H__ diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index fe041e03a1b6..0253717da3cd 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -66,7 +66,6 @@ You can also 
build and install vLLM from source: $ git clone https://github.com/vllm-project/vllm.git $ cd vllm - $ # export VLLM_INSTALL_PUNICA_KERNELS=1 # optionally build for multi-LoRA capability $ pip install -e . # This may take 5-10 minutes. .. tip:: diff --git a/setup.py b/setup.py index 72ef26f15e40..63c1f466d291 100644 --- a/setup.py +++ b/setup.py @@ -181,9 +181,6 @@ def configure(self, ext: CMakeExtension) -> None: # match. cmake_args += ['-DVLLM_PYTHON_EXECUTABLE={}'.format(sys.executable)] - if _install_punica(): - cmake_args += ['-DVLLM_INSTALL_PUNICA_KERNELS=ON'] - # # Setup parallelism and build tool # @@ -274,10 +271,6 @@ def _build_custom_ops() -> bool: return _is_cuda() or _is_hip() or _is_cpu() -def _install_punica() -> bool: - return envs.VLLM_INSTALL_PUNICA_KERNELS - - def get_hipcc_rocm_version(): # Run the hipcc --version command result = subprocess.run(['hipcc', '--version'], @@ -446,9 +439,6 @@ def _read_requirements(filename: str) -> List[str]: if _build_custom_ops(): ext_modules.append(CMakeExtension(name="vllm._C")) - if _install_punica(): - ext_modules.append(CMakeExtension(name="vllm._punica_C")) - package_data = { "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"] } diff --git a/tests/kernels/test_sampler.py b/tests/kernels/test_sampler.py index 3c53f7decc6e..713e868986a5 100644 --- a/tests/kernels/test_sampler.py +++ b/tests/kernels/test_sampler.py @@ -1,14 +1,17 @@ import gc +from unittest.mock import patch import pytest import torch import triton import triton.language as tl -from vllm.model_executor.layers.ops.sample import (_uniform_to_exponential, +from vllm.model_executor.layers.ops.sample import (_sample_triton, + _uniform_to_exponential, sample) from vllm.model_executor.sampling_metadata import SamplingTensors from vllm.model_executor.utils import set_random_seed +from vllm.triton_utils.libentry import LibEntry from vllm.triton_utils.sample import (MAX_TRITON_N_COLS, get_num_triton_sampler_splits) @@ -76,15 +79,20 @@ def test_sample_decoding_only(random_sampling, max_best_of, seeds = torch.randint(1, torch.iinfo(torch.long).max, (n_splits, bs), device="cuda").mul_(random_sampling_mask) - sampled_tokens, sampled_logprobs, sampled_modified_probs = sample( - probs=probs, - logprobs=logprobs, - sample_indices=sample_indices, - seeds=seeds, - max_best_of=max_best_of, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=save_logprobs, - _save_modified_probs=True) + # The current _sample_triton kernel is not decorated with LibEntry. + # Patching it with LibEntry here lets this test also verify that the + # LibEntry wrapper preserves correctness. 
+ with patch("vllm.model_executor.layers.ops.sample._sample_triton", + LibEntry(_sample_triton)): + sampled_tokens, sampled_logprobs, sampled_modified_probs = sample( + probs=probs, + logprobs=logprobs, + sample_indices=sample_indices, + seeds=seeds, + max_best_of=max_best_of, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=save_logprobs, + _save_modified_probs=True) assert sampled_tokens.shape == (bs, max_best_of) for i in range(bs): assert torch.all(sampled_tokens[i] == i * (vocab_size // bs)) @@ -130,6 +138,7 @@ def test_sample_decoding_only(random_sampling, max_best_of, [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE]) def test_sample_prompt_logprobs(random_sampling, max_best_of, modify_greedy_probs, seed, vocab_size): + set_random_seed(seed) prompt_sizes = [16, 32, 64, 128] * 2 samples = 8 @@ -157,14 +166,17 @@ def test_sample_prompt_logprobs(random_sampling, max_best_of, seeds = torch.randint(1, torch.iinfo(torch.long).max, (n_splits, samples), device="cuda").mul_(random_sampling_mask) - sampled_tokens, sampled_logprobs, _ = sample( - probs=probs, - logprobs=logprobs, - sample_indices=sample_indices, - seeds=seeds, - max_best_of=max_best_of, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=True) + # Same LibEntry patching as in test_sample_decoding_only above. + with patch("vllm.model_executor.layers.ops.sample._sample_triton", + LibEntry(_sample_triton)): + sampled_tokens, sampled_logprobs, _ = sample( + probs=probs, + logprobs=logprobs, + sample_indices=sample_indices, + seeds=seeds, + max_best_of=max_best_of, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=True) assert sampled_tokens.shape == (samples, max_best_of) assert sampled_logprobs.shape == (samples, max_best_of) for i, t in enumerate(sample_indices): diff --git a/tests/lora/test_gemma.py b/tests/lora/test_gemma.py index 709246179bfe..478bb86b7861 100644 --- a/tests/lora/test_gemma.py +++ b/tests/lora/test_gemma.py @@ -37,7 +37,7 @@ def test_gemma_lora(gemma_lora_files): expected_lora_output = [ "more important than knowledge.\nAuthor: Albert Einstein\n", "everyone else is already taken.\nAuthor: Oscar Wilde\n", - "so little time\nAuthor: Frank Zappa\n", + "so little time.\nAuthor: Frank Zappa\n", ] output1 = do_sample(llm, gemma_lora_files, lora_id=1) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 7207af6b1a4b..6f33f56616fc 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -26,7 +26,8 @@ VocabParallelEmbeddingWithLoRA) # yapf: enable from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights, - PackedLoRALayerWeights, convert_mapping) + PackedLoRALayerWeights) +from vllm.lora.punica import PunicaWrapper from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, @@ -47,6 +48,9 @@ CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] +# Different Triton kernels are launched for the prefill and decode +# stages, so both paths need to be verified: 
prefill stage(True) or decode stage(False) +STAGES = [True, False] def get_random_id_to_index(num_loras: int, @@ -182,10 +186,12 @@ def create_random_inputs( @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) -def test_embeddings(dist_init, num_loras, device, vocab_size) -> None: +@pytest.mark.parametrize("stage", STAGES) +def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: torch.set_default_device(device) max_loras = 8 + punica_wrapper = PunicaWrapper(8192, 256, device) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16) @@ -204,7 +210,7 @@ def create_random_embedding_layer(): id_to_index = get_random_id_to_index(num_loras, max_loras) embedding, lora_embedding = create_random_embedding_layer() - + lora_embedding.set_mapping(punica_wrapper) lora_dict, _ = populate_loras( id_to_index, layer=lora_embedding, @@ -217,12 +223,12 @@ def create_random_embedding_layer(): input_size=(200, ), input_range=(1, vocab_size), ) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping) - - mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, vocab_size, lora_config.lora_extra_vocab_size) - lora_embedding.set_mapping(*mapping_info) lora_result = lora_embedding(torch.cat(inputs)) @@ -255,12 +261,12 @@ def create_random_embedding_layer(): input_size=(200, ), input_range=(1, vocab_size), ) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping) - - mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, vocab_size, lora_config.lora_extra_vocab_size) - lora_embedding.set_mapping(*mapping_info, ) lora_result = lora_embedding(torch.cat(inputs)) expected_result = embedding(torch.cat(inputs)) @@ -278,11 +284,13 @@ def create_random_embedding_layer(): @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) +@pytest.mark.parametrize("stage", STAGES) def test_embeddings_with_new_embeddings(dist_init, num_loras, device, - vocab_size) -> None: + vocab_size, stage) -> None: torch.set_default_device(device) max_loras = 8 + punica_wrapper = PunicaWrapper(8192, 256, device) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16) @@ -318,6 +326,7 @@ def create_random_embedding_layer(): generate_embeddings_tensor=256, ) + lora_embedding.set_mapping(punica_wrapper) # All embeddings tensors have the same shape. 
embeddings_tensors = [ lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys()) @@ -334,8 +343,12 @@ def create_random_embedding_layer(): input_size=(200, ), input_range=(1, vocab_size), ) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping) - + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + vocab_size, + lora_config.lora_extra_vocab_size) original_inputs = deepcopy(inputs) # Force some of the inputs to be in the extended embeddings range @@ -349,11 +362,6 @@ def create_random_embedding_layer(): (embedding_id + 1) * embeddings_tensor_len - 1) original_input_[-2] = vocab_size + embeddings_tensor_len - 1 - mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) - lora_embedding.set_mapping(*mapping_info, ) - expanded_embedding.weight[vocab_size:vocab_size + (embeddings_tensor_len * max_loras)] = torch.cat(embeddings_tensors) @@ -390,15 +398,13 @@ def create_random_embedding_layer(): input_size=(200, ), input_range=(1, vocab_size), ) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping) - original_inputs = deepcopy(inputs) - - mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, vocab_size, lora_config.lora_extra_vocab_size) - lora_embedding.set_mapping(*mapping_info, ) - lora_result = lora_embedding(torch.cat(original_inputs)) expected_result = expanded_embedding(torch.cat(inputs)) @@ -413,11 +419,13 @@ def create_random_embedding_layer(): @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) -def test_lm_head_logits_processor(dist_init, num_loras, device, - vocab_size) -> None: +@pytest.mark.parametrize("stage", STAGES) +def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, + stage) -> None: torch.set_default_device(device) max_loras = 8 + punica_wrapper = PunicaWrapper(8192, 256, device) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16) @@ -443,7 +451,7 @@ def _pretest(): id_to_index = get_random_id_to_index(num_loras, max_loras) linear, logits_processor, lora_logits_processor = _pretest() - + lora_logits_processor.set_mapping(punica_wrapper) # NOTE: all the generated loras share the same embeddings tensor. 
lora_dict, _ = populate_loras( id_to_index, @@ -461,17 +469,17 @@ def _pretest(): input_range=(0, 1), input_type=torch.float16, ) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping) - - input_ = torch.rand(20, 1024) - mapping_info = convert_mapping( + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata( lora_mapping, id_to_index, max_loras, vocab_size, lora_config.lora_extra_vocab_size, ) - lora_logits_processor.set_mapping(*mapping_info, ) + input_ = torch.rand(20, 1024) lora_result = lora_logits_processor._get_logits( hidden_states=torch.cat(inputs), @@ -510,12 +518,16 @@ def _pretest(): input_range=(0, 1), input_type=torch.float16, ) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping) - - mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) - lora_logits_processor.set_mapping(*mapping_info, ) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + vocab_size, + lora_config.lora_extra_vocab_size, + ) lora_result = lora_logits_processor._get_logits( hidden_states=torch.cat(inputs), @@ -538,10 +550,12 @@ def _pretest(): @pytest.mark.parametrize("orientation", ["row", "column"]) @pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("stage", STAGES) def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, - device) -> None: + device, stage) -> None: torch.set_default_device(device) + punica_wrapper = PunicaWrapper(8192, 256, device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -575,7 +589,7 @@ def create_random_linear_parallel_layer(): id_to_index = get_random_id_to_index(num_loras, max_loras) linear, lora_linear = create_random_linear_parallel_layer() - + lora_linear.set_mapping(punica_wrapper) lora_dict, _ = populate_loras( id_to_index, layer=lora_linear, @@ -589,16 +603,16 @@ def create_random_linear_parallel_layer(): input_range=(0, 1), input_type=torch.float16, ) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping) - - mapping_info = convert_mapping( + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata( lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size, ) - lora_linear.set_mapping(*mapping_info, ) lora_result = lora_linear(torch.cat(inputs))[0] @@ -628,11 +642,12 @@ def create_random_linear_parallel_layer(): input_range=(0, 1), input_type=torch.float16, ) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) - mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size) - lora_linear.set_mapping(*mapping_info, ) lora_result = lora_linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0] @@ -649,10 +664,12 @@ def create_random_linear_parallel_layer(): @pytest.mark.parametrize("repeats", [1, 2, 3]) @pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("stage", STAGES) def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, - device) -> None: + device, stage) -> None: torch.set_default_device(device) + 
punica_wrapper = PunicaWrapper(8192, 256, device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -707,7 +724,7 @@ class FakeConfig: id_to_index = get_random_id_to_index(num_loras, max_loras) linear, lora_linear = create_column_parallel_packed_layer() - + lora_linear.set_mapping(punica_wrapper) lora_dict, sublora_dict = populate_loras( id_to_index, layer=lora_linear, @@ -722,16 +739,17 @@ class FakeConfig: input_range=(0, 1), input_type=torch.float16, ) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) - mapping_info = convert_mapping( + punica_wrapper.update_metadata( lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size, ) - lora_linear.set_mapping(*mapping_info) lora_result = lora_linear(torch.cat(inputs))[0] @@ -762,16 +780,18 @@ class FakeConfig: input_range=(0, 1), input_type=torch.float16, ) - lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) - mapping_info = convert_mapping( + punica_wrapper.update_metadata( lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size, ) - lora_linear.set_mapping(*mapping_info) + # lora_linear.set_mapping(*mapping_info) lora_result = lora_linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0] @@ -803,7 +823,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.set_default_device(device) - + punica_wrapper = PunicaWrapper(8192, 256, device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -825,6 +845,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, is_neox_style, ) lora_rope = LinearScalingRotaryEmbeddingWithLora(rope) + lora_rope.set_mapping(punica_wrapper) lora_rope.create_lora_weights(max_loras, lora_config) linear_rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { @@ -840,6 +861,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, input_range=(0, lora_config.lora_extra_vocab_size), input_type=torch.float16, ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) long_lora_context = LongContextLoRAContext(list(scaling_factors), rotary_dim) @@ -854,7 +876,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, for i in range(len(scaling_factors)): long_lora_context.offsets_by_lora_id[i] = scaling_factor_to_offset.get( scaling_factors[i], 0) - mapping_info = convert_mapping( + punica_wrapper.update_metadata( lora_mapping, id_to_index, max_loras, @@ -862,7 +884,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, lora_config.lora_extra_vocab_size, long_lora_context=long_lora_context, ) - lora_rope.set_mapping(*mapping_info) + # lora_rope.set_mapping(*mapping_info) positions = torch.randint(0, max_position, (batch_size, seq_len)) query = torch.randn(batch_size, diff --git a/tests/lora/test_lora.py b/tests/lora/test_lora.py deleted file mode 100644 index 3415d36b7e34..000000000000 --- a/tests/lora/test_lora.py +++ /dev/null @@ -1,224 +0,0 @@ -import pytest -import torch - -from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice - -from .utils import DummyLoRAManager - -TENSOR_SIZES = [128, 1024, 2048, 4096, 8192, 11008, 11008 // 2, 11008 // 4] -QKV_TENSOR_SIZES = [ - (8192, 1024, 1024), - (8192 // 8, 1024 // 8, 1024 // 8), - (4096, 4096, 
4096), - (4096 // 2, 4096 // 2, 4096 // 2), -] -BATCH_SIZES = [8, 32, 256] -RANKS = [8] -DTYPES = [torch.float16] -TOLERANCES = { - torch.float16: (5e-3, 5e-3), - torch.bfloat16: (3e-2, 2e-2), -} - - -@pytest.mark.parametrize("m", TENSOR_SIZES) -@pytest.mark.parametrize("n", TENSOR_SIZES) -@pytest.mark.parametrize("k", BATCH_SIZES) -@pytest.mark.parametrize("rank", RANKS) -@pytest.mark.parametrize("dtype", DTYPES) -def test_apply_lora(m, n, k, rank, dtype) -> None: - manager = DummyLoRAManager() - - module_name = "module" - weight = torch.rand([m, n], device="cuda", dtype=dtype) - - manager.init_random_lora(module_name, weight, rank=rank) - lora = manager.get_module_lora(module_name) - - input = torch.rand(k, n, device="cuda", dtype=dtype) - expected = input @ lora.lora_a @ lora.lora_b * lora.scaling - - lora_a_stack = torch.zeros(8, - 1, - lora.lora_a.shape[1], - lora.lora_a.shape[0], - device="cuda", - dtype=dtype) - lora_b_stack = torch.zeros(8, - 1, - lora.lora_b.shape[1], - lora.lora_b.shape[0], - device="cuda", - dtype=dtype) - for i in range(lora_a_stack.shape[0]): - lora_a_stack[i][0] = lora.lora_a.T - lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T - - output = torch.zeros(k, m, device="cuda", dtype=dtype) - _apply_lora( - input, lora_a_stack, lora_b_stack, - torch.randint(0, lora_a_stack.shape[0], (len(input), ), device="cuda"), - output) - - rtol, atol = TOLERANCES[dtype] - assert torch.allclose(expected, output, rtol=rtol, atol=atol) - - output[:] = 0 - _apply_lora(input, lora_a_stack, lora_b_stack, - torch.full((len(input), ), -1, device="cuda"), output) - assert torch.allclose(torch.zeros_like(output), output) - - manager.reset_lora() - - -@pytest.mark.parametrize("m", TENSOR_SIZES) -@pytest.mark.parametrize("n", TENSOR_SIZES) -@pytest.mark.parametrize("k", BATCH_SIZES) -@pytest.mark.parametrize("rank", RANKS) -@pytest.mark.parametrize("dtype", DTYPES) -def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None: - if m % 2 != 0: - pytest.skip("m must be divisible by 2") - if m // 2 not in TENSOR_SIZES: - pytest.skip("m//2 must be in TENSOR_SIZES") - - manager = DummyLoRAManager() - - module_name = "module" - weight = torch.rand([m // 2, n], device="cuda", dtype=dtype) - - manager.init_random_lora(module_name + "1", weight, rank=rank) - lora_1 = manager.get_module_lora(module_name + "1") - manager.init_random_lora(module_name + "2", weight, rank=rank) - lora_2 = manager.get_module_lora(module_name + "2") - - input = torch.rand(k, n, device="cuda", dtype=dtype) - expected = torch.cat([ - input @ lora_1.lora_a @ lora_1.lora_b * lora_1.scaling, - input @ lora_2.lora_a @ lora_2.lora_b * lora_2.scaling - ], - dim=1) - - lora_a_stacks = [ - torch.zeros(8, - 1, - lora_1.lora_a.shape[1], - lora_1.lora_a.shape[0], - device="cuda", - dtype=dtype) for i in range(2) - ] - lora_b_stacks = [ - torch.zeros(8, - 1, - lora_1.lora_b.shape[1], - lora_1.lora_b.shape[0], - device="cuda", - dtype=dtype) for i in range(2) - ] - for i in range(lora_a_stacks[0].shape[0]): - lora_a_stacks[0][i][0] = lora_1.lora_a.T - lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T - lora_a_stacks[1][i][0] = lora_2.lora_a.T - lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T - - output = torch.zeros(k, m, device="cuda", dtype=dtype) - _apply_lora_packed_nslice( - input, lora_a_stacks, lora_b_stacks, - torch.randint(0, - lora_a_stacks[0].shape[0], (len(input), ), - device="cuda"), output, (m // 2, m // 2)) - - rtol, atol = TOLERANCES[dtype] - assert torch.allclose(expected, output, 
rtol=rtol, atol=atol) - - output[:] = 0 - _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, - torch.full((len(input), ), -1, device="cuda"), - output, (m // 2, m // 2)) - assert torch.allclose(torch.zeros_like(output), output) - - manager.reset_lora() - - -@pytest.mark.parametrize("qkv", QKV_TENSOR_SIZES) -@pytest.mark.parametrize("n", TENSOR_SIZES) -@pytest.mark.parametrize("k", BATCH_SIZES) -@pytest.mark.parametrize("rank", RANKS) -@pytest.mark.parametrize("dtype", DTYPES) -def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None: - manager = DummyLoRAManager() - - module_name = "module" - weight_q = torch.empty(qkv[0], n, device="cuda", dtype=dtype) - weight_kv = torch.empty(qkv[1], n, device="cuda", dtype=dtype) - - manager.init_random_lora(module_name + "q", weight_q, rank=rank) - lora_q = manager.get_module_lora(module_name + "q") - manager.init_random_lora(module_name + "k", weight_kv, rank=rank) - lora_k = manager.get_module_lora(module_name + "k") - manager.init_random_lora(module_name + "v", weight_kv, rank=rank) - lora_v = manager.get_module_lora(module_name + "v") - - input = torch.rand(k, n, device="cuda", dtype=dtype) - expected = torch.cat([ - input @ lora_q.lora_a @ lora_q.lora_b * lora_q.scaling, - input @ lora_k.lora_a @ lora_k.lora_b * lora_k.scaling, - input @ lora_v.lora_a @ lora_v.lora_b * lora_v.scaling - ], - dim=1) - - lora_a_stacks = [ - torch.zeros(8, - 1, - lora_q.lora_a.shape[1], - lora_q.lora_a.shape[0], - device="cuda", - dtype=dtype) - ] + [ - torch.zeros(8, - 1, - lora_k.lora_a.shape[1], - lora_k.lora_a.shape[0], - device="cuda", - dtype=dtype) for i in range(2) - ] - lora_b_stacks = [ - torch.zeros(8, - 1, - lora_q.lora_b.shape[1], - lora_q.lora_b.shape[0], - device="cuda", - dtype=dtype) - ] + [ - torch.zeros(8, - 1, - lora_k.lora_b.shape[1], - lora_k.lora_b.shape[0], - device="cuda", - dtype=dtype) for i in range(2) - ] - for i in range(lora_a_stacks[0].shape[0]): - lora_a_stacks[0][i][0] = lora_q.lora_a.T - lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T - lora_a_stacks[1][i][0] = lora_k.lora_a.T - lora_b_stacks[1][i][0] = (lora_k.lora_b * lora_k.scaling).T - lora_a_stacks[2][i][0] = lora_v.lora_a.T - lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T - - output = torch.zeros(k, sum(qkv), device="cuda", dtype=dtype) - _apply_lora_packed_nslice( - input, lora_a_stacks, lora_b_stacks, - torch.randint(0, - lora_a_stacks[0].shape[0], (len(input), ), - device="cuda"), output, (qkv[0], qkv[1], qkv[2])) - - rtol, atol = TOLERANCES[dtype] - assert torch.allclose(expected, output, rtol=rtol, atol=atol) - - output[:] = 0 - _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, - torch.full((len(input), ), -1, device="cuda"), - output, (qkv[0], qkv[1], qkv[2])) - assert torch.allclose(torch.zeros_like(output), output) - - manager.reset_lora() diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py deleted file mode 100644 index dbeb16cb21ad..000000000000 --- a/tests/lora/test_punica.py +++ /dev/null @@ -1,258 +0,0 @@ -# Based on code from https://github.com/punica-ai/punica - -import pytest -import torch - -import vllm.lora.punica as punica - - -def assert_close(a, b): - rtol, atol = { - torch.float16: (5e-3, 5e-3), - torch.bfloat16: (3e-2, 2e-2), - torch.float32: (None, None), - }[a.dtype] - torch.testing.assert_close(a, b, rtol=rtol, atol=atol) - - -def _lora_ref_impl( - y_final: torch.Tensor, - x: torch.Tensor, - wa_T_all: torch.Tensor, - wb_T_all: torch.Tensor, - indicies: 
torch.LongTensor, - layer_idx: int, - scale: float, -): - y_stage_1 = torch.empty( - (x.size(0), wa_T_all.size(-2)), - dtype=torch.float32, - device=x.device, - ) - bs = x.shape[0] - s = torch.tensor(scale, dtype=torch.float32, device=x.device) - for i, lora_idx in zip(range(bs), indicies.cpu().tolist()): - xi = x[i].unsqueeze(0).to(torch.float32) - wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32) - if wb_T_all is not None: - wb = wb_T_all[lora_idx, layer_idx].transpose(-1, - -2).to(torch.float32) - - tmp = xi @ wa - y_stage_1[i] = tmp.squeeze(0) - y_final[i] += ((tmp @ wb).squeeze(0) * - s if wb_T_all is not None else y_stage_1[i]) - return y_final, y_stage_1 - - -H1 = H2 = [ - 128, - 256, - 512, - 896, - 1024, - 1152, - 1216, - 1280, - 1536, - 1664, - 2048, - 2240, - 2304, - 2368, - 2432, - 2560, - 2752, - 3072, - 3328, - 3456, - 3584, - 3712, - 4096, - 4480, - 4608, - 4736, - 4864, - 5120, - 5504, - 5632, - 5888, - 6144, - 6400, - 6848, - 6912, - 7168, - 7424, - 8192, - 8960, - 9216, - 9472, - 10240, - 11008, - 11264, - 13824, - 14336, - 14784, - 14848, - 15360, - 18944, - 22016, - 22528, - 24576, - 27392, - 27648, - 29568, - 29696, - 32000, - 32256, - 32512, - 32768, - 33024, - 36864, - 43264, - 49152, - 49408, - 60544, - 60672, - 64000, - 64256, - 102400, - 102656, - 128000, - 128256, -] -H2 = [64] + H2 -R = [1, 2, 4] -SEED = [0xabcdabcd987] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] - - -@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) -@pytest.mark.parametrize("h1", H1) -@pytest.mark.parametrize("r", R) -@pytest.mark.parametrize("seed", SEED) -@torch.inference_mode() -def test_lora_a_extra_shapes(dtype_str, h1, r, seed): - torch.manual_seed(seed) - num_loras = 4 - num_layers = 1 - bs = 32 - dtype = getattr(torch, dtype_str) - device = torch.device("cuda") - - wa_T_all = torch.randn(num_loras, - num_layers, - r, - h1, - dtype=dtype, - device=device) - indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device) - - for layer_idx in range(num_layers): - x = torch.randn(bs, h1, dtype=dtype, device=device) - y = torch.randn(bs, r, dtype=dtype, device=device) - - y_ref = y.clone() - _lora_ref_impl( - y_ref, - x, - wa_T_all, - None, - indices, - layer_idx, - 1.0, - ) - - y_our = y.clone() - punica.bgmv(y_our, x, wa_T_all, indices, layer_idx, 1.0) - - assert_close(y_ref, y_our) - - -@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) -@pytest.mark.parametrize("h1", H1) -@pytest.mark.parametrize("h2", H2) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_lora_correctness(dtype_str, h1, h2, seed, device): - torch.manual_seed(seed) - num_loras = 4 - num_layers = 1 - r = 8 - bs = 32 - scale = 0.123 - dtype = getattr(torch, dtype_str) - torch.set_default_device(device) - - wa_T_all = torch.randn(num_loras, num_layers, r, h1, dtype=dtype) - wb_T_all = torch.randn(num_loras, num_layers, h2, r, dtype=dtype) - indices = torch.randint(num_loras, (bs, ), dtype=torch.long) - - for layer_idx in range(num_layers): - x = torch.randn(bs, h1, dtype=dtype) - y = torch.randn(bs, h2, dtype=dtype) - - y_ref = y.clone() - _lora_ref_impl(y_ref, x, wa_T_all, wb_T_all, indices, layer_idx, scale) - - y_our = y.clone() - punica.add_lora(y_our, x, wa_T_all, wb_T_all, indices, layer_idx, - scale) - - assert_close(y_ref, y_our) - - -@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) -@pytest.mark.parametrize("h1", H1) 
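# NOTE: illustrative sketch, not part of the diff. The updated tests in
# tests/lora/test_layers.py above all follow the same wiring pattern: instead
# of convert_mapping()/set_mapping(*mapping_info), a layer is bound once to a
# shared PunicaWrapper and per-batch index state is refreshed through
# update_metadata(). Argument meanings are inferred from this diff only.
#
#     punica_wrapper = PunicaWrapper(8192, 256, device)
#     lora_layer.set_mapping(punica_wrapper)        # bind once per layer
#     lora_mapping = LoRAMapping(index_mapping,
#                                prompt_mapping,
#                                is_prefill=stage)
#     punica_wrapper.update_metadata(
#         lora_mapping,
#         id_to_index,
#         max_loras,
#         vocab_size,
#         lora_config.lora_extra_vocab_size,
#     )
#     lora_result = lora_layer(torch.cat(inputs))[0]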
-@pytest.mark.parametrize("h2", H2) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_lora_correctness_slice(dtype_str, h1, h2, seed, device): - if h2 % 3 != 0 or h2 // 3 not in H1: - pytest.skip("h2 must be divisible by 3 and in supported shapes") - torch.manual_seed(seed) - num_loras = 4 - num_layers = 1 - r = 8 - bs = 32 - scale = 0.123 - dtype = getattr(torch, dtype_str) - torch.set_default_device(device) - - wa_T_all_0 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype) - wa_T_all_1 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype) - wa_T_all_2 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype) - wb_T_all_0 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype) - wb_T_all_1 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype) - wb_T_all_2 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype) - - indices = torch.randint(num_loras, (bs, ), dtype=torch.long) - - for layer_idx in range(num_layers): - x = torch.randn(bs, h1, dtype=dtype) - y = torch.randn(bs, h2, dtype=dtype) - s = h2 // 3 - - y_ref = y.clone() - _lora_ref_impl(y_ref[:, :s], x, wa_T_all_0, wb_T_all_0, indices, - layer_idx, scale) - _lora_ref_impl(y_ref[:, s:s * 2], x, wa_T_all_1, wb_T_all_1, indices, - layer_idx, scale) - _lora_ref_impl(y_ref[:, s * 2:], x, wa_T_all_2, wb_T_all_2, indices, - layer_idx, scale) - - y_our = y.clone() - punica.add_lora_slice(y_our, x, wa_T_all_0, wb_T_all_0, indices, - layer_idx, scale, 0, s) - punica.add_lora_slice(y_our, x, wa_T_all_1, wb_T_all_1, indices, - layer_idx, scale, s, s) - punica.add_lora_slice(y_our, x, wa_T_all_2, wb_T_all_2, indices, - layer_idx, scale, s * 2, s) - - assert_close(y_ref[:, :s], y_our[:, :s]) - assert_close(y_ref[:, s:s * 2], y_our[:, s:s * 2]) - assert_close(y_ref[:, s * 2:], y_our[:, s * 2:]) diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py new file mode 100644 index 000000000000..c052568dc2e3 --- /dev/null +++ b/tests/lora/test_punica_sizes.py @@ -0,0 +1,408 @@ +""" +This script is mainly used to tests various hidden_sizes. We have collected the +hidden_sizes included in the LoRA models currently supported by vLLM. It tests +whether the corresponding Triton kernel can run normally when tensor parallelism +is set to [1, 2, 4, 8, 16, 32, 64]. 
+""" +import random +from unittest.mock import patch + +import pytest +import torch + +from vllm.lora.ops.bgmv_expand import bgmv_expand +from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice +from vllm.lora.ops.bgmv_shrink import bgmv_shrink +from vllm.lora.ops.sgmv_expand import sgmv_expand +from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice +from vllm.lora.ops.sgmv_shrink import sgmv_shrink +from vllm.triton_utils.libentry import LibEntry + +from .utils import (generate_data, generate_data_for_expand_nslices, + ref_torch_groupgemm) + +HIDDEN_SIZES = [ + 128, + 256, + 512, + 896, + 1024, + 1152, + 1216, + 1280, + 1536, + 1664, + 2048, + 2240, + 2304, + 2368, + 2432, + 2560, + 2752, + 3072, + 3328, + 3456, + 3584, + 3712, + 4096, + 4480, + 4608, + 4736, + 4864, + 5120, + 5504, + 5632, + 5888, + 6144, + 6400, + 6848, + 6912, + 7168, + 7424, + 8192, + 8960, + 9216, + 9472, + 10240, + 11008, + 11264, + 13824, + 14336, + 14784, + 14848, + 15360, + 18944, + 22016, + 22528, + 24576, + 27392, + 27648, + 29568, + 29696, + 32000, + 32256, + 32512, + 32768, + 33024, + 36864, + 43264, + 49152, + 49408, + 60544, + 60672, + 64000, + 64256, + 102400, + 102656, + 128000, + 128256, +] +#The size of TP +divisibility = [1, 2, 4, 8, 16, 32, 64] + +all_hidden_size = [] +for div in divisibility: + for hidden_size in HIDDEN_SIZES: + all_hidden_size.append(hidden_size // div) + +HIDDEN_SIZES = list(set(all_hidden_size)) + +BATCHES = [4] +NUM_LORA = [4] +DTYPES = [torch.float16, torch.bfloat16] +MAX_RANKS = [32] +SCALES = [0.5] +SEED = [0] +CUDA_DEVICES = [f"cuda:{0}"] + + +def assert_close(a, b): + rtol, atol = { + torch.float16: (6e-2, 6e-2), + torch.bfloat16: (6e-2, 6e-2), + torch.float32: (1e-2, 1e-2), + }[a.dtype] + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("batches", BATCHES) +@pytest.mark.parametrize("num_loras", NUM_LORA) +@pytest.mark.parametrize("rank", MAX_RANKS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("scaling", SCALES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_punica_sgmv( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + scaling: float, + dtype: torch.dtype, + op_type: str, + seed: int, + device: str, +): + random.seed(seed) + torch.set_default_device(device) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + seq_length = 128 + ( + inputs_tensor, + lora_weights, + our_out_tensor, + ref_out_tensor, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) = generate_data( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + op_type, + device, + ) + max_seq_length = seq_len_tensor.max() + if isinstance(max_seq_length, tuple): + max_seq_length = max_seq_length[0].item() + else: + max_seq_length = max_seq_length.item() + if op_type == "shrink": + sgmv_shrink( + inputs_tensor, + lora_weights, + our_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + scaling, + ) + else: + sgmv_expand( + inputs_tensor, + lora_weights, + our_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + add_inputs=True, + ) + ref_torch_groupgemm( + ref_out_tensor, + inputs_tensor, + lora_weights, + lora_indices_tensor, + seq_len_tensor, + batches, + scaling if op_type == 
"shrink" else 1.0, + op_type, + ) + if op_type == "shrink": + ref_out_tensor = ref_out_tensor.to(torch.float32) + assert_close(our_out_tensor, ref_out_tensor) + + +@pytest.mark.parametrize("batches", BATCHES) +@pytest.mark.parametrize("num_loras", NUM_LORA) +@pytest.mark.parametrize("rank", MAX_RANKS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("scaling", SCALES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_punica_bgmv( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + scaling: float, + dtype: torch.dtype, + op_type: str, + seed: int, + device: str, +): + from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel + from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel + + random.seed(seed) + torch.set_default_device(device) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + seq_length = 1 + ( + inputs_tensor, + lora_weights, + our_out_tensor, + ref_out_tensor, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) = generate_data( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + op_type, + device, + ) + if op_type == "shrink": + # The current _bgmv_shrink_kernel does not require the libentry + # decoration. The purpose of adding this patch is to test the + # correctness of libentry. + with patch( + "vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel", + LibEntry(_bgmv_shrink_kernel), + ): + bgmv_shrink( + inputs_tensor, + lora_weights, + our_out_tensor, + indices, + scaling, + ) + else: + # ditto + with patch( + "vllm.lora.ops.bgmv_expand._bgmv_expand_kernel", + LibEntry(_bgmv_expand_kernel), + ): + bgmv_expand( + inputs_tensor, + lora_weights, + our_out_tensor, + indices, + add_inputs=True, + ) + ref_torch_groupgemm( + ref_out_tensor, + inputs_tensor, + lora_weights, + lora_indices_tensor, + seq_len_tensor, + batches, + scaling if op_type == "shrink" else 1.0, + op_type, + ) + if op_type == "shrink": + ref_out_tensor = ref_out_tensor.to(torch.float32) + assert_close(our_out_tensor, ref_out_tensor) + + +@pytest.mark.parametrize("batches", BATCHES) +@pytest.mark.parametrize("num_loras", NUM_LORA) +@pytest.mark.parametrize("rank", MAX_RANKS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("nslices", [2, 3]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"]) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_punica_expand_nslices( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + nslices: int, + dtype: torch.dtype, + op_type: str, + seed: int, + device: str, +): + from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel + + random.seed(seed) + torch.set_default_device(device) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + seq_length = 128 if op_type == "sgmv" else 1 + ( + inputs_tensor, + lora_weights_lst, + our_outputs, + ref_outputs, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) = generate_data_for_expand_nslices( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + nslices, + device, + ) + max_seq_length = seq_len_tensor.max() + if isinstance(max_seq_length, tuple): + max_seq_length = max_seq_length[0].item() + else: + max_seq_length = 
max_seq_length.item() + slice_offset = 0 + for index in range(nslices): + lora_weights = lora_weights_lst[index] + if op_type == "sgmv": + sgmv_expand_slice( + inputs_tensor, + lora_weights, + our_outputs, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + slice_offset, + hidden_size, + add_inputs=True, + ) + else: + # The current _bgmv_expand_slice_kernel does not require the + # libentry decoration. The purpose of adding this patch is to test + # the correctness of libentry. + with patch( + "vllm.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel", + LibEntry(_bgmv_expand_slice_kernel), + ): + bgmv_expand_slice( + inputs_tensor, + lora_weights, + our_outputs, + indices, + slice_offset, + slice_size=hidden_size, + add_inputs=True, + ) + ref_torch_groupgemm( + ref_outputs[:, slice_offset:slice_offset + hidden_size], + inputs_tensor, + lora_weights, + lora_indices_tensor, + seq_len_tensor, + batches, + 1.0, + op_type="expand", + ) + + slice_offset += hidden_size + assert_close(our_outputs, ref_outputs) diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py new file mode 100644 index 000000000000..7e73ea67ee5f --- /dev/null +++ b/tests/lora/test_punica_variation.py @@ -0,0 +1,342 @@ +""" +This script is mainly used to test whether trtion kernels can run normally +under different conditions, including various batches, numbers of LoRA , and +maximum ranks. +""" +import random +from unittest.mock import patch + +import pytest +import torch + +from vllm.lora.ops.bgmv_expand import bgmv_expand +from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice +from vllm.lora.ops.bgmv_shrink import bgmv_shrink +from vllm.lora.ops.sgmv_expand import sgmv_expand +from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice +from vllm.lora.ops.sgmv_shrink import sgmv_shrink +from vllm.triton_utils.libentry import LibEntry + +from .utils import (generate_data, generate_data_for_expand_nslices, + ref_torch_groupgemm) + +HIDDEN_SIZES = [3424, 4096, 4097] + +BATCHES = [1, 4, 16, 32] +NUM_LORA = [1, 4, 8, 16, 32, 64, 128] +DTYPES = [torch.float16, torch.bfloat16] +MAX_RANKS = [1, 4, 8, 16, 32, 64, 128] +SCALES = [0.5] +SEED = [0] +CUDA_DEVICES = [f"cuda:{0}"] + + +def assert_close(a, b): + rtol, atol = { + torch.float16: (6e-2, 6e-2), + torch.bfloat16: (6e-2, 6e-2), + torch.float32: (1e-2, 1e-2), + }[a.dtype] + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("batches", BATCHES) +@pytest.mark.parametrize("num_loras", NUM_LORA) +@pytest.mark.parametrize("rank", MAX_RANKS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("scaling", SCALES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_punica_sgmv( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + scaling: float, + dtype: torch.dtype, + op_type: str, + seed: int, + device: str, +): + random.seed(seed) + torch.set_default_device(device) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + seq_length = 128 + ( + inputs_tensor, + lora_weights, + our_out_tensor, + ref_out_tensor, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) = generate_data( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + op_type, + device, + ) + max_seq_length = seq_len_tensor.max() + if 
isinstance(max_seq_length, tuple): + max_seq_length = max_seq_length[0].item() + else: + max_seq_length = max_seq_length.item() + if op_type == "shrink": + sgmv_shrink( + inputs_tensor, + lora_weights, + our_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + scaling, + ) + else: + sgmv_expand( + inputs_tensor, + lora_weights, + our_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + add_inputs=True, + ) + ref_torch_groupgemm( + ref_out_tensor, + inputs_tensor, + lora_weights, + lora_indices_tensor, + seq_len_tensor, + batches, + scaling if op_type == "shrink" else 1.0, + op_type, + ) + if op_type == "shrink": + ref_out_tensor = ref_out_tensor.to(torch.float32) + assert_close(our_out_tensor, ref_out_tensor) + + +@pytest.mark.parametrize("batches", BATCHES) +@pytest.mark.parametrize("num_loras", NUM_LORA) +@pytest.mark.parametrize("rank", MAX_RANKS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("scaling", SCALES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_punica_bgmv( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + scaling: float, + dtype: torch.dtype, + op_type: str, + seed: int, + device: str, +): + from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel + from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel + + random.seed(seed) + torch.set_default_device(device) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + seq_length = 1 + ( + inputs_tensor, + lora_weights, + our_out_tensor, + ref_out_tensor, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) = generate_data( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + op_type, + device, + ) + if op_type == "shrink": + # The current _bgmv_shrink_kernel does not require the libentry + # decoration. The purpose of adding this patch is to test the + # correctness of libentry. 
+ with patch( + "vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel", + LibEntry(_bgmv_shrink_kernel), + ): + bgmv_shrink( + inputs_tensor, + lora_weights, + our_out_tensor, + indices, + scaling, + ) + else: + # ditto + with patch( + "vllm.lora.ops.bgmv_expand._bgmv_expand_kernel", + LibEntry(_bgmv_expand_kernel), + ): + bgmv_expand( + inputs_tensor, + lora_weights, + our_out_tensor, + indices, + add_inputs=True, + ) + ref_torch_groupgemm( + ref_out_tensor, + inputs_tensor, + lora_weights, + lora_indices_tensor, + seq_len_tensor, + batches, + scaling if op_type == "shrink" else 1.0, + op_type, + ) + if op_type == "shrink": + ref_out_tensor = ref_out_tensor.to(torch.float32) + assert_close(our_out_tensor, ref_out_tensor) + + +@pytest.mark.parametrize("batches", BATCHES) +@pytest.mark.parametrize("num_loras", NUM_LORA) +@pytest.mark.parametrize("rank", MAX_RANKS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("nslices", [2, 3]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"]) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_punica_expand_nslices( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + nslices: int, + dtype: torch.dtype, + op_type: str, + seed: int, + device: str, +): + from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel + + random.seed(seed) + torch.set_default_device(device) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + seq_length = 128 if op_type == "sgmv" else 1 + ( + inputs_tensor, + lora_weights_lst, + our_outputs, + ref_outputs, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) = generate_data_for_expand_nslices( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + nslices, + device, + ) + max_seq_length = seq_len_tensor.max() + if isinstance(max_seq_length, tuple): + max_seq_length = max_seq_length[0].item() + else: + max_seq_length = max_seq_length.item() + slice_offset = 0 + for index in range(nslices): + lora_weights = lora_weights_lst[index] + if op_type == "sgmv": + sgmv_expand_slice( + inputs_tensor, + lora_weights, + our_outputs, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + slice_offset, + hidden_size, + add_inputs=True, + ) + else: + # The current _bgmv_expand_slice_kernel does not require the + # libentry decoration. The purpose of adding this patch is to test + # the correctness of libentry. 
+ with patch( + "vllm.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel", + LibEntry(_bgmv_expand_slice_kernel), + ): + bgmv_expand_slice( + inputs_tensor, + lora_weights, + our_outputs, + indices, + slice_offset, + slice_size=hidden_size, + add_inputs=True, + ) + ref_torch_groupgemm( + ref_outputs[:, slice_offset:slice_offset + hidden_size], + inputs_tensor, + lora_weights, + lora_indices_tensor, + seq_len_tensor, + batches, + 1.0, + op_type="expand", + ) + + slice_offset += hidden_size + assert_close(our_outputs, ref_outputs) + + +if __name__ == "__main__": + from itertools import product + + lst = list( + product( + BATCHES, + NUM_LORA, + MAX_RANKS, + [1.0], + [torch.float16], + ["expand"], + SEED, + CUDA_DEVICES, + )) + for ele in lst: + test_punica_bgmv(*ele) + print(f"{ele},pass") diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py index 8fd968c69e58..2370c693e953 100644 --- a/tests/lora/test_quant_model.py +++ b/tests/lora/test_quant_model.py @@ -64,14 +64,16 @@ def test_quant_model_lora(tinyllama_lora_files, model, tp_size): # if torch.cuda.device_count() < tp_size: # pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") - llm = vllm.LLM(model=model.model_path, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - max_model_len=400, - tensor_parallel_size=tp_size, - quantization=model.quantization, - trust_remote_code=True) + llm = vllm.LLM( + model=model.model_path, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + max_model_len=400, + tensor_parallel_size=tp_size, + gpu_memory_utilization=0.2, #avoid OOM + quantization=model.quantization, + trust_remote_code=True) if model.quantization is None: expected_no_lora_output = [ @@ -156,24 +158,28 @@ def test_quant_model_tp_equality(tinyllama_lora_files, model): # if torch.cuda.device_count() < 2: # pytest.skip(f"Not enough GPUs for tensor parallelism {2}") - llm_tp1 = vllm.LLM(model=model.model_path, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - tensor_parallel_size=1, - quantization=model.quantization, - trust_remote_code=True) + llm_tp1 = vllm.LLM( + model=model.model_path, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=1, + gpu_memory_utilization=0.2, #avoid OOM + quantization=model.quantization, + trust_remote_code=True) output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1) del llm_tp1 cleanup() - llm_tp2 = vllm.LLM(model=model.model_path, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - tensor_parallel_size=2, - quantization=model.quantization) + llm_tp2 = vllm.LLM( + model=model.model_path, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=2, + gpu_memory_utilization=0.2, #avoid OOM + quantization=model.quantization) output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1) del llm_tp2 diff --git a/tests/lora/utils.py b/tests/lora/utils.py index b73cf5bf5532..00f8e26d1041 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -86,3 +86,151 @@ def init_packed_lora( packed_lora = PackedLoRALayerWeights.pack(base_loras) self.set_module_lora(module_name, packed_lora) return packed_lora + + +def assert_close(a, b): + rtol, atol = { + torch.float16: (6e-2, 6e-2), + torch.bfloat16: (6e-2, 6e-2), + torch.float32: (1e-2, 1e-2), + }[a.dtype] + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) + + +def ref_torch_groupgemm( + out_tensor, + inputs, + lora_weights, + lora_indices_tensor, + seq_len_tensor, + batches, + scaling, + op_type, +) -> torch.Tensor: + out_list = [] + current_offset = 0 + 
for lora_index, b_length in zip(range(batches), seq_len_tensor): + input_weight = inputs[current_offset:b_length + current_offset, :] + current_offset += b_length + lora_weight = lora_weights[lora_indices_tensor[lora_index]] + result = torch.nn.functional.linear(input_weight, lora_weight) + result *= scaling + out_list.append(result) + cat_result = torch.cat(out_list, dim=0) + if op_type == "expand": + out_tensor += cat_result + else: + out_tensor.copy_(cat_result) + return + + +def generate_data(batches, hidden_size, lora_nums, max_rank, seq_length, dtype, + op_type, device): + seq_len_tensor = torch.randint(seq_length, seq_length + 1, + (batches, )).to(device) + b_seq_start_loc = torch.cumsum( + torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long), + dim=0, + ).to(device) + total_tokens = seq_len_tensor.sum() + if op_type == "shrink": + inputs_tensor = torch.rand((total_tokens, hidden_size), + dtype=dtype).to(device) + lora_weights = torch.rand( + (lora_nums, max_rank, hidden_size), # col-major + dtype=dtype, + ).to(device) + # shrink op need atomic_add, so output is initinized by 0 + ref_out_tensor = torch.zeros((total_tokens, max_rank), + dtype=dtype, + device=inputs_tensor.device) + # NOTE shrink kernel using torch.float32 as output type + our_out_tensor = torch.zeros((total_tokens, max_rank), + dtype=torch.float32).to(device) + else: + inputs_tensor = torch.rand( + (total_tokens, max_rank), + dtype=dtype, + ).to(device) + lora_weights = torch.rand( + (lora_nums, hidden_size, max_rank), # col-major + dtype=dtype, + ).to(device) + # expand op needs to complete y+=a@lora_b, so output is + # initinized randomly + ref_out_tensor = torch.rand( + (total_tokens, hidden_size), + dtype=dtype, + ).to(device) + # Ensure the same input. + our_out_tensor = ref_out_tensor.clone() + lora_indices_tensor = torch.randint(0, + lora_nums - 1 if lora_nums > 1 else 1, + (batches, )).to(device) + indices = torch.zeros((total_tokens), dtype=torch.long).to(device) + current_offset = 0 + for b_id in range(batches): + lora_index = lora_indices_tensor[b_id] + indices[current_offset:current_offset + + seq_len_tensor[b_id]].copy_(lora_index) + current_offset += seq_len_tensor[b_id].item() + return ( + inputs_tensor, + lora_weights, + our_out_tensor, + ref_out_tensor, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) + + +def generate_data_for_expand_nslices(batches, hidden_size, lora_nums, max_rank, + seq_length, dtype, nslices, device): + seq_len_tensor = torch.randint(seq_length, seq_length + 1, + (batches, )).to(device) + b_seq_start_loc = torch.cumsum( + torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long), + dim=0, + ).to(device) + total_tokens = seq_len_tensor.sum() + inputs_tensor = torch.rand( + (total_tokens, max_rank), + dtype=dtype, + ).to(device) + lora_weights_lst = [] + for _ in range(nslices): + lora_weights_lst.append( + torch.rand( + (lora_nums, hidden_size, max_rank), # col-major + dtype=dtype, + ).to(device)) + # expand op needs to complete y+=a@lora_b, so output is + # initinized randomly + ref_out_tensor = torch.rand((total_tokens, hidden_size * nslices), + dtype=dtype).to(device) + # Ensure the same input. 
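# NOTE: illustrative sketch, not part of the diff. The helpers above encode the
# "shrink"/"expand" semantics the Triton kernels are checked against: shrink
# maps hidden_size -> rank with a scaling factor and accumulates into a
# float32, zero-initialized buffer; expand maps rank -> hidden_size and adds
# onto an existing output (y += x @ lora_b^T). For a single LoRA in plain
# PyTorch:
import torch

tokens, hidden_size, rank, scaling = 8, 64, 16, 0.5
x = torch.rand(tokens, hidden_size)
lora_a = torch.rand(rank, hidden_size)        # shrink weight, (rank, hidden)
lora_b = torch.rand(hidden_size, rank)        # expand weight, (hidden, rank)

shrink_out = torch.zeros(tokens, rank)        # kernel accumulates this in fp32
shrink_out += scaling * (x @ lora_a.T)

y = torch.rand(tokens, hidden_size)           # expand adds onto existing output
y += shrink_out @ lora_b.T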
+ our_out_tensor = ref_out_tensor.clone() + lora_indices_tensor = torch.randint(0, + lora_nums - 1 if lora_nums > 1 else 1, + (batches, )) + indices = torch.zeros((total_tokens), dtype=torch.long).to(device) + current_offset = 0 + for b_id in range(batches): + lora_index = lora_indices_tensor[b_id] + indices[current_offset:current_offset + + seq_len_tensor[b_id]] = lora_index.item() + current_offset += seq_len_tensor[b_id].item() + + lora_indices_tensor = lora_indices_tensor.to(device) + return ( + inputs_tensor, + lora_weights_lst, + our_out_tensor, + ref_out_tensor, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 9e09b9a32eab..6cd77f75cae8 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -13,12 +13,9 @@ except ImportError as e: logger.warning("Failed to import from vllm._C with %r", e) -with contextlib.suppress(ImportError): - import vllm._moe_C - with contextlib.suppress(ImportError): # ruff: noqa: F401 - import vllm._punica_C + import vllm._moe_C def is_custom_op_supported(op_name: str) -> bool: @@ -519,43 +516,6 @@ def register_graph_buffers(fa: int, handles: List[str], torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) -# punica -def dispatch_bgmv( - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - indicies: torch.Tensor, - layer_idx: int, - scale: float, -) -> None: - torch.ops._punica_C.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, - scale) - - -def dispatch_bgmv_low_level( - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - indicies: torch.Tensor, - layer_idx: int, - scale: float, - h_in: int, - h_out: int, - y_offset: int, -) -> None: - torch.ops._punica_C.dispatch_bgmv_low_level( - y, - x, - w_t_all, - indicies, - layer_idx, - scale, - h_in, - h_out, - y_offset, - ) - - # temporary fix for https://github.com/vllm-project/vllm/issues/5456 # TODO: remove this in v0.6.0 names_and_values = globals() diff --git a/vllm/envs.py b/vllm/envs.py index aef7ac385ec6..9bcb26f8e5a6 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -45,7 +45,6 @@ MAX_JOBS: Optional[str] = None NVCC_THREADS: Optional[str] = None VLLM_USE_PRECOMPILED: bool = False - VLLM_INSTALL_PUNICA_KERNELS: bool = False VLLM_NO_DEPRECATION_WARNING: bool = False CMAKE_BUILD_TYPE: Optional[str] = None VERBOSE: bool = False @@ -94,10 +93,6 @@ def get_default_config_root(): "VLLM_USE_PRECOMPILED": lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")), - # If set, vllm will install Punica kernels - "VLLM_INSTALL_PUNICA_KERNELS": - lambda: bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0"))), - # CMake build type # If not set, defaults to "Debug" or "RelWithDebInfo" # Available options: "Debug", "Release", "RelWithDebInfo" diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py index d27171f72083..a7887a048746 100644 --- a/vllm/lora/fully_sharded_layers.py +++ b/vllm/lora/fully_sharded_layers.py @@ -14,7 +14,6 @@ MergedQKVParallelLinearWithLora, QKVParallelLinearWithLora, RowParallelLinearWithLoRA) -from vllm.lora.punica import bgmv, dispatch_bgmv_low_level if TYPE_CHECKING: pass @@ -28,7 +27,7 @@ def _fully_sharded_can_replace(can_replace): def dec(*args, **kwargs): return (can_replace(*args, **kwargs) - and kwargs['lora_config'].fully_sharded_loras) + and kwargs["lora_config"].fully_sharded_loras) return dec @@ -59,25 +58,30 @@ def apply(self, x: torch.Tensor, x = x.view(-1, x.shape[-1]) output, out_orig_shape = output.view(-1, 
output.shape[-1]), output.shape - buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), - dtype=torch.float32, - device=x.device) - - bgmv(buffer, x, self.lora_a_stacked, - self.indices[:self.indices_len[0]], 0, 1.0) + buffer = torch.zeros( + (x.shape[0], self.lora_a_stacked.shape[2]), + dtype=torch.float32, + device=x.device, + ) + self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0) buffer = tensor_model_parallel_all_gather(buffer) - bgmv(output, buffer, self.lora_b_stacked, - self.indices[:self.indices_len[0]], 0, 1.0) + self.punica_wrapper.add_expand(output, + buffer, + self.lora_b_stacked, + add_input=True) # now have column partitioned output - output = output.view(*out_orig_shape) return output @classmethod @_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: # specifying kwargs so they can be easily accessed in decorator return super().can_replace_layer( source_layer=source_layer, @@ -88,14 +92,14 @@ def can_replace_layer(cls, source_layer: nn.Module, ) -def _mcp_apply(x, bias, layer): +def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora): """ - MergedColumnParallelLinearWithShardedLoRA and - MergedQKVParallelLinearWithShardedLora share the same + MergedColumnParallelLinearWithShardedLoRA and + MergedQKVParallelLinearWithShardedLora share the same LoRa weight application method. The main difference is the step by shard_size for lora_b which can - vary for MergedQKVParallelLinearWithShardedLora but is constant for + vary for MergedQKVParallelLinearWithShardedLora but is constant for MergedColumnParallelLinearWithShardedLoRA. """ # expecting 2 for column parallel and 3 for qkv @@ -104,21 +108,27 @@ def _mcp_apply(x, bias, layer): x = x.view(-1, x.shape[-1]) output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape - buffers = torch.zeros((n, x.shape[0], layer.lora_a_stacked[0].shape[2]), - dtype=torch.float32, - device=x.device) + buffers = torch.zeros( + (n, x.shape[0], layer.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device, + ) for idx in range(n): - bgmv(buffers[idx], x, layer.lora_a_stacked[idx], - layer.indices[:layer.indices_len[0]], 0, 1.0) + layer.punica_wrapper.add_shrink(buffers[idx], x, + layer.lora_a_stacked[idx], 1.0) buffers = tensor_model_parallel_all_gather(buffers) left_offset = 0 for idx in range(n): shard_size = layer.lora_b_stacked[idx].shape[2] - dispatch_bgmv_low_level(output, buffers[idx], - layer.lora_b_stacked[idx], - layer.indices[:layer.indices_len[0]], 0, 1.0, - left_offset, shard_size) + layer.punica_wrapper.add_expand_slice( + output, + buffers[idx], + layer.lora_b_stacked[idx], + left_offset, + shard_size, + add_input=True, + ) left_offset += shard_size output = output.view(*out_orig_shape) @@ -129,7 +139,7 @@ def _mcp_apply(x, bias, layer): class MergedColumnParallelLinearWithShardedLoRA( MergedColumnParallelLinearWithLoRA): """ - Differs from MergedColumnParallelLinearWithLoRA by slicing the + Differs from MergedColumnParallelLinearWithLoRA by slicing the LoRA A's also. Based on S-LoRA, slicing happens along the rank dim. 
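# NOTE: illustrative sketch, not part of the diff. The fully-sharded
# column-parallel path above keeps a slice of lora_a along the rank dim (per
# S-LoRA) and a slice of lora_b along the output dim on every tensor-parallel
# rank: each rank shrinks with its A slice, the buffers are all-gathered, and
# each rank expands only into its own column partition. Single-process check
# of the arithmetic, with torch.cat standing in for the collectives:
import torch

tp_size, tokens, hidden, r, out_dim = 4, 5, 32, 16, 64
x = torch.rand(tokens, hidden)
lora_a = torch.rand(r, hidden)                     # full shrink weight
lora_b = torch.rand(out_dim, r)                    # full expand weight
a_shards = lora_a.chunk(tp_size, dim=0)            # sliced along the rank dim
b_shards = lora_b.chunk(tp_size, dim=0)            # sliced along the output dim

# per-rank shrink, then "all_gather" along the last dim
buffer = torch.cat([x @ a.T for a in a_shards], dim=-1)        # (tokens, r)
# each rank expands into its column partition of the output
full = torch.cat([buffer @ b.T for b in b_shards], dim=-1)     # (tokens, out_dim)

assert torch.allclose(full, (x @ lora_a.T) @ lora_b.T, atol=1e-4)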
@@ -145,7 +155,8 @@ def slice_lora_a( lora_a = [ lora_a[0][:, output_start_idx:output_start_idx + output_shard_size], - lora_a[1][:, output_start_idx:output_start_idx + output_shard_size] + lora_a[1][:, + output_start_idx:output_start_idx + output_shard_size], ] return lora_a @@ -155,9 +166,13 @@ def apply(self, x: torch.Tensor, @classmethod @_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: # specifying kwargs so they can be easily accessed in decorator return super().can_replace_layer( source_layer=source_layer, @@ -170,7 +185,7 @@ def can_replace_layer(cls, source_layer: nn.Module, class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora): """ - Differs from QKVParallelLinearWithLora by slicing the + Differs from QKVParallelLinearWithLora by slicing the LoRA A's also. Based on S-LoRA, slicing happens along the rank dim. @@ -193,14 +208,13 @@ def apply(self, x: torch.Tensor, buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), dtype=torch.float32, device=x.device) - - bgmv(buffer, x, self.lora_a_stacked, - self.indices[:self.indices_len[0]], 0, 1.0) + self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0) buffer = tensor_model_parallel_all_gather(buffer) - bgmv(output, buffer, self.lora_b_stacked, - self.indices[:self.indices_len[0]], 0, 1.0) + self.punica_wrapper.add_expand(output, + buffer, + self.lora_b_stacked, + add_input=True) # now have column partitioned output - output = output.view(*out_orig_shape) return output @@ -237,7 +251,7 @@ def slice_lora_a( lora_a = [ lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]], lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]], - lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]] + lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]], ] return lora_a @@ -247,9 +261,13 @@ def apply(self, x: torch.Tensor, @classmethod @_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: # specifying kwargs so they can be easily accessed in decorator return super().can_replace_layer( source_layer=source_layer, @@ -262,11 +280,11 @@ def can_replace_layer(cls, source_layer: nn.Module, class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): """ - Differs from RowParallelLinearWithLoRA by slicing the + Differs from RowParallelLinearWithLoRA by slicing the LoRA B's also. Based on S-LoRA, slicing happens along the output dim. - This yields a combined partial sum from the row parallel base + This yields a combined partial sum from the row parallel base layer and column partitioned output from the LoRA. 
""" @@ -283,11 +301,13 @@ def apply(self, x: torch.Tensor) -> torch.Tensor: x = x.view(-1, x.shape[-1]) output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape - buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), - dtype=torch.float32, - device=x.device) - bgmv(buffer, x, self.lora_a_stacked, - self.indices[:self.indices_len[0]], 0, 1.0) + buffer = torch.zeros( + (x.shape[0], self.lora_a_stacked.shape[2]), + dtype=torch.float32, + device=x.device, + ) + + self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0) buffer = tensor_model_parallel_all_reduce(buffer) # following S-LoRA, allows the fusing of all_gather and all_reduce @@ -298,18 +318,21 @@ def apply(self, x: torch.Tensor) -> torch.Tensor: # reduced before being used shard_size = self.lora_b_stacked.shape[2] start_idx = self.tp_rank * shard_size - dispatch_bgmv_low_level(output, buffer, self.lora_b_stacked, - self.indices[:self.indices_len[0]], 0, 1.0, - start_idx, shard_size) - + self.punica_wrapper.add_expand_slice(output, buffer, + self.lora_b_stacked, start_idx, + shard_size) output = output.view(*out_orig_shape) return output @classmethod @_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: # specifying kwargs so they can be easily accessed in decorator return super().can_replace_layer( source_layer=source_layer, diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 87de285a373a..3176badabbc7 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -17,7 +17,7 @@ tensor_model_parallel_all_reduce, tensor_model_parallel_gather) from vllm.distributed.utils import divide -from vllm.lora.punica import add_lora, add_lora_slice, bgmv +from vllm.lora.punica import PunicaWrapper from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, @@ -55,88 +55,17 @@ def _not_fully_sharded_can_replace(can_replace): """ def dec(*args, **kwargs): - decorate = kwargs.pop('decorate') if 'decorate' in kwargs else True - condition = (not kwargs['lora_config'].fully_sharded_loras + decorate = kwargs.pop("decorate") if "decorate" in kwargs else True + condition = (not kwargs["lora_config"].fully_sharded_loras if decorate else True) return can_replace(*args, **kwargs) and condition return dec -def _apply_lora( - x: torch.Tensor, - lora_a_stacked: torch.Tensor, - lora_b_stacked: torch.Tensor, - indices: torch.Tensor, - output: torch.Tensor, -): - """Applies lora to each input. - - This method applies all loras to each input. It uses the - indices vector to determine which lora yields the - correct output. An index of -1 means no lora should be - applied. This method adds the final lora results to the - output. 
- - Input shapes: - x: (batch_size, hidden_dim) - lora_a_stacked: (num_loras, lora_rank, hidden_dim) - lora_b_stacked: (num_loras, output_dim, lora_rank) - indices: (batch_size) - output: (batch_size, output_dim) - """ - org_output = output - x = x.view(-1, x.shape[-1]) - output = output.view(-1, output.shape[-1]) - indices = indices.view(-1) - add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0) - return output.view_as(org_output) - - -def _apply_lora_packed_nslice( - x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - indices: torch.Tensor, - output: torch.Tensor, - output_slices: Tuple[int, ...], -): - """Applies lora to each input. - - This method applies all loras to each input. It uses the - indices vector to determine which lora yields the - correct output. An index of -1 means no lora should be - applied. This method adds the final lora results to the - output. - - This method is used for layers that are composed of multiple sublayers - (slices) packed together. - - Input shapes: - x: (batch_size, hidden_dim) - lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim) - lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank) - indices: (batch_size) - output: (batch_size, q_slice_size + 2*kv_slice_size) - output_slices: n-1 element tuple of (slice_size...), - where n is number of slices - """ - org_output = output - x = x.view(-1, x.shape[-1]) - output = output.view(-1, output.shape[-1]) - indices = indices.view(-1) - offset_left = 0 - for slice_idx in range(len(output_slices)): - add_lora_slice(output, x, lora_a_stacked[slice_idx], - lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left, - output_slices[slice_idx]) - offset_left += output_slices[slice_idx] - return output.view_as(org_output) - - @dataclass class LoRAMapping(AdapterMapping): - pass + is_prefill: bool = False class BaseLayerWithLoRA(nn.Module): @@ -154,10 +83,11 @@ def slice_lora_b( ... def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None) -> None: + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: """Initializes lora matrices.""" ... @@ -177,20 +107,18 @@ def set_lora( def set_mapping( self, - base_indices: torch.Tensor, - sampler_indices: torch.Tensor, - sampler_indices_padded: torch.Tensor, - embeddings_indices: torch.Tensor, - long_lora_indices: torch.Tensor, - indices_len: List[int], + punica_wrapper: PunicaWrapper, ): - """Sets the mapping indices.""" - ... + self.punica_wrapper: PunicaWrapper = punica_wrapper @classmethod - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: """Returns True if the layer can be replaced by this LoRA layer.""" raise NotImplementedError @@ -259,10 +187,6 @@ def create_lora_weights( self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], self.lora_a_stacked.shape[2], ) - # Lazily initialized. 
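# NOTE: illustrative sketch, not part of the diff. The removed _apply_lora()
# helper above (and the punica_wrapper.add_lora() call that replaces it) adds
# a per-token indexed LoRA onto the base-layer output; an index of -1 means
# "no LoRA" for that token. Reference semantics in plain PyTorch, using the
# stacked-weight layout from the tests (an extra dim of 1 after num_loras):
import torch

num_loras, tokens, hidden, r, out_dim = 4, 6, 32, 8, 64
x = torch.rand(tokens, hidden)
lora_a_stacked = torch.rand(num_loras, 1, r, hidden)
lora_b_stacked = torch.rand(num_loras, 1, out_dim, r)
indices = torch.tensor([0, 2, -1, 1, -1, 3])       # -1 -> leave token untouched
output = torch.rand(tokens, out_dim)               # stands in for base output

for i, idx in enumerate(indices.tolist()):
    if idx == -1:
        continue
    a = lora_a_stacked[idx, 0]                     # (r, hidden)
    b = lora_b_stacked[idx, 0]                     # (out_dim, r)
    output[i] += (x[i] @ a.T) @ b.T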
- self.indices: torch.Tensor - self.indices_len: List[int] - self.embeddings_indices: torch.Tensor def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 @@ -285,40 +209,27 @@ def set_lora( if embeddings_tensor is not None: self.embeddings_tensors[ index, :embeddings_tensor.shape[0], :embeddings_tensor. - shape[1]].copy_(embeddings_tensor, non_blocking=True) + shape[1], ].copy_(embeddings_tensor, non_blocking=True) if self.embeddings_slice is not None: # TODO(yard1): Optimize this copy, we don't need to copy # everything, just the modified part embeddings = self.embeddings_tensors.view( self.embeddings_tensors.shape[0] * self.embeddings_tensors.shape[1], - self.embeddings_tensors.shape[2] + self.embeddings_tensors.shape[2], )[self.embeddings_slice[0]:self.embeddings_slice[1]] assert self.embeddings_weights is not None self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) - def set_mapping( - self, - base_indices: torch.Tensor, - sampler_indices: torch.Tensor, - sampler_indices_padded: torch.Tensor, - embeddings_indices: torch.Tensor, - long_lora_indices: torch.Tensor, - indices_len: List[int], - ): - self.indices = base_indices - self.embeddings_indices = embeddings_indices - self.indices_len = indices_len - def forward(self, x: torch.Tensor) -> torch.Tensor: added_tokens_mask = x > self.base_layer.org_vocab_size - 1 - embedding_len = self.indices_len[3] - indices = self.embeddings_indices[1][:embedding_len].view_as(x) + embeddings_indices = self.punica_wrapper.embeddings_indices + indices = embeddings_indices[1].view_as(x) full_lora_a_embeddings = F.embedding( x + indices, self.lora_a_stacked_2d, ) - indices = self.embeddings_indices[0][:embedding_len].view_as(x) + indices = embeddings_indices[0].view_as(x) full_output = self.base_layer.forward( x.add_(indices * added_tokens_mask)) @@ -329,22 +240,32 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if full_lora_a_embeddings.ndim == 3: full_lora_a_embeddings = full_lora_a_embeddings.view( full_lora_a_embeddings.shape[0] * - full_lora_a_embeddings.shape[1], -1) - bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked, - self.indices[:self.indices_len[0]], 0, 1.0) + full_lora_a_embeddings.shape[1], + -1, + ) + + # Embedding layer only need expand op + self.punica_wrapper.add_expand(full_output, + full_lora_a_embeddings, + self.lora_b_stacked, + add_input=True) return full_output.view_as(full_output_org) @classmethod - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: return type(source_layer) is VocabParallelEmbedding class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): """ LoRA on top of ColumnParallelLinear layer. - + LoRA B is sliced for tensor parallelism. 
""" @@ -357,10 +278,11 @@ def __init__(self, base_layer: ColumnParallelLinear) -> None: self.device = _get_lora_device(self.base_layer) def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None) -> None: + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: self.lora_config = lora_config self.tp_size = get_tensor_model_parallel_world_size() lora_a_output_size_per_partition = ( @@ -384,10 +306,6 @@ def create_lora_weights( ) self.output_dim = self.lora_b_stacked.shape[2] - # lazily initialized. - self.indices: torch.Tensor - self.indices_len: List[int] - def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 self.lora_b_stacked[index] = 0 @@ -423,28 +341,11 @@ def set_lora( 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( lora_b.T, non_blocking=True) - def set_mapping( - self, - base_indices: torch.Tensor, - sampler_indices: torch.Tensor, - sampler_indices_padded: torch.Tensor, - embeddings_indices: torch.Tensor, - long_lora_indices: torch.Tensor, - indices_len: List[int], - ): - self.indices = base_indices - self.indices_len = indices_len - def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - _apply_lora( - x, - self.lora_a_stacked, - self.lora_b_stacked, - self.indices[:self.indices_len[0]], - output, - ) + self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, + self.lora_b_stacked, 1.0) return output def forward(self, input_): @@ -473,9 +374,13 @@ def forward(self, input_): @classmethod @_not_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: return type(source_layer) is ColumnParallelLinear or ( type(source_layer) is MergedColumnParallelLinear and len(packed_modules_list) == 1) @@ -494,10 +399,11 @@ def __init__(self, base_layer: MergedColumnParallelLinear) -> None: super().__init__(base_layer) def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None) -> None: + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: self.lora_config = lora_config n_slices = 2 if not (len(self.base_layer.output_sizes) == n_slices @@ -533,8 +439,6 @@ def create_lora_weights( ) for _ in range(n_slices)) self.output_dim = self.lora_b_stacked[0].shape[2] - # Lazily initialized. 
- self.indices: torch.Tensor def reset_lora(self, index: int): self.lora_a_stacked[0][index] = 0 @@ -556,7 +460,8 @@ def slice_lora_b( start_idx = self.tp_rank * shard_size end_idx = (self.tp_rank + 1) * shard_size lora_b = [ - lora_b[0][:, start_idx:end_idx], lora_b[1][:, start_idx:end_idx] + lora_b[0][:, start_idx:end_idx], + lora_b[1][:, start_idx:end_idx], ] return lora_b @@ -591,34 +496,33 @@ def set_lora( def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - _apply_lora_packed_nslice( - x, - self.lora_a_stacked, - self.lora_b_stacked, - self.indices[:self.indices_len[0]], - output, - (self.output_dim, self.output_dim), - ) + self.punica_wrapper.add_lora_packed_nslice( + output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, + (self.output_dim, self.output_dim)) return output @classmethod @_not_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: - return type(source_layer) is MergedColumnParallelLinear and len( - packed_modules_list) == 2 + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 2) class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): """ - ColumnParallelLinear layer that is specifically designed for - qkv_proj. Certain models, such as chtglm3 and baichuan-7b, - only contains a single LoRA within their qkv_proj layer. + ColumnParallelLinear layer that is specifically designed for + qkv_proj. Certain models, such as chtglm3 and baichuan-7b, + only contains a single LoRA within their qkv_proj layer. - During inference with Tensor Parallel, the weights of lora_b + During inference with Tensor Parallel, the weights of lora_b must be accurately partitioned according to the respective ranks. - + Q slice may have different shape than K and V slices (which both have the same shape). """ @@ -696,10 +600,11 @@ def __init__(self, base_layer: QKVParallelLinear) -> None: super().__init__(base_layer) def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None) -> None: + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: self.lora_config = lora_config self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() @@ -767,11 +672,15 @@ def create_lora_weights( ), ) - self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size, - self.kv_proj_shard_size) + self.output_slices = ( + self.q_proj_shard_size, + self.kv_proj_shard_size, + self.kv_proj_shard_size, + ) self.packed_indices: Optional[torch.Tensor] = None self.standard_indices: Optional[torch.Tensor] = None # lazily initialized. 
+ self.indices: torch.Tensor self.indices_len: List[int] def reset_lora(self, index: int): @@ -794,15 +703,15 @@ def slice_lora_b( if lora_b[0] is not None: lora_b_q = lora_b[0][:, self.q_proj_shard_size * self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1)] + (self.q_shard_id + 1), ] if lora_b[1] is not None: lora_b_k = lora_b[1][:, self.kv_proj_shard_size * self.kv_shard_id:self.kv_proj_shard_size * - (self.kv_shard_id + 1)] + (self.kv_shard_id + 1), ] if lora_b[2] is not None: lora_b_v = lora_b[2][:, self.kv_proj_shard_size * self.kv_shard_id:self.kv_proj_shard_size * - (self.kv_shard_id + 1)] + (self.kv_shard_id + 1), ] lora_b = [lora_b_q, lora_b_k, lora_b_v] return lora_b @@ -851,23 +760,23 @@ def set_lora( def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - _apply_lora_packed_nslice( - x, - self.lora_a_stacked, - self.lora_b_stacked, - self.indices[:self.indices_len[0]], - output, - self.output_slices, - ) + self.punica_wrapper.add_lora_packed_nslice(output, x, + self.lora_a_stacked, + self.lora_b_stacked, 1.0, + self.output_slices) return output @classmethod @_not_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: - return type(source_layer) is QKVParallelLinear and len( - packed_modules_list) == 3 + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is QKVParallelLinear + and len(packed_modules_list) == 3) class RowParallelLinearWithLoRA(BaseLayerWithLoRA): @@ -880,10 +789,11 @@ def __init__(self, base_layer: RowParallelLinear) -> None: self.device = _get_lora_device(self.base_layer) def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None) -> None: + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: self.lora_config = lora_config self.tp_rank = get_tensor_model_parallel_rank() self.lora_a_stacked = torch.zeros( @@ -911,9 +821,6 @@ def create_lora_weights( dtype=lora_config.lora_dtype, device=self.device, ) - # Lazily initialized - self.indices: torch.Tensor - self.indices_len: List[int] def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 @@ -950,27 +857,10 @@ def set_lora( 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( lora_b.T, non_blocking=True) - def set_mapping( - self, - base_indices: torch.Tensor, - sampler_indices: torch.Tensor, - sampler_indices_padded: torch.Tensor, - embeddings_indices: torch.Tensor, - long_lora_indices: torch.Tensor, - indices_len: List[int], - ): - self.indices = base_indices - self.indices_len = indices_len - def apply(self, x: torch.Tensor) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x) - _apply_lora( - x, - self.lora_a_stacked, - self.lora_b_stacked, - self.indices[:self.indices_len[0]], - output, - ) + self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, + self.lora_b_stacked, 1.0) return output def forward(self, input_): @@ -1013,14 +903,18 @@ def forward(self, input_): @property def weight(self): - return self.base_layer.weight if hasattr( - self.base_layer, "weight") else self.base_layer.qweight + return (self.base_layer.weight if hasattr(self.base_layer, "weight") + 
else self.base_layer.qweight) @classmethod @_not_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: return type(source_layer) is RowParallelLinear @@ -1125,10 +1019,6 @@ def create_lora_weights( dtype=torch.long) else: self.sharded_to_full_mapping_gpu = None - # Lazily initialized. - self.indices: torch.Tensor - self.indices_len: List[int] - self.indices_padded: torch.Tensor def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 @@ -1154,19 +1044,6 @@ def set_lora( index, :embeddings_tensor.shape[0], :embeddings_tensor. shape[1], ] = embeddings_tensor - def set_mapping( - self, - base_indices: torch.Tensor, - sampler_indices: torch.Tensor, - sampler_indices_padded: torch.Tensor, - embeddings_indices: torch.Tensor, - long_lora_indices: torch.Tensor, - indices_len: List[int], - ): - self.indices = sampler_indices - self.indices_padded = sampler_indices_padded - self.indices_len = indices_len - def _get_logits( self, hidden_states: torch.Tensor, @@ -1212,38 +1089,37 @@ def _get_logits( out=lora_logits[:-1]) lora_logits[-1] = float("-inf") lora_logits = lora_logits.mT + indices_padded = self.punica_wrapper.sampler_indices_padded lora_logits = (lora_logits.reshape( lora_logits.shape[0] * lora_logits.shape[1], lora_logits.shape[2], - ).index_select(0, - self.indices_padded[:self.indices_len[2]]).nan_to_num_( - nan=float("-inf"), - posinf=float("inf"), - neginf=float("-inf"))) + ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"), + posinf=float("inf"), + neginf=float("-inf"))) logits[:, self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + - lora_logits.shape[1]] = lora_logits - - _apply_lora( - hidden_states, - self.lora_a_stacked, - self.lora_b_stacked, - self.indices[:self.indices_len[1]], - logits, - ) + lora_logits.shape[1], ] = lora_logits + + # LogitsProcessorWithLoRA always using bgmv + self.punica_wrapper.add_lora_logits(logits, hidden_states, + self.lora_a_stacked, + self.lora_b_stacked, 1.0) # Remove paddings in vocab (if any). logits = logits[:, :self.base_layer.vocab_size] - return logits def forward(self, *args, **kwargs): return type(self.base_layer).forward(self, *args, **kwargs) @classmethod - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: # Special handling for the LogitsProcessor. 
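        # (It is instead wrapped via from_layer_logits_processor in
        # lora/models.py, so this generic replacement check always declines.)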
return False @@ -1259,9 +1135,6 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA): def __init__(self, base_layer: RotaryEmbedding) -> None: super().__init__() self.base_layer = base_layer - # Lazily initialized - self.long_lora_indices: torch.Tensor - self.indices_len: List[int] @property def scaling_factors(self): @@ -1277,9 +1150,8 @@ def create_lora_weights( lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None, ) -> None: - scaling_factors = list( - lora_config.long_lora_scaling_factors - ) if lora_config.long_lora_scaling_factors else [] + scaling_factors = (list(lora_config.long_lora_scaling_factors) + if lora_config.long_lora_scaling_factors else []) base_scaling_factor = (self.base_layer.scaling_factor if isinstance( self.base_layer, LinearScalingRotaryEmbedding) else 1.0) scaling_factors = sorted( @@ -1306,18 +1178,6 @@ def set_lora( ): ... - def set_mapping( - self, - base_indices: torch.Tensor, - sampler_indices: torch.Tensor, - sampler_indices_padded: torch.Tensor, - embeddings_indices: torch.Tensor, - long_lora_indices: torch.Tensor, - indices_len: List[int], - ): - self.long_lora_indices = long_lora_indices - self.indices_len = indices_len - def forward( self, positions: torch.Tensor, @@ -1328,19 +1188,24 @@ def forward( positions, query, key, - offsets=self.long_lora_indices[:self.indices_len[4]]) + offsets=self.punica_wrapper.long_lora_indices, + ) @property def scaling_factor_to_offset(self) -> Dict[float, int]: return self.base_layer.scaling_factor_to_offset @classmethod - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, - model_config: Optional[PretrainedConfig]) -> bool: + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: """Returns True if the layer can be replaced by this LoRA layer.""" - return type(source_layer) is LinearScalingRotaryEmbedding or type( - source_layer) is RotaryEmbedding + return (type(source_layer) is LinearScalingRotaryEmbedding + or type(source_layer) is RotaryEmbedding) def extra_repr(self) -> str: return self.base_layer.extra_repr() diff --git a/vllm/lora/models.py b/vllm/lora/models.py index e1ede7d4d710..017a1002bb9a 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -4,7 +4,7 @@ import os import re from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Type import safetensors.torch import torch @@ -21,6 +21,7 @@ LinearScalingRotaryEmbeddingWithLora, LoRAMapping) from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights +from vllm.lora.punica import PunicaWrapper from vllm.lora.utils import (from_layer, from_layer_logits_processor, parse_fine_tuned_lora_name, replace_submodule) from vllm.model_executor.models.interfaces import SupportsLoRA @@ -43,115 +44,6 @@ class LongContextLoRAContext: offsets_by_lora_id: Dict[int, int] = field(default_factory=dict) -def convert_mapping( - mapping: LoRAMapping, - lora_index_to_id: List[Optional[int]], - max_loras: int, - vocab_size: int, - extra_vocab_size: int, - long_lora_context: Optional[LongContextLoRAContext] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - Optional[torch.Tensor], List[int]]: - """Converts LoRAMapping to index tensors. - - Args: - mapping: LoRAMapping mapping rows in a batch to LoRA ids. 
- lora_index_to_id: List mapping LoRA ids to LoRA indices. - max_loras: Maximum number of LoRAs. - vocab_size: Model vocab size. - extra_vocab_size: Extra vocab size each LoRA can have. - long_lora_context: Passed if there are long context lora in a batch. - - Returns: - A tuple of tensors: - base_indices: Tensor of shape [batch_size] mapping batch rows to - LoRA indices. - sampler_indices: Tensor of shape [batch_size] mapping requests to - LoRA indices for sampler. For generation, this will be the - same as base_indicies. For prefill, this will map requests - to LoRA indices. - sampler_indices_padded: Tensor of shape [batch_size] mapping - requests to LoRA indices for sampler with padding. - Same as sampler_indicies, but -1 is replaced with - max_loras. - embeddings_indices: Tensor of shape [2, batch_size] mapping - requests to embedding indices. First row is for embeddings - added by the LoRAs, second row is for the LoRA.lora_a - embeddings. - long_lora_indices: Tensor of shape [batch_size] mapping - requests to RoPE offsets and rot dims for long LoRAs. - None if long context lora doesn't exist. - indices_len: List of lengths of the above tensors. - Used to index into each tensor. It contains length for - (base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, long_lora_indices). If long_lora doesn't - exist, it only contains first 4 entries. - """ - index_mapping_indices: List[int] = list(mapping.index_mapping).copy() - embedding_indices = index_mapping_indices.copy() - lora_indices = index_mapping_indices.copy() - long_lora_offsets: Optional[torch.Tensor] = None - if long_lora_context: - long_lora_offsets = torch.zeros(len(index_mapping_indices), - device="cuda", - dtype=torch.long) - prompt_mapping: List[int] = [ - lora_index_to_id.index(x) if x > 0 else -1 - for x in mapping.prompt_mapping - ] - lora_idx = None - for i in range(len(index_mapping_indices)): - # TODO index can be slow. optimize - lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) - if index_mapping_indices[i] > 0 else -1) - embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 - lora_indices[i] = lora_idx - if long_lora_context: - assert long_lora_offsets is not None - lora_offset: int = long_lora_context.offsets_by_lora_id.get( - index_mapping_indices[i], 0) - long_lora_offsets[i] = lora_offset - - indices_list: List[Union[List[int], torch.Tensor]] = [ - index_mapping_indices, lora_indices, embedding_indices - ] - if long_lora_context: - assert long_lora_offsets is not None - indices_list.append(long_lora_offsets) - indices = torch.tensor(indices_list, dtype=torch.long, device="cuda") - prompt_mapping_tensor = torch.tensor(prompt_mapping, - device="cuda", - dtype=torch.long) - embeddings_indices = torch.stack([ - indices[2] * extra_vocab_size, - indices[2] * (vocab_size + extra_vocab_size) - ]) - embeddings_indices[embeddings_indices == -1] = max_loras - 1 - base_indices = indices[1] - sampler_indices = prompt_mapping_tensor - sampler_indices_padded = sampler_indices.clone() - sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 - sampler_indices_padded = ( - torch.arange( - 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + - (sampler_indices_padded * len(sampler_indices_padded))) - long_lora_indices = None - long_lora_indices_len: Optional[int] = None - if long_lora_context: - long_lora_indices = indices[3] - long_lora_indices_len = long_lora_indices.shape[-1] - # Contain length of indices tensors. Used to index into each tensor. 
- indices_len = [ - base_indices.shape[-1], sampler_indices.shape[-1], - sampler_indices_padded.shape[-1], embeddings_indices.shape[-1] - ] - if long_lora_indices_len is not None: - indices_len.append(long_lora_indices_len) - - return (base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, long_lora_indices, indices_len) - - def get_lora_id(): global _GLOBAL_LORA_ID _GLOBAL_LORA_ID += 1 @@ -422,29 +314,12 @@ def __init__( self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots self.vocab_size = vocab_size self.long_lora_context: Optional[LongContextLoRAContext] = None - self.base_indices = torch.empty(self.max_num_batched_tokens, - dtype=torch.long, - device="cuda") - self.sampler_indices = torch.empty(self.max_num_batched_tokens, - dtype=torch.long, - device="cuda") - self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens, - dtype=torch.long, - device="cuda") - self.embeddings_indices = torch.empty(2, - self.max_num_batched_tokens, - dtype=torch.long, - device="cuda") - self.long_lora_indices = torch.empty(self.max_num_batched_tokens, - dtype=torch.long, - device="cuda") + self.punica_wrapper = PunicaWrapper(max_num_batched_tokens, + max_batches=self.max_num_seqs, + device="cuda") # Scaling factor -> offset to the sin_cos_cache to it. # Used for long context lora. self.scaling_factor_to_offset: Dict[float, int] = {} - # 4 is the number of indicies tensors defined above - # base_indices, sampler_indices, sampler_indices_padded, - # embeddings_indices - self.indices_len: List[Optional[int]] = [None] * 4 super().__init__(model) if hasattr(self.model, "supported_lora_modules"): self.supported_lora_modules = copy.deepcopy( @@ -536,28 +411,16 @@ def pin_adapter(self, lora_id: int) -> bool: "Pinning is not supported in LoRAModelManager." "Use LRUCacheLoRAModelManager for pinning") # type: ignore - # TODO see if this can be vectorized def _set_adapter_mapping(self, mapping: LoRAMapping) -> None: - (base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, long_lora_offsets_tensor, - indices_len) = convert_mapping(mapping, self.lora_index_to_id, - self.lora_slots + 1, self.vocab_size, - self.lora_config.lora_extra_vocab_size, - self.long_lora_context) - self.base_indices[:base_indices.shape[0]].copy_(base_indices) - self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) - self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( - sampler_indices_padded) - self.embeddings_indices[:embeddings_indices. - shape[0], :embeddings_indices.shape[1]].copy_( - embeddings_indices) - if long_lora_offsets_tensor is not None: - self.long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( - long_lora_offsets_tensor) - else: - self.long_lora_indices.zero_() - # Maintain the reference - self.indices_len[:] = indices_len + # update lora states + self.punica_wrapper.update_metadata( + mapping, + self.lora_index_to_id, + self.lora_slots + 1, + self.vocab_size, + self.lora_config.lora_extra_vocab_size, + self.long_lora_context, + ) def remove_all_adapters(self): """Remove all LoRAModels from the manager.""" @@ -595,10 +458,8 @@ def _create_lora_modules(self): self.model.config)) self.register_module(module_name, new_module) self._register_packed_modules(module_name) - new_module.set_mapping(self.base_indices, self.sampler_indices, - self.sampler_indices_padded, - self.embeddings_indices, - self.long_lora_indices, self.indices_len) + # All lora layers share the same punica_wrapper based on reference. 
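            # A single punica_wrapper.update_metadata() call therefore
            # refreshes the index tensors seen by every wrapped layer.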
+ new_module.set_mapping(self.punica_wrapper) def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): assert isinstance(module, BaseLayerWithLoRA) diff --git a/vllm/lora/ops/__init__.py b/vllm/lora/ops/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/lora/ops/bgmv_expand.py b/vllm/lora/ops/bgmv_expand.py new file mode 100644 index 000000000000..dcaf2e3d462c --- /dev/null +++ b/vllm/lora/ops/bgmv_expand.py @@ -0,0 +1,169 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +from typing import Dict, Optional + +import torch +import triton +import triton.language as tl + +from .utils import get_lora_op_configs + + +@triton.jit +def _bgmv_expand_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + lora_indices, + xm_stride, + xk_stride, + l0_stride, + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + SPLIT_N: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + """ + GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's + performance + """ + pid_sn = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + offset_k = tl.arange(0, BLOCK_K) + offset_n = tl.arange(0, BLOCK_N) + if EVEN_K: + tiled_a = tl.load(input_ptr + cur_batch * xm_stride + + offset_k * xk_stride, ) # [BLOCK_K] + else: + tiled_a = tl.load( + input_ptr + cur_batch * xm_stride + offset_k * xk_stride, + mask=offset_k < K, + other=0, + ) # [BLOCK_K] + # N must be divisible by SPLIT_N + split_n_length = tl.cdiv(N, SPLIT_N) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + # sliding to next row-block + b_ptr = (lora_ptr + l0_stride * lora_index + + pid_sn * split_n_length * lora_k_stride) + c_ptr = out_ptr + cur_batch * cm_stride + pid_sn * split_n_length + for n in range(0, split_n_length, BLOCK_N): + current_n = n + offset_n + current_n_c = tl.max_contiguous(current_n, BLOCK_N) + b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :] + < K) + c_mask = current_n < split_n_length + tiled_b = tl.load( + b_ptr + current_n_c[:, None] * lora_k_stride + + offset_k[None, :] * lora_n_stride, + mask=b_ptr_mask, + other=0.0, + ) # [BLOCK_N,BLOCK_K] + if ADD_INPUTS: + tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask) + accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out + else: + accumulator = tl.sum(tiled_a * tiled_b, 1) + + tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask) + + +@torch.inference_mode() +def bgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True, + override_config: Optional[Dict[str, int]] = None, +): + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch, An index of -1 means no lora should be + applied. + batches (int): batch size + add_inputs (bool, optional): Defaults to False. adds the final lora + results to the output. + override_config (Optional[Dict[str, int]], optional): Defaults to None. 
+ Triton grid config + """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + BLOCK_K = triton.next_power_of_2(K) + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + batches = lora_indices_tensor.size(0) + if override_config: + config = override_config + else: + config = get_lora_op_configs("expand", batches, N) + grid = lambda META: ( + META["SPLIT_N"], + batches, + ) + _bgmv_expand_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_K=BLOCK_K, + EVEN_K=EVEN_K, + ADD_INPUTS=ADD_INPUTS, + CAST_TYPE=CAST_TYPE, + **config, + ) + return diff --git a/vllm/lora/ops/bgmv_expand_slice.py b/vllm/lora/ops/bgmv_expand_slice.py new file mode 100644 index 000000000000..fa6571074f3a --- /dev/null +++ b/vllm/lora/ops/bgmv_expand_slice.py @@ -0,0 +1,182 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547 +""" + +from typing import Dict, Optional + +import torch +import triton +import triton.language as tl + +from .utils import get_lora_op_configs + + +@triton.jit +def _bgmv_expand_slice_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + lora_indices, + xm_stride, + xk_stride, + l0_stride, + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + slice_offset, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + SPLIT_N: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + """ + GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's + performance + """ + pid_sn = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + offset_k = tl.arange(0, BLOCK_K) + offset_n = tl.arange(0, BLOCK_N) + if EVEN_K: + tiled_a = tl.load(input_ptr + cur_batch * xm_stride + + offset_k * xk_stride, ) # [BLOCK_K] + else: + tiled_a = tl.load( + input_ptr + cur_batch * xm_stride + offset_k * xk_stride, + mask=offset_k < K, + other=0, + ) # [BLOCK_K] + # N must be divisible by SPLIT_N + split_n_length = tl.cdiv(N, SPLIT_N) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + # sliding to next row-block + b_ptr = (lora_ptr + l0_stride * lora_index + + pid_sn * split_n_length * lora_k_stride) + c_ptr = (out_ptr + cur_batch * cm_stride + pid_sn * split_n_length + + slice_offset * cn_stride) + + for n in range(0, split_n_length, BLOCK_N): + current_n = n + offset_n + b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :] + < K) + c_mask = current_n < split_n_length + tiled_b = tl.load( + b_ptr + current_n[:, None] * lora_k_stride + + offset_k[None, :] * lora_n_stride, + mask=b_ptr_mask, + other=0.0, + ) # [BLOCK_N,BLOCK_K] + + if ADD_INPUTS: + tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask) + accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out + else: + accumulator = tl.sum(tiled_a * tiled_b, 1) + + tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask) + + +@torch.inference_mode() +def bgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True, + override_config: Optional[Dict[str, int]] = None, +): + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'b weight + output_tensor (torch.Tensor): output tensor + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch, An index of -1 means no lora should be + applied. + slice_offst (int): output_tensor's offst + slice_size (int): current output_tensor's size + batches (int): batch size + add_inputs (bool, optional): Defaults to False. + override_config (Optional[Dict[str, int]], optional): Defaults to None. 
+ Triton grid config + """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + + assert slice_size == lora_b_weights.size(-2) + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + BLOCK_K = triton.next_power_of_2(K) + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + + batches = lora_indices_tensor.size(0) + + if override_config: + config = override_config + else: + config = get_lora_op_configs("expand", batches, N) + + grid = lambda META: ( + META["SPLIT_N"], + batches, + ) + _bgmv_expand_slice_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + slice_offset, + BLOCK_K=BLOCK_K, + EVEN_K=EVEN_K, + ADD_INPUTS=ADD_INPUTS, + CAST_TYPE=CAST_TYPE, + **config, + ) + return diff --git a/vllm/lora/ops/bgmv_shrink.py b/vllm/lora/ops/bgmv_shrink.py new file mode 100644 index 000000000000..e69d33078f5a --- /dev/null +++ b/vllm/lora/ops/bgmv_shrink.py @@ -0,0 +1,150 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547 +""" + +from typing import Dict, Optional + +import torch +import triton +import triton.language as tl + +from .utils import get_lora_op_configs + + +@triton.jit +def _bgmv_shrink_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + lora_indices, + scaling, + xm_stride, + xk_stride, + l0_stride, + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + SPLIT_K: tl.constexpr, +): + """ + GroupGEMV, additionally, introducing SPLIT-K can improve large hidden_size's + performance + """ + pid_sk = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + + offset_n = tl.arange(0, BLOCK_N) + offset_k = tl.arange(0, BLOCK_K) + pid_sk * BLOCK_K + a_ptr = input_ptr + cur_batch * xm_stride + b_ptr = lora_ptr + l0_stride * lora_index + accumulator = tl.zeros((BLOCK_N, ), dtype=tl.float32) + for k in range(0, K, BLOCK_K * SPLIT_K): + current_k = k + offset_k + current_k_c = tl.max_contiguous(current_k, BLOCK_K) + tiled_a = tl.load( + a_ptr + current_k_c, + mask=current_k < K, + other=0.0, + ) # [BLOCK_K] + b_ptr_mask = (offset_n[:, None] < N) & (current_k[None, :] < K) + + tiled_b = tl.load( + b_ptr + offset_n[:, None] * lora_k_stride + + current_k[None, :] * lora_n_stride, + mask=b_ptr_mask, + other=0.0, + ) # [BLOCK_N,BLOCK_K] + + accumulator += tl.sum(tiled_a * tiled_b, 1) + accumulator *= scaling + offset_cn = tl.arange(0, BLOCK_N) + c_ptr = out_ptr + cur_batch * cm_stride + offset_cn * cn_stride + c_mask = offset_cn < N + if SPLIT_K == 1: + tl.store(c_ptr, accumulator, mask=c_mask) + else: + tl.atomic_add(c_ptr, accumulator, mask=c_mask) + + +@torch.inference_mode() +def bgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0, + override_config: Optional[Dict[str, int]] = None, +): + """ + Args: + inputs (torch.Tensor): input tensor + lora_a_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + scaling (float): Scaling factor. + override_config (Optional[Dict[str, int]], optional): Defaults to None. 
+ Triton grid config + """ + assert inputs.dtype == lora_a_weights.dtype + assert inputs.dtype in [torch.float16, torch.bfloat16] + assert lora_a_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_a_weights.size(-1) + assert inputs.is_contiguous() + + if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size) + assert lora_a_weights.size(1) == 1 + lora_a_weights = lora_a_weights.squeeze(dim=1) + else: + assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size) + assert lora_a_weights.is_contiguous() + assert output_tensor.is_contiguous() + # TODO tuning this config + batches = lora_indices_tensor.size(0) + N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank + BLOCK_N = triton.next_power_of_2(N) + if override_config: + config = override_config + else: + # First try to load optimal config from the file + config = get_lora_op_configs("bgmv_shrink", batches, K) + + grid = lambda META: ( + META["SPLIT_K"], + batches, + ) + _bgmv_shrink_kernel[grid]( + inputs, + lora_a_weights, + output_tensor, + N, + K, + lora_indices_tensor, + scaling, + inputs.stride(0), + inputs.stride(1), + lora_a_weights.stride(0), + lora_a_weights.stride(1), + lora_a_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_N=BLOCK_N, + **config, + ) + return diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py new file mode 100644 index 000000000000..459049546909 --- /dev/null +++ b/vllm/lora/ops/sgmv_expand.py @@ -0,0 +1,192 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +import torch +import triton +import triton.language as tl + +from vllm.triton_utils import libentry + + +@libentry() +@triton.jit +def _sgmv_expand_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + xm_stride, + xk_stride, # 1 + l0_stride, # hidden_size*max_rank + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + """ + The sgmv's expand triton kernel is based on GroupGEMM. 
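    Grid axis 0 enumerates the (BLOCK_M, BLOCK_N) output tiles of the longest
    sequence and grid axis 1 selects the sequence in the batch; a program
    exits early if its tile lies beyond the sequence length or if the
    sequence has no LoRA (lora index -1).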
+ """ + pid = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + cta_n_num = tl.cdiv(N, BLOCK_N) + pid_m = pid // cta_n_num + pid_n = pid % cta_n_num + M = tl.load(seq_lens + cur_batch) + if pid_m * BLOCK_M > M: + return + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) + offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_k = tl.arange(0, BLOCK_K) + ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + + offset_k[None, :] * xk_stride, ) + b_ptr = (lora_ptr + l0_stride * lora_index + + offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] < K - k * BLOCK_K, + other=0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] < K - k * BLOCK_K, + other=0) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * xk_stride + b_ptr += BLOCK_K * lora_n_stride + tiled_c = accumulator.to(lora_ptr.dtype.element_ty) + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + + offset_cn[None, :] * cn_stride) + M = tl.load(seq_lens + cur_batch) + c_mask = (offset_cm[:, None] < + (cur_seq_start + M)) & (offset_cn[None, :] < N) + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + + +@torch.inference_mode() +def sgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + add_inputs: bool = False, +): + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative + sequence lengths of the sequences in the batch, used to index + into sequence. E.g.,if the sequence length is [4, 6], it is + [0, 4, 10]. + seq_len_tensor (torch.Tensor): (batch_size,). record the sequence + length of the sequences in the batch + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + max_seq_length (int): The max sequence lengths of the sequences + in the batch + add_inputs (bool, optional): Defaults to False. adds the final lora + results to the output. 
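    A minimal usage sketch (all shapes, dtypes, variable names and values
    below are illustrative assumptions, not fixed by the API):

        import torch
        # 2 sequences of lengths 3 and 5 (8 tokens total), rank 8,
        # hidden_size 16, both sequences mapped to LoRA slot 0.
        inputs = torch.randn(8, 8, dtype=torch.float16, device="cuda")
        lora_b = torch.randn(1, 16, 8, dtype=torch.float16, device="cuda")
        out = torch.zeros(8, 16, dtype=torch.float16, device="cuda")
        seq_start = torch.tensor([0, 3], device="cuda")
        seq_lens = torch.tensor([3, 5], device="cuda")
        lora_idx = torch.tensor([0, 0], device="cuda")
        sgmv_expand(inputs, lora_b, out, seq_start, seq_lens, lora_idx,
                    batches=2, max_seq_length=5, add_inputs=True)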
+ """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + assert b_seq_start_loc.size(0) == batches + assert lora_indices_tensor.size(0) == batches + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + BLOCK_M = 32 + BLOCK_N = 32 + BLOCK_K = 16 + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + grid = ( + triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), + batches, + ) + _sgmv_expand_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + ) + return diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py new file mode 100644 index 000000000000..ff3bcda071b8 --- /dev/null +++ b/vllm/lora/ops/sgmv_expand_slice.py @@ -0,0 +1,205 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +import torch +import triton +import triton.language as tl + +from vllm.triton_utils import libentry + + +@libentry() +@triton.jit +def _sgmv_expand_slice_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + xm_stride, + xk_stride, # 1 + l0_stride, # hidden_size*max_rank + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + slice_offset, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + """ + + Similar to the 'sgmv_expand' operator, but with an added parameter + 'slice_offset'. The reason for not reusing the 'sgmv_expand' operator + might be that in the future, we could implement a fusion operator to + achieve the current functionality instead of having to call it multiple + times. 
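    For example (illustrative), a packed projection whose output slices have
    sizes (s0, s1, s2) is served by three launches of this kernel, with
    slice_offset values 0, s0 and s0 + s1, each writing its own column range
    of the shared output tensor.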
+ """ + pid = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + cta_n_num = tl.cdiv(N, BLOCK_N) + pid_m = pid // cta_n_num + pid_n = pid % cta_n_num + M = tl.load(seq_lens + cur_batch) + if pid_m * BLOCK_M > M: + return + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) + offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_k = tl.arange(0, BLOCK_K) + ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + + offset_k[None, :] * xk_stride, ) + b_ptr = (lora_ptr + l0_stride * lora_index + + offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] < K - k * BLOCK_K, + other=0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] < K - k * BLOCK_K, + other=0) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * xk_stride + b_ptr += BLOCK_K * lora_n_stride + tiled_c = accumulator.to(lora_ptr.dtype.element_ty) + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset + c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + + offset_cn[None, :] * cn_stride) + M = tl.load(seq_lens + cur_batch) + c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < + (slice_offset + N)) + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + + +@torch.inference_mode() +def sgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False, +): + """_summary_ + + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative + sequence lengths of the sequences in the batch, used to index + into sequence. E.g.,if the sequence length is [4, 6], it is + [0, 4, 10]. + seq_len_tensor (torch.Tensor): (batch_size,). record the sequence + length of the sequences in the batch + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + max_seq_length (int): The max sequence lengths of the sequences + in the batch + slice_offst (int): output_tensor's offst + slice_size (int): current output_tensor's size + add_inputs (bool, optional): Defaults to False. adds the final lora + results to the output.. 
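    A minimal usage sketch (shapes, dtypes, variable names and values are
    illustrative assumptions): two 16-column LoRA B slices written into one
    packed 32-column output; in real use each slice would normally have its
    own shrink result as inputs:

        import torch
        inputs = torch.randn(8, 8, dtype=torch.float16, device="cuda")
        lora_b_0 = torch.randn(1, 16, 8, dtype=torch.float16, device="cuda")
        lora_b_1 = torch.randn(1, 16, 8, dtype=torch.float16, device="cuda")
        out = torch.zeros(8, 32, dtype=torch.float16, device="cuda")
        seq_start = torch.tensor([0, 3], device="cuda")
        seq_lens = torch.tensor([3, 5], device="cuda")
        lora_idx = torch.tensor([0, 0], device="cuda")
        sgmv_expand_slice(inputs, lora_b_0, out, seq_start, seq_lens,
                          lora_idx, 2, 5, slice_offset=0, slice_size=16,
                          add_inputs=True)
        sgmv_expand_slice(inputs, lora_b_1, out, seq_start, seq_lens,
                          lora_idx, 2, 5, slice_offset=16, slice_size=16,
                          add_inputs=True)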
+ """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + assert b_seq_start_loc.size(0) == batches + assert lora_indices_tensor.size(0) == batches + assert slice_size == lora_b_weights.size(-2) + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + + BLOCK_M = 32 + BLOCK_N = 32 + BLOCK_K = 16 + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + grid = ( + triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), + batches, + ) + _sgmv_expand_slice_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + slice_offset, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + ) + return diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py new file mode 100644 index 000000000000..8ab049989abe --- /dev/null +++ b/vllm/lora/ops/sgmv_shrink.py @@ -0,0 +1,189 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +import torch +import triton +import triton.language as tl + +from vllm.triton_utils import libentry + + +@libentry() +@triton.jit +def _sgmv_shrink_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + scaling, + xm_stride, # hidden_size + xk_stride, # 1 + l0_stride, # hidden_size*max_rank + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, +): + """ + The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K. + The GEMM of Multi-LoRA can be considered as GroupGEMM. 
Additionally, + introducing SPLIT-K can improve performance + """ + pid = tl.program_id(axis=0) + pid_sk = tl.program_id(axis=1) + cur_batch = tl.program_id(axis=2) + cta_n_num = tl.cdiv(N, BLOCK_N) + pid_m = pid // cta_n_num + pid_n = pid % cta_n_num + + M = tl.load(seq_lens + cur_batch) + if pid_m * BLOCK_M > M: + return + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) + offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K) + + ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + + offset_k[None, :] * xk_stride) + b_ptr = (lora_ptr + l0_stride * lora_index + rbn[None, :] * lora_k_stride + + offset_k[:, None] * lora_n_stride) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] < k_remaining, + other=0.0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] < k_remaining, + other=0.0) + accumulator += tl.dot(tiled_a, tiled_b) + + a_ptr += BLOCK_K * SPLIT_K * xk_stride + b_ptr += BLOCK_K * SPLIT_K * lora_n_stride + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + + offset_cn[None, :] * cn_stride) + c_mask = (offset_cm[:, None] < + (cur_seq_start + M)) & (offset_cn[None, :] < N) + accumulator *= scaling + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(c_ptr, accumulator, mask=c_mask) + else: + tl.atomic_add(c_ptr, accumulator, mask=c_mask) + + +@torch.inference_mode() +def sgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + scaling: float, +): + """ + + Args: + inputs (torch.Tensor): input tensor + lora_a_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative + sequence lengths of the sequences in the batch, used to index + into sequence. E.g.,if the sequence length is [4, 6], it is + [0, 4]. + seq_len_tensor (torch.Tensor): (batch_size,). record the sequence + length of the sequences in the batch + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + max_seq_length (int): The max sequence lengths of the sequences + in the batch + scaling (float): Scaling factor. 
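    A minimal usage sketch (shapes, dtypes, variable names and values are
    illustrative assumptions):

        import torch
        # 8 tokens (two sequences of 3 and 5), hidden_size 16, rank 8.
        x = torch.randn(8, 16, dtype=torch.float16, device="cuda")
        lora_a = torch.randn(1, 8, 16, dtype=torch.float16, device="cuda")
        # A float32 buffer is a safe choice here, since the SPLIT_K partial
        # results are combined with atomic adds.
        buf = torch.zeros(8, 8, dtype=torch.float32, device="cuda")
        seq_start = torch.tensor([0, 3], device="cuda")
        seq_lens = torch.tensor([3, 5], device="cuda")
        lora_idx = torch.tensor([0, 0], device="cuda")
        sgmv_shrink(x, lora_a, buf, seq_start, seq_lens, lora_idx,
                    batches=2, max_seq_length=5, scaling=0.5)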
+ """ + assert inputs.dtype == lora_a_weights.dtype + assert inputs.dtype in [torch.float16, torch.bfloat16] + assert lora_a_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_a_weights.size(-1) + assert b_seq_start_loc.size(0) == batches + assert lora_indices_tensor.size(0) == batches + assert inputs.is_contiguous() + + if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size) + assert lora_a_weights.size(1) == 1 + lora_a_weights = lora_a_weights.squeeze(dim=1) + else: + assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size) + assert lora_a_weights.is_contiguous() + assert output_tensor.is_contiguous() + # TODO tuning this config + N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank + BLOCK_M = 32 + BLOCK_N = 16 + BLOCK_K = 32 + SPLIT_K = 8 + EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 + grid = ( + triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), + SPLIT_K, + batches, + ) + + _sgmv_shrink_kernel[grid]( + inputs, + lora_a_weights, + output_tensor, + N, + K, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + scaling, + inputs.stride(0), + inputs.stride(1), + lora_a_weights.stride(0), + lora_a_weights.stride(1), + lora_a_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + ) + return diff --git a/vllm/lora/ops/utils.py b/vllm/lora/ops/utils.py new file mode 100644 index 000000000000..7c3e27313ad9 --- /dev/null +++ b/vllm/lora/ops/utils.py @@ -0,0 +1,46 @@ +import functools +from typing import Dict + + +@functools.lru_cache +def _get_op_configs(op_type: str, batch: int, hidden_size: int): + # TODO: add optimal configurations + return None + + +def _check_divisibility(hidden_size: int): + # The bgmv_expand kernel requires that the hidden_size be divisible by + # the number below. + divisibility = [2, 4, 8, 16, 32, 64] + divisibility.sort(reverse=True) + for div in divisibility: + if hidden_size % div == 0: + return div + # hidden_size is an odd number + return 1 + + +def _get_default_config(op_type: str, batch: int, hidden_size: int): + if op_type == "expand": + return { + "BLOCK_N": 256, + "SPLIT_N": _check_divisibility(hidden_size), + "num_warps": 8 + } + else: + return {"BLOCK_K": 256, "SPLIT_K": 64, "num_warps": 8} + + +def get_lora_op_configs(op_type: str, batch: int, + hidden_size: int) -> Dict[str, int]: + """Inspired by `fused_moe_kernel` + The return value will be a dictionary mapping an irregular grid of batch + sizes and hidden_size to configurations of the bgmv-related kernel. + NOTE: It currently only supports the default configuration. We plan to + generate optimal configurations for different hardware in the future using + scripts similar to `benchmark_moe.py`. + """ + config = _get_op_configs(op_type, batch, hidden_size) + if not config: + config = _get_default_config(op_type, batch, hidden_size) + return config diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index 64f87a4b2c69..6d5c83429996 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -1,207 +1,604 @@ -# Based on code from https://github.com/punica-ai/punica +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547 +""" -from typing import Optional +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union import torch -from vllm import _custom_ops as ops -from vllm.platforms import current_platform +from vllm.triton_utils import HAS_TRITON +if HAS_TRITON: + from vllm.lora.ops.bgmv_expand import bgmv_expand + from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice + from vllm.lora.ops.bgmv_shrink import bgmv_shrink + from vllm.lora.ops.sgmv_expand import sgmv_expand + from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice + from vllm.lora.ops.sgmv_shrink import sgmv_shrink -def _check_punica_support(): - if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"): - return +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext - if current_platform.get_device_capability() < (8, 0): - raise ImportError( - "punica LoRA kernels require compute capability >= 8.0") - else: - raise ImportError( - "punica LoRA kernels could not be imported. If you built vLLM " - "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var " - "was set.") - - -def bgmv( - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - indicies: torch.LongTensor, - layer_idx: int, - scale: float, -): - """ - Semantics: - y[i] += ( - x[i].unsqueeze(0) - @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - * scale - ).squeeze(0) - Args: - y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. - x: Shape: `[B, H1]`. Input vectors. - w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight - matrices. - indicies: Shape: `[B]`. Indices of the weight matrices. - layer_idx: Layer index of the weight matrices. - scale: Scaling factor. +def compute_meta( + token_lora_tensor: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]: + """ + Get the information required for the sgmv kernel. With the features: + 1. If consecutive requests in the batch use the same LoRA, this function + will combine them into a single request, improving sgmv kernel inference + performance. + 2. At the beginning of each prefill stage inference, recalculations are + needed based on the input, but only once. """ - _check_punica_support() - ops.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) + lora_indices_tensor, seq_length_tensor = torch.unique_consecutive( + token_lora_tensor, return_counts=True) + cum_result = torch.cumsum(seq_length_tensor, dim=0) + b_seq_start_tensor = torch.zeros_like(seq_length_tensor) + b_seq_start_tensor[1:].copy_(cum_result[:-1]) + max_length = seq_length_tensor.max().item() + batch_size = lora_indices_tensor.size(0) + no_lora = False + # -1 means no lora should be applied. Use `no_lora` to determine whether + # the current step requires LoRA. If LoRA is not needed, the prefill stage + # does not need to launch the triton kernel, which can improve performance + if batch_size == 1 and lora_indices_tensor == -1: + no_lora = True + return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, + batch_size, max_length, no_lora) -def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor, - w_t_all: torch.Tensor, indicies: torch.LongTensor, - layer_idx: int, scale: float, y_offset: int, - y_slice_size: int): - """ - Same as `bgmv` but you can operate on slices of y. - Pass whole y, define y_offset and y_slice_size. 
- Semantics: - y[i] += ( - x[i].unsqueeze(0) - @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - * scale - ).squeeze(0) +# TODO see if this can be vectorized +def convert_mapping( + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], List[int]]: + """Converts LoRAMapping to index tensors. Args: - y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. - x: Shape: `[B, H1]`. Input vectors. - w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of - all of the transposed LoRA matrices. - indicies: Shape: `[B]`. Indices of the LoRA weights. - layer_idx: Layer index of LoRA weights. - scale: Scaling factor. - y_offset: Offset to apply to the starting column of y. - y_slice_size: Size of the y column slice. + mapping: LoRAMapping mapping rows in a batch to LoRA ids. + lora_index_to_id: List mapping LoRA ids to LoRA indices. + max_loras: Maximum number of LoRAs. + vocab_size: Model vocab size. + extra_vocab_size: Extra vocab size each LoRA can have. + long_lora_context: Passed if there are long context lora in a batch. + + Returns: + A tuple of tensors: + base_indices: Tensor of shape [batch_size] mapping batch rows to + LoRA indices. + sampler_indices: Tensor of shape [batch_size] mapping requests to + LoRA indices for sampler. For generation, this will be the + same as base_indicies. For prefill, this will map requests + to LoRA indices. + sampler_indices_padded: Tensor of shape [batch_size] mapping + requests to LoRA indices for sampler with padding. + Same as sampler_indicies, but -1 is replaced with + max_loras. + embeddings_indices: Tensor of shape [2, batch_size] mapping + requests to embedding indices. First row is for embeddings + added by the LoRAs, second row is for the LoRA.lora_a + embeddings. + long_lora_indices: Tensor of shape [batch_size] mapping + requests to RoPE offsets and rot dims for long LoRAs. + None if long context lora doesn't exist. + indices_len: List of lengths of the above tensors. It contains + (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, long_lora_indices). """ - _check_punica_support() - - ops.dispatch_bgmv_low_level( - y, - x, - w_t_all, - indicies, - layer_idx, - scale, - x.size(1), - y_slice_size, - y_offset, - ) + index_mapping_indices: List[int] = list(mapping.index_mapping).copy() + embedding_indices = index_mapping_indices.copy() + lora_indices = index_mapping_indices.copy() + long_lora_offsets: Optional[torch.Tensor] = None + if long_lora_context: + long_lora_offsets = torch.zeros(len(index_mapping_indices), + device="cuda", + dtype=torch.long) + prompt_mapping: List[int] = [ + lora_index_to_id.index(x) if x > 0 else -1 + for x in mapping.prompt_mapping + ] + lora_idx = None + for i in range(len(index_mapping_indices)): + # TODO index can be slow. 
optimize + lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) + if index_mapping_indices[i] > 0 else -1) + embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 + lora_indices[i] = lora_idx + if long_lora_context: + assert long_lora_offsets is not None + lora_offset: int = long_lora_context.offsets_by_lora_id.get( + index_mapping_indices[i], 0) + long_lora_offsets[i] = lora_offset + indices_list: List[Union[List[int], torch.Tensor]] = [ + index_mapping_indices, + lora_indices, + embedding_indices, + ] + if long_lora_context: + assert long_lora_offsets is not None + indices_list.append(long_lora_offsets) + indices = torch.tensor(indices_list, dtype=torch.long, device="cuda") + prompt_mapping_tensor = torch.tensor(prompt_mapping, + device="cuda", + dtype=torch.long) + embeddings_indices = torch.stack([ + indices[2] * extra_vocab_size, + indices[2] * (vocab_size + extra_vocab_size), + ]) + embeddings_indices[embeddings_indices == -1] = max_loras - 1 + base_indices = indices[1] + sampler_indices = prompt_mapping_tensor + sampler_indices_padded = sampler_indices.clone() + sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 + sampler_indices_padded = torch.arange( + 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + ( + sampler_indices_padded * len(sampler_indices_padded)) + long_lora_indices = None + long_lora_indices_len: Optional[int] = None + if long_lora_context: + long_lora_indices = indices[3] + long_lora_indices_len = long_lora_indices.shape[-1] + # Contain length of indices tensors. Used to index into each tensor. + indices_len = [ + base_indices.shape[-1], + sampler_indices.shape[-1], + sampler_indices_padded.shape[-1], + embeddings_indices.shape[-1], + ] + if long_lora_indices_len is not None: + indices_len.append(long_lora_indices_len) + else: + # If long_lora doesn't exist,append None + indices_len.append(None) -def add_lora(y: torch.Tensor, - x: torch.Tensor, - wa_t_all: torch.Tensor, - wb_t_all: torch.Tensor, - indicies: torch.LongTensor, - layer_idx: int, - scale: float, - *, - buffer: Optional[torch.Tensor] = None): - """ - Semantics: - y[i] += ( - x[i].unsqueeze(0) - @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - * scale - ).squeeze(0) + return ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_indices, + indices_len, + ) - Args: - y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. - x: Shape: `[B, H1]`. Input vectors. - wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed - LoRA A matrices. - wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed - LoRA B matrices. - indicies: Shape: `[B]`. Indices of the LoRA weights. - layer_idx: Layer index of LoRA weights. - scale: Scaling factor. - buffer: Optional. Shape: `[B, R]`. Temporary buffer. + +class PunicaWrapper: """ - _check_punica_support() - - r = wb_t_all.size(-1) - if buffer is None: - # We set the buffer to be float32 by default to avoid - # numerical inaccuracies that would otherwise happen - # due to downcasting. 
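A toy run (not from the patch) of the padded sampler-index arithmetic built in convert_mapping above; the interpretation in the final comment is a reading of the code, not a statement from the source:

import torch

max_loras = 4
prompt_mapping = torch.tensor([2, -1, 0])      # -1 == request without LoRA
padded = prompt_mapping.clone()
padded[padded == -1] = max_loras - 1           # park "no LoRA" in the last slot
padded = torch.arange(len(padded)) + padded * len(padded)
# padded -> tensor([ 6, 10,  2]), i.e. the flat offset of (slot, i) when a
# tensor is viewed with a trailing dimension equal to the batch size.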
- buffer = torch.zeros((x.size(0), r), - dtype=torch.float32, - device=x.device) - ops.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0) - ops.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, scale) - - -def add_lora_slice(y: torch.Tensor, - x: torch.Tensor, - wa_t_all: torch.Tensor, - wb_t_all: torch.Tensor, - indicies: torch.LongTensor, - layer_idx: int, - scale: float, - y_offset: int, - y_slice_size: int, - *, - buffer: Optional[torch.Tensor] = None): + PunicaWrapper is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the punica kernel. """ - Same as `add_lora` but you can operate on slices of y. - Pass whole y, define y_offset and y_slice_size. - Semantics: - y[i] += ( - x[i].unsqueeze(0) - @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - * scale - ).squeeze(0) + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: str): + self._token_lora_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._sampler_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._sampler_indices_padded = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._embeddings_indices = torch.empty(2, + max_num_batched_tokens, + dtype=torch.long, + device=device) + self._long_lora_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) - Args: - y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. - x: Shape: `[B, H1]`. Input vectors. - wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed - LoRA A matrices. - wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed - LoRA B matrices. - indicies: Shape: `[B]`. Indices of the LoRA weights. - layer_idx: Layer index of LoRA weights. - scale: Scaling factor. - y_offset: Offset to apply to the starting column of y. - y_slice_size: Size of the y column slice. - """ - _check_punica_support() - - r = wb_t_all.size(-1) - if buffer is None: - # We set the buffer to be float32 by default to avoid - # numerical inaccuracies that would otherwise happen - # due to downcasting. - buffer = torch.zeros((x.size(0), r), - dtype=torch.float32, - device=x.device) - ops.dispatch_bgmv_low_level( - buffer, - x, - wa_t_all, - indicies, - layer_idx, - 1.0, - x.size(1), - buffer.size(1), - 0, - ) - ops.dispatch_bgmv_low_level( - y, - buffer, - wb_t_all, - indicies, - layer_idx, - scale, - buffer.size(1), - y_slice_size, - y_offset, - ) + # 5 is the number of indicies tensors. 
+ # base_indices, sampler_indices, sampler_indices_padded, + # embeddings_indices,long_lora_indices + self.indices_len: List[Optional[int]] = [None] * 5 + # these attributes are the information required for sgmv kernel + self._seq_start_locs = torch.empty(max_batches, + dtype=torch.long, + device=device) + self._seq_lengths = torch.empty(max_batches, + dtype=torch.long, + device=device) + self._lora_indices_per_batch = torch.empty(max_batches, + dtype=torch.long, + device=device) + self.max_length: int = 0 + self.batch_size: int = -1 + self.is_prefill = False + self.no_lora = False + + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + + self._update_base_metadata(mapping, lora_index_to_id, max_loras, + vocab_size, extra_vocab_size, + long_lora_context) + if mapping.is_prefill: + # Update metadata required for prefill-related operators. + self._update_prefill_metada(self.token_lora_indices) + self.is_prefill = True + else: + self.is_prefill = False + + def _update_base_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_offsets_tensor, + indices_len, + ) = convert_mapping( + mapping, + lora_index_to_id, + max_loras, + vocab_size, + extra_vocab_size, + long_lora_context, + ) + self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) + self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) + self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded) + self._embeddings_indices[:embeddings_indices. + shape[0], :embeddings_indices.shape[1]].copy_( + embeddings_indices) + if long_lora_offsets_tensor is not None: + self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( + long_lora_offsets_tensor) + else: + self._long_lora_indices.zero_() + + self.indices_len[:] = indices_len + + def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: + + (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, + batch_size, max_length, no_lora) = compute_meta(token_lora_tensor) + + self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( + b_seq_start_tensor) + self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor) + self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_( + lora_indices_tensor) + self.batch_size = batch_size + self.max_length = max_length + self.no_lora = no_lora + + @property + def prefill_metadata( + self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]: + """ + This property provides a convenient way to access the necessary + metadata for prefill-related kernel computations. + 1. seq_start_locs: Tensor of sequence start positions + 2. seq_lengths: Tensor of sequence lengths + 3. lora_indices_per_batch: Tensor of lora indices, and an index of + -1 means no lora should be applied. + 4. batch_size: batch size after clustering identical lora indices + 5. 
max_length: The maximum sequence length in the batch + """ + return (self._seq_start_locs[:self.batch_size], + self._seq_lengths[:self.batch_size], + self._lora_indices_per_batch[:self.batch_size], + self.batch_size, self.max_length) + + @property + def token_lora_indices(self) -> torch.Tensor: + """ + This property provides the lora indices corresponding to each token + in the batch. An index of -1 means no lora should be applied. + """ + token_lora_len = self.indices_len[0] + return self._token_lora_indices[:token_lora_len] + + @property + def sampler_indices(self) -> torch.Tensor: + """ + This property is used to access the lora indices specifically for + LogitsProcessorWithLoRA + """ + sampler_indices_len = self.indices_len[1] + return self._sampler_indices[:sampler_indices_len] + + @property + def sampler_indices_padded(self) -> torch.Tensor: + """ + This property provides access to padded sampler indices + """ + indices_padded_len = self.indices_len[2] + return self._sampler_indices_padded[:indices_padded_len] + + @property + def embeddings_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for lora embeddings, + specifically for VocabParallelEmbeddingWithLoRA + """ + embeddings_indices_len = self.indices_len[3] + return self._embeddings_indices[:, :embeddings_indices_len] + + @property + def long_lora_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for long context + lora, specifically for LinearScalingRotaryEmbeddingWithLora + """ + long_lora_len = self.indices_len[4] + return self._long_lora_indices[:long_lora_len] + + def shrink_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_shrink( + x, + w_t_all, + y, + *self.prefill_metadata, + scale, + ) + + def shrink_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + + def expand_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand( + x, + w_t_all, + y, + *self.prefill_metadata, + add_input, + ) + + def expand_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool, + ): + bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input) + + def expand_slice_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand_slice( + x, + w_t_all, + y, + *self.prefill_metadata, + y_offset, + y_slice_size, + add_input, + ) + + def expand_slice_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool, + ): + bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, + y_slice_size, add_input) + + def add_shrink( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `shrink_prefill` function should be called. 
+ Otherwise, it is the decode stage, and the shrink_decode function + should be called. + """ + shrink_fun: Callable = (self.shrink_prefill + if self.is_prefill else self.shrink_decode) + shrink_fun(y, x, w_t_all, scale) + + def add_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool = True, + ): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'b. + When `is_prefill` is true, it indicates that it is currently the + prefill stage, and the `expand_prefill` function should be called. + Otherwise, it is the decode stage, and the expand_decode function + should be called. + """ + + expand_fun: Callable = (self.expand_prefill + if self.is_prefill else self.expand_decode) + expand_fun(y, x, w_t_all, add_input) + + def add_expand_slice(self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool = True): + """ + Similar to `add_expand` + """ + + expand_slice_fun: Callable = (self.expand_slice_prefill + if self.is_prefill else + self.expand_slice_decode) + expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) + + def add_lora(self, + y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + scale: float, + y_offset: Optional[int] = None, + y_slice_size: Optional[int] = None, + *, + buffer: Optional[torch.Tensor] = None) -> None: + """ + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + wa_t_all (torch.Tensor): lora_a's weight + wb_t_all (torch.Tensor): lora_b's weight + scale (float): Scaling factor. + y_offset (Optional[int], optional): Offset to apply to the starting + column of y. + y_slice_size (Optional[int], optional): Size of the y column slice.. + buffer (Optional[torch.Tensor], optional): Defaults to None. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + + self.add_shrink(buffer, x, wa_t_all, scale) + if y_offset is None and y_slice_size is None: + self.add_expand(y, buffer, wb_t_all, add_input=True) + else: + self.add_expand_slice(y, + buffer, + wb_t_all, + y_offset, + y_slice_size, + add_input=True) + y = y.view_as(y_org) + + def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, + torch.Tensor, + torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, + torch.Tensor, + torch.Tensor], + scale: float, + output_slices: Tuple[int, ...]) -> None: + """ + Applies lora to each input. Similar to add_lora, This method is + used for layers that are composed of multiple sublayers + (slices) packed together. 
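A hypothetical call sequence for the wrapper, kept entirely in comments because the surrounding objects (the LoRA mapping and the stacked A/B weights) are assumptions rather than values taken from the patch:

# wrapper = PunicaWrapper(max_num_batched_tokens=8192, max_batches=256,
#                         device="cuda")
# # once per scheduler step, before the LoRA layers run:
# wrapper.update_metadata(lora_mapping, lora_index_to_id, max_loras=8,
#                         vocab_size=32000, extra_vocab_size=256)
# # inside a LoRA layer: y += (x @ A^T) @ B^T * scale, dispatched to the
# # sgmv_* kernels for prefill and the bgmv_* kernels for decode
# wrapper.add_lora(y, x, lora_a_stacked, lora_b_stacked, scale=1.0)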
+ """ + y_org = y + x = x.view(-1, x.shape[-1]) + y = y.view(-1, y.shape[-1]) + offset_left = 0 + # TODO fuse these kernels + for slice_idx in range(len(output_slices)): + self.add_lora(y, x, lora_a_stacked[slice_idx], + lora_b_stacked[slice_idx], scale, offset_left, + output_slices[slice_idx]) + offset_left += output_slices[slice_idx] + + y = y.view_as(y_org) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None) -> None: + """ + LogitsProcessorWithLoRA always using bgmv + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + + bgmv_shrink(x, wa_t_all, buffer, self.sampler_indices, scale) + bgmv_expand(buffer, wb_t_all, y, self.sampler_indices, add_inputs=True) + y = y.view_as(y_org) diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py index 568185383aa5..3f57c22e1f2e 100644 --- a/vllm/triton_utils/__init__.py +++ b/vllm/triton_utils/__init__.py @@ -6,5 +6,6 @@ from vllm.triton_utils.custom_cache_manager import ( maybe_set_triton_cache_manager) + from vllm.triton_utils.libentry import libentry - __all__ += ["maybe_set_triton_cache_manager"] + __all__ += ["maybe_set_triton_cache_manager", "libentry"] diff --git a/vllm/triton_utils/libentry.py b/vllm/triton_utils/libentry.py new file mode 100644 index 000000000000..ae00af44a048 --- /dev/null +++ b/vllm/triton_utils/libentry.py @@ -0,0 +1,167 @@ +# Copied From https://github.com/FlagOpen/FlagGems + +import inspect + +import triton + + +class LibEntry(triton.KernelInterface): + + def __init__( + self, + fn, + ): + self.fn = fn + self.arg_names = fn.arg_names + self.divisibility = 16 + self.kernel_cache = dict() + fn = self.fn + while not isinstance(fn, triton.runtime.JITFunction): + fn = fn.fn + self.jit_function: triton.runtime.JITFunction = fn + self.specialize_indices = [ + p.num for p in self.jit_function.params + if not p.is_constexpr and not p.do_not_specialize + ] + self.do_not_specialize_indices = [ + p.num for p in self.jit_function.params + if not p.is_constexpr and p.do_not_specialize + ] + + def key(self, spec_args, dns_args, const_args): + spec_key = [(arg.dtype, arg.data_ptr() % + self.divisibility == 0) if hasattr(arg, "data_ptr") else + (type(arg), arg) for arg in spec_args] + dns_key = [ + arg.dtype if hasattr( + arg, "data_ptr") else type(arg) if not isinstance(arg, int) + else "i32" if -(2**31) <= arg and arg <= 2**31 - + 1 else "u64" if 2**63 <= arg and arg <= 2**64 - 1 else "i64" + for arg in dns_args + ] + # const args passed by position + return tuple(spec_key + dns_key + const_args) + + def run(self, *args, **kwargs): + grid = kwargs["grid"] + # collect all the arguments + spec_args = [] # specialize arguments + dns_args = [] # do not specialize arguments + const_args = [] # constexpr arguments + k_args = [] # kernel arguments + for i, arg in enumerate(args): + if i in self.specialize_indices: + k_args.append(arg) + spec_args.append(arg) + elif i in self.do_not_specialize_indices: + k_args.append(arg) + dns_args.append(arg) + else: + const_args.append(arg) + for p in self.jit_function.params[len(args):]: + if p.name in kwargs: + val = kwargs[p.name] + elif p.default is inspect._empty: + 
continue + else: + val = p.default + + if p.is_constexpr: + const_args.append(val) + elif p.do_not_specialize: + dns_args.append(val) + k_args.append(val) + else: + spec_args.append(val) + k_args.append(val) + + entry_key = self.key(spec_args, dns_args, const_args) + + if entry_key not in self.kernel_cache: + # compile the kernel also completes the related computations + kernel = self.fn.run(*args, **kwargs) + fn = self.fn + # collect constexpr arguments for grid computation + constexprs = {} + while not isinstance(fn, triton.runtime.JITFunction): + if isinstance(fn, triton.runtime.Autotuner): + config = fn.best_config + constexprs["num_warps"] = config.num_warps + constexprs["num_stages"] = config.num_stages + constexprs["num_ctas"] = config.num_ctas + constexprs = {**constexprs, **config.kwargs} + elif isinstance(fn, triton.runtime.Heuristics): + for v, heur in fn.values.items(): + constexprs[v] = heur({ + **dict(zip(fn.arg_names, args)), + **kwargs, + **constexprs, + }) + else: + raise RuntimeError("Invalid Runtime Function") + fn = fn.fn + # In vLLM, certain kernels like fused_moe_kernel get the + # best_config(as kwargs) from a configuration json file, rather + # than using Autotuner & Heuristics. Therefore, all their constexprs + # (tl.constexpr) are assigned values through the following loop. + for p in self.jit_function.params: + if p.is_constexpr and p.name not in constexprs: + constexprs[p.name] = p.default #default=inspect._empty + self.kernel_cache[entry_key] = (kernel, constexprs) + else: + # load kernel from cache directly + kernel, constexprs = self.kernel_cache[entry_key] + + if callable(grid): + # collect all arguments to the grid fn,ie: + # 1. args, + # 2. kwargs, + # 3. all all other captured arguments in CompiledKernel from + # Autotunner & Heuristics when kwargs & captured args conflict, + # captured args have higher priority + # 4. We must filter out captured args with default value firstly + constexprs = { + k: v + for k, v in constexprs.items() if v is not inspect._empty + } + meta = { + **dict(zip(self.arg_names, args)), + **kwargs, + **constexprs, + } + grid = grid(meta) + if isinstance(grid, tuple): + grid = grid + (1, 1) + elif isinstance(grid, list): + grid = grid + [1, 1] + kernel[grid[0:3]](*k_args) + # maintaining the same return type as the JITFunction.run + return kernel + + +def libentry(): + """ + Decorator for triton library entries. + Motivation: + The runtime overhead of Triton kernels is the reason for the lower + performance of small kernels, particularly evident with smaller models. + Using this decorator can reduce Triton runtime overhead. + How: + The `run` function of JITFunction needs to accomplish: + - Parameter binding using inspect + - KernelArg type wrapping + - Cache key calculation + When dealing with small size, these steps can become bottlenecks in + Triton runtime. Libentry simplifies these steps to reduce runtime + overhead, thereby improving the runtime expenses of small kernels. 
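A hypothetical usage sketch of the decorator (the kernel body and launch below are illustrative, not code from vLLM):

import triton
import triton.language as tl

from vllm.triton_utils import libentry


@libentry()   # cached launches skip JITFunction's binding and key computation
@triton.jit
def _scale_kernel(x_ptr, out_ptr, scale, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * scale, mask=mask)

# launched like any JIT kernel, e.g.
#   _scale_kernel[(triton.cdiv(n, 128),)](x, out, 2.0, n, BLOCK=128)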
+ NOTE: + When Triton is upgraded to version 3.0.0, libentry can be removed, + see: https://github.com/vllm-project/vllm/pull/5036#issuecomment-2243396245 + + + """ + + def decorator(fn): + return LibEntry(fn) + + return decorator diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index de999b11d91b..777344289958 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -578,9 +578,9 @@ def build(self) -> ModelInputForGPU: for inter_data in self.inter_data_list ]) lora_mapping = LoRAMapping( - lora_index_mapping, - lora_prompt_mapping, - ) + **dict(index_mapping=lora_index_mapping, + prompt_mapping=lora_prompt_mapping, + is_prefill=not self.decode_only)) # Prompt adapter data. prompt_adapter_requests: Set[PromptAdapterRequest] = set() @@ -1152,9 +1152,9 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: if self.lora_config: lora_mapping = LoRAMapping( - [0] * batch_size, - [0] * batch_size, - ) + **dict(index_mapping=[0] * batch_size, + prompt_mapping=[0] * batch_size, + is_prefill=False)) self.set_active_loras(set(), lora_mapping) if self.prompt_adapter_config: From 1d2e7fb73f1205ae03e4ee3bcd3de566733bf582 Mon Sep 17 00:00:00 2001 From: xuyi Date: Thu, 1 Aug 2024 09:49:51 +0800 Subject: [PATCH 0012/3246] [Model] Pipeline parallel support for Qwen2 (#6924) --- vllm/config.py | 2 + vllm/model_executor/models/qwen2.py | 57 +++++++++++++++----- vllm/model_executor/models/qwen2_moe.py | 69 +++++++++++++++++++------ 3 files changed, 101 insertions(+), 27 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index de5d0402a1bc..e06574459237 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -40,6 +40,8 @@ "GPT2LMHeadModel", "MixtralForCausalLM", "NemotronForCausalLM", + "Qwen2ForCausalLM", + "Qwen2MoeForCausalLM", ] diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 35fd6f37589a..99fdd993943b 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -30,7 +30,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -49,6 +49,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsLoRA +from .utils import is_pp_missing_parameter, make_layers class Qwen2MLP(nn.Module): @@ -227,6 +228,7 @@ def __init__( config: Qwen2Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -237,10 +239,14 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - Qwen2DecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Qwen2DecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -255,20 +261,30 @@ def forward( 
intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_tokens(input_ids) + residual = None else: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -351,6 +367,20 @@ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata) return logits + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + def sample( self, logits: torch.Tensor, @@ -381,6 +411,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -393,7 +425,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): name = maybe_remap_kv_scale_name(name, params_dict) if name is None: continue - + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 2cc2f1440d14..b895788206d1 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -31,7 +31,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_world_size, +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE @@ -52,6 +53,8 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import print_warning_once +from .utils import is_pp_missing_parameter, make_layers + class Qwen2MoeMLP(nn.Module): @@ -315,6 +318,7 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -324,13 +328,15 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - Qwen2MoeDecoderLayer(config, - layer_idx, - cache_config, - quant_config=quant_config) - for layer_idx in 
range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Qwen2MoeDecoderLayer(config=config, + layer_idx=int( + prefix.split(".")[-1]), + cache_config=cache_config, + quant_config=quant_config), + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( @@ -339,14 +345,25 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, - kv_caches[i], attn_metadata, - residual) + kv_caches[i - self.start_layer], + attn_metadata, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -380,7 +397,7 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, @@ -389,6 +406,20 @@ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata) return logits + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + def sample( self, logits: Optional[torch.Tensor], @@ -435,6 +466,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue if name not in params_dict: continue @@ -448,6 +482,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if weight_name not in name: continue name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, @@ -460,6 +497,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue # Remapping the name of FP8 kv-scale. 
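A conceptual sketch of the layer partitioning that make_layers provides to both models above (the real helper lives in vllm/model_executor/models/utils.py; the even split here is a simplification):

def split_layers(num_layers: int, pp_rank: int, pp_size: int):
    """Roughly how num_hidden_layers is divided across pipeline ranks."""
    per_rank = num_layers // pp_size
    start = pp_rank * per_rank
    end = num_layers if pp_rank == pp_size - 1 else start + per_rank
    return start, end

# e.g. 28 decoder layers on 2 pipeline ranks:
#   rank 0 owns layers [0, 14) and rank 1 owns layers [14, 28); every rank
#   except the last returns IntermediateTensors(hidden_states, residual)
#   instead of running the final norm, as in the forward() changes above.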
if name.endswith("kv_scale"): remapped_kv_scale_name = name.replace( @@ -474,7 +514,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue else: name = remapped_kv_scale_name - param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) From 23993a7997ff927decaca60281871d5fdab11334 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 31 Jul 2024 18:50:28 -0700 Subject: [PATCH 0013/3246] [Bugfix][TPU] Do not use torch.Generator for TPUs (#6981) --- vllm/model_executor/model_loader/weight_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 942215da01af..5e142e8cb8b8 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -22,6 +22,7 @@ from vllm.model_executor.layers.quantization import (QuantizationConfig, get_quantization_config) from vllm.model_executor.layers.quantization.schema import QuantParamSchema +from vllm.platforms import current_platform from vllm.utils import print_warning_once logger = init_logger(__name__) @@ -490,6 +491,11 @@ def initialize_dummy_weights( """ for param in model.state_dict().values(): if torch.is_floating_point(param): + if current_platform.is_tpu(): + # XLA device does not support torch.Generator() + param.uniform_(low, high) + continue + generator = torch.Generator(device=param.data.device) generator.manual_seed(seed) if torch.finfo(param.data.dtype).bits < 16: From 630dd9e0aea166085a4c897e21a98ec752954265 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Wed, 31 Jul 2024 20:49:11 -0600 Subject: [PATCH 0014/3246] [Bugfix][Model] Skip loading lm_head weights if using tie_word_embeddings (#6758) Signed-off-by: Travis Johnson --- vllm/model_executor/models/chameleon.py | 7 +++++++ vllm/model_executor/models/llama.py | 5 +++++ vllm/model_executor/models/minicpm.py | 6 +++++- vllm/model_executor/models/olmo.py | 5 +++++ 4 files changed, 22 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 7659f598bab9..10a82207d90e 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -998,6 +998,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue + + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + use_default_weight_loading = False if "vqmodel" in name: if self.model.vqmodel is not None: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 2052c443a888..048c292c79c8 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -469,6 +469,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. 
+ if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue if scale_name := get_compressed_tensors_cache_scale(name): # Loading kv cache scales for compressed-tensors quantization param = params_dict[scale_name] diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index b46e88f5fc58..7f8f38fe8439 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -514,7 +514,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 408c0c883a9d..1a0a3774dc8f 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -343,6 +343,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue From 0437492ea97f0650a8b2ca39121be8864625fd70 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Wed, 31 Jul 2024 20:15:42 -0700 Subject: [PATCH 0015/3246] PP comm optimization: replace send with partial send + allgather (#6695) Co-authored-by: Aurick Qiao --- vllm/distributed/parallel_state.py | 38 ++++++++++++++++++++++++++++-- vllm/worker/worker_base.py | 8 ++++--- 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index bf7a7de0724a..d7ca8fd82e1a 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -569,7 +569,8 @@ def broadcast_tensor_dict( def send_tensor_dict( self, tensor_dict: Dict[str, Union[torch.Tensor, Any]], - dst: Optional[int] = None + dst: Optional[int] = None, + all_gather_group: Optional["GroupCoordinator"] = None, ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: """Send the input tensor dictionary. NOTE: `dst` is the local rank of the source rank. @@ -578,6 +579,11 @@ def send_tensor_dict( if not torch.distributed.is_initialized() or self.world_size == 1: return tensor_dict + all_gather_size = (1 if all_gather_group is None else + all_gather_group.world_size) + all_gather_rank = (0 if all_gather_group is None else + all_gather_group.rank_in_group) + group = self.device_group metadata_group = self.cpu_group @@ -598,6 +604,12 @@ def send_tensor_dict( if tensor.numel() == 0: # Skip sending empty tensors. continue + + # send-allgather: send only a slice, then do allgather. 
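A standalone sketch (shapes and ranks are illustrative) of the slicing that the new all_gather_group path performs around the point-to-point transfer:

import torch

tp_size, tp_rank = 4, 1
hidden_states = torch.randn(8, 4096)            # [num_tokens, hidden_size]
assert hidden_states.numel() % tp_size == 0

# sender: each TP rank transmits only its 1/tp_size slice over the PP link
shard = hidden_states.reshape(tp_size, -1)[tp_rank]     # shape [8192]

# receiver (conceptually): recv the shard, all-gather it within the TP group
# over the fast intra-node interconnect, then restore the original shape:
#   gathered = tp_group.all_gather(shard, dim=0).reshape(8, 4096)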
+ if (all_gather_group is not None + and tensor.numel() % all_gather_size == 0): + tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] + if tensor.is_cpu: # use metadata_group for CPU tensors torch.distributed.send(tensor, @@ -612,7 +624,8 @@ def send_tensor_dict( def recv_tensor_dict( self, - src: Optional[int] = None + src: Optional[int] = None, + all_gather_group: Optional["GroupCoordinator"] = None, ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: """Recv the input tensor dictionary. NOTE: `src` is the local rank of the source rank. @@ -621,6 +634,11 @@ def recv_tensor_dict( if not torch.distributed.is_initialized() or self.world_size == 1: return None + all_gather_size = (1 if all_gather_group is None else + all_gather_group.world_size) + all_gather_rank = (0 if all_gather_group is None else + all_gather_group.rank_in_group) + group = self.device_group metadata_group = self.cpu_group @@ -639,6 +657,16 @@ def recv_tensor_dict( # Skip broadcasting empty tensors. tensor_dict[key] = tensor continue + + # send-allgather: send only a slice, then do allgather. + use_all_gather = (all_gather_group is not None + and tensor.numel() % all_gather_size == 0) + + if use_all_gather: + orig_shape = tensor.shape + tensor = tensor.reshape(all_gather_size, + -1)[all_gather_rank] + if tensor.is_cpu: # use metadata_group for CPU tensors torch.distributed.recv(tensor, @@ -649,6 +677,12 @@ def recv_tensor_dict( torch.distributed.recv(tensor, src=self.ranks[src], group=group) + if use_all_gather: + # do the allgather + tensor = all_gather_group.all_gather( # type: ignore + tensor, dim=0) + tensor = tensor.reshape(orig_shape) + tensor_dict[key] = tensor else: tensor_dict[key] = value diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 03e3857e23c4..8a4d1958c65a 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -6,7 +6,7 @@ import torch -from vllm.distributed import broadcast_tensor_dict, get_pp_group +from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.platforms import current_platform @@ -267,7 +267,8 @@ def execute_model( intermediate_tensors = None if not get_pp_group().is_first_rank: intermediate_tensors = IntermediateTensors( - get_pp_group().recv_tensor_dict()) + get_pp_group().recv_tensor_dict( + all_gather_group=get_tp_group())) output = self.model_runner.execute_model( model_input, self.kv_cache[worker_input.virtual_engine] @@ -276,7 +277,8 @@ def execute_model( if not get_pp_group().is_last_rank: # output is IntermediateTensors - get_pp_group().send_tensor_dict(output.tensors) + get_pp_group().send_tensor_dict(output.tensors, + all_gather_group=get_tp_group()) return [None] # output is List[SamplerOutput] From 3c10591ef2e78fbb6aa341195c4b24c36ae8b84d Mon Sep 17 00:00:00 2001 From: zifeitong Date: Wed, 31 Jul 2024 21:13:34 -0700 Subject: [PATCH 0016/3246] [Bugfix] Set SamplingParams.max_tokens for OpenAI requests if not provided by user (#6954) --- tests/entrypoints/openai/test_serving_chat.py | 39 +++++++++++++++++++ vllm/entrypoints/openai/protocol.py | 30 ++++++++++---- vllm/entrypoints/openai/serving_chat.py | 23 ++++------- vllm/entrypoints/openai/serving_completion.py | 27 +++++-------- vllm/entrypoints/openai/serving_engine.py | 17 ++++++-- 5 files changed, 92 insertions(+), 44 deletions(-) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 
464465494b71..168ba7ba888e 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -1,7 +1,12 @@ import asyncio +from contextlib import suppress from dataclasses import dataclass +from unittest.mock import MagicMock +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.transformers_utils.tokenizer import get_tokenizer MODEL_NAME = "openai-community/gpt2" CHAT_TEMPLATE = "Dummy chat template for testing {}" @@ -42,3 +47,37 @@ async def _async_serving_chat_init(): def test_async_serving_chat_init(): serving_completion = asyncio.run(_async_serving_chat_init()) assert serving_completion.chat_template == CHAT_TEMPLATE + + +def test_serving_chat_should_set_correct_max_tokens(): + mock_engine = MagicMock(spec=AsyncLLMEngine) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + + serving_chat = OpenAIServingChat(mock_engine, + MockModelConfig(), + served_model_names=[MODEL_NAME], + response_role="assistant", + chat_template=CHAT_TEMPLATE, + lora_modules=None, + prompt_adapters=None, + request_logger=None) + req = ChatCompletionRequest( + model=MODEL_NAME, + messages=[{ + "role": "user", + "content": "what is 1+1?" + }], + guided_decoding_backend="outlines", + ) + + with suppress(Exception): + asyncio.run(serving_chat.create_chat_completion(req)) + + # AsyncLLMEngine.generate(inputs, sampling_params, ...) + assert mock_engine.generate.call_args.args[1].max_tokens == 93 + + req.max_tokens = 10 + with suppress(Exception): + asyncio.run(serving_chat.create_chat_completion(req)) + + assert mock_engine.generate.call_args.args[1].max_tokens == 10 diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 205860aa8e72..3b35ae1ebd70 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -11,7 +11,7 @@ from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.openai.logits_processors import get_logits_processors from vllm.pooling_params import PoolingParams -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.utils import random_uuid @@ -215,15 +215,22 @@ class ChatCompletionRequest(OpenAIBaseModel): # doc: end-chat-completion-extra-params - def to_sampling_params(self, - tokenizer: PreTrainedTokenizer) -> SamplingParams: - # We now allow logprobs being true without top_logrobs. + def to_sampling_params( + self, tokenizer: PreTrainedTokenizer, + guided_decode_logits_processor: Optional[LogitsProcessor], + default_max_tokens: int) -> SamplingParams: + max_tokens = self.max_tokens + if max_tokens is None: + max_tokens = default_max_tokens + # We now allow logprobs being true without top_logrobs. 
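A minimal sketch of the defaulting rule that the new default_max_tokens argument implements (the helper name below is made up for illustration):

def resolve_max_tokens(max_model_len: int, prompt_len: int,
                       requested_max_tokens=None) -> int:
    # an unspecified max_tokens now means "the rest of the context window"
    if requested_max_tokens is None:
        return max_model_len - prompt_len
    return requested_max_tokens

assert resolve_max_tokens(2048, 200) == 1848
assert resolve_max_tokens(2048, 200, 10) == 10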
logits_processors = get_logits_processors( logit_bias=self.logit_bias, allowed_token_ids=None, tokenizer=tokenizer, ) + if guided_decode_logits_processor: + logits_processors.append(guided_decode_logits_processor) return SamplingParams( n=self.n, @@ -241,7 +248,7 @@ def to_sampling_params(self, logprobs=self.top_logprobs if self.logprobs else None, prompt_logprobs=self.top_logprobs if self.echo else None, ignore_eos=self.ignore_eos, - max_tokens=self.max_tokens, + max_tokens=max_tokens, min_tokens=self.min_tokens, use_beam_search=self.use_beam_search, early_stopping=self.early_stopping, @@ -395,7 +402,14 @@ class CompletionRequest(OpenAIBaseModel): # doc: end-completion-extra-params - def to_sampling_params(self, tokenizer: PreTrainedTokenizer): + def to_sampling_params( + self, tokenizer: PreTrainedTokenizer, + guided_decode_logits_processor: Optional[LogitsProcessor], + default_max_tokens: int) -> SamplingParams: + max_tokens = self.max_tokens + if max_tokens is None: + max_tokens = default_max_tokens + echo_without_generation = self.echo and self.max_tokens == 0 logits_processors = get_logits_processors( @@ -403,6 +417,8 @@ def to_sampling_params(self, tokenizer: PreTrainedTokenizer): allowed_token_ids=self.allowed_token_ids, tokenizer=tokenizer, ) + if guided_decode_logits_processor: + logits_processors.append(guided_decode_logits_processor) return SamplingParams( n=self.n, @@ -419,7 +435,7 @@ def to_sampling_params(self, tokenizer: PreTrainedTokenizer): stop_token_ids=self.stop_token_ids, logprobs=self.logprobs, ignore_eos=self.ignore_eos, - max_tokens=self.max_tokens if not echo_without_generation else 1, + max_tokens=max_tokens if not echo_without_generation else 1, min_tokens=self.min_tokens, use_beam_search=self.use_beam_search, early_stopping=self.early_stopping, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 01843930bf11..c832cf2a24b5 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -25,8 +25,6 @@ PromptAdapterPath) from vllm.inputs import PromptInputs from vllm.logger import init_logger -from vllm.model_executor.guided_decoding import ( - get_guided_decoding_logits_processor) from vllm.multimodal import MultiModalDataDict from vllm.outputs import RequestOutput from vllm.sequence import Logprob @@ -134,28 +132,23 @@ async def create_chat_completion( request_id = f"chat-{random_uuid()}" try: - sampling_params = request.to_sampling_params(tokenizer) - decoding_config = await self.engine.get_decoding_config() - guided_decoding_backend = request.guided_decoding_backend \ - or decoding_config.guided_decoding_backend guided_decode_logits_processor = ( - await - get_guided_decoding_logits_processor(guided_decoding_backend, - request, tokenizer)) - if guided_decode_logits_processor: - if sampling_params.logits_processors is None: - sampling_params.logits_processors = [] - sampling_params.logits_processors.append( - guided_decode_logits_processor) + await self._guided_decode_logits_processor(request, tokenizer)) prompt_inputs = self._tokenize_prompt_input( request, tokenizer, prompt, - truncate_prompt_tokens=sampling_params.truncate_prompt_tokens, + truncate_prompt_tokens=request.truncate_prompt_tokens, add_special_tokens=request.add_special_tokens, ) + sampling_params = request.to_sampling_params( + tokenizer, + guided_decode_logits_processor, + default_max_tokens=self.max_model_len - + len(prompt_inputs["prompt_token_ids"])) + self._log_inputs(request_id, prompt_inputs, 
params=sampling_params, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 854835279168..7765c5903f34 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -24,8 +24,6 @@ OpenAIServing, PromptAdapterPath) from vllm.logger import init_logger -from vllm.model_executor.guided_decoding import ( - get_guided_decoding_logits_processor) from vllm.outputs import RequestOutput from vllm.sequence import Logprob from vllm.tracing import (contains_trace_headers, extract_trace_headers, @@ -95,31 +93,24 @@ async def create_completion(self, request: CompletionRequest, tokenizer = await self.engine.get_tokenizer(lora_request) - sampling_params = request.to_sampling_params(tokenizer) - decoding_config = await self.engine.get_decoding_config() - guided_decoding_backend = request.guided_decoding_backend \ - or decoding_config.guided_decoding_backend - guided_decode_logit_processor = ( - await - get_guided_decoding_logits_processor(guided_decoding_backend, - request, tokenizer)) - if guided_decode_logit_processor is not None: - if sampling_params.logits_processors is None: - sampling_params.logits_processors = [] - sampling_params.logits_processors.append( - guided_decode_logit_processor) - + guided_decode_logits_processor = ( + await self._guided_decode_logits_processor(request, tokenizer)) prompts = list( self._tokenize_prompt_input_or_inputs( request, tokenizer, request.prompt, - truncate_prompt_tokens=sampling_params. - truncate_prompt_tokens, + truncate_prompt_tokens=request.truncate_prompt_tokens, add_special_tokens=request.add_special_tokens, )) for i, prompt_inputs in enumerate(prompts): + sampling_params = request.to_sampling_params( + tokenizer, + guided_decode_logits_processor, + default_max_tokens=self.max_model_len - + len(prompt_inputs["prompt_token_ids"])) + request_id_item = f"{request_id}-{i}" self._log_inputs(request_id_item, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index b374a7946b11..8c7929a12e9a 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -25,9 +25,11 @@ from vllm.inputs import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding import ( + get_guided_decoding_logits_processor) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer_group import AnyTokenizer @@ -150,6 +152,15 @@ def create_streaming_error_response( }) return json_str + async def _guided_decode_logits_processor( + self, request: Union[ChatCompletionRequest, CompletionRequest], + tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]: + decoding_config = await self.engine.get_decoding_config() + guided_decoding_backend = request.guided_decoding_backend \ + or decoding_config.guided_decoding_backend + return await get_guided_decoding_logits_processor( + guided_decoding_backend, request, tokenizer) + async def _check_model( self, request: AnyRequest, @@ -254,9 +265,7 @@ def _validate_input( f"{self.max_model_len} tokens. 
However, you requested " f"{token_num} tokens in the messages, " f"Please reduce the length of the messages.") - request.max_tokens = self.max_model_len - token_num - - if token_num + request.max_tokens > self.max_model_len: + elif token_num + request.max_tokens > self.max_model_len: raise ValueError( f"This model's maximum context length is " f"{self.max_model_len} tokens. However, you requested " From c8a7e93273ff4338d6f89f8a63ff16426ac240b8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 31 Jul 2024 23:51:09 -0700 Subject: [PATCH 0017/3246] [core][scheduler] simplify and improve scheduler (#6867) --- tests/core/block/e2e/test_correctness.py | 2 +- tests/core/test_scheduler.py | 163 ++++++++++------------- vllm/core/policy.py | 45 ------- vllm/core/scheduler.py | 116 ++++++---------- 4 files changed, 112 insertions(+), 214 deletions(-) delete mode 100644 vllm/core/policy.py diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index 8502eab0f8da..e0dee43f500a 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -183,7 +183,7 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator, # Allow only 2 sequences of ~128 tokens in worst case. # Note 16 = 128/block_size - "num_gpu_blocks_override": 2 * (16 + 1), + "num_gpu_blocks_override": 2 * (16 + 2), } ]) @pytest.mark.parametrize("baseline_llm_kwargs", [{ diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 4ca2260b5e01..447e8f8a586f 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -1,13 +1,12 @@ import time from collections import deque -from typing import Deque, List, Set, Tuple +from typing import List, Set, Tuple from unittest.mock import MagicMock import pytest # noqa from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus -from vllm.core.policy import PolicyFactory from vllm.core.scheduler import Scheduler, SchedulingBudget from vllm.lora.request import LoRARequest from vllm.sequence import Logprob, SequenceGroup, SequenceStatus @@ -348,10 +347,10 @@ def test_prefill_schedule_max_prompt_len(): """ scheduler = initialize_scheduler(max_model_len=30) _, seq_group = create_dummy_prompt("0", prompt_length=60) - waiting = deque([seq_group]) + scheduler.add_seq_group(seq_group) budget = create_token_budget() - remaining_waiting, output = scheduler._schedule_prefills( - waiting, budget, None) + output = scheduler._schedule_prefills(budget, None) + remaining_waiting = scheduler.waiting assert len(output.ignored_seq_groups) == 1 assert len(output.seq_groups) == 0 assert budget.num_batched_tokens == 0 @@ -364,15 +363,14 @@ def test_prefill_schedule_token_budget(): Test token budget respected. """ scheduler = initialize_scheduler() - waiting: Deque[SequenceGroup] = deque() budget = create_token_budget(token_budget=0) for i in range(2): _, seq_group = create_dummy_prompt(str(i), prompt_length=60) - waiting.append(seq_group) + scheduler.add_seq_group(seq_group) # 0 token budget == nothing is scheduled. - remaining_waiting, output = scheduler._schedule_prefills( - waiting, budget, None) + output = scheduler._schedule_prefills(budget, None) + remaining_waiting = scheduler.waiting assert len(output.ignored_seq_groups) == 0 assert len(output.seq_groups) == 0 assert budget.num_batched_tokens == 0 @@ -381,8 +379,8 @@ def test_prefill_schedule_token_budget(): # 60 token budget == 1 request scheduled. 
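For orientation, a tiny mock (not the real SchedulingBudget) of the accounting these budget tests drive:

class ToyBudget:

    def __init__(self, token_budget: int, max_num_seqs: int):
        self.token_budget = token_budget
        self.max_num_seqs = max_num_seqs
        self.num_batched_tokens = 0
        self.num_curr_seqs = 0

    def can_schedule(self, num_new_tokens: int, num_new_seqs: int) -> bool:
        return (self.num_batched_tokens + num_new_tokens <= self.token_budget
                and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)


budget = ToyBudget(token_budget=60, max_num_seqs=2)
assert budget.can_schedule(60, 1)       # a 60-token prompt fits exactly
assert not budget.can_schedule(61, 1)   # anything beyond the budget waits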
budget = create_token_budget(token_budget=60) - remaining_waiting, output = scheduler._schedule_prefills( - waiting, budget, None) + output = scheduler._schedule_prefills(budget, None) + remaining_waiting = scheduler.waiting assert len(output.ignored_seq_groups) == 0 assert len(output.seq_groups) == 1 assert budget.num_batched_tokens == 60 @@ -391,14 +389,13 @@ def test_prefill_schedule_token_budget(): # Test when current_batched_tokens respected. scheduler = initialize_scheduler() - waiting = deque() budget = create_token_budget(token_budget=60) add_token_budget(budget, 30, 0) _, seq_group = create_dummy_prompt(str(i), prompt_length=60) # Cannot schedule a prompt that doesn't fit the budget. - waiting.append(seq_group) - remaining_waiting, output = scheduler._schedule_prefills( - waiting, budget, None) + scheduler.add_seq_group(seq_group) + output = scheduler._schedule_prefills(budget, None) + remaining_waiting = scheduler.waiting assert len(output.ignored_seq_groups) == 0 assert len(output.seq_groups) == 0 assert budget.num_batched_tokens == 30 @@ -406,8 +403,8 @@ def test_prefill_schedule_token_budget(): assert len(remaining_waiting) == 1 budget = create_token_budget(token_budget=90) add_token_budget(budget, 30, 0) - remaining_waiting, output = scheduler._schedule_prefills( - waiting, budget, None) + output = scheduler._schedule_prefills(budget, None) + remaining_waiting = scheduler.waiting assert len(output.seq_groups) == 1 assert budget.num_batched_tokens == 90 assert budget.num_curr_seqs == 1 @@ -419,13 +416,12 @@ def test_prefill_schedule_max_seqs(): Test max seq respected. """ scheduler = initialize_scheduler() - waiting: Deque[SequenceGroup] = deque() budget = create_token_budget(max_num_seqs=2) for i in range(3): _, seq_group = create_dummy_prompt(str(i), prompt_length=60) - waiting.append(seq_group) - remaining_waiting, output = scheduler._schedule_prefills( - waiting, budget, None) + scheduler.add_seq_group(seq_group) + output = scheduler._schedule_prefills(budget, None) + remaining_waiting = scheduler.waiting assert len(output.ignored_seq_groups) == 0 assert len(output.seq_groups) == 2 assert budget.num_batched_tokens == 120 @@ -433,13 +429,13 @@ def test_prefill_schedule_max_seqs(): assert len(remaining_waiting) == 1 # Verify curr_num_seqs respected. - waiting = deque() + scheduler.waiting = deque() budget = create_token_budget(max_num_seqs=2) add_token_budget(budget, 0, 2) _, seq_group = create_dummy_prompt(str(i), prompt_length=60) - waiting.append(seq_group) - remaining_waiting, output = scheduler._schedule_prefills( - waiting, budget, None) + scheduler.add_seq_group(seq_group) + output = scheduler._schedule_prefills(budget, None) + remaining_waiting = scheduler.waiting assert len(output.ignored_seq_groups) == 0 assert len(output.seq_groups) == 0 assert budget.num_batched_tokens == 0 @@ -453,7 +449,6 @@ def test_prefill_schedule_max_lora(): """ lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) scheduler = initialize_scheduler(lora_config=lora_config) - waiting: Deque[SequenceGroup] = deque() budget = create_token_budget(token_budget=120) curr_loras: Set[int] = set() for i in range(2): @@ -463,7 +458,7 @@ def test_prefill_schedule_max_lora(): lora_name=str(i), lora_int_id=i + 1, lora_path="abc")) - waiting.append(seq_group) + scheduler.add_seq_group(seq_group) # Add two more requests to verify lora is prioritized. # 0: Lora, 1: Lora, 2: regular, 3: regular # In the first iteration, index 0, 2 is scheduled. 
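The max-LoRA gate this test exercises, as a simplified sketch of the prefill scheduling loop (a reading of the scheduler, not code from the patch):

# for each waiting seq_group, front to back:
#     if lora_id and lora_id not in curr_loras and len(curr_loras) >= max_loras:
#         leftover.appendleft(seq_group)   # skipped, retried first next step
#         continue
#     curr_loras.add(lora_id)              # only for lora_id > 0
#     schedule(seq_group)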
@@ -471,10 +466,10 @@ def test_prefill_schedule_max_lora(): # prioritized. Verify that. for i in range(2, 4): _, seq_group = create_dummy_prompt(str(i), prompt_length=60) - waiting.append(seq_group) + scheduler.add_seq_group(seq_group) # Schedule 2 requests (0 and 2) - remaining_waiting, output = scheduler._schedule_prefills( - waiting, budget, curr_loras) + output = scheduler._schedule_prefills(budget, curr_loras) + remaining_waiting = scheduler.waiting assert len(output.ignored_seq_groups) == 0 assert len(output.seq_groups) == 2 assert budget.num_batched_tokens == 120 @@ -485,8 +480,8 @@ def test_prefill_schedule_max_lora(): # Reset curr_loras so that it can be scheduled. curr_loras = set() budget = create_token_budget(token_budget=60) - remaining_waiting, output = scheduler._schedule_prefills( - remaining_waiting, budget, curr_loras) + output = scheduler._schedule_prefills(budget, curr_loras) + remaining_waiting = scheduler.waiting assert len(output.seq_groups) == 1 assert output.seq_groups[0].seq_group.request_id == "1" assert len(remaining_waiting) == 1 @@ -499,31 +494,29 @@ def test_prefill_schedule_no_block_manager_capacity(): Test sequence cannot be scheduled due to block manager has no capacity. """ scheduler = initialize_scheduler() - waiting: Deque[SequenceGroup] = deque() budget = create_token_budget() for i in range(3): _, seq_group = create_dummy_prompt(str(i), prompt_length=60) - waiting.append(seq_group) + scheduler.add_seq_group(seq_group) scheduler.block_manager.can_allocate = MagicMock() scheduler.block_manager.can_allocate.return_value = AllocStatus.LATER - remainig_waiting, output = scheduler._schedule_prefills( - waiting, budget, None) + output = scheduler._schedule_prefills(budget, None) + remaining_waiting = scheduler.waiting assert len(output.ignored_seq_groups) == 0 assert len(output.seq_groups) == 0 assert budget.num_batched_tokens == 0 assert budget.num_curr_seqs == 0 - assert len(remainig_waiting) == 3 + assert len(remaining_waiting) == 3 scheduler = initialize_scheduler() - waiting = deque() budget = create_token_budget() for i in range(3): _, seq_group = create_dummy_prompt(str(i), prompt_length=60) - waiting.append(seq_group) + scheduler.add_seq_group(seq_group) scheduler.block_manager.can_allocate = MagicMock() scheduler.block_manager.can_allocate.return_value = AllocStatus.NEVER - remaining_waiting, output = scheduler._schedule_prefills( - waiting, budget, None) + output = scheduler._schedule_prefills(budget, None) + remaining_waiting = scheduler.waiting assert len(output.ignored_seq_groups) == 3 assert len(output.seq_groups) == 0 assert budget.num_batched_tokens == 0 @@ -536,14 +529,12 @@ def test_decode_schedule_preempted(): Test decodes cannot be scheduled and preempted. """ scheduler = initialize_scheduler() - running: Deque[SequenceGroup] = deque() - policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None for i in range(3): _, seq_group = create_dummy_prompt(str(i), prompt_length=60) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) - running.append(seq_group) + scheduler._add_seq_group_to_running(seq_group) scheduler.block_manager.can_append_slots = MagicMock() def cannot_append_second_group(seq_group, num_lookahead_slots): @@ -555,8 +546,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): # 1 cannot be scheduled, and the lowest priority (request 2) # should be preempted. 1 will also be preempted. 
budget = create_token_budget() - remainig_running, output = scheduler._schedule_running( - running, budget, curr_loras, policy) + output = scheduler._schedule_running(budget, curr_loras) + remainig_running = scheduler.running assert len(remainig_running) == 0 assert len(output.decode_seq_groups) == 1 assert len(output.prefill_seq_groups) == 0 @@ -577,14 +568,12 @@ def test_decode_swap_beam_search(): Test best_of > 1 swap out blocks """ scheduler = initialize_scheduler() - running: Deque[SequenceGroup] = deque() - policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None budget = create_token_budget() for i in range(3): _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) scheduler._allocate_and_set_running(seq_group) - running.append(seq_group) + scheduler._add_seq_group_to_running(seq_group) append_new_token_seq_group(60, seq_group, 1) budget.add_num_seqs(seq_group.request_id, seq_group.get_max_num_running_seqs()) @@ -603,8 +592,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): expected_swap_mapping = [("5", "7")] scheduler.block_manager.swap_out.return_value = expected_swap_mapping - remainig_running, output = scheduler._schedule_running( - running, budget, curr_loras, policy) + output = scheduler._schedule_running(budget, curr_loras) + remainig_running = scheduler.running assert len(remainig_running) == 0 assert len(output.decode_seq_groups) == 2 assert len(output.prefill_seq_groups) == 0 @@ -628,20 +617,18 @@ def test_schedule_decode_blocks_to_copy_update(): """ scheduler = initialize_scheduler() _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) - running: Deque[SequenceGroup] = deque() - policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) - running.append(seq_group) + scheduler._add_seq_group_to_running(seq_group) # The last request should be swapped out. 
scheduler.block_manager.append_slots = MagicMock() scheduler.block_manager.append_slots.return_value = [(2, 3)] budget = create_token_budget() - remaining_running, output = scheduler._schedule_running( - running, budget, curr_loras, policy) + output = scheduler._schedule_running(budget, curr_loras) + remaining_running = scheduler.running assert len(remaining_running) == 0 assert len(output.decode_seq_groups) == 1 assert len(output.prefill_seq_groups) == 0 @@ -656,19 +643,17 @@ def test_schedule_decode_blocks_to_copy_update(): def test_schedule_swapped_simple(): scheduler = initialize_scheduler() - swapped: Deque[SequenceGroup] = deque() - policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) - swapped.append(seq_group) + scheduler._add_seq_group_to_swapped(seq_group) budget = create_token_budget() - remaining_swapped, output = scheduler._schedule_swapped( - swapped, budget, curr_loras, policy) + output = scheduler._schedule_swapped(budget, curr_loras) + remaining_swapped = scheduler.swapped assert len(remaining_swapped) == 0 assert budget.num_batched_tokens == 1 assert budget.num_curr_seqs == 2 @@ -683,8 +668,6 @@ def test_schedule_swapped_simple(): def test_schedule_swapped_max_token_budget(): scheduler = initialize_scheduler() - swapped: Deque[SequenceGroup] = deque() - policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] for _ in range(2): @@ -692,11 +675,11 @@ def test_schedule_swapped_max_token_budget(): scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) - swapped.append(seq_group) + scheduler._add_seq_group_to_swapped(seq_group) budget = create_token_budget(token_budget=1) - remaining_swapped, output = scheduler._schedule_swapped( - swapped, budget, curr_loras, policy) + output = scheduler._schedule_swapped(budget, curr_loras) + remaining_swapped = scheduler.swapped assert len(remaining_swapped) == 1 assert budget.num_batched_tokens == 1 assert budget.num_curr_seqs == 2 @@ -706,8 +689,8 @@ def test_schedule_swapped_max_token_budget(): # Verify num_batched_tokens are respected. 
budget = create_token_budget(token_budget=1) add_token_budget(budget, 1, 0) - remaining_swapped, output = scheduler._schedule_swapped( - remaining_swapped, budget, curr_loras, policy) + output = scheduler._schedule_swapped(budget, curr_loras) + remaining_swapped = scheduler.swapped assert len(remaining_swapped) == 1 assert budget.num_batched_tokens == 1 assert budget.num_curr_seqs == 0 @@ -717,8 +700,6 @@ def test_schedule_swapped_max_token_budget(): def test_schedule_swapped_max_seqs(): scheduler = initialize_scheduler() - swapped: Deque[SequenceGroup] = deque() - policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] for i in range(4): @@ -726,11 +707,11 @@ def test_schedule_swapped_max_seqs(): scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) - swapped.append(seq_group) + scheduler._add_seq_group_to_swapped(seq_group) budget = create_token_budget(max_num_seqs=2) - remaining_swapped, output = scheduler._schedule_swapped( - swapped, budget, curr_loras, policy) + output = scheduler._schedule_swapped(budget, curr_loras) + remaining_swapped = scheduler.swapped assert len(remaining_swapped) == 2 assert budget.num_batched_tokens == 2 assert budget.num_curr_seqs == 2 @@ -738,8 +719,8 @@ def test_schedule_swapped_max_seqs(): assert len(output.prefill_seq_groups) == 0 # Verify num_curr_seqs are respected. - remaining_swapped, output = scheduler._schedule_swapped( - remaining_swapped, budget, curr_loras, policy) + output = scheduler._schedule_swapped(budget, curr_loras) + remaining_swapped = scheduler.swapped assert len(remaining_swapped) == 2 assert budget.num_batched_tokens == 2 assert budget.num_curr_seqs == 2 @@ -750,8 +731,6 @@ def test_schedule_swapped_max_seqs(): def test_schedule_swapped_max_loras(): lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) scheduler = initialize_scheduler(lora_config=lora_config) - swapped: Deque[SequenceGroup] = deque() - policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras: Set[int] = set() blocks_to_swap_out: List[Tuple[int, int]] = [] for i in range(2): @@ -764,11 +743,11 @@ def test_schedule_swapped_max_loras(): scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) - swapped.append(seq_group) + scheduler._add_seq_group_to_swapped(seq_group) budget = create_token_budget() - remaining_swapped, output = scheduler._schedule_swapped( - swapped, budget, curr_loras, policy) + output = scheduler._schedule_swapped(budget, curr_loras) + remaining_swapped = scheduler.swapped assert len(remaining_swapped) == 1 assert budget.num_batched_tokens == 1 assert budget.num_curr_seqs == 1 @@ -779,8 +758,6 @@ def test_schedule_swapped_max_loras(): def test_schedule_swapped_cannot_swap_in(): scheduler = initialize_scheduler() - swapped: Deque[SequenceGroup] = deque() - policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] for _ in range(2): @@ -788,15 +765,15 @@ def test_schedule_swapped_cannot_swap_in(): scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) - swapped.append(seq_group) + scheduler._add_seq_group_to_swapped(seq_group) # The last request should be swapped out. 
scheduler.block_manager.can_swap_in = MagicMock() scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER # Since we cannot swap in, none of the requests are swapped in. budget = create_token_budget() - remaining_swapped, output = scheduler._schedule_swapped( - swapped, budget, curr_loras, policy) + output = scheduler._schedule_swapped(budget, curr_loras) + remaining_swapped = scheduler.swapped assert len(remaining_swapped) == 2 assert budget.num_batched_tokens == 0 assert budget.num_curr_seqs == 0 @@ -806,8 +783,6 @@ def test_schedule_swapped_cannot_swap_in(): def test_infeasible_swap(): scheduler = initialize_scheduler() - swapped: Deque[SequenceGroup] = deque() - policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] for _ in range(2): @@ -815,15 +790,15 @@ def test_infeasible_swap(): scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) - swapped.append(seq_group) + scheduler._add_seq_group_to_swapped(seq_group) # The last request should be swapped out. scheduler.block_manager.can_swap_in = MagicMock() scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER # Since we cannot swap in, none of the requests are swapped in. budget = create_token_budget() - remaining_swapped, output = scheduler._schedule_swapped( - swapped, budget, curr_loras, policy) + output = scheduler._schedule_swapped(budget, curr_loras) + remaining_swapped = scheduler.swapped assert len(remaining_swapped) == 0 assert len(output.infeasible_seq_groups) == 2 assert budget.num_batched_tokens == 0 @@ -834,23 +809,21 @@ def test_infeasible_swap(): def test_schedule_swapped_blocks_to_copy(): scheduler = initialize_scheduler() - swapped: Deque[SequenceGroup] = deque() - policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) blocks_to_swap_out: List[Tuple[int, int]] = [] scheduler._swap_out(seq_group, blocks_to_swap_out) - swapped.append(seq_group) + scheduler._add_seq_group_to_swapped(seq_group) # The last request should be swapped out. 
scheduler.block_manager.append_slots = MagicMock() scheduler.block_manager.append_slots.return_value = [(2, 3)] budget = create_token_budget() - remaining_swapped, output = scheduler._schedule_swapped( - swapped, budget, curr_loras, policy) + output = scheduler._schedule_swapped(budget, curr_loras) + remaining_swapped = scheduler.swapped assert len(remaining_swapped) == 0 assert len(output.decode_seq_groups) == 1 assert len(output.prefill_seq_groups) == 0 diff --git a/vllm/core/policy.py b/vllm/core/policy.py deleted file mode 100644 index a4463ac0f340..000000000000 --- a/vllm/core/policy.py +++ /dev/null @@ -1,45 +0,0 @@ -from collections import deque -from typing import Deque - -from vllm.sequence import SequenceGroup - - -class Policy: - - def get_priority( - self, - now: float, - seq_group: SequenceGroup, - ) -> float: - raise NotImplementedError - - def sort_by_priority( - self, - now: float, - seq_groups: Deque[SequenceGroup], - ) -> Deque[SequenceGroup]: - return deque( - sorted( - seq_groups, - key=lambda seq_group: self.get_priority(now, seq_group), - reverse=True, - )) - - -class FCFS(Policy): - - def get_priority( - self, - now: float, - seq_group: SequenceGroup, - ) -> float: - return now - seq_group.metrics.arrival_time - - -class PolicyFactory: - - _POLICY_REGISTRY = {'fcfs': FCFS} - - @classmethod - def get_policy(cls, policy_name: str, **kwargs) -> Policy: - return cls._POLICY_REGISTRY[policy_name](**kwargs) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 5cdf1d15c31e..11d020be0c94 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -8,7 +8,6 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus, BlockSpaceManager -from vllm.core.policy import Policy, PolicyFactory from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest @@ -345,6 +344,16 @@ def add_seq_group(self, seq_group: SequenceGroup) -> None: # Add sequence groups to the waiting queue. self.waiting.append(seq_group) + def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None: + # Add sequence groups to the running queue. + # Only for testing purposes. + self.running.append(seq_group) + + def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None: + # Add sequence groups to the swapped queue. + # Only for testing purposes. + self.swapped.append(seq_group) + def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: """Aborts a sequence group with the given ID. @@ -398,32 +407,26 @@ def get_and_reset_finished_requests_ids(self) -> List[str]: def _schedule_running( self, - running_queue: deque, budget: SchedulingBudget, curr_loras: Optional[Set[int]], - policy: Policy, enable_chunking: bool = False, - ) -> Tuple[deque, SchedulerRunningOutputs]: + ) -> SchedulerRunningOutputs: """Schedule sequence groups that are running. Running queue should include decode and chunked prefill requests. Args: - running_queue: The queue that contains running requests (i.e., - decodes). The given arguments are NOT in-place modified. budget: The scheduling budget. The argument is in-place updated when any decodes are preempted. curr_loras: Currently batched lora request ids. The argument is in-place updated when any decodes are preempted. - policy: The sorting policy to sort running_queue. 
enable_chunking: If True, seq group can be chunked and only a chunked number of tokens are scheduled if `budget.num_batched_tokens` has not enough capacity to schedule all tokens. Returns: - A tuple of remaining running queue (should be always 0) after - scheduling and SchedulerRunningOutputs. + SchedulerRunningOutputs. """ # Blocks that need to be swapped or copied before model execution. blocks_to_swap_out: List[Tuple[int, int]] = [] @@ -436,10 +439,9 @@ def _schedule_running( # NOTE(woosuk): Preemption happens only when there is no available slot # to keep all the sequence groups in the RUNNING state. - # In this case, the policy is responsible for deciding which sequence - # groups to preempt. - now = time.time() - running_queue = policy.sort_by_priority(now, running_queue) + + running_queue = self.running + while running_queue: seq_group = running_queue[0] num_running_tokens = self._get_num_new_tokens( @@ -503,7 +505,7 @@ def _schedule_running( if curr_loras is not None and seq_group.lora_int_id > 0: curr_loras.add(seq_group.lora_int_id) - return running_queue, SchedulerRunningOutputs( + return SchedulerRunningOutputs( decode_seq_groups=decode_seq_groups, prefill_seq_groups=prefill_seq_groups, preempted=preempted, @@ -515,12 +517,10 @@ def _schedule_running( def _schedule_swapped( self, - swapped_queue: deque, budget: SchedulingBudget, curr_loras: Optional[Set[int]], - policy: Policy, enable_chunking: bool = False, - ) -> Tuple[deque, SchedulerSwappedInOutputs]: + ) -> SchedulerSwappedInOutputs: """Schedule sequence groups that are swapped out. It schedules swapped requests as long as it fits `budget` and @@ -528,20 +528,16 @@ def _schedule_swapped( `budget` and `curr_loras` are updated based on scheduled seq_groups. Args: - swapped_queue: The queue that contains swapped out requests. - The given arguments are NOT in-place modified. budget: The scheduling budget. The argument is in-place updated when any requests are swapped in. curr_loras: Currently batched lora request ids. The argument is in-place updated when any requests are swapped in. - policy: The sorting policy to sort swapped_queue. enable_chunking: If True, seq group can be chunked and only a chunked number of tokens are scheduled if `budget.num_batched_tokens` has not enough capacity to schedule all tokens. Returns: - A tuple of remaining swapped_queue after scheduling and SchedulerSwappedInOutputs. """ # Blocks that need to be swapped or copied before model execution. 
@@ -549,10 +545,10 @@ def _schedule_swapped( blocks_to_copy: List[Tuple[int, int]] = [] decode_seq_groups: List[ScheduledSequenceGroup] = [] prefill_seq_groups: List[ScheduledSequenceGroup] = [] - now = time.time() - swapped_queue = policy.sort_by_priority(now, swapped_queue) infeasible_seq_groups: List[SequenceGroup] = [] + swapped_queue = self.swapped + leftover_swapped: Deque[SequenceGroup] = deque() while swapped_queue: seq_group = swapped_queue[0] @@ -617,7 +613,7 @@ def _schedule_swapped( swapped_queue.extendleft(leftover_swapped) - return swapped_queue, SchedulerSwappedInOutputs( + return SchedulerSwappedInOutputs( decode_seq_groups=decode_seq_groups, prefill_seq_groups=prefill_seq_groups, blocks_to_swap_in=blocks_to_swap_in, @@ -644,11 +640,10 @@ def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: def _schedule_prefills( self, - waiting_queue: deque, budget: SchedulingBudget, curr_loras: Optional[Set[int]], enable_chunking: bool = False, - ) -> Tuple[deque, SchedulerPrefillOutputs]: + ) -> SchedulerPrefillOutputs: """Schedule sequence groups that are in prefill stage. Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE @@ -660,8 +655,6 @@ def _schedule_prefills( `budget` and `curr_loras` are updated based on scheduled seq_groups. Args: - waiting_queue: The queue that contains prefill requests. - The given arguments are NOT in-place modified. budget: The scheduling budget. The argument is in-place updated when any requests are scheduled. curr_loras: Currently batched lora request ids. The argument is @@ -672,14 +665,12 @@ def _schedule_prefills( all tokens. Returns: - A tuple of remaining waiting_queue after scheduling and SchedulerSwappedInOutputs. """ ignored_seq_groups: List[SequenceGroup] = [] seq_groups: List[SequenceGroup] = [] - # We don't sort waiting queue because we assume it is sorted. - # Copy the queue so that the input queue is not modified. - waiting_queue = deque([s for s in waiting_queue]) + + waiting_queue = self.waiting leftover_waiting_sequences: Deque[SequenceGroup] = deque() while self._passed_delay(time.time()) and waiting_queue: @@ -758,7 +749,7 @@ def _schedule_prefills( if len(seq_groups) > 0: self.prev_prompt = True - return waiting_queue, SchedulerPrefillOutputs( + return SchedulerPrefillOutputs( seq_groups=seq_groups, ignored_seq_groups=ignored_seq_groups, num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True)) @@ -785,53 +776,43 @@ def _schedule_default(self) -> SchedulerOutputs: seq_group.lora_int_id for seq_group in self.running if seq_group.lora_int_id > 0) if self.lora_enabled else None - remaining_waiting, prefills = (self.waiting, - SchedulerPrefillOutputs.create_empty()) - remaining_running, running_scheduled = ( - self.running, SchedulerRunningOutputs.create_empty()) - remaining_swapped, swapped_in = ( - self.swapped, SchedulerSwappedInOutputs.create_empty()) + prefills = SchedulerPrefillOutputs.create_empty() + running_scheduled = SchedulerRunningOutputs.create_empty() + swapped_in = SchedulerSwappedInOutputs.create_empty() # If any requests are swapped, prioritized swapped requests. if not self.swapped: - remaining_waiting, prefills = self._schedule_prefills( - self.waiting, budget, curr_loras, enable_chunking=False) + prefills = self._schedule_prefills(budget, + curr_loras, + enable_chunking=False) - fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs") # Don't schedule decodes if prefills are scheduled. 
# NOTE: If `_schedule_prefills` doesn't enable chunking, self.running # only contains decode requests, not chunked prefills. if len(prefills.seq_groups) == 0: - remaining_running, running_scheduled = self._schedule_running( - self.running, - budget, - curr_loras, - fcfs_policy, - enable_chunking=False) + running_scheduled = self._schedule_running(budget, + curr_loras, + enable_chunking=False) # If any sequence group is preempted, do not swap in any sequence # group. because it means there's no slot for new running requests. if len(running_scheduled.preempted) + len( running_scheduled.swapped_out) == 0: - remaining_swapped, swapped_in = self._schedule_swapped( - self.swapped, budget, curr_loras, fcfs_policy) + swapped_in = self._schedule_swapped(budget, curr_loras) assert (budget.num_batched_tokens <= self.scheduler_config.max_num_batched_tokens) assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs # Update waiting requests. - self.waiting = remaining_waiting self.waiting.extendleft(running_scheduled.preempted) # Update new running requests. - self.running = remaining_running self.running.extend([s.seq_group for s in prefills.seq_groups]) self.running.extend( [s.seq_group for s in running_scheduled.decode_seq_groups]) self.running.extend( [s.seq_group for s in swapped_in.decode_seq_groups]) # Update swapped requests. - self.swapped = remaining_swapped self.swapped.extend(running_scheduled.swapped_out) preempted = (len(running_scheduled.preempted) + len(running_scheduled.swapped_out)) @@ -877,42 +858,32 @@ def _schedule_chunked_prefill(self): ) curr_loras: Set[int] = set() - remaining_waiting, prefills = (self.waiting, - SchedulerPrefillOutputs.create_empty()) - remaining_running, running_scheduled = ( - self.running, SchedulerRunningOutputs.create_empty()) - remaining_swapped, swapped_in = ( - self.swapped, SchedulerSwappedInOutputs.create_empty()) + prefills = SchedulerPrefillOutputs.create_empty() + swapped_in = SchedulerSwappedInOutputs.create_empty() # Decoding should be always scheduled first by fcfs. - fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs") - remaining_running, running_scheduled = self._schedule_running( - self.running, - budget, - curr_loras, - fcfs_policy, - enable_chunking=True) + running_scheduled = self._schedule_running(budget, + curr_loras, + enable_chunking=True) # Schedule swapped out requests. # If preemption happens, it means we don't have space for swap-in. if len(running_scheduled.preempted) + len( running_scheduled.swapped_out) == 0: - remaining_swapped, swapped_in = self._schedule_swapped( - self.swapped, budget, curr_loras, fcfs_policy) + swapped_in = self._schedule_swapped(budget, curr_loras) # Schedule new prefills. - remaining_waiting, prefills = self._schedule_prefills( - self.waiting, budget, curr_loras, enable_chunking=True) + prefills = self._schedule_prefills(budget, + curr_loras, + enable_chunking=True) assert (budget.num_batched_tokens <= self.scheduler_config.max_num_batched_tokens) assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs # Update waiting requests. - self.waiting = remaining_waiting self.waiting.extendleft(running_scheduled.preempted) # Update new running requests. - self.running = remaining_running self.running.extend([s.seq_group for s in prefills.seq_groups]) self.running.extend( [s.seq_group for s in running_scheduled.decode_seq_groups]) @@ -923,7 +894,6 @@ def _schedule_chunked_prefill(self): self.running.extend( [s.seq_group for s in swapped_in.prefill_seq_groups]) # Update swapped requests. 
- self.swapped = remaining_swapped self.swapped.extend(running_scheduled.swapped_out) return SchedulerOutputs( scheduled_seq_groups=(prefills.seq_groups + From a72a424b3eac43a26d2214c0f2a7f91cc59f2f84 Mon Sep 17 00:00:00 2001 From: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com> Date: Thu, 1 Aug 2024 13:07:37 -0500 Subject: [PATCH 0018/3246] [Build/CI] Fixing Docker Hub quota issue. (#7043) --- .buildkite/run-amd-test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 77e451354caf..85b2b6b50353 100644 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -56,7 +56,7 @@ done echo "--- Pulling container" docker login registry-1.docker.io -u alexeivivanovamd -p ${DH_TOKEN} -image_name="rocmshared/vllm-ci:${BUILDKITE_COMMIT}" +image_name="rocm/vllm-ci:${BUILDKITE_COMMIT}" container_name="rocm_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)" docker pull ${image_name} From 7e0861bd0bb25ea5ceaa3a513da4133fb828b5fe Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Thu, 1 Aug 2024 11:11:24 -0700 Subject: [PATCH 0019/3246] [CI/Build] Update PyTorch to 2.4.0 (#6951) Co-authored-by: Michael Goin --- .buildkite/test-pipeline.yaml | 6 +++--- .github/workflows/publish.yml | 2 +- CMakeLists.txt | 2 +- Dockerfile | 2 +- pyproject.toml | 2 +- requirements-build.txt | 2 +- requirements-cuda.txt | 8 ++++---- vllm/model_executor/layers/ops/sample.py | 2 +- 8 files changed, 13 insertions(+), 13 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9ec9ec12bfcf..573c3740f0bb 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -44,7 +44,7 @@ steps: fast_check: true commands: # This flashinfer installation will fail on AMD ROCm, so it is set as optional. 
- - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl || true + - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl || true - pytest -v -s basic_correctness/test_basic_correctness.py - pytest -v -s basic_correctness/test_cpu_offload.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py @@ -164,7 +164,7 @@ steps: - label: Models Test #mirror_hardwares: [amd] commands: - - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl + - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl - pytest -v -s models -m \"not vlm\" - label: Vision Language Models Test @@ -281,7 +281,7 @@ steps: - pytest -v -s distributed/test_custom_all_reduce.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl + - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - pytest -v -s -x lora/test_mixtral.py diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 15c2ec05b25d..607fda754bf2 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -49,7 +49,7 @@ jobs: matrix: os: ['ubuntu-20.04'] python-version: ['3.8', '3.9', '3.10', '3.11'] - pytorch-version: ['2.3.1'] # Must be the most recent version that meets requirements-cuda.txt. + pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements-cuda.txt. cuda-version: ['11.8', '12.1'] steps: diff --git a/CMakeLists.txt b/CMakeLists.txt index 0d599c547070..9c2cb360fca3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,7 +32,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11 # requirements.txt files and should be kept consistent. 
The ROCm torch # versions are derived from Dockerfile.rocm # -set(TORCH_SUPPORTED_VERSION_CUDA "2.3.1") +set(TORCH_SUPPORTED_VERSION_CUDA "2.4.0") set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0") # diff --git a/Dockerfile b/Dockerfile index db4453ab0efc..7294707046ab 100644 --- a/Dockerfile +++ b/Dockerfile @@ -192,7 +192,7 @@ RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamb python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.9/flashinfer-0.0.9+cu121torch2.3-cp310-cp310-linux_x86_64.whl + python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl #################### vLLM installation IMAGE #################### diff --git a/pyproject.toml b/pyproject.toml index b0d115a091c4..26d963aa5109 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ requires = [ "ninja", "packaging", "setuptools >= 49.4.0", - "torch == 2.3.1", + "torch == 2.4.0", "wheel", ] build-backend = "setuptools.build_meta" diff --git a/requirements-build.txt b/requirements-build.txt index b05f38a0ed91..d0f677fd344e 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -3,5 +3,5 @@ cmake>=3.21 ninja packaging setuptools>=49.4.0 -torch==2.3.1 +torch==2.4.0 wheel diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 3eb91212e976..1f60d54083b4 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -4,8 +4,8 @@ # Dependencies for NVIDIA GPUs ray >= 2.9 nvidia-ml-py # for pynvml package -torch == 2.3.1 +torch == 2.4.0 # These must be updated alongside torch -torchvision == 0.18.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version -xformers == 0.0.27 # Requires PyTorch 2.3.1 -vllm-flash-attn == 2.5.9.post1 # Requires PyTorch 2.3.1 +torchvision == 0.19 # Required for phi3v processor. 
See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version +xformers == 0.0.27.post2 # Requires PyTorch 2.4.0 +vllm-flash-attn == 2.6.0 # Requires PyTorch 2.4.0 diff --git a/vllm/model_executor/layers/ops/sample.py b/vllm/model_executor/layers/ops/sample.py index bdb577da3172..fb88a05daf48 100644 --- a/vllm/model_executor/layers/ops/sample.py +++ b/vllm/model_executor/layers/ops/sample.py @@ -7,7 +7,7 @@ from vllm.model_executor.layers.ops.rand import seeded_uniform from vllm.triton_utils.sample import get_num_triton_sampler_splits -_EPS = 1e-6 +_EPS: tl.constexpr = 1e-6 def _multi_split_sample( From 2dd34371a6054966d30971dae89b0c431d7f0f08 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 2 Aug 2024 03:00:28 +0800 Subject: [PATCH 0020/3246] [Bugfix] Fix RMSNorm forward in InternViT attention qk_layernorm (#6992) --- vllm/model_executor/models/intern_vit.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index 86d0930d8012..c6c692deca2e 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -113,10 +113,10 @@ def forward(self, x): if self.qk_normalization: B_, H_, N_, D_ = q.shape - q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view( - B_, N_, H_, D_).transpose(1, 2) - k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view( - B_, N_, H_, D_).transpose(1, 2) + q = self.q_norm.forward_native(q.transpose(1, 2).flatten( + -2, -1)).view(B_, N_, H_, D_).transpose(1, 2) + k = self.k_norm.forward_native(k.transpose(1, 2).flatten( + -2, -1)).view(B_, N_, H_, D_).transpose(1, 2) x = F.scaled_dot_product_attention(q, k, v, scale=self.scale) x = x.transpose(1, 2).reshape(B, N, C) From fb3db616881d7225c4bbe64bb709ea6bcd6157f7 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Thu, 1 Aug 2024 15:00:51 -0400 Subject: [PATCH 0021/3246] [CI/Build] Remove sparseml requirement from testing (#7037) --- requirements-test.txt | 1 - tests/conftest.py | 4 -- tests/models/test_compressed_tensors.py | 52 ------------------- tests/quantization/test_compressed_tensors.py | 2 +- 4 files changed, 1 insertion(+), 58 deletions(-) delete mode 100644 tests/models/test_compressed_tensors.py diff --git a/requirements-test.txt b/requirements-test.txt index 9b88fcce3e84..df247496be16 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -14,7 +14,6 @@ peft requests ray sentence-transformers # required for embedding -sparseml==1.8.0 # required for compressed-tensors compressed-tensors==0.4.0 # required for compressed-tensors timm # required for internvl test diff --git a/tests/conftest.py b/tests/conftest.py index 59510075b006..999ca60d07a4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -152,7 +152,6 @@ def __init__( model_kwargs: Optional[Dict[str, Any]] = None, is_embedding_model: bool = False, is_vision_model: bool = False, - is_sparseml_model: bool = False, ) -> None: torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] @@ -169,9 +168,6 @@ def __init__( else: if is_vision_model: auto_cls = AutoModelForVision2Seq - elif is_sparseml_model: - from sparseml.transformers import SparseAutoModelForCausalLM - auto_cls = SparseAutoModelForCausalLM else: auto_cls = AutoModelForCausalLM diff --git a/tests/models/test_compressed_tensors.py b/tests/models/test_compressed_tensors.py deleted file mode 100644 index da47d5f3f3d2..000000000000 --- a/tests/models/test_compressed_tensors.py 
+++ /dev/null @@ -1,52 +0,0 @@ -"""Compares vllm vs sparseml for compressed-tensors - -Note: vllm and sparseml do not have bitwise correctness, -so in this test, we just confirm that the top selected -tokens of the are in the top 5 selections of each other. -""" - -import pytest - -from tests.quantization.utils import is_quant_method_supported - -from .utils import check_logprobs_close - -MODELS = [ - # No bias - "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test", - # Bias - "neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8" -] - -MAX_TOKENS = 32 -NUM_LOGPROBS = 5 - - -@pytest.mark.skipif( - not is_quant_method_supported("compressed-tensors"), - reason="compressed-tensors is not supported on this machine type.") -@pytest.mark.parametrize("model_name", MODELS) -def test_models( - vllm_runner, - hf_runner, - example_prompts, - model_name, -) -> None: - # Run sparseml. - with hf_runner(model_name=model_name, - is_sparseml_model=True) as sparseml_model: - - sparseml_outputs = sparseml_model.generate_greedy_logprobs_limit( - example_prompts, MAX_TOKENS, NUM_LOGPROBS) - - # Run vllm. - with vllm_runner(model_name=model_name) as vllm_model: - vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, MAX_TOKENS, NUM_LOGPROBS) - - check_logprobs_close( - outputs_0_lst=sparseml_outputs, - outputs_1_lst=vllm_outputs, - name_0="sparseml", - name_1="vllm", - ) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index c5a01b73f4a8..bd79da84a776 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -1,4 +1,4 @@ -"""Test model set-up and weight loading for sparseml-quantized models. +"""Test model set-up and weight loading for llmcompressor-quantized models. Run `pytest tests/quantization/test_compressed_tensors.py`. 
""" From f4fd390f5de585fd94877158bea4e1b2d1920df3 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Thu, 1 Aug 2024 15:01:07 -0400 Subject: [PATCH 0022/3246] [Bugfix] Lower gemma's unloaded_params exception to warning (#7002) --- vllm/model_executor/models/gemma.py | 6 +++--- vllm/model_executor/models/gemma2.py | 9 ++++++--- vllm/model_executor/models/paligemma.py | 6 +++--- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 7e0888b5f5ab..64aef1024a1a 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -404,6 +404,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loaded_params.add(name) unloaded_params = params_dict.keys() - loaded_params if unloaded_params: - raise RuntimeError( - "Some weights are not initialized from checkpoints: " - f"{unloaded_params}") + logger.warning( + "Some weights are not initialized from checkpoints: %s", + unloaded_params) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 8386084c2b3f..b77c901f6cd3 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -23,6 +23,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -41,6 +42,8 @@ from .interfaces import SupportsLoRA +logger = init_logger(__name__) + class Gemma2MLP(nn.Module): @@ -390,6 +393,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): unloaded_params = params_dict.keys() - loaded_params if unloaded_params: - raise RuntimeError( - "Some weights are not initialized from checkpoints: " - f"{unloaded_params}") + logger.warning( + "Some weights are not initialized from checkpoints: %s", + unloaded_params) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 2af48b6bc190..fe91611cd30f 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -342,6 +342,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): unloaded_params = params_dict.keys() - loaded_params if unloaded_params: - raise RuntimeError( - "Some weights are not initialized from checkpoints: " - f"{unloaded_params}") + logger.warning( + "Some weights are not initialized from checkpoints: %s", + unloaded_params) From fc912e0886f5eaa584c1a65fad81c6c269f609a0 Mon Sep 17 00:00:00 2001 From: Murali Andoorveedu <37849411+andoorve@users.noreply.github.com> Date: Thu, 1 Aug 2024 12:40:43 -0700 Subject: [PATCH 0023/3246] [Models] Support Qwen model with PP (#6974) Signed-off-by: Muralidhar Andoorveedu --- docs/source/serving/distributed_serving.rst | 2 +- vllm/config.py | 1 + vllm/model_executor/models/qwen.py | 54 +++++++++++++++++---- 3 files changed, 46 insertions(+), 11 deletions(-) diff --git a/docs/source/serving/distributed_serving.rst b/docs/source/serving/distributed_serving.rst index 5f14fd2b0ee0..fcb2646df50d 100644 --- a/docs/source/serving/distributed_serving.rst +++ b/docs/source/serving/distributed_serving.rst @@ -50,7 +50,7 @@ You can also additionally specify :code:`--pipeline-parallel-size` to enable pip $ 
--pipeline-parallel-size 2 .. note:: - Pipeline parallel is a beta feature. It is only supported for online serving as well as LLaMa, GPT2, and Mixtral style models. + Pipeline parallel is a beta feature. It is only supported for online serving as well as LLaMa, GPT2, Mixtral, Qwen, Qwen2, and Nemotron style models. Multi-Node Inference and Serving -------------------------------- diff --git a/vllm/config.py b/vllm/config.py index e06574459237..ef56e2b6395b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -42,6 +42,7 @@ "NemotronForCausalLM", "Qwen2ForCausalLM", "Qwen2MoeForCausalLM", + "QWenLMHeadModel", ] diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 47c85c783db7..eb61adf34e9a 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -12,7 +12,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -30,6 +30,8 @@ from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import print_warning_once +from .utils import is_pp_missing_parameter, make_layers + class QWenMLP(nn.Module): @@ -186,6 +188,7 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -195,10 +198,10 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.h = nn.ModuleList([ - QWenBlock(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: QWenBlock(config, cache_config, quant_config), + prefix=f"{prefix}.h") self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) def forward( @@ -207,18 +210,29 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], ) -> torch.Tensor: - hidden_states = self.wte(input_ids) - residual = None - for i in range(len(self.h)): + if get_pp_group().is_first_rank: + hidden_states = self.wte(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.h[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.ln_f(hidden_states, residual) return hidden_states @@ -250,9 +264,23 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + 
torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states, @@ -284,6 +312,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -301,6 +332,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): "Only text inputs are allowed. Images won't be handled " "until Qwen-VL models are fully supported.") continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) From 562e580abc63cd6c1d39bd04d7a007ddefba7575 Mon Sep 17 00:00:00 2001 From: omkar kakarparthi <75638701+okakarpa@users.noreply.github.com> Date: Thu, 1 Aug 2024 15:12:37 -0500 Subject: [PATCH 0024/3246] Update run-amd-test.sh (#7044) --- .buildkite/run-amd-test.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 85b2b6b50353..ccc2f090565e 100644 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -55,7 +55,6 @@ while true; do done echo "--- Pulling container" -docker login registry-1.docker.io -u alexeivivanovamd -p ${DH_TOKEN} image_name="rocm/vllm-ci:${BUILDKITE_COMMIT}" container_name="rocm_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)" docker pull ${image_name} From 805a8a75f2f17ee56c0882efcc34d35e1801cbee Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 1 Aug 2024 13:14:37 -0700 Subject: [PATCH 0025/3246] [Misc] Support attention logits soft-capping with flash-attn (#7022) --- requirements-cuda.txt | 2 +- tests/kernels/test_flash_attn.py | 19 +++++++++++++------ vllm/attention/backends/abstract.py | 1 + vllm/attention/backends/blocksparse_attn.py | 3 +++ vllm/attention/backends/flash_attn.py | 21 ++++++++++----------- vllm/attention/backends/flashinfer.py | 14 +++++--------- vllm/attention/backends/ipex_attn.py | 8 ++++++-- vllm/attention/backends/pallas.py | 4 ++++ vllm/attention/backends/rocm_flash_attn.py | 10 ++++++++-- vllm/attention/backends/torch_sdpa.py | 8 ++++++-- vllm/attention/backends/utils.py | 9 --------- vllm/attention/backends/xformers.py | 9 +++++++-- vllm/attention/layer.py | 3 ++- vllm/model_executor/models/gemma2.py | 7 +++++-- 14 files changed, 71 insertions(+), 47 deletions(-) diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 1f60d54083b4..1d00f0c17dee 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -8,4 +8,4 @@ torch == 2.4.0 # These must be updated alongside torch torchvision == 0.19 # Required for phi3v processor. 
See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version xformers == 0.0.27.post2 # Requires PyTorch 2.4.0 -vllm-flash-attn == 2.6.0 # Requires PyTorch 2.4.0 +vllm-flash-attn == 2.6.1 # Requires PyTorch 2.4.0 diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index cd06c27175ce..6c5eff00de44 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -20,6 +20,7 @@ def ref_paged_attn( block_tables: torch.Tensor, scale: float, sliding_window: Optional[int] = None, + soft_cap: Optional[float] = None, ) -> torch.Tensor: num_seqs = len(query_lens) block_tables = block_tables.cpu().numpy() @@ -53,6 +54,8 @@ def ref_paged_attn( (query_len + sliding_window) + 1).bool().logical_not() mask |= sliding_window_mask + if soft_cap is not None: + attn = soft_cap * torch.tanh(attn / soft_cap) attn.masked_fill_(mask, float("-inf")) attn = torch.softmax(attn, dim=-1).to(v.dtype) out = torch.einsum("hqk,khd->qhd", attn, v) @@ -68,13 +71,15 @@ def ref_paged_attn( @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) -@torch.inference_mode +@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) +@torch.inference_mode() def test_flash_attn_with_paged_kv( kv_lens: List[int], num_heads: Tuple[int, int], head_size: int, dtype: torch.dtype, block_size: int, + soft_cap: Optional[float], ) -> None: torch.set_default_device("cuda") torch.cuda.manual_seed_all(0) @@ -108,6 +113,7 @@ def test_flash_attn_with_paged_kv( causal=True, block_table=block_tables, cache_seqlens=kv_lens_tensor, + softcap=soft_cap if soft_cap is not None else 0, ).squeeze(1) ref_output = ref_paged_attn( @@ -118,6 +124,7 @@ def test_flash_attn_with_paged_kv( kv_lens=kv_lens, block_tables=block_tables, scale=scale, + soft_cap=soft_cap, ) assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ f"{torch.max(torch.abs(output - ref_output))}" @@ -129,7 +136,8 @@ def test_flash_attn_with_paged_kv( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("sliding_window", [None]) @pytest.mark.parametrize("dtype", DTYPES) -@torch.inference_mode +@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) +@torch.inference_mode() def test_varlen_with_paged_kv( seq_lens: List[Tuple[int, int]], num_heads: Tuple[int, int], @@ -137,6 +145,7 @@ def test_varlen_with_paged_kv( sliding_window: Optional[int], dtype: torch.dtype, block_size: int, + soft_cap: Optional[float], ) -> None: torch.set_default_device("cuda") torch.cuda.manual_seed_all(0) @@ -163,10 +172,6 @@ def test_varlen_with_paged_kv( head_size, dtype=dtype) value_cache = torch.randn_like(key_cache) - # Normalize the scale of the key and value caches to mitigate - # numerical instability. 
- key_cache /= head_size**0.5 - value_cache /= head_size**0.5 cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(dim=0, dtype=torch.int32) @@ -192,6 +197,7 @@ def test_varlen_with_paged_kv( causal=True, window_size=window_size, block_table=block_tables, + softcap=soft_cap if soft_cap is not None else 0, ) ref_output = ref_paged_attn( @@ -203,6 +209,7 @@ def test_varlen_with_paged_kv( block_tables=block_tables, scale=scale, sliding_window=sliding_window, + soft_cap=soft_cap, ) assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ f"{torch.max(torch.abs(output - ref_output))}" diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 106b00cc1014..97b13917ccfa 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -150,6 +150,7 @@ def __init__( sliding_window: Optional[int] = None, kv_cache_dtype: str = "auto", blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, ) -> None: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 71954f864a9b..907b45393eeb 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -283,12 +283,15 @@ def __init__( sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, ) -> None: assert blocksparse_params is not None assert alibi_slopes is None, ValueError( "Alibi not support for blocksparse flash attention.") assert sliding_window is None, ValueError( "sliding_window is invalid for blocksparse attention.") + assert logits_soft_cap is None, ValueError( + "logits_soft_cap is invalid for blocksparse attention.") if "num_heads" not in blocksparse_params: blocksparse_params["num_heads"] = num_heads diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 7d7aff9dc3cd..00654dca2adf 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -288,15 +288,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 - logits_soft_cap = getattr(self.runner.model_config.hf_config, - "attn_logit_softcapping", None) - if logits_soft_cap is not None: - raise ValueError( - "Please use Flashinfer backend for models with logits_soft_cap" - " (i.e., Gemma-2). Otherwise, the output might be wrong." 
- " Set Flashinfer backend by " - "export VLLM_ATTENTION_BACKEND=FLASHINFER.") - max_query_len = max(query_lens) max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) @@ -405,9 +396,11 @@ def __init__( sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, ) -> None: - assert blocksparse_params is None, ValueError( - "FlashAttention does not support block-sparse attention.") + if blocksparse_params is not None: + raise ValueError( + "FlashAttention does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -418,6 +411,10 @@ def __init__( self.sliding_window = ((sliding_window, sliding_window) if sliding_window is not None else (-1, -1)) self.kv_cache_dtype = kv_cache_dtype + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -525,6 +522,7 @@ def forward( causal=True, window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, ) assert output[:num_prefill_tokens].shape == out.shape output[:num_prefill_tokens] = out @@ -544,6 +542,7 @@ def forward( causal=True, alibi_slopes=self.alibi_slopes, block_table=prefill_meta.block_tables, + softcap=self.logits_soft_cap, ) if decode_meta := attn_metadata.decode_metadata: diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 83a420d76834..ccf8ab03a621 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -116,8 +116,6 @@ class FlashInferMetadata(AttentionMetadata): # The data type of the paged kv cache data_type: torch.dtype = None device: torch.device = torch.device("cuda") - # Only used by gemma2 model - logits_soft_cap: Optional[float] = None def __post_init__(self): # Refer to @@ -391,9 +389,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=torch.long, device=device) - logits_soft_cap = getattr(self.runner.model_config.hf_config, - "attn_logit_softcapping", None) - if len(self.paged_kv_indptr) > 0: paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, device="cpu", @@ -430,8 +425,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], query_start_loc=query_start_loc, device=device, data_type=kv_cache_dtype, - use_cuda_graph=use_captured_graph, - logits_soft_cap=logits_soft_cap) + use_cuda_graph=use_captured_graph) class FlashInferImpl(AttentionImpl): @@ -446,6 +440,7 @@ def __init__( sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -458,6 +453,7 @@ def __init__( raise ValueError("Sliding window is not supported in FlashInfer.") self.sliding_window = (-1, -1) self.kv_cache_dtype = kv_cache_dtype + self.logits_soft_cap = logits_soft_cap assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -532,7 +528,7 @@ def forward( output = prefill_meta.prefill_wrapper.forward( query, kv_cache, - logits_soft_cap=attn_metadata.logits_soft_cap, + logits_soft_cap=self.logits_soft_cap, causal=True) else: assert attn_metadata.decode_metadata is not None 
@@ -541,5 +537,5 @@ def forward( query, kv_cache, sm_scale=self.scale, - logits_soft_cap=attn_metadata.logits_soft_cap) + logits_soft_cap=self.logits_soft_cap) return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 4559dd15f600..bac30aec2482 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -105,9 +105,13 @@ def __init__( sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, ) -> None: - assert blocksparse_params is None, ValueError( - "Torch SPDA does not support block-sparse attention.") + if blocksparse_params is not None: + raise ValueError( + "IPEX backend does not support block-sparse attention.") + if logits_soft_cap is not None: + raise ValueError("IPEX backend does not support logits_soft_cap.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 2269ac2606e8..4ecf698c8d51 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -91,6 +91,7 @@ def __init__( sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -109,6 +110,9 @@ def __init__( raise NotImplementedError("FP8 KV cache dtype is not supported.") if blocksparse_params is not None: raise NotImplementedError("Blocksparse is not supported.") + if logits_soft_cap is not None: + raise NotImplementedError( + "Attention logits soft-capping is not supported.") if torch_xla.tpu.version() < 4: raise NotImplementedError("TPU version must be 4 or higher.") diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 058c8df0eaf8..26e9b8a93fb9 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -244,9 +244,15 @@ def __init__( sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, ) -> None: - assert blocksparse_params is None, ValueError( - "ROCFlashAttention does not support blocksparse attention.") + if blocksparse_params is not None: + raise ValueError( + "ROCmFlashAttention does not support blocksparse attention.") + if logits_soft_cap is not None: + raise ValueError( + "ROCmFlashAttention does not support attention logits soft " + "capping.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index fe6a56123ce7..b83c673f0165 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -109,9 +109,13 @@ def __init__( sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, ) -> None: - assert blocksparse_params is None, ValueError( - "Torch SPDA does not support block-sparse attention.") + if blocksparse_params is not None: + raise ValueError( + "Torch SPDA does not support block-sparse attention.") + if logits_soft_cap is not None: + raise ValueError("Torch SPDA does not support logits soft cap.") self.num_heads = num_heads self.head_size = head_size self.scale = 
float(scale) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index dcd10ed410a7..bca1370343b7 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -165,15 +165,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 - logits_soft_cap = getattr(self.runner.model_config.hf_config, - "attn_logit_softcapping", None) - if logits_soft_cap is not None: - raise ValueError( - "Please use Flashinfer backend for models with logits_soft_cap " - "(i.e., Gemma-2). Otherwise, the output might be wrong. " - "Set Flashinfer backend by " - "export VLLM_ATTENTION_BACKEND=FLASHINFER.") - max_query_len = max(query_lens) max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 1573cd7da94c..24ba5fc72540 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -408,9 +408,14 @@ def __init__( sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, ) -> None: - assert blocksparse_params is None, ValueError( - "XFormer does not support block-sparse attention.") + if blocksparse_params is not None: + raise ValueError( + "XFormers does not support block-sparse attention.") + if logits_soft_cap is not None: + raise ValueError( + "XFormers does not support attention logits soft capping.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 5fa552f2f4ec..2c21502dcf40 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -34,6 +34,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, prefix: str = "", ) -> None: super().__init__() @@ -82,7 +83,7 @@ def __init__( impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params) + blocksparse_params, logits_soft_cap) def forward( self, diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index b77c901f6cd3..7bad2626fec6 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -90,7 +90,8 @@ def __init__(self, max_position_embeddings: int, rope_theta: float, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + quant_config: Optional[QuantizationConfig] = None, + attn_logits_soft_cap: Optional[float] = None) -> None: super().__init__() self.layer_idx = layer_idx self.config = config @@ -150,7 +151,8 @@ def __init__(self, self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + logits_soft_cap=attn_logits_soft_cap) def forward( self, @@ -189,6 +191,7 @@ def __init__( rope_theta=config.rope_theta, cache_config=cache_config, quant_config=quant_config, + attn_logits_soft_cap=config.attn_logit_softcapping, ) self.hidden_size = config.hidden_size self.mlp = Gemma2MLP( From 6a11fdfbb8d6701c7ad38648aead23d8cbe6aac5 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 1 Aug 2024 
16:51:15 -0400 Subject: [PATCH 0026/3246] [CI/Build][Bugfix] Fix CUTLASS header-only line (#7034) --- CMakeLists.txt | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9c2cb360fca3..77a8af549b02 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -156,12 +156,15 @@ set(VLLM_EXT_SRC if(VLLM_GPU_LANG STREQUAL "CUDA") include(FetchContent) - SET(CUTLASS_ENABLE_HEADERS_ONLY=ON) + SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git # CUTLASS 3.5.0 GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc + # Shallow clone with depth 1 + GIT_SHALLOW TRUE + GIT_PROGRESS TRUE ) FetchContent_MakeAvailable(cutlass) From 6ce01f30667bbae33f112152e07a3b66b841078f Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 1 Aug 2024 18:29:52 -0700 Subject: [PATCH 0027/3246] [Performance] Optimize `get_seqs` (#7051) --- vllm/core/block_manager_v1.py | 2 +- vllm/sequence.py | 40 +++++++++++++------------- vllm/transformers_utils/detokenizer.py | 2 +- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index e29eba375f4d..d81648caa585 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -700,5 +700,5 @@ def get_common_computed_block_ids( def mark_blocks_as_computed(self, seq_group: SequenceGroup): if self.enable_caching: - for seq in seq_group.seqs_dict.values(): + for seq in seq_group.get_seqs(): self.compute_full_blocks_in_seq(seq) diff --git a/vllm/sequence.py b/vllm/sequence.py index ab50cfdfd29a..7ef9387c611f 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -444,6 +444,7 @@ def __init__( prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: self.request_id = request_id + self.seqs = seqs self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.sampling_params = sampling_params self.metrics = RequestMetrics(arrival_time=arrival_time, @@ -458,25 +459,24 @@ def __init__( self.prompt_adapter_request = prompt_adapter_request self.encoder_seq = encoder_seq self.trace_headers = trace_headers - self._first_seq = next(iter(self.seqs_dict.values())) @property def prompt(self) -> Optional[str]: # All sequences in the group should have the same prompt. # We use the prompt of an arbitrary sequence. - return self._first_seq.prompt + return self.seqs[0].prompt @property def prompt_token_ids(self) -> List[int]: # All sequences in the group should have the same prompt. # We use the prompt of an arbitrary sequence. - return self._first_seq.prompt_token_ids + return self.seqs[0].prompt_token_ids @property def multi_modal_data(self) -> "MultiModalDataDict": # All sequences in the group should have the same multi-modal data. # We use the multi-modal data of an arbitrary sequence. - return self._first_seq.multi_modal_data + return self.seqs[0].multi_modal_data @property def lora_int_id(self) -> int: @@ -512,7 +512,7 @@ def maybe_set_first_token_time(self, time: float) -> None: # in TPOT, rather than recalculating TTFT (since from the ) # POV of the user, there is simply a long generation delay. 
if (self.metrics.first_token_time is None - and self.get_seqs()[0].get_output_len() == 1): + and self.seqs[0].get_output_len() == 1): self.metrics.first_token_time = time def maybe_set_first_scheduled_time(self, time: float) -> None: @@ -548,9 +548,9 @@ def get_seqs( self, status: Optional[SequenceStatus] = None, ) -> List[Sequence]: - return list(self.seqs_dict.values()) if status is None else [ - seq for seq in self.seqs_dict.values() if seq.status == status - ] + if status is None: + return self.seqs + return [seq for seq in self.seqs if seq.status == status] def is_encoder_decoder(self) -> bool: return self.encoder_seq is not None @@ -559,22 +559,20 @@ def get_encoder_seq(self) -> Optional[Sequence]: return self.encoder_seq def get_unfinished_seqs(self) -> List[Sequence]: - return [ - seq for seq in self.seqs_dict.values() if not seq.is_finished() - ] + return [seq for seq in self.seqs if not seq.is_finished()] def get_finished_seqs(self) -> List[Sequence]: - return [seq for seq in self.seqs_dict.values() if seq.is_finished()] + return [seq for seq in self.seqs if seq.is_finished()] def update_num_computed_tokens(self, num_new_computed_tokens: int): """Update number of tokens computed so far.""" - for seq in self.seqs_dict.values(): + for seq in self.seqs: if not seq.is_finished(): seq.data.update_num_computed_tokens(num_new_computed_tokens) def get_num_uncomputed_tokens(self) -> int: num_uncomputed_tokens = 0 - for seq in self.get_seqs(): + for seq in self.seqs: if not seq.is_finished(): num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() return num_uncomputed_tokens @@ -583,7 +581,7 @@ def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: # Optimization. We don't need to call get_seqs if we don't need to # filter by states. if status is None: - return len(self.seqs_dict) + return len(self.seqs) return len(self.get_seqs(status)) @@ -602,23 +600,25 @@ def add(self, seq: Sequence) -> None: if seq.seq_id in self.seqs_dict: raise ValueError(f"Sequence {seq.seq_id} already exists.") self.seqs_dict[seq.seq_id] = seq + self.seqs.append(seq) def remove(self, seq_id: int) -> None: - if seq_id not in self.seqs_dict: + seq = self.seqs_dict.pop(seq_id, None) + if seq is None: raise ValueError(f"Sequence {seq_id} not found.") - del self.seqs_dict[seq_id] + self.seqs.remove(seq) def is_finished(self) -> bool: - return all(seq.is_finished() for seq in self.get_seqs()) + return all(seq.is_finished() for seq in self.seqs) def is_prefill(self) -> bool: # Every sequence should be in the same stage. - return self.get_seqs()[0].is_prefill() + return self.seqs[0].is_prefill() def __repr__(self) -> str: return (f"SequenceGroup(request_id={self.request_id}, " f"sampling_params={self.sampling_params}, " - f"num_seqs={len(self.seqs_dict)})") + f"num_seqs={len(self.seqs)})") class SequenceGroupMetadata: diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 76f418674532..001af67f3bb9 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -40,7 +40,7 @@ def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, assert prms is not None # We can pick any sequence for the prompt. - seq = next(iter(seq_group.seqs_dict.values())) + seq = seq_group.get_seqs()[0] # Only prompt, without the generated token. 
all_token_ids = seq.get_token_ids() prompt_token_ids = all_token_ids[:-1] From 954f7305a106058815bd7e47f5b9d585d8764c05 Mon Sep 17 00:00:00 2001 From: Lily Liu Date: Thu, 1 Aug 2024 18:44:16 -0700 Subject: [PATCH 0028/3246] [Kernel] Fix input for flashinfer prefill wrapper. (#7008) --- vllm/attention/backends/flashinfer.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index ccf8ab03a621..91abaab78dcb 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -133,13 +133,20 @@ def begin_forward(self): return assert self.prefill_wrapper is not None + assert self.query_start_loc is not None assert self.paged_kv_indices is not None assert self.paged_kv_indptr is not None assert self.paged_kv_last_page_len is not None - self.paged_kv_indices = self.paged_kv_indices.to(self.device) - self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + batch_size = self.query_start_loc.shape[0] - 1 + assert batch_size >= 0 + # The prefill stage does not read kv cache. + # Both paged_kv_indices and paged_kv_last_page_len are empty. + # paged_kv_indptr is a zero tensor with size batch_size + 1. + self.paged_kv_indptr = torch.zeros(batch_size + 1, + device=self.device) self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( self.device) + self.paged_kv_indices = self.paged_kv_indices.to(self.device) self.prefill_wrapper.end_forward() self.prefill_wrapper.begin_forward( self.query_start_loc, self.paged_kv_indptr, From 3bb4b1e4cd3d07c80a208d875b016631d91844f8 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 2 Aug 2024 10:49:43 +0800 Subject: [PATCH 0029/3246] [mypy] Speed up mypy checking (#7056) --- .github/workflows/mypy.yaml | 2 +- format.sh | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 721c9c026cf1..68e3a3fefdc5 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -32,6 +32,7 @@ jobs: pip install types-setuptools - name: Mypy run: | + mypy mypy tests --follow-imports skip mypy vllm/attention --follow-imports skip mypy vllm/core --follow-imports skip @@ -44,5 +45,4 @@ jobs: mypy vllm/prompt_adapter --follow-imports skip mypy vllm/spec_decode --follow-imports skip mypy vllm/worker --follow-imports skip - mypy diff --git a/format.sh b/format.sh index 71697cffacfb..abc688c702aa 100755 --- a/format.sh +++ b/format.sh @@ -96,6 +96,7 @@ echo 'vLLM yapf: Done' # Run mypy echo 'vLLM mypy:' +mypy --follow-imports skip # Note that this is less strict than CI mypy tests --follow-imports skip mypy vllm/attention --follow-imports skip mypy vllm/core --follow-imports skip @@ -108,7 +109,7 @@ mypy vllm/model_executor --follow-imports skip mypy vllm/prompt_adapter --follow-imports skip mypy vllm/spec_decode --follow-imports skip mypy vllm/worker --follow-imports skip -mypy +echo 'vLLM mypy: Done' # If git diff returns a file that is in the skip list, the file may be checked anyway: @@ -127,7 +128,7 @@ spell_check_all(){ codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}" } -# Spelling check of files that differ from main branch. +# Spelling check of files that differ from main branch. 
spell_check_changed() { # The `if` guard ensures that the list of filenames is not empty, which # could cause ruff to receive 0 positional arguments, making it hang From 252357793dd1fe9d30c34e68e4b8b2143a4c5138 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 1 Aug 2024 22:03:12 -0700 Subject: [PATCH 0030/3246] [ci][distributed] try to fix pp test (#7054) --- tests/distributed/test_pipeline_parallel.py | 4 +- tests/utils.py | 39 +++++++++++++++++++ .../tokenizer_group/ray_tokenizer_group.py | 2 +- vllm/utils.py | 3 +- 4 files changed, 45 insertions(+), 3 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 5ff39ddfbf99..f632caba9017 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -9,7 +9,7 @@ import pytest -from ..utils import compare_two_settings +from ..utils import compare_two_settings, fork_new_process_for_each_test VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" @@ -28,6 +28,7 @@ (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"), (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), ]) +@fork_new_process_for_each_test def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, DIST_BACKEND): if VLLM_MULTI_NODE and DIST_BACKEND == "mp": @@ -77,6 +78,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, "FLASH_ATTN", "FLASHINFER", ]) +@fork_new_process_for_each_test def test_pp_cudagraph(PP_SIZE, MODEL_NAME, ATTN_BACKEND): cudagraph_args = [ # use half precision for speed and memory savings in CI environment diff --git a/tests/utils.py b/tests/utils.py index 1086591464d4..f3ee801ee774 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,6 @@ +import functools import os +import signal import subprocess import sys import time @@ -336,3 +338,40 @@ def wait_for_gpu_memory_to_clear(devices: List[int], f'{dur_s=:.02f} ({threshold_bytes/2**30=})') time.sleep(5) + + +def fork_new_process_for_each_test(f): + + @functools.wraps(f) + def wrapper(*args, **kwargs): + # Make the process the leader of its own process group + # to avoid sending SIGTERM to the parent process + os.setpgrp() + from _pytest.outcomes import Skipped + pid = os.fork() + if pid == 0: + try: + f(*args, **kwargs) + except Skipped as e: + # convert Skipped to exit code 0 + print(str(e)) + os._exit(0) + except Exception: + import traceback + traceback.print_exc() + os._exit(1) + else: + os._exit(0) + else: + pgid = os.getpgid(pid) + _pid, _exitcode = os.waitpid(pid, 0) + # ignore SIGTERM signal itself + old_singla_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN) + # kill all child processes + os.killpg(pgid, signal.SIGTERM) + # restore the signal handler + signal.signal(signal.SIGTERM, old_singla_handler) + assert _exitcode == 0, (f"function {f} failed when called with" + f" args {args} and kwargs {kwargs}") + + return wrapper diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py index eebdf7bf644d..79081c04ddc1 100644 --- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -3,7 +3,7 @@ from typing import List, Optional try: - from ray.exceptions import ActorDiedError + from ray.exceptions import ActorDiedError # type: ignore except ImportError: # For older versions of Ray from ray.exceptions import RayActorError as ActorDiedError # type: ignore diff --git 
a/vllm/utils.py b/vllm/utils.py index 38e1782a51ab..358788c95f30 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -928,7 +928,8 @@ def error_on_invalid_device_count_status(): with contextlib.suppress(Exception): # future pytorch will fix the issue, device_count will not be cached # at that time, `.cache_info().currsize` will error out - cache_entries = torch.cuda.device_count.cache_info().currsize + cache_entries = torch.cuda.device_count.cache_info( # type: ignore + ).currsize if cache_entries != 0: # the function is already called, and the result is cached remembered = torch.cuda.device_count() From cf2a1a4d9d8168d2e8e7bef244c1dfec80780405 Mon Sep 17 00:00:00 2001 From: Bongwon Jang <152451401+bong-furiosa@users.noreply.github.com> Date: Fri, 2 Aug 2024 15:28:00 +0900 Subject: [PATCH 0031/3246] Fix tracing.py (#7065) --- vllm/tracing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/tracing.py b/vllm/tracing.py index ba6732cab68f..dc8377f2396f 100644 --- a/vllm/tracing.py +++ b/vllm/tracing.py @@ -15,7 +15,7 @@ OTEL_EXPORTER_OTLP_TRACES_PROTOCOL) from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor - from opentelemetry.semconv.ai import SpanAttributes as BaseSpanAttributes + from opentelemetry.semconv_ai import SpanAttributes as BaseSpanAttributes from opentelemetry.trace import SpanKind, Tracer, set_tracer_provider from opentelemetry.trace.propagation.tracecontext import ( TraceContextTextMapPropagator) From 660dea1235bfe8987e4e9136ce70269084384b2f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 2 Aug 2024 00:14:21 -0700 Subject: [PATCH 0032/3246] [cuda][misc] remove error_on_invalid_device_count_status (#7069) --- vllm/executor/multiproc_gpu_executor.py | 3 --- vllm/executor/ray_gpu_executor.py | 9 +++------ vllm/utils.py | 23 ----------------------- 3 files changed, 3 insertions(+), 32 deletions(-) diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index e1e92958e667..08a35a074b37 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -17,7 +17,6 @@ from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.triton_utils import maybe_set_triton_cache_manager from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless, - error_on_invalid_device_count_status, get_distributed_init_method, get_open_port, get_vllm_instance_id, make_async, update_environment_variables) @@ -79,8 +78,6 @@ def _init_executor(self) -> None: f"please ensure that world_size ({world_size}) " f"is less than than max local gpu count ({cuda_device_count})") - error_on_invalid_device_count_status() - # Multiprocessing-based executor does not support multi-node setting. # Since it only works for single node, we can use the loopback address # 127.0.0.1 for communication. 
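# The helper removed in this patch existed only to detect the case where
# torch.cuda.device_count() had already been called (and cached) before
# CUDA_VISIBLE_DEVICES was finalized. The executors rely instead on
# cuda_device_count_stateless(), which re-reads CUDA_VISIBLE_DEVICES on every
# call. A small illustrative sketch of that pattern (the device IDs are
# hypothetical and the resulting count depends on the GPUs actually present):
from vllm.utils import cuda_device_count_stateless, update_environment_variables

update_environment_variables({"CUDA_VISIBLE_DEVICES": "0,1"})
# Unlike torch.cuda.device_count(), this reflects the new environment value
# immediately instead of returning a stale cached count.
visible_gpus = cuda_device_count_stateless()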
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 564fa79acfd4..14007e5518d4 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -10,10 +10,9 @@ from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.logger import init_logger from vllm.sequence import ExecuteModelRequest, SamplerOutput -from vllm.utils import (_run_task_with_lock, - error_on_invalid_device_count_status, - get_distributed_init_method, get_ip, get_open_port, - get_vllm_instance_id, make_async) +from vllm.utils import (_run_task_with_lock, get_distributed_init_method, + get_ip, get_open_port, get_vllm_instance_id, + make_async) if ray is not None: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -216,8 +215,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", distributed_init_method = get_distributed_init_method( driver_ip, get_open_port()) - error_on_invalid_device_count_status() - # Initialize the actual workers inside worker wrapper. init_worker_all_kwargs = [ self._get_worker_kwargs( diff --git a/vllm/utils.py b/vllm/utils.py index 358788c95f30..c4c17bfbefc6 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1,6 +1,5 @@ import argparse import asyncio -import contextlib import datetime import enum import gc @@ -923,28 +922,6 @@ def cuda_device_count_stateless() -> int: return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) -def error_on_invalid_device_count_status(): - cache_entries = 0 - with contextlib.suppress(Exception): - # future pytorch will fix the issue, device_count will not be cached - # at that time, `.cache_info().currsize` will error out - cache_entries = torch.cuda.device_count.cache_info( # type: ignore - ).currsize - if cache_entries != 0: - # the function is already called, and the result is cached - remembered = torch.cuda.device_count() - current = cuda_device_count_stateless() - if remembered > current: - raise RuntimeError( - "The number of CUDA devices has changed since the first " - "call to torch.cuda.device_count(). This is not allowed " - "and may result in undefined behavior. Please check out " - "https://github.com/vllm-project/vllm/issues/6056 to " - "find the first call to torch.cuda.device_count() " - "and defer it until the engine is up. Or you can set " - "CUDA_VISIBLE_DEVICES to the GPUs you want to use.") - - # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, # all the related functions work on real physical device ids. From db35186391a2abfc6c91d703527dac20d2488107 Mon Sep 17 00:00:00 2001 From: Peng Guanwen Date: Fri, 2 Aug 2024 15:58:26 +0800 Subject: [PATCH 0033/3246] [Core] Comment out unused code in sampler (#7023) --- vllm/model_executor/sampling_metadata.py | 58 +++++++++++++----------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 59cfec9ec893..015e85b4ca81 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -13,6 +13,8 @@ _SAMPLING_EPS = 1e-5 _SEED_0_REPLACEMENT = 3403598558 +# Some triton sampler related code is guarded before it is ready. 
+_USE_TRITON_SAMPLER = False @dataclass @@ -347,14 +349,16 @@ def from_sampling_metadata( repetition_penalties: List[float] = [] sampling_seeds: List[int] = [] sample_indices: List[int] = [] - prompt_best_of: List[int] = [] do_penalties = False do_top_p_top_k = False do_min_p = False - # We need one base seed per Triton slice. - seeds_to_generate = (extra_seeds_to_generate + - get_num_triton_sampler_splits(vocab_size)) + if _USE_TRITON_SAMPLER: + prompt_best_of: List[int] = [] + + # We need one base seed per Triton slice. + seeds_to_generate = (extra_seeds_to_generate + + get_num_triton_sampler_splits(vocab_size)) assert sampling_metadata.seq_groups is not None for seq_group in sampling_metadata.seq_groups: @@ -366,9 +370,6 @@ def from_sampling_metadata( r = sampling_params.repetition_penalty top_p = sampling_params.top_p min_p = sampling_params.min_p - seed = sampling_params.seed - - is_greedy = sampling_params.sampling_type == SamplingType.GREEDY # k should not be greater than the vocab size. top_k = min(sampling_params.top_k, vocab_size) @@ -389,8 +390,7 @@ def from_sampling_metadata( do_penalties = True is_prompt = seq_group.is_prompt - if (seq_group.is_prompt - and sampling_params.prompt_logprobs is not None): + if (is_prompt and sampling_params.prompt_logprobs is not None): # For tokens in the prompt that we only need to get # their logprobs query_len = seq_group.query_len @@ -415,23 +415,27 @@ def from_sampling_metadata( frequency_penalties += [f] * len(seq_ids) repetition_penalties += [r] * len(seq_ids) - if is_prompt: - prompt_best_of.append(sampling_params.best_of) - query_len = seq_group.query_len - assert query_len is not None - - for seq_id in seq_ids: - seq_data = seq_group.seq_data[seq_id] - extra_entropy = extra_entropy or () - seq_seeds = cls._get_sequence_seeds( - seed, - seq_data.get_len(), - *extra_entropy, - seq_id, - seeds_to_generate=seeds_to_generate, - is_greedy=is_greedy) - sampling_seeds.append(seq_seeds) - sample_indices.extend(seq_group.sample_indices) + if _USE_TRITON_SAMPLER: + if is_prompt: + prompt_best_of.append(sampling_params.best_of) + query_len = seq_group.query_len + assert query_len is not None + + seed = sampling_params.seed + is_greedy = sampling_params.sampling_type == SamplingType.GREEDY + + for seq_id in seq_ids: + seq_data = seq_group.seq_data[seq_id] + extra_entropy = extra_entropy or () + seq_seeds = cls._get_sequence_seeds( + seed, + seq_data.get_len(), + *extra_entropy, + seq_id, + seeds_to_generate=seeds_to_generate, + is_greedy=is_greedy) + sampling_seeds.append(seq_seeds) + sample_indices.extend(seq_group.sample_indices) if do_penalties: for seq_group in sampling_metadata.seq_groups: @@ -549,7 +553,7 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], device="cpu", dtype=torch.long, pin_memory=pin_memory, - ).T.contiguous() + ).t().contiguous() # Because the memory is pinned, we can do non-blocking # transfer to device. 
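The pinned-memory comment at the end of this hunk refers to a standard PyTorch idiom: page-locked host tensors can be copied to the GPU with non_blocking=True so the transfer overlaps with other work. A minimal illustrative sketch, with hypothetical shapes and names and assuming a CUDA device is available:

import torch

# Staging tensor allocated in pinned (page-locked) host memory.
seeds_cpu = torch.empty(8, 4, dtype=torch.long, pin_memory=True)
# The non-blocking copy is only truly asynchronous because the source is pinned.
seeds_gpu = seeds_cpu.to(device="cuda", non_blocking=True)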
From c16eaac5001d9e2bfb51c9812ec0c2b9e32b8d25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jie=20Fu=20=28=E5=82=85=E6=9D=B0=29?= Date: Fri, 2 Aug 2024 23:55:58 +0800 Subject: [PATCH 0034/3246] [Hardware][Intel CPU] Update torch 2.4.0 for CPU backend (#6931) --- Dockerfile.cpu | 2 +- requirements-cpu.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile.cpu b/Dockerfile.cpu index c473ba431e68..78730f39721c 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -26,7 +26,7 @@ COPY ./ /workspace/vllm WORKDIR /workspace/vllm -RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/test/cpu +RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu # Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ... ARG VLLM_CPU_DISABLE_AVX512 diff --git a/requirements-cpu.txt b/requirements-cpu.txt index 2dcd86274a2a..27ca8ca5dbc5 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -2,5 +2,5 @@ -r requirements-common.txt # Dependencies for x86_64 CPUs -torch == 2.4.0; platform_machine != "ppc64le" +torch == 2.4.0+cpu; platform_machine != "ppc64le" torchvision; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch From 806949514ab07a2d7218645022c22962696adf46 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 2 Aug 2024 10:03:24 -0700 Subject: [PATCH 0035/3246] [ci] set timeout for test_oot_registration.py (#7082) --- tests/entrypoints/openai/test_oot_registration.py | 4 ++++ vllm/worker/worker.py | 4 +++- vllm/worker/xpu_worker.py | 4 +++- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/entrypoints/openai/test_oot_registration.py b/tests/entrypoints/openai/test_oot_registration.py index dbbda6de1fa0..5272ac4065f1 100644 --- a/tests/entrypoints/openai/test_oot_registration.py +++ b/tests/entrypoints/openai/test_oot_registration.py @@ -36,10 +36,12 @@ def test_oot_registration_for_api_server(): ctx = torch.multiprocessing.get_context() server = ctx.Process(target=server_function, args=(port, )) server.start() + MAX_SERVER_START_WAIT_S = 60 client = OpenAI( base_url=f"http://localhost:{port}/v1", api_key="token-abc123", ) + now = time.time() while True: try: completion = client.chat.completions.create( @@ -57,6 +59,8 @@ def test_oot_registration_for_api_server(): except OpenAIError as e: if "Connection error" in str(e): time.sleep(3) + if time.time() - now > MAX_SERVER_START_WAIT_S: + raise RuntimeError("Server did not start in time") from e else: raise e server.kill() diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index f3c379d1aa34..9e2cfff435cf 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -186,7 +186,9 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # GPU did not change their memory usage during the profiling. peak_memory = self.init_gpu_memory - free_gpu_memory assert peak_memory > 0, ( - "Error in memory profiling. This happens when the GPU memory was " + "Error in memory profiling. " + f"Initial free memory {self.init_gpu_memory}, current free memory" + f" {free_gpu_memory}. 
This happens when the GPU memory was " "not properly cleaned up before initializing the vLLM instance.") cache_block_size = self.get_cache_block_size_bytes() diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py index 6a822c2ba3e7..0f22d67c4f25 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -138,7 +138,9 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # GPU did not change their memory usage during the profiling. peak_memory = self.init_gpu_memory - free_gpu_memory assert peak_memory > 0, ( - "Error in memory profiling. This happens when the GPU memory was " + "Error in memory profiling. " + f"Initial free memory {self.init_gpu_memory}, current free memory" + f" {free_gpu_memory}. This happens when the GPU memory was " "not properly cleaned up before initializing the vLLM instance.") cache_block_size = self.get_cache_block_size_bytes() From b482b9a5b13ba7d126adabbedb3ba66f48d4d83b Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Fri, 2 Aug 2024 16:51:22 -0400 Subject: [PATCH 0036/3246] [CI/Build] Add support for Python 3.12 (#7035) --- .github/workflows/mypy.yaml | 2 +- .github/workflows/publish.yml | 2 +- .github/workflows/ruff.yml | 2 +- .github/workflows/yapf.yml | 2 +- CMakeLists.txt | 2 +- docs/source/getting_started/installation.rst | 2 +- setup.py | 1 + 7 files changed, 7 insertions(+), 6 deletions(-) diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 68e3a3fefdc5..8d423657630c 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 607fda754bf2..aeeaf6efab04 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -48,7 +48,7 @@ jobs: fail-fast: false matrix: os: ['ubuntu-20.04'] - python-version: ['3.8', '3.9', '3.10', '3.11'] + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements-cuda.txt. cuda-version: ['11.8', '12.1'] diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 773def58fd96..1a794af572fe 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/yapf.yml b/.github/workflows/yapf.yml index 04f307bcf8b0..c89f82dfaaaf 100644 --- a/.github/workflows/yapf.yml +++ b/.github/workflows/yapf.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/CMakeLists.txt b/CMakeLists.txt index 77a8af549b02..dbe688186f17 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,7 +14,7 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) # Supported python versions. These versions will be searched in order, the # first match will be selected. These should be kept in sync with setup.py. 
# -set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11") +set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12") # Supported NVIDIA architectures. set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0") diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index 0253717da3cd..57ad8bacedfc 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -9,7 +9,7 @@ Requirements ------------ * OS: Linux -* Python: 3.8 -- 3.11 +* Python: 3.8 -- 3.12 * GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, H100, etc.) Install with pip diff --git a/setup.py b/setup.py index 63c1f466d291..91307e8a9406 100644 --- a/setup.py +++ b/setup.py @@ -465,6 +465,7 @@ def _read_requirements(filename: str) -> List[str]: "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "License :: OSI Approved :: Apache Software License", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], From a8d604ca2a2912b3a5352821c53c080383580df1 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 2 Aug 2024 16:51:58 -0400 Subject: [PATCH 0037/3246] [Misc] Disambiguate quantized types via a new ScalarType (#6396) --- CMakeLists.txt | 52 ++- Dockerfile.openvino | 3 + benchmarks/kernels/benchmark_marlin.py | 50 +-- cmake/cpu_extension.cmake | 1 - csrc/{ => core}/registration.h | 0 csrc/core/scalar_type.hpp | 382 ++++++++++++++++++ csrc/core/torch_bindings.cpp | 16 + csrc/cpu/torch_bindings.cpp | 2 +- csrc/moe/torch_bindings.cpp | 2 +- csrc/ops.h | 8 +- csrc/quantization/gptq_marlin/gptq_marlin.cu | 66 ++- .../marlin/sparse/marlin_24_cuda_kernel.cu | 17 +- csrc/torch_bindings.cpp | 2 +- setup.py | 9 +- tests/kernels/test_int8_quant.py | 2 - tests/kernels/test_marlin_gemm.py | 75 ++-- tests/test_scalartype.py | 36 ++ vllm/_core_ext.py | 177 ++++++++ vllm/_custom_ops.py | 29 +- .../layers/quantization/awq_marlin.py | 49 ++- .../schemes/compressed_tensors_w4a16_24.py | 18 +- .../schemes/compressed_tensors_wNa16.py | 29 +- .../layers/quantization/gptq_marlin.py | 43 +- .../layers/quantization/gptq_marlin_24.py | 29 +- .../layers/quantization/utils/marlin_utils.py | 120 +++--- .../quantization/utils/marlin_utils_test.py | 29 +- .../utils/marlin_utils_test_24.py | 30 +- .../layers/quantization/utils/quant_utils.py | 148 +++---- vllm/scalar_type.py | 35 ++ 29 files changed, 1107 insertions(+), 352 deletions(-) rename csrc/{ => core}/registration.h (100%) create mode 100644 csrc/core/scalar_type.hpp create mode 100644 csrc/core/torch_bindings.cpp create mode 100644 tests/test_scalartype.py create mode 100644 vllm/_core_ext.py create mode 100644 vllm/scalar_type.py diff --git a/CMakeLists.txt b/CMakeLists.txt index dbe688186f17..922613ec5dda 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -66,6 +66,39 @@ endif() # find_package(Torch REQUIRED) +# +# Add the `default` target which detects which extensions should be +# built based on platform/architecture. This is the same logic that +# setup.py uses to select which extensions should be built and should +# be kept in sync. +# +# The `default` target makes direct use of cmake easier since knowledge +# of which extensions are supported has been factored in, e.g. +# +# mkdir build && cd build +# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm .. +# cmake --build . 
--target default +# +add_custom_target(default) +message(STATUS "Enabling core extension.") + +# Define _core_C extension +# built for (almost) every target platform, (excludes TPU and Neuron) + +set(VLLM_EXT_SRC + "csrc/core/torch_bindings.cpp") + +define_gpu_extension_target( + _core_C + DESTINATION vllm + LANGUAGE CXX + SOURCES ${VLLM_EXT_SRC} + COMPILE_FLAGS ${CXX_COMPILE_FLAGS} + USE_SABI 3 + WITH_SOABI) + +add_dependencies(default _core_C) + # # Forward the non-CUDA device extensions to external CMake scripts. # @@ -74,7 +107,7 @@ if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND if (VLLM_TARGET_DEVICE STREQUAL "cpu") include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake) else() - message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}") + return() endif() return() endif() @@ -132,7 +165,7 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA") endif() # -# Define extension targets +# Define other extension targets # # @@ -228,21 +261,6 @@ define_gpu_extension_target( -# -# Add the `default` target which detects which extensions should be -# built based on platform/architecture. This is the same logic that -# setup.py uses to select which extensions should be built and should -# be kept in sync. -# -# The `default` target makes direct use of cmake easier since knowledge -# of which extensions are supported has been factored in, e.g. -# -# mkdir build && cd build -# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm .. -# cmake --build . --target default -# -add_custom_target(default) - if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling C extension.") add_dependencies(default _C) diff --git a/Dockerfile.openvino b/Dockerfile.openvino index 7c62dd845aa9..c84dea419e58 100644 --- a/Dockerfile.openvino +++ b/Dockerfile.openvino @@ -13,6 +13,9 @@ COPY requirements-common.txt /workspace/vllm/ COPY requirements-openvino.txt /workspace/vllm/ COPY vllm/ /workspace/vllm/vllm +COPY csrc/core /workspace/vllm/csrc/core +COPY cmake/utils.cmake /workspace/vllm/cmake/ +COPY CMakeLists.txt /workspace/vllm/ COPY setup.py /workspace/vllm/ # install build requirements diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 684985b81f69..536c133bb334 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -7,16 +7,17 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, - GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) + GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS) + MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( MarlinWorkspace, marlin_quantize) from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( marlin_24_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - gptq_pack, quantize_weights, sort_weights) + gptq_pack, gptq_quantize_weights, sort_weights) +from vllm.scalar_type import ScalarType from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] @@ -27,13 +28,14 @@ def 
bench_run(results: List[benchmark.Measurement], model: str, - act_order: bool, is_k_full: bool, num_bits: int, group_size: int, - size_m: int, size_k: int, size_n: int): + act_order: bool, is_k_full: bool, quant_type: ScalarType, + group_size: int, size_m: int, size_k: int, size_n: int): label = "Quant Matmul" - sub_label = ("{}, act={} k_full={}, b={}, g={}, " - "MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits, - group_size, size_m, size_k, size_n)) + sub_label = ("{}, act={} k_full={}, q={}, g={}, " + "MKN=({}x{}x{})".format(model, act_order, is_k_full, + str(quant_type), group_size, size_m, + size_k, size_n)) print(f"Testing: {sub_label}") @@ -50,18 +52,18 @@ def bench_run(results: List[benchmark.Measurement], model: str, marlin_g_idx, marlin_sort_indices, marlin_rand_perm, - ) = marlin_quantize(b, num_bits, group_size, act_order) + ) = marlin_quantize(b, quant_type, group_size, act_order) # Marlin_24 quant (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, - marlin_24_s) = marlin_24_quantize(b, num_bits, group_size) + marlin_24_s) = marlin_24_quantize(b, quant_type, group_size) marlin_zp = torch.empty(0, dtype=torch.int, device=b.device) # GPTQ quant (w_ref, q_w, s, g_idx, - rand_perm) = quantize_weights(b, num_bits, group_size, act_order) - q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n) + rand_perm) = gptq_quantize_weights(b, quant_type, group_size, act_order) + q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) # For act_order, sort the "weights" and "g_idx" # so that group ids are increasing @@ -75,10 +77,11 @@ def bench_run(results: List[benchmark.Measurement], model: str, marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL) + marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int) globals = { # Gen params - "num_bits": num_bits, + "quant_type": quant_type, "group_size": group_size, "size_m": size_m, "size_n": size_n, @@ -128,7 +131,7 @@ def bench_run(results: List[benchmark.Measurement], model: str, results.append( benchmark.Timer( stmt= - "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501 + "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -138,19 +141,19 @@ def bench_run(results: List[benchmark.Measurement], model: str, results.append( benchmark.Timer( stmt= - "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501 + "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="gptq_marlin_gemm_fp32", ).blocked_autorange(min_run_time=min_run_time)) - if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS + if (quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES): results.append( benchmark.Timer( stmt= - "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)", # noqa: E501 + 
"output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -160,7 +163,7 @@ def bench_run(results: List[benchmark.Measurement], model: str, results.append( benchmark.Timer( stmt= - "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501 + "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -196,9 +199,10 @@ def main(args): ) > 0 and is_k_full not in args.limit_k_full: continue - for num_bits in MARLIN_SUPPORTED_NUM_BITS: - if len(args.limit_num_bits - ) > 0 and num_bits not in args.limit_num_bits: + for quant_type in query_marlin_supported_quant_types( + False): + if len(args.limit_num_bits) > 0 and \ + quant_type.size_bits not in args.limit_num_bits: continue for group_size in MARLIN_SUPPORTED_GROUP_SIZES: @@ -215,8 +219,8 @@ def main(args): for size_m in args.batch_sizes: bench_run(results, model, act_order, is_k_full, - num_bits, group_size, size_m, size_k, - size_n) + quant_type, group_size, size_m, + size_k, size_n) compare = benchmark.Compare(results) compare.print() diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 118f9b28e0ae..3ba3a2b6a93c 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -113,6 +113,5 @@ define_gpu_extension_target( WITH_SOABI ) -add_custom_target(default) message(STATUS "Enabling C extension.") add_dependencies(default _C) diff --git a/csrc/registration.h b/csrc/core/registration.h similarity index 100% rename from csrc/registration.h rename to csrc/core/registration.h diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp new file mode 100644 index 000000000000..9f78402eee2a --- /dev/null +++ b/csrc/core/scalar_type.hpp @@ -0,0 +1,382 @@ +#pragma once + +#include + +namespace vllm { + +// +// ScalarType can represent a wide range of floating point and integer types, +// in particular it can be used to represent sub-byte data types (something +// that torch.dtype currently does not support). +// +// ScalarTypeTorch is a subclass of ScalarType that is compatible with +// TORCH_LIBRARY, making it accessible from Python as well meaning this class +// can be used as a argument for custom operators, helping to simplify these +// interfaces. +// +// The type definitions on the Python side can be found in: vllm/_core_ext.pyi +// these type definitions should be kept up to date with any Python API changes +// here. 
+// +class ScalarType { + public: + enum NanRepr : int64_t { + NAN_NONE = 0, // nans are not supported + NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s + NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s + + NAN_REPR_ID_MAX + }; + + constexpr ScalarType(bool signed_, int64_t exponent, int64_t mantissa, + int64_t bias, bool finite_values_only = false, + NanRepr nan_repr = NAN_IEEE_754) + : exponent(exponent), + mantissa(mantissa), + bias(bias), + signed_(signed_), + finite_values_only(finite_values_only), + nan_repr(nan_repr){}; + + static constexpr ScalarType int_(int64_t size_bits, int64_t bias = 0) { + return ScalarType(true, 0, size_bits - 1, bias); + } + + static constexpr ScalarType uint(int64_t size_bits, int64_t bias = 0) { + return ScalarType(false, 0, size_bits, bias); + } + + // IEEE 754 compliant floating point type + static constexpr ScalarType float_IEEE754(int64_t exponent, + int64_t mantissa) { + TORCH_CHECK(mantissa > 0 && exponent > 0); + return ScalarType(true, exponent, mantissa, 0, false, NAN_IEEE_754); + } + + // IEEE 754 non-compliant floating point type + static constexpr ScalarType float_(int64_t exponent, int64_t mantissa, + bool finite_values_only, + NanRepr nan_repr) { + TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr"); + TORCH_CHECK(mantissa > 0 && exponent > 0); + TORCH_CHECK(nan_repr != NAN_IEEE_754, + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions"); + return ScalarType(true, exponent, mantissa, 0, finite_values_only, + nan_repr); + } + + int64_t const exponent; // size of the exponent field (0 for integer types) + int64_t const mantissa; // size of the mantissa field (size of the integer + // excluding the sign bit for integer types) + int64_t const bias; // stored values equal value + bias, + // used for quantized type + bool const signed_; // flag if the type supports negative numbers (i.e. has a + // sign bit) + + // Extra Floating point info + bool const finite_values_only; // i.e. 
no +/-inf if true + NanRepr const nan_repr; // how NaNs are represented + // (not applicable for integer types) + + int64_t size_bits() const { return mantissa + exponent + is_signed(); } + bool is_signed() const { return signed_; } + bool is_integer() const { return exponent == 0; } + bool is_floating_point() const { return exponent > 0; } + bool is_ieee_754() const { + return is_floating_point() && finite_values_only == false && + nan_repr == NAN_IEEE_754; + } + bool has_nans() const { return is_floating_point() && nan_repr != NAN_NONE; } + bool has_infs() const { + return is_floating_point() && finite_values_only == false; + } + bool has_bias() const { return bias != 0; } + + private: + double _floating_point_max() const { + TORCH_CHECK(mantissa <= 52 && exponent <= 11, + "Cannot represent max/min as a double for type ", str()); + + uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { + max_mantissa -= 1; + } + + uint64_t max_exponent = (uint64_t(1) << exponent) - 2; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { + TORCH_CHECK(exponent < 11, + "Cannot represent max/min as a double for type ", str()); + max_exponent += 1; + } + + // adjust the exponent to match that of a double + // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e + // is the exponent bits), there is some precedent for non-standard biases, + // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes + // but to avoid premature over complication we are just assuming the + // standard exponent bias until there is a need to support non-standard + // biases + uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; + uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11 + + uint64_t max_exponent_double = + max_exponent - exponent_bias + exponent_bias_double; + + // shift the mantissa into the position for a double and + // the exponent + uint64_t double_raw = + (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); + + return *reinterpret_cast(&double_raw); + } + + std::variant _raw_max() const { + if (is_floating_point()) { + return {_floating_point_max()}; + } else { + TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(), + "Cannot represent max as a int64_t"); + return {(int64_t(1) << mantissa) - 1}; + } + } + + std::variant _raw_min() const { + if (is_floating_point()) { + TORCH_CHECK(is_signed(), + "We currently assume all floating point types are signed"); + constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); + + double max = _floating_point_max(); + uint64_t max_raw = *reinterpret_cast(&max); + uint64_t min_raw = max_raw | sign_bit_double; + return {*reinterpret_cast(&min_raw)}; + } else { + TORCH_CHECK(!is_signed() || size_bits() <= 64, + "Cannot represent min as a int64_t"); + if (is_signed()) { + // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 + // then perform an arithmetic shift right to set all the bits above + // (size_bits() - 1) to 1 + return {INT64_MIN >> (64 - size_bits())}; + } else { + return {int64_t(0)}; + } + } + } + + public: + // Max representable value for this scalar type. + // (accounting for bias if there is one) + std::variant max() const { + return std::visit( + [this](auto x) -> std::variant { return {x - bias}; }, + _raw_max()); + } + + // Min representable value for this scalar type. 
+ // (accounting for bias if there is one) + std::variant min() const { + return std::visit( + [this](auto x) -> std::variant { return {x - bias}; }, + _raw_min()); + } + + std::string str() const { + /* naming generally follows: https://github.com/jax-ml/ml_dtypes + * for floating point types (leading f) the scheme is: + * `float_em[flags]` + * flags: + * - no-flags: means it follows IEEE 754 conventions + * - f: means finite values only (no infinities) + * - n: means nans are supported (non-standard encoding) + * for integer types the scheme is: + * `[u]int[b]` + * - if bias is not present it means its zero + */ + if (is_floating_point()) { + auto ret = "float" + std::to_string(size_bits()) + "_e" + + std::to_string(exponent) + "m" + std::to_string(mantissa); + if (!is_ieee_754()) { + if (finite_values_only) { + ret += "f"; + } + if (nan_repr != NAN_NONE) { + ret += "n"; + } + } + return ret; + } else { + auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits()); + if (has_bias()) { + ret += "b" + std::to_string(bias); + } + return ret; + } + } + + bool operator==(ScalarType const& other) const { + return mantissa == other.mantissa && exponent == other.exponent && + bias == other.bias && signed_ == other.signed_ && + finite_values_only == other.finite_values_only && + nan_repr == other.nan_repr; + } +}; + +// Create a TORCH_LIBRARY compatible version of ScalarType (i.e. inherit from +// torch::CustomClassHolder), we use multiple inheritance here since we cannot +// have ScalarType inherit from torch::CustomClassHolder and have a constexpr +// constructor at the same time (torch::CustomClassHolder does not have a +// constexpr destructor) +class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType { + public: + ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias, + bool _signed) + : ScalarType(exponent, mantissa, bias, _signed){}; + + ScalarTypeTorch(ScalarType type) : ScalarType(type){}; + + using Base = ScalarType; + using Self = ScalarTypeTorch; + using SelfPtr = c10::intrusive_ptr; + + static SelfPtr int_(int64_t size_bits, c10::optional bias) { + return c10::make_intrusive( + ScalarType::int_(size_bits, bias.value_or(0))); + } + + static SelfPtr uint(int64_t size_bits, c10::optional bias) { + return c10::make_intrusive( + ScalarType::uint(size_bits, bias.value_or(0))); + } + + static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) { + return c10::make_intrusive( + ScalarType::float_IEEE754(exponent, mantissa)); + } + + static SelfPtr float_(int64_t exponent, int64_t mantissa, + bool finite_values_only, int64_t nan_repr) { + return c10::make_intrusive(ScalarType::float_( + exponent, mantissa, finite_values_only, NanRepr(nan_repr))); + } + + template + static void bind_readonly_property(torch::class_& cls, + std::string const& name, T Base::*field) { + auto getter_func = [field = std::move(field)](SelfPtr const& self) { + if constexpr (std::is_member_function_pointer_v) { + return (self.get()->*field)(); + } else { + return self.get()->*field; + } + }; + + cls.def_property(name, getter_func); + } + + template + static void bind_function(torch::class_& cls, const std::string& name, + MemberFunc Cls::*member) { + cls.def(name, [member = std::move(member)](SelfPtr const& self) { + return (self.get()->*member)(); + }); + } + + template + static void bind_function(torch::class_& cls, const std::string& name, + Func func) { + cls.def(name, func); + } + + template + static void bind_static_function(torch::class_& cls, + 
const std::string& name, Func func) { + cls.def_static(name, func); + } + + static void bind_class(torch::Library& lib) { + auto cls = lib.class_("ScalarType") + .def(torch::init()); + + // Bind Properties + bind_readonly_property(cls, "mantissa", &Base::mantissa); + bind_readonly_property(cls, "exponent", &Base::exponent); + bind_readonly_property(cls, "bias", &Base::bias); + bind_readonly_property(cls, "signed", &Base::is_signed); + bind_readonly_property(cls, "size_bits", &Base::size_bits); + + // Bind member functions + bind_function(cls, "is_signed", &Base::is_signed); + bind_function(cls, "is_integer", &Base::is_integer); + bind_function(cls, "is_floating_point", &Base::is_floating_point); + bind_function(cls, "is_ieee_754", &Base::is_ieee_754); + bind_function(cls, "has_nans", &Base::has_nans); + bind_function(cls, "has_infs", &Base::has_infs); + bind_function(cls, "has_bias", &Base::has_bias); + + bind_function(cls, "max", [](SelfPtr const& self) { + return std::visit([](auto arg) { return c10::IValue(arg); }, + self.get()->max()); + }); + bind_function(cls, "min", [](SelfPtr const& self) { + return std::visit([](auto arg) { return c10::IValue(arg); }, + self.get()->min()); + }); + + bind_function(cls, "__str__", &Base::str); + bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) { + return *self == *other; + }); + bind_function(cls, "__repr__", [](SelfPtr const& self) { + return "ScalarType." + self.get()->str(); + }); + + // Bind static functions (convenience constructors) + bind_static_function(cls, "int_", &ScalarTypeTorch::int_); + bind_static_function(cls, "uint", &ScalarTypeTorch::uint); + bind_static_function(cls, "float_IEEE754", &ScalarTypeTorch::float_IEEE754); + bind_static_function(cls, "float_", &ScalarTypeTorch::float_); + } +}; + +using ScalarTypeTorchPtr = c10::intrusive_ptr; + +// "rust style" names generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70 +static inline constexpr auto kS4 = ScalarType::int_(4); +static inline constexpr auto kU4 = ScalarType::uint(4); +static inline constexpr auto kU4B8 = ScalarType::uint(4, 8); +static inline constexpr auto kS8 = ScalarType::int_(8); +static inline constexpr auto kU8 = ScalarType::uint(8); +static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); + +static inline constexpr auto kFE3M2f = + ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE4M3fn = + ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); +static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); +static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); + +// Fixed width style names, generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57 +static inline constexpr auto kInt4 = kS4; +static inline constexpr auto kUint4 = kU4; +static inline constexpr auto kUint4b8 = kU4B8; +static inline constexpr auto kInt8 = kS8; +static inline constexpr auto kUint8 = kU8; +static inline constexpr auto kUint8b128 = kU8B128; + +static inline constexpr auto kFloat6_e3m2f = kFE3M2f; +static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; +static inline constexpr auto kFloat8_e5m2 = kFE5M2; +static inline constexpr auto kFloat16_e8m7 = kFE8M7; +static inline constexpr auto 
kFloat16_e5m10 = kFE5M10; + +// colloquial names +static inline constexpr auto kHalf = kFE5M10; +static inline constexpr auto kFloat16 = kHalf; +static inline constexpr auto kBFloat16 = kFE8M7; + +}; // namespace vllm diff --git a/csrc/core/torch_bindings.cpp b/csrc/core/torch_bindings.cpp new file mode 100644 index 000000000000..f60254189a2f --- /dev/null +++ b/csrc/core/torch_bindings.cpp @@ -0,0 +1,16 @@ +#include + +#include "scalar_type.hpp" +#include "registration.h" + +// Note the CORE exstension will be built for (almost) all hardware targets so +// new additions must account for this. (currently not built for TPU and Neuron) + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, lib) { + // ScalarType, a custom class for representing data types that supports + // quantized types, declared here so it can be used when creating interfaces + // for custom ops. + vllm::ScalarTypeTorch::bind_class(lib); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 7d549e271a30..cf7d977da7c1 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -1,6 +1,6 @@ #include "cache.h" #include "ops.h" -#include "registration.h" +#include "core/registration.h" #include diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 243752b9a9e8..86e42af44df1 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -1,4 +1,4 @@ -#include "registration.h" +#include "core/registration.h" #include "moe_ops.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { diff --git a/csrc/ops.h b/csrc/ops.h index f274a7e647b9..3bd4a9eda5ee 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -3,6 +3,8 @@ #include #include +#include "core/scalar_type.hpp" + void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, @@ -84,14 +86,16 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t num_bits, + torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, int64_t size_k); torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& b_zeros, torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, int64_t num_bits, + torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool has_zp, bool use_fp32_reduce); diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 26cc248e6ac5..edf19365c809 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -21,6 +21,7 @@ #include "marlin.cuh" #include "marlin_dtypes.cuh" +#include "core/scalar_type.hpp" #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ static_assert(std::is_same::value || \ @@ -71,14 +72,15 @@ __global__ void Marlin( bool use_fp32_reduce // whether to use fp32 global reduce ) {} -} // namespace gptq_marlin +} // namespace marlin torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& b_zeros, torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, int64_t num_bits, + torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr 
const& b_q_type, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full) { + bool is_k_full, bool has_zp) { TORCH_CHECK_NOT_IMPLEMENTED(false, "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); return torch::empty({1, 1}); @@ -1963,18 +1965,29 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) template -void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp, - void* s, void* zp, void* g_idx, void* perm, void* a_tmp, - int prob_m, int prob_n, int prob_k, void* workspace, - int num_bits, bool has_act_order, bool is_k_full, - bool has_zp, int num_groups, int group_size, int dev, - cudaStream_t stream, int thread_k, int thread_n, int sms, - int max_par, bool use_fp32_reduce) { - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. Got = ", num_bits); +void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, + void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m, + int prob_n, int prob_k, void* workspace, + vllm::ScalarType const& q_type, bool has_act_order, + bool is_k_full, bool has_zp, int num_groups, int group_size, + int dev, cudaStream_t stream, int thread_k, int thread_n, + int sms, int max_par, bool use_fp32_reduce) { + if (has_zp) { + TORCH_CHECK( + q_type == vllm::kU4 || q_type == vllm::kU8, + "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); + } else { + TORCH_CHECK( + q_type == vllm::kU4B8 || q_type == vllm::kU8B128, + "q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ", + q_type.str()); + } + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); + // TODO: remove alias when we start supporting other 8bit types + int num_bits = q_type.size_bits(); int tot_m = prob_m; int tot_m_blocks = div_ceil(tot_m, 16); int pad = 16 * tot_m_blocks - tot_m; @@ -2126,19 +2139,28 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp, } } -} // namespace gptq_marlin +} // namespace marlin torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& b_zeros, torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, int64_t num_bits, + torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool has_zp, bool use_fp32_reduce) { - // Verify num_bits - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. Got = ", num_bits); - int pack_factor = 32 / num_bits; + if (has_zp) { + TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8, + "b_q_type must be u4 or u8 when has_zp = True. Got = ", + b_q_type->str()); + } else { + TORCH_CHECK( + *b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, + "b_q_type must be uint4b8 or uint8b128 when has_zp = False. 
Got = ", + b_q_type->str()); + } + + int pack_factor = 32 / b_q_type->size_bits(); // Verify A TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), @@ -2265,21 +2287,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, int dev = a.get_device(); if (a.scalar_type() == at::ScalarType::Half) { - marlin::marlin_mm_f16i4( + marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(), b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp, + workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce); } else if (a.scalar_type() == at::ScalarType::BFloat16) { - marlin::marlin_mm_f16i4( + marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(), b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp, + workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce); } else { diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu index 3c50f1786bc6..93445a386593 100644 --- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu +++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu @@ -27,6 +27,7 @@ #include #include "common/base.h" +#include "core/scalar_type.hpp" #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 @@ -86,7 +87,8 @@ __global__ void Marlin_24( torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t num_bits, + torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, int64_t size_k) { TORCH_CHECK_NOT_IMPLEMENTED( @@ -1025,13 +1027,14 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t num_bits, + torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, int64_t size_k) { // Verify num_bits - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. Got = ", num_bits); - int pack_factor = 32 / num_bits; + TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, + "num_bits must be uint4b8 or uint8b128. 
Got = ", b_q_type->str()); + int pack_factor = 32 / b_q_type->size_bits(); // Verify M TORCH_CHECK(size_m == a.size(0), @@ -1126,8 +1129,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, marlin_24::marlin_cuda_2_4( a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(), b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(), - num_bits, groupsize, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, - thread_m, sms, max_par); + b_q_type->size_bits(), groupsize, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_m, sms, max_par); return c; } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index bf8cefa8d471..7c0d617fc8b3 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -1,7 +1,7 @@ #include "cache.h" #include "cuda_utils.h" #include "ops.h" -#include "registration.h" +#include "core/registration.h" #include diff --git a/setup.py b/setup.py index 91307e8a9406..b146299f8269 100644 --- a/setup.py +++ b/setup.py @@ -271,6 +271,10 @@ def _build_custom_ops() -> bool: return _is_cuda() or _is_hip() or _is_cpu() +def _build_core_ext() -> bool: + return not _is_neuron() and not _is_tpu() + + def get_hipcc_rocm_version(): # Run the hipcc --version command result = subprocess.run(['hipcc', '--version'], @@ -433,6 +437,9 @@ def _read_requirements(filename: str) -> List[str]: ext_modules = [] +if _build_core_ext(): + ext_modules.append(CMakeExtension(name="vllm._core_C")) + if _is_cuda() or _is_hip(): ext_modules.append(CMakeExtension(name="vllm._moe_C")) @@ -477,7 +484,7 @@ def _read_requirements(filename: str) -> List[str]: extras_require={ "tensorizer": ["tensorizer>=2.9.0"], }, - cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {}, + cmdclass={"build_ext": cmake_build_ext} if len(ext_modules) > 0 else {}, package_data=package_data, entry_points={ "console_scripts": [ diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 03acbf7968ff..0b7ed26a39e1 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -1,8 +1,6 @@ import pytest import torch -# ruff: noqa: F401 -import vllm._C from tests.kernels.quant_utils import ref_dynamic_per_token_quant from vllm._custom_ops import scaled_int8_quant diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py index a9e34ac8a7aa..2f58ffda2140 100644 --- a/tests/kernels/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -9,14 +9,14 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, - GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) + GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) from vllm.model_executor.layers.quantization.qqq import ( MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N, MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS, - marlin_make_empty_g_idx, marlin_permute_scales) + MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx, + marlin_permute_scales, query_marlin_supported_quant_types) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( pack_fp8_to_int32) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( @@ -27,8 +27,7 
@@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import ( # noqa: E501 marlin_qqq_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - awq_pack, gptq_pack, quantize_weights, quantize_weights_with_zp, - sort_weights) + awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights) ACT_ORDER_OPTS = [False, True] K_FULL_OPTS = [False, True] @@ -65,12 +64,13 @@ def rand_data(shape, dtype=torch.float16): reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("quant_type", + query_marlin_supported_quant_types(False)) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, - mnk_factors): +def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, + act_order, mnk_factors): m_factor, n_factor, k_factor = mnk_factors size_m = m_factor @@ -95,11 +95,11 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, b_weight = rand_data((size_k, size_n)) # Quantize (and apply act_order if provided) - w_ref, q_w, s, g_idx, rand_perm = quantize_weights(b_weight, num_bits, - group_size, act_order) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + b_weight, quant_type, group_size, act_order) # Pack to GPTQ format - q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n) + q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) # For act_order, sort the "weights" and "g_idx" so that group ids are # increasing @@ -108,8 +108,9 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) # Pack to Marlin format - weight_perm = get_weight_perm(num_bits) - marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, + weight_perm) # Run Marlin repack GPU kernel marlin_q_w_2 = ops.gptq_marlin_repack( @@ -117,7 +118,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, sort_indices, size_k, size_n, - num_bits, + quant_type.size_bits, ) torch.cuda.synchronize() @@ -128,10 +129,11 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("quant_type", + query_marlin_supported_quant_types(False)) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, +def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors): m_factor, n_factor, k_factor = mnk_factors @@ -150,22 +152,25 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, b_weight = rand_data((size_k, size_n)) # Quantize - w_ref, q_w, s, zp = quantize_weights_with_zp(b_weight, num_bits, - group_size) + w_ref, q_w, s, zp = quantize_weights(b_weight, + quant_type, + group_size, + 
zero_points=True) # Pack to AWQ format - q_w_awq = awq_pack(q_w, num_bits, size_k, size_n) + q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n) # Pack to Marlin format - weight_perm = get_weight_perm(num_bits) - marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, + weight_perm) # Run Marlin repack GPU kernel marlin_q_w_2 = ops.awq_marlin_repack( q_w_awq, size_k, size_n, - num_bits, + quant_type.size_bits, ) torch.cuda.synchronize() @@ -176,7 +181,8 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("quant_type", + query_marlin_supported_quant_types(False)) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @@ -185,7 +191,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, def test_gptq_marlin_gemm( k_chunk, n_chunk, - num_bits, + quant_type, group_size, mnk_factors, act_order, @@ -211,7 +217,7 @@ def test_gptq_marlin_gemm( b_weight = rand_data((size_k, size_n)) w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( - b_weight, num_bits, group_size, act_order) + b_weight, quant_type, group_size, act_order) marlin_zp = marlin_make_empty_g_idx(marlin_s.device) @@ -226,7 +232,7 @@ def test_gptq_marlin_gemm( g_idx, sort_indices, workspace.scratch, - num_bits, + quant_type, a_input.shape[0], b_weight.shape[1], a_input.shape[1], @@ -248,10 +254,10 @@ def test_gptq_marlin_gemm( reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS) -@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) @pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, +def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors): m_factor, n_factor, k_factor = mnk_factors @@ -266,7 +272,7 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, b_weight = rand_data((size_k, size_n)) (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, - marlin_24_s) = marlin_24_quantize(b_weight, num_bits, group_size) + marlin_24_s) = marlin_24_quantize(b_weight, quant_type, group_size) workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL) @@ -279,7 +285,7 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, marlin_24_meta, marlin_24_s, workspace_24.scratch, - num_bits, + quant_type, a_input.shape[0], b_weight.shape[1], a_input.shape[1], @@ -371,14 +377,15 @@ def test_fp8_marlin_gemm( reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("quant_type", + query_marlin_supported_quant_types(True)) @pytest.mark.parametrize("group_size", 
MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS) def test_awq_marlin_gemm( k_chunk, n_chunk, - num_bits, + quant_type, group_size, mnk_factors, use_fp32_reduce, @@ -396,7 +403,7 @@ def test_awq_marlin_gemm( b_weight = rand_data((size_k, size_n)) w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( - b_weight, num_bits, group_size) + b_weight, quant_type, group_size) g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device) sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device) @@ -414,7 +421,7 @@ def test_awq_marlin_gemm( g_idx, sort_indices, workspace.scratch, - num_bits, + quant_type, a_input.shape[0], b_weight.shape[1], a_input.shape[1], diff --git a/tests/test_scalartype.py b/tests/test_scalartype.py new file mode 100644 index 000000000000..1201aaa92ea8 --- /dev/null +++ b/tests/test_scalartype.py @@ -0,0 +1,36 @@ +import pytest +import torch + +from vllm.scalar_type import scalar_types + + +@pytest.mark.parametrize("type_tuple", ( + (-8, 7, scalar_types.int4), + (0, 15, scalar_types.uint4), + (-8, 7, scalar_types.uint4b8), + (-128, 127, scalar_types.uint8b128), + (-28., 28., scalar_types.float6_e3m2f), + (torch.int8, scalar_types.int8), + (torch.uint8, scalar_types.uint8), + (torch.float8_e5m2, scalar_types.float8_e5m2), + (torch.float8_e4m3fn, scalar_types.float8_e4m3fn), + (torch.bfloat16, scalar_types.float16_e8m7), + (torch.float16, scalar_types.float16_e5m10), +), + ids=lambda x: str(x)) +def test_scalar_type_min_max(type_tuple): + print(type_tuple) + if len(type_tuple) == 3: + min, max, t = type_tuple + else: + torch_type, t = type_tuple + if torch_type.is_floating_point: + min = torch.finfo(torch_type).min + max = torch.finfo(torch_type).max + else: + min = torch.iinfo(torch_type).min + max = torch.iinfo(torch_type).max + + print(t, min, max, t.min(), t.max()) + assert min == t.min() + assert max == t.max() diff --git a/vllm/_core_ext.py b/vllm/_core_ext.py new file mode 100644 index 000000000000..e3b9fbb93891 --- /dev/null +++ b/vllm/_core_ext.py @@ -0,0 +1,177 @@ +import importlib.util +from enum import Enum +from typing import TYPE_CHECKING, Optional, Union + +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) +core_C_available = importlib.util.find_spec('._core_C', 'vllm') is not None + + +# Mirrors enum in `core/scalar_type.hpp` +class NanRepr(Enum): + NONE = 0 # nans are not supported + IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s + EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s + + +if TYPE_CHECKING or not core_C_available: + # On platforms were we cannot use/build the C++ core extension (i.e. namely + # neuron and tpu), we define the mock ScalarType class here that partially + # mimics the C++ ScalarType class. + # + # We also use this provide type signatures to the Python LSP for the methods + # in the C++ ScalarType class. So these type signatures should be kept + # in sync with csrc/core/scalar_type.hpp + + from dataclasses import dataclass + + @dataclass(frozen=True) + class ScalarType: + """ + ScalarType can represent a wide range of floating point and integer + types, in particular it can be used to represent sub-byte data types + (something that torch.dtype currently does not support). It is also + capable of representing types with a bias, i.e.: + `stored_value = value + bias`, + this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias + of 8). 
The implementation for this class can be found in
+        csrc/core/scalar_type.hpp; these type signatures should be kept in
+        sync with that file.
+        """
+
+        exponent: int
+        """
+        Number of bits in the exponent if this is a floating point type
+        (zero if this is an integer type).
+        """
+
+        mantissa: int
+        """
+        Number of bits in the mantissa if this is a floating point type,
+        or the number of bits representing an integer (excluding the sign
+        bit) if this is an integer type.
+        """
+
+        bias: int
+        """
+        Bias used to encode the values in this scalar type
+        (value = stored_value - bias, default 0). For example, if we store
+        the type as an unsigned integer with a bias of 128, then the value
+        0 will be stored as 128, -1 will be stored as 127, and 1 will be
+        stored as 129.
+        """
+
+        signed: bool
+        "If the type is signed (i.e. has a sign bit)"
+
+        _finite_values_only: bool = False
+        """
+        Private: whether only finite values are representable (i.e. no
+        +/-inf); use `has_infs()` instead of reading this directly.
+        """
+
+        nan_repr: int = NanRepr.IEEE_754.value
+        """
+        How NaNs are represented in this scalar type, stored as a NanRepr
+        value. (not applicable for integer types)
+        """
+
+        @property
+        def size_bits(self):
+            return self.exponent + self.mantissa + int(self.signed)
+
+        def min(self) -> Union[int, float]:
+            """
+            Min representable value for this scalar type.
+            (accounting for bias if there is one)
+            """
+            raise NotImplementedError
+
+        def max(self) -> Union[int, float]:
+            """
+            Max representable value for this scalar type.
+            (accounting for bias if there is one)
+            """
+            raise NotImplementedError
+
+        def is_signed(self) -> bool:
+            """
+            If the type is signed (i.e. has a sign bit), same as `signed`;
+            added for consistency with:
+            https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
+            """
+            return self.signed
+
+        def is_floating_point(self):
+            "If the type is a floating point type"
+            return self.exponent != 0
+
+        def is_integer(self):
+            "If the type is an integer type"
+            return self.exponent == 0
+
+        def has_bias(self):
+            "If the type has a non-zero bias"
+            return self.bias != 0
+
+        def has_infs(self):
+            "If the type is floating point and supports infinity"
+            return not self._finite_values_only
+
+        def has_nans(self):
+            return self.nan_repr != NanRepr.NONE.value
+
+        def is_ieee_754(self) -> bool:
+            """
+            If the type is a floating point type that follows IEEE 754
+            conventions
+            """
+            return self.nan_repr == NanRepr.IEEE_754.value and \
+                not self._finite_values_only
+
+        def __str__(self) -> str:
+            raise NotImplementedError
+
+        def __repr__(self) -> str:
+            raise NotImplementedError
+
+        #
+        # Convenience Constructors
+        #
+
+        @classmethod
+        def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+            "Create a signed integer scalar type (size_bits includes sign-bit)."
+            # exponent = 0, mantissa = size_bits - 1
+            # (field order: exponent, mantissa, bias, signed)
+            return cls(0, size_bits - 1, bias if bias else 0, True)
+
+        @classmethod
+        def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+            """Create an unsigned integer scalar type."""
+            return cls(0, size_bits, bias if bias else 0, False)
+
+        @classmethod
+        def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
+            """
+            Create a standard floating point type
+            (i.e. follows IEEE 754 conventions).
+            """
+            return cls(exponent, mantissa, 0, True)
+
+        @classmethod
+        def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
+                   nan_repr: int):
+            """
+            Create a non-standard floating point type
+            (i.e. does not follow IEEE 754 conventions).
+ """ + return cls(exponent, mantissa, 0, True, finite_values_only, + nan_repr) + +elif core_C_available: + try: + import vllm._core_C # noqa: F401 + except ImportError as e: + logger.warning("Failed to import from vllm._core_C with %r", e) + + ScalarType = torch.classes._core_C.ScalarType diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 6cd77f75cae8..ad7e5bd19933 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -4,6 +4,7 @@ import torch +from vllm._core_ext import ScalarType from vllm.logger import init_logger logger = init_logger(__name__) @@ -220,10 +221,10 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, # marlin_24 def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, - workspace: torch.Tensor, num_bits: int, size_m: int, - size_n: int, size_k: int) -> torch.Tensor: + workspace: torch.Tensor, b_q_type: ScalarType, + size_m: int, size_n: int, size_k: int) -> torch.Tensor: return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales, - workspace, num_bits, size_m, + workspace, b_q_type, size_m, size_n, size_k) @@ -279,14 +280,22 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int, return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) -def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, b_zeros: torch.Tensor, - g_idx: torch.Tensor, perm: torch.Tensor, - workspace: torch.Tensor, num_bits: int, size_m: int, - size_n: int, size_k: int, is_k_full: bool, has_zp: bool, - use_fp32_reduce: bool) -> torch.Tensor: +def gptq_marlin_gemm(a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False) -> torch.Tensor: return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros, - g_idx, perm, workspace, num_bits, + g_idx, perm, workspace, b_q_type, size_m, size_n, size_k, is_k_full, has_zp, use_fp32_reduce) diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 5ffbb8e854e8..2cc080608c7a 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -10,11 +10,11 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_awq_marlin_linear, awq_to_marlin_zero_points, - check_awq_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace, - marlin_permute_scales, replace_tensor, verify_awq_marlin_supported, - verify_marlin_supports_shape) + apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, + marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, + replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -22,20 +22,31 @@ class AWQMarlinConfig(QuantizationConfig): """Config class for AWQ Marlin""" + # num_bits -> type + TYPE_MAP = { + 4: scalar_types.uint4, + 8: scalar_types.uint8, + } + def __init__(self, weight_bits: int, group_size: int, has_zp: bool, lm_head_quantized: bool) -> None: - 
self.weight_bits = weight_bits - self.pack_factor = 32 // self.weight_bits # packed into 32bits + self.pack_factor = 32 // weight_bits # packed into int32 self.group_size = group_size self.has_zp = has_zp self.lm_head_quantized = lm_head_quantized - verify_awq_marlin_supported(num_bits=self.weight_bits, - group_size=self.group_size, - has_zp=self.has_zp) + if weight_bits not in self.TYPE_MAP: + raise ValueError(f"Unsupported num_bits = {weight_bits}. " + f"Supported num_bits = {self.TYPE_MAP.keys()}") + + self.quant_type = self.TYPE_MAP[weight_bits] + + verify_marlin_supported(self.quant_type, + group_size=self.group_size, + has_zp=self.has_zp) def __repr__(self) -> str: - return (f"AWQMarlinConfig(weight_bits={self.weight_bits}, " + return (f"AWQMarlinConfig(quant_type={self.quant_type}, " f"group_size={self.group_size}, " f"has_zp={self.has_zp}, " f"lm_head_quantized={self.lm_head_quantized})") @@ -110,11 +121,13 @@ def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]): if (num_bits is None or group_size is None or has_zp is None): return False - return check_awq_marlin_supported( - num_bits=num_bits, - group_size=group_size, - has_zp=has_zp, - min_capability=cls.get_min_capability()) + if num_bits not in cls.TYPE_MAP: + return False + + return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits], + group_size=group_size, + has_zp=has_zp, + min_capability=cls.get_min_capability()) class AWQMarlinLinearMethod(LinearMethodBase): @@ -226,7 +239,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.qweight, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, - num_bits=self.quant_config.weight_bits) + num_bits=self.quant_config.quant_type.size_bits) replace_tensor(layer, "qweight", marlin_qweight) # Permute scales from AWQ format to marlin format. 
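The AWQ config now resolves `num_bits` to a `ScalarType` once and derives everything else from it; the GPTQ, compressed-tensors and Marlin-2:4 configs in this patch follow the same pattern. A minimal sketch of that pattern (the standalone `resolve_awq_type` helper is hypothetical; the real logic lives in the constructors shown above):

```python
from vllm.scalar_type import scalar_types

# num_bits -> ScalarType, mirroring AWQMarlinConfig.TYPE_MAP (AWQ uses the
# unbiased unsigned types because it carries explicit zero-points).
AWQ_TYPE_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}

def resolve_awq_type(weight_bits: int):
    if weight_bits not in AWQ_TYPE_MAP:
        raise ValueError(f"Unsupported num_bits = {weight_bits}")
    quant_type = AWQ_TYPE_MAP[weight_bits]
    # pack_factor: how many quantized values fit in one int32
    return quant_type, 32 // quant_type.size_bits
```

With the type object in hand, `quant_type.size_bits` replaces the old scattered `weight_bits` bookkeeping when repacking weights and zero-points.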
@@ -242,7 +255,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.qzeros, size_k=layer.num_groups, size_n=layer.output_size_per_partition, - num_bits=self.quant_config.weight_bits) + num_bits=self.quant_config.quant_type.size_bits) replace_tensor(layer, "qzeros", marlin_zp) # Not-used @@ -263,7 +276,7 @@ def apply( g_idx=layer.g_idx, g_idx_sort_indices=layer.g_idx_sort_indices, workspace=layer.workspace, - num_bits=self.quant_config.weight_bits, + quant_type=self.quant_config.quant_type, output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, bias=bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py index b8ffb22d7a89..c1adfdb2980b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py @@ -9,9 +9,13 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N) from vllm.model_executor.utils import set_weight_attrs +from vllm.scalar_type import scalar_types __all__ = ["CompressedTensorsW4A16Sparse24"] -W4A16SPARSE24_SUPPORTED_BITS = [4] +W4A16SPARSE24_SUPPORTED_TYPES_MAP = { + 4: scalar_types.uint4b8, +} +W4A16SPARSE24_SUPPORTED_BITS = list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys()) class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): @@ -22,9 +26,15 @@ def __init__(self, group_size: Optional[int] = None): self.strategy = strategy self.group_size = group_size - self.num_bits = num_bits self.tile_size = 16 + if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP: + raise ValueError( + f"Unsupported num_bits = {num_bits}. 
" + f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}") + + self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits] + if self.strategy == "group" and self.group_size is None: raise ValueError( "group_size must be given when using strategy group") @@ -43,7 +53,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): - pack_factor = 32 // self.num_bits + pack_factor = 32 // self.quant_type.size_bits output_size_per_partition = sum(output_partition_sizes) qweight = Parameter( @@ -138,7 +148,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, size_n = scales.shape[1] output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, - workspace, self.num_bits, size_m, + workspace, self.quant_type, size_m, size_n, size_k) output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index a41962ccd66d..b8880f7ac136 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -8,12 +8,17 @@ CompressedTensorsScheme) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, - marlin_permute_scales, replace_tensor, verify_gptq_marlin_supported, + marlin_permute_scales, replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.utils import set_weight_attrs +from vllm.scalar_type import scalar_types __all__ = ["CompressedTensorsWNA16"] -WNA16_SUPPORTED_BITS = [4, 8] +WNA16_SUPPORTED_TYPES_MAP = { + 4: scalar_types.uint4b8, + 8: scalar_types.uint8b128, +} +WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) class CompressedTensorsWNA16(CompressedTensorsScheme): @@ -22,8 +27,8 @@ def __init__(self, strategy: str, num_bits: int, group_size: Optional[int] = None): - self.num_bits = num_bits - self.pack_factor = 32 // self.num_bits + + self.pack_factor = 32 // num_bits self.strategy = strategy self.group_size: int @@ -37,10 +42,16 @@ def __init__(self, else: self.group_size = group_size + if num_bits not in WNA16_SUPPORTED_TYPES_MAP: + raise ValueError( + f"Unsupported num_bits = {num_bits}. " + f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}") + + self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits] + # Verify supported on platform. - verify_gptq_marlin_supported(num_bits=self.num_bits, - group_size=self.group_size, - is_sym=True) + verify_marlin_supported(quant_type=self.quant_type, + group_size=self.group_size) @classmethod def get_min_capability(cls) -> int: @@ -150,7 +161,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: perm=layer.g_idx_sort_indices, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, - num_bits=self.num_bits) + num_bits=self.quant_type.size_bits) replace_tensor(layer, "weight_packed", marlin_qweight) # Permute scales from compressed-tensors format to marlin format. 
@@ -172,7 +183,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, g_idx=layer.g_idx, g_idx_sort_indices=layer.g_idx_sort_indices, workspace=layer.workspace, - num_bits=self.num_bits, + wtype=self.quant_type, output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, is_k_full=True, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index bdcc9c3b4f0c..4a11b1497107 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -10,11 +10,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_gptq_marlin_linear, check_gptq_marlin_supported, marlin_is_k_full, + apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor, - verify_gptq_marlin_supported, verify_marlin_supports_shape) + verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -22,6 +23,12 @@ class GPTQMarlinConfig(QuantizationConfig): """Config class for GPTQ Marlin""" + # (num_bits, is_sym) -> quant_type + TYPE_MAP = { + (4, True): scalar_types.uint4b8, + (8, True): scalar_types.uint8b128, + } + def __init__(self, weight_bits: int, group_size: int, desc_act: bool, is_sym: bool, lm_head_quantized: bool) -> None: if desc_act and group_size == -1: @@ -29,20 +36,23 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool, # (since we have only one group per output channel) desc_act = False - self.weight_bits = weight_bits - self.pack_factor = 32 // self.weight_bits # packed into int32 + self.pack_factor = 32 // weight_bits # packed into int32 self.group_size = group_size self.desc_act = desc_act - self.is_sym = is_sym self.lm_head_quantized = lm_head_quantized + if (weight_bits, is_sym) not in self.TYPE_MAP: + raise ValueError("Unsupported quantization config: " + f"bits={weight_bits}, sym={is_sym}") + + self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)] + # Verify supported on platform. 
- verify_gptq_marlin_supported(num_bits=self.weight_bits, - group_size=self.group_size, - is_sym=self.is_sym) + verify_marlin_supported(quant_type=self.quant_type, + group_size=self.group_size) def __repr__(self) -> str: - return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, " + return (f"GPTQMarlinConfig(quant_type={self.quant_type}, " f"group_size={self.group_size}, " f"desc_act={self.desc_act}, " f"lm_head_quantized={self.lm_head_quantized})") @@ -122,11 +132,12 @@ def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): or desc_act is None): return False - return check_gptq_marlin_supported( - num_bits=num_bits, - group_size=group_size, - is_sym=sym, - min_capability=cls.get_min_capability()) + if (num_bits, sym) not in cls.TYPE_MAP: + return False + + return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)], + group_size=group_size, + min_capability=cls.get_min_capability()) class GPTQMarlinLinearMethod(LinearMethodBase): @@ -293,7 +304,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: perm=layer.g_idx_sort_indices, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, - num_bits=self.quant_config.weight_bits) + num_bits=self.quant_config.quant_type.size_bits) replace_tensor(layer, "qweight", marlin_qweight) # Permute scales from autogptq format to marlin format. @@ -319,7 +330,7 @@ def apply( g_idx=layer.g_idx, g_idx_sort_indices=layer.g_idx_sort_indices, workspace=layer.workspace, - num_bits=self.quant_config.weight_bits, + wtype=self.quant_config.quant_type, output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, is_k_full=layer.is_k_full, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py index e708c4da95af..cafd100a2f40 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.utils import set_weight_attrs +from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -17,9 +18,10 @@ GPTQ_MARLIN_24_MIN_THREAD_K = 128 GPTQ_MARLIN_24_MAX_PARALLEL = 64 -GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] +GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [ + scalar_types.uint4b8, scalar_types.uint8b128 +] GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] -GPTQ_MARLIN_24_SUPPORTED_SYM = [True] class GPTQMarlin24Config(QuantizationConfig): @@ -31,14 +33,19 @@ def __init__( weight_bits: int, group_size: int, ) -> None: - self.weight_bits = weight_bits + quant_type = { + 4: scalar_types.uint4b8, + 8: scalar_types.uint8b128, + }.get(weight_bits) + self.group_size = group_size # Verify - if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS: + if quant_type is None or \ + quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES: raise ValueError( - f"Marlin_24 does not support weight_bits = {self.weight_bits}. " - f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} " + f"Marlin_24 does not support quant_type = {quant_type}. " + f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} " "are supported.") if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: raise ValueError( @@ -46,8 +53,10 @@ def __init__( f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} " "are supported.") + self.quant_type = quant_type + # 4 Bits packed into 32 bit datatype. 
- self.pack_factor = 32 // self.weight_bits + self.pack_factor = 32 // self.quant_type.size_bits # Tile size used by marlin kernels. self.tile_size = 16 @@ -66,8 +75,8 @@ def __init__( self.perm_len = 1024 def __repr__(self) -> str: - return "Marlin24Config(weight_bits={}, group_size={})".format( - self.weight_bits, self.group_size) + return "Marlin24Config(quant_type={}, group_size={})".format( + self.quant_type, self.group_size) @classmethod def get_name(cls) -> str: @@ -279,7 +288,7 @@ def apply( output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, workspace, - self.quant_config.weight_bits, + self.quant_config.quant_type, size_m, size_n, size_k) output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index b789ca20cadb..6e84d3621936 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -5,6 +5,7 @@ from vllm import _custom_ops as ops from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types from .quant_utils import pack_cols, unpack_cols @@ -13,7 +14,6 @@ GPTQ_MARLIN_MIN_THREAD_K = 128 GPTQ_MARLIN_MAX_PARALLEL = 16 -MARLIN_SUPPORTED_NUM_BITS = [4, 8] MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] # In case there is a performance issue with Marlin, the variable below can be @@ -22,76 +22,70 @@ USE_FP32_REDUCE_DEFAULT = True -def _check_marlin_supported(num_bits: int, group_size: int, is_sym: bool, - min_capability: Optional[int], - has_zp: bool) -> Tuple[bool, Optional[str]]: - if min_capability is not None: +# For binary size and compile time, we don't support the same types for with and +# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. +# TODO: we may want to move this into the C++ so its closer to the actual impl +def query_marlin_supported_quant_types(has_zp: bool, + min_capability: Optional[int] = None): + if min_capability is None: major, minor = current_platform.get_device_capability() - device_capability = major * 10 + minor - if device_capability < min_capability: - return (False, "Marlin does not support device_capability = {}" - ", the min_capability required is {}".format( - device_capability, min_capability)) - - if num_bits not in MARLIN_SUPPORTED_NUM_BITS: - return (False, "Marlin does not support weight_bits = {}. " - "Only weight_bits = {} are supported.".format( - num_bits, MARLIN_SUPPORTED_NUM_BITS)) - - if group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return (False, "Marlin does not support group_size = {}. 
Only " - "group_sizes = {} are supported.".format( - group_size, MARLIN_SUPPORTED_GROUP_SIZES)) - - if not has_zp and not is_sym: - return (False, - "Marlin without zero_points must have symmetric quantization") + min_capability = major * 10 + minor - return True, None + if min_capability < 80: + return [] + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able + # to add `scalar_types.float8_e4m3fn` here + return [scalar_types.uint4b8, scalar_types.uint8b128] -def check_gptq_marlin_supported(num_bits: int, group_size: int, is_sym: bool, - min_capability: int) -> bool: - cond, _ = _check_marlin_supported(num_bits, - group_size, - is_sym, - min_capability, - has_zp=False) - return cond +def _check_marlin_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + min_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]: -def check_awq_marlin_supported(num_bits: int, group_size: int, has_zp: bool, - min_capability: int) -> bool: - cond, _ = _check_marlin_supported(num_bits, - group_size, - False, - min_capability, - has_zp=has_zp) - return cond + if min_capability is None: + major, minor = current_platform.get_device_capability() + min_capability = major * 10 + minor + supported_types = query_marlin_supported_quant_types( + has_zp, min_capability) -def verify_gptq_marlin_supported(num_bits: int, group_size: int, - is_sym: bool) -> None: - cond, err_msg = _check_marlin_supported(num_bits, - group_size, - is_sym, - min_capability=None, - has_zp=False) - if not cond: - assert err_msg is not None - raise ValueError("GPTQ" + err_msg) + if quant_type not in supported_types: + return (False, f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"min_capability = {min_capability}, zp = {has_zp}).") + if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES): + return (False, f"Marlin does not support group_size = {group_size}. 
" + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.") + + return True, None + + +def check_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + min_capability: Optional[int] = None) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, + min_capability) + return cond -def verify_awq_marlin_supported(num_bits: int, group_size: int, - has_zp: bool) -> None: - cond, err_msg = _check_marlin_supported(num_bits, - group_size, - False, - min_capability=None, - has_zp=has_zp) +def verify_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False) -> None: + cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) if not cond: assert err_msg is not None - raise ValueError("AWQ" + err_msg) + raise ValueError(err_msg) def verify_marlin_supports_shape(output_size_per_partition: int, @@ -245,7 +239,7 @@ def apply_gptq_marlin_linear( g_idx: torch.Tensor, g_idx_sort_indices: torch.Tensor, workspace: torch.Tensor, - num_bits: int, + wtype: ScalarType, output_size_per_partition: int, input_size_per_partition: int, is_k_full: bool, @@ -261,7 +255,7 @@ def apply_gptq_marlin_linear( g_idx, g_idx_sort_indices, workspace, - num_bits, + wtype, size_m=reshaped_x.shape[0], size_n=output_size_per_partition, size_k=input_size_per_partition, @@ -283,7 +277,7 @@ def apply_awq_marlin_linear( g_idx: torch.Tensor, g_idx_sort_indices: torch.Tensor, workspace: torch.Tensor, - num_bits: int, + quant_type: ScalarType, output_size_per_partition: int, input_size_per_partition: int, bias: Optional[torch.Tensor] = None, @@ -298,7 +292,7 @@ def apply_awq_marlin_linear( g_idx, g_idx_sort_indices, workspace, - num_bits, + quant_type, size_m=reshaped_x.shape[0], size_n=output_size_per_partition, size_k=input_size_per_partition, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py index 541d148c761f..7d08ac6f8746 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -5,10 +5,12 @@ import numpy as np import torch +from vllm.scalar_type import ScalarType + from .marlin_utils import (GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points) -from .quant_utils import (get_pack_factor, quantize_weights, - quantize_weights_with_zp, sort_weights) +from .quant_utils import (get_pack_factor, gptq_quantize_weights, + quantize_weights, sort_weights) class MarlinWorkspace: @@ -90,9 +92,10 @@ def get_weight_perm(num_bits: int): return perm -def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int, +def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int, act_order: bool): size_k, size_n = w.shape + num_bits = quant_type.size_bits # Normalize group_size if group_size == -1: @@ -100,8 +103,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int, assert group_size <= size_k # Quantize (and apply act_order if provided) - w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, - act_order) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + w, quant_type, group_size, act_order) # For act_order, sort the "weights" and "g_idx" so that group ids are # increasing @@ -122,7 +125,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int, return res_list -def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int): +def 
awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, + group_size: int): size_k, size_n = w.shape # Normalize group_size @@ -135,13 +139,18 @@ def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int): num_groups = size_k // group_size # Quantize with zp - w_ref, q_w, s, zp = quantize_weights_with_zp(w, num_bits, group_size) + w_ref, q_w, s, zp = quantize_weights(w, + quant_type, + group_size, + zero_points=True) # Reformat to marlin - weight_perm = get_weight_perm(num_bits) - marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, + weight_perm) marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) - marlin_zp = marlin_zero_points(zp, num_groups, size_n, num_bits) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, + quant_type.size_bits) # Create result res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py index 648c32249a57..17d09055b1ea 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py @@ -6,8 +6,10 @@ import numpy import torch +from vllm.scalar_type import ScalarType + from .marlin_utils_test import marlin_weights -from .quant_utils import quantize_weights +from .quant_utils import gptq_quantize_weights # This is PyTorch implementation of main part of reorder_meta() @@ -348,13 +350,11 @@ def check_24(w, num_rows_to_sample=50, _verbose=False): print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") -def compress_quantized_24_weight(q_24, size_k, size_n, num_bits): +def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): assert q_24.shape == (size_k, size_n) - # Remove zp to normalize over 0 - max_q_val = (1 << num_bits) - 1 - zp = (max_q_val + 1) // 2 - q_24_no_zp = q_24 - zp + # Remove bias to normalize over 0 + q_24_no_zp = q_24 - wtype.bias # Compress q_24_no_zp = q_24_no_zp.t().contiguous() @@ -362,8 +362,8 @@ def compress_quantized_24_weight(q_24, size_k, size_n, num_bits): q_24_no_zp) q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() - # Restore zp - q_24_comp = q_24_no_zp_comp + zp + # Restore bias + q_24_comp = q_24_no_zp_comp + wtype.bias # Resize meta to its actual shape (without moving any data) meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) @@ -427,7 +427,7 @@ def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int, def marlin_24_quantize( w: torch.Tensor, - num_bits: int, + quant_type: ScalarType, group_size: int, ): size_k, size_n = w.shape @@ -441,20 +441,18 @@ def marlin_24_quantize( w_24, mask_24 = inject_24(w, size_k, size_n) # Quantize - w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24, - num_bits, - group_size, - act_order=False) + w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( + w_24, quant_type, group_size, act_order=False) # Compress quantized weight q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, - num_bits) + quant_type) size_k_comp = size_k // 2 # Reformat to marlin - weight_perm = get_weight_perm_24(num_bits) + weight_perm = get_weight_perm_24(quant_type.size_bits) marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n, - num_bits, weight_perm) + quant_type.size_bits, weight_perm) marlin_24_s = 
marlin_permute_scales_24(s, size_k, size_n, group_size) # Create result diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 7ade8bf664cc..7f9081b25770 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -4,7 +4,11 @@ import numpy import torch -SUPPORTED_NUM_BITS = [4, 8] +from vllm.model_executor.layers.quantization.qqq import ( + MARLIN_QQQ_SUPPORTED_NUM_BITS) +from vllm.scalar_type import ScalarType, scalar_types + +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] # Note: this is a hack. We should update each model to register the @@ -45,7 +49,7 @@ def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: def get_pack_factor(num_bits): - assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" return 32 // num_bits @@ -74,24 +78,23 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): ) -def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, - act_order: bool): +def quantize_weights(w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + zero_points: bool = False): + assert quant_type.is_integer(), \ + "Floating point quantization may work but has not been tested" + orig_device = w.device + orig_type = w.dtype size_k, size_n = w.shape assert w.is_floating_point(), "w must be float" - assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" if group_size == -1: group_size = size_k assert group_size <= size_k - max_q_val = 2**num_bits - 1 - half_q_val = (max_q_val + 1) // 2 - # Reshape to [groupsize, -1] if group_size < size_k: w = w.reshape((-1, group_size, size_n)) @@ -99,16 +102,34 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, w = w.reshape((group_size, -1)) # Compute scale for each group - s = torch.max(torch.abs(w), 0, keepdim=True)[0] - s *= 2 / max_q_val # 2 => symmetric + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \ + .clamp(min_q_val, max_q_val).int() + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf))) + maybe_w_zp = None # Quantize - q_w = torch.round(w / s).int() - q_w += half_q_val - q_w = torch.clamp(q_w, 0, max_q_val) + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) # Compute ref (dequantized) - w_ref = (q_w - half_q_val).half() * s + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias # Restore original shapes if group_size < size_k: @@ -119,90 +140,48 @@ def reshape_w(w): w = w.reshape((size_k, size_n)).contiguous() return w - q_w = reshape_w(q_w) + w_q = 
reshape_w(w_q) w_ref = reshape_w(w_ref) - s = s.reshape((-1, size_n)).contiguous() + w_s = w_s.reshape((-1, size_n)).contiguous() - # Apply act_order - g_idx = torch.empty(0, dtype=torch.int, device=w.device) - rand_perm = torch.empty(0, dtype=torch.int, device=w.device) - if act_order: - assert ( - group_size < size_k - ), "For act_order, groupsize = {} must be less than size_k = {}".format( - group_size, size_k) - - w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size) + if zero_points: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) return ( w_ref.to(device=orig_device), - q_w.to(device=orig_device), - s.to(device=orig_device), - g_idx.to(device=orig_device), - rand_perm.to(device=orig_device), + w_q.to(device=orig_device), + w_s.to(device=orig_device), + maybe_w_zp, ) -def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int): - orig_device = w.device - size_k, size_n = w.shape +def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType, + group_size: int, act_order: bool): + size_k, _ = w.shape assert w.is_floating_point(), "w must be float" - assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" + assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, \ + f"Unsupported gptq type = {quant_type}" assert group_size in SUPPORTED_GROUP_SIZES + [ size_k ], f"Unsupported groupsize = {group_size}" - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - max_q_val = 2**num_bits - 1 - min_q_val = 0 + w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) - # Reshape to [groupsize, -1] - if group_size < size_k: - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - # Compute scale for each group - max = torch.max(w, 0, keepdim=True)[0] - min = torch.min(w, 0, keepdim=True)[0] - s = (max - min).clamp(min=1e-5) / max_q_val - - # Compute zero-point for each group - zp = (-torch.round(min / s)).clamp(min_q_val, max_q_val).int() - - # Quantize - q_w = torch.round(w / s).int() + zp - q_w = torch.clamp(q_w, min_q_val, max_q_val) - - # Compute ref (dequantized) - w_ref = (q_w - zp).half() * s - - # Restore original shapes - if group_size < size_k: - - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - q_w = reshape_w(q_w) - w_ref = reshape_w(w_ref) + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k) - s = s.reshape((-1, size_n)).contiguous() - zp = zp.reshape((-1, size_n)).contiguous() + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size) - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - s.to(device=orig_device), - zp.to(device=orig_device), - ) + return w_ref, w_q, w_s, g_idx, rand_perm # QQQ employs different quant schemes for per-group and @@ -212,7 +191,8 @@ def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): size_k, size_n = w.shape assert w.is_floating_point(), "w must be float" - assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" + assert num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS, \ + f"Unsupported num_bits = {num_bits}" assert group_size in SUPPORTED_GROUP_SIZES + [ size_k ], 
f"Unsupported groupsize = {group_size}" diff --git a/vllm/scalar_type.py b/vllm/scalar_type.py new file mode 100644 index 000000000000..eb491dd1554a --- /dev/null +++ b/vllm/scalar_type.py @@ -0,0 +1,35 @@ +from ._core_ext import NanRepr, ScalarType + +# naming generally follows: https://github.com/jax-ml/ml_dtypes +# for floating point types (leading f) the scheme is: +# `float_em[flags]` +# flags: +# - no-flags: means it follows IEEE 754 conventions +# - f: means finite values only (no infinities) +# - n: means nans are supported (non-standard encoding) +# for integer types the scheme is: +# `[u]int[b]` +# - if bias is not present it means its zero + + +class scalar_types: + int4 = ScalarType.int_(4, None) + uint4 = ScalarType.uint(4, None) + int8 = ScalarType.int_(8, None) + uint8 = ScalarType.uint(8, None) + float8_e4m3fn = ScalarType.float_(4, 3, True, + NanRepr.EXTD_RANGE_MAX_MIN.value) + float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float16_e8m7 = ScalarType.float_IEEE754(8, 7) + float16_e5m10 = ScalarType.float_IEEE754(5, 10) + + # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main + float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value) + + # "gptq" types + uint4b8 = ScalarType.uint(4, 8) + uint8b128 = ScalarType.uint(8, 128) + + # colloquial names + bfloat16 = float16_e8m7 + float16 = float16_e5m10 From 05308891e203329a733bcf29a3452b15b75b5eb4 Mon Sep 17 00:00:00 2001 From: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Date: Fri, 2 Aug 2024 13:55:40 -0700 Subject: [PATCH 0038/3246] [Core] Pipeline parallel with Ray ADAG (#6837) Support pipeline-parallelism with Ray accelerated DAG. Signed-off-by: Rui Qiao --- Dockerfile | 2 + MANIFEST.in | 1 + requirements-adag.txt | 3 + requirements-test.txt | 3 + tests/distributed/test_pipeline_parallel.py | 51 +++++--- tests/utils.py | 31 ++++- vllm/envs.py | 12 +- vllm/executor/ray_gpu_executor.py | 137 +++++++++++++------- vllm/executor/ray_utils.py | 30 ++++- vllm/worker/worker_base.py | 6 +- 10 files changed, 199 insertions(+), 77 deletions(-) create mode 100644 requirements-adag.txt diff --git a/Dockerfile b/Dockerfile index 7294707046ab..49aaea2949ac 100644 --- a/Dockerfile +++ b/Dockerfile @@ -42,6 +42,7 @@ WORKDIR /workspace # install build and runtime dependencies COPY requirements-common.txt requirements-common.txt +COPY requirements-adag.txt requirements-adag.txt COPY requirements-cuda.txt requirements-cuda.txt RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-cuda.txt @@ -78,6 +79,7 @@ COPY setup.py setup.py COPY cmake cmake COPY CMakeLists.txt CMakeLists.txt COPY requirements-common.txt requirements-common.txt +COPY requirements-adag.txt requirements-adag.txt COPY requirements-cuda.txt requirements-cuda.txt COPY pyproject.toml pyproject.toml COPY vllm vllm diff --git a/MANIFEST.in b/MANIFEST.in index 82be639ef4d7..5a41e5e71418 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,5 @@ include LICENSE +include requirements-adag.txt include requirements-common.txt include requirements-cuda.txt include requirements-rocm.txt diff --git a/requirements-adag.txt b/requirements-adag.txt new file mode 100644 index 000000000000..e77f90fb8f85 --- /dev/null +++ b/requirements-adag.txt @@ -0,0 +1,3 @@ +# Dependencies for Ray accelerated DAG +cupy-cuda12x +ray >= 2.32 \ No newline at end of file diff --git a/requirements-test.txt b/requirements-test.txt index df247496be16..5f3fd15c7ee5 100644 --- a/requirements-test.txt +++ b/requirements-test.txt 
@@ -1,3 +1,6 @@ +# Needed for Ray accelerated DAG tests +-r requirements-adag.txt + # testing pytest tensorizer>=2.9.0 diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index f632caba9017..ab325e096692 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -15,22 +15,31 @@ @pytest.mark.parametrize( - "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, DIST_BACKEND", - [ - (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"), - (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"), - (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - ]) -@fork_new_process_for_each_test + ("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, " + "MODEL_NAME, DIST_BACKEND, USE_RAY_ADAG, USE_RAY_ADAG_NCCL"), [ + (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", False, False), + (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False), + (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False), + (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", False, False), + (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False), + (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, False), + (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False), + (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False), + (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, False), + (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False), + (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, True), + (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True), + (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True), + (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, True), + (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True), + (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp", False, False), + (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False), + (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False), + (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp", False, False), + (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False), + ]) def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, - DIST_BACKEND): + DIST_BACKEND, USE_RAY_ADAG, USE_RAY_ADAG_NCCL): if VLLM_MULTI_NODE and DIST_BACKEND == "mp": pytest.skip("Skipping multi-node pipeline parallel test for " "multiprocessing distributed backend") @@ -67,8 +76,18 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, if EAGER_MODE: pp_args.append("--enforce-eager") tp_args.append("--enforce-eager") + pp_env = None + if USE_RAY_ADAG: + assert DIST_BACKEND == "ray", ( + "Ray ADAG is only supported with Ray distributed backend") + pp_env = { + "VLLM_USE_RAY_COMPILED_DAG": "1", + "VLLM_USE_RAY_SPMD_WORKER": "1", + "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": + str(int(USE_RAY_ADAG_NCCL)), + } - compare_two_settings(MODEL_NAME, pp_args, tp_args) + compare_two_settings(MODEL_NAME, pp_args, tp_args, pp_env) @pytest.mark.parametrize("PP_SIZE, MODEL_NAME", [ diff --git a/tests/utils.py b/tests/utils.py index f3ee801ee774..dd8af8e3afe7 100644 --- a/tests/utils.py +++ b/tests/utils.py 
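(For readers who want to try the Ray ADAG path outside of the test matrix above, the switches are plain environment variables; a minimal sketch, assuming they are exported before the engine process starts.)

import os

# Values are "0"/"1" strings because vllm/envs.py parses them with bool(int(...)).
os.environ.update({
    "VLLM_USE_RAY_SPMD_WORKER": "1",                # SPMD worker execution; enabled together with ADAG in the tests
    "VLLM_USE_RAY_COMPILED_DAG": "1",               # use Ray's compiled DAG control plane
    "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",  # pass PP intermediate tensors over NCCL channels
})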
@@ -7,7 +7,7 @@ import warnings from contextlib import contextmanager from pathlib import Path -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import openai import ray @@ -57,6 +57,7 @@ def __init__( model: str, cli_args: List[str], *, + env_dict: Optional[Dict[str, str]] = None, auto_port: bool = True, ) -> None: if auto_port: @@ -77,6 +78,8 @@ def __init__( # the current process might initialize cuda, # to be safe, we should use spawn method env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' + if env_dict is not None: + env.update(env_dict) self.proc = subprocess.Popen(["vllm", "serve"] + [model] + cli_args, env=env, stdout=sys.stdout, @@ -89,6 +92,11 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.proc.terminate() + try: + self.proc.wait(3) + except subprocess.TimeoutExpired: + # force kill if needed + self.proc.kill() def _wait_for_server(self, *, url: str, timeout: float): # run health check @@ -127,10 +135,21 @@ def get_async_client(self): ) -def compare_two_settings(model: str, arg1: List[str], arg2: List[str]): +def compare_two_settings(model: str, + arg1: List[str], + arg2: List[str], + env1: Optional[Dict[str, str]] = None, + env2: Optional[Dict[str, str]] = None): """ - Launch API server with two different sets of arguments and compare the - results of the API calls. The arguments are after the model name. + Launch API server with two different sets of arguments/environments + and compare the results of the API calls. + + Args: + model: The model to test. + arg1: The first set of arguments to pass to the API server. + arg2: The second set of arguments to pass to the API server. + env1: The first set of environment variables to pass to the API server. + env2: The second set of environment variables to pass to the API server. """ tokenizer = AutoTokenizer.from_pretrained(model) @@ -138,8 +157,8 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]): prompt = "Hello, my name is" token_ids = tokenizer(prompt)["input_ids"] results = [] - for args in (arg1, arg2): - with RemoteOpenAIServer(model, args) as server: + for args, env in ((arg1, env1), (arg2, env2)): + with RemoteOpenAIServer(model, args, env_dict=env) as server: client = server.get_client() # test models list diff --git a/vllm/envs.py b/vllm/envs.py index 9bcb26f8e5a6..5b8a65bd6545 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -38,6 +38,7 @@ VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024 VLLM_USE_RAY_SPMD_WORKER: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False + VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL: bool = True VLLM_WORKER_MULTIPROC_METHOD: str = "fork" VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") VLLM_IMAGE_FETCH_TIMEOUT: int = 5 @@ -273,13 +274,20 @@ def get_default_config_root(): # execution on all workers. # Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it. "VLLM_USE_RAY_SPMD_WORKER": - lambda: bool(os.getenv("VLLM_USE_RAY_SPMD_WORKER", 0)), + lambda: bool(int(os.getenv("VLLM_USE_RAY_SPMD_WORKER", "0"))), # If the env var is set, it uses the Ray's compiled DAG API # which optimizes the control plane overhead. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. "VLLM_USE_RAY_COMPILED_DAG": - lambda: bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)), + lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0"))), + + # If the env var is set, it uses NCCL for communication in + # Ray's compiled DAG. This flag is ignored if + # VLLM_USE_RAY_COMPILED_DAG is not set. 
+ "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": + lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL", "1")) + ), # Use dedicated multiprocess context for workers. # Both spawn and fork work diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 14007e5518d4..46d216910a08 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -105,12 +105,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # The remaining workers are the actual ray actors. self.workers: List[RayWorkerWrapper] = [] + # Used in ray compiled DAG: indexed first by PP rank, + # and then TP rank. In other words, the inner list is + # the TP group of workers for a PP rank. + self.pp_tp_workers: List[List[RayWorkerWrapper]] = [] + if self.parallel_config.ray_workers_use_nsight: ray_remote_kwargs = self._configure_ray_workers_use_nsight( ray_remote_kwargs) + logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) # Create the workers. driver_ip = get_ip() + logger.info("driver_ip: %s", driver_ip) worker_wrapper_kwargs = self._get_worker_wrapper_args() for bundle_id, bundle in enumerate(placement_group.bundle_specs): if not bundle.get("GPU", 0): @@ -142,42 +149,49 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # Else, added to the list of workers. self.workers.append(worker) + logger.debug("workers: %s", self.workers) + logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) if not self.use_ray_spmd_worker and self.driver_dummy_worker is None: raise ValueError( "Ray does not allocate any GPUs on the driver node. Consider " "adjusting the Ray placement group or running the driver on a " "GPU node.") + worker_ips = [ + ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined] + for worker in self.workers + ] + ip_counts: Dict[str, int] = {} + for ip in worker_ips: + ip_counts[ip] = ip_counts.get(ip, 0) + 1 + + def sort_by_driver_then_worker_ip(worker): + """ + Sort the workers based on 3 properties: + 1. If the worker is on the same node as the driver (vllm engine), + it should be placed first. + 2. Then, if the worker is on a node with fewer workers, it should + be placed first. + 3. Finally, if the work is on a node with smaller IP address, it + should be placed first. + """ + ip = ray.get(worker.get_node_ip.remote()) + return (ip != driver_ip, ip_counts[ip], ip) + + # After sorting, the workers on the same node will be + # close to each other, and the workers on the driver + # node will be placed first. + self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) + # Get the set of GPU IDs used on each node. worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", use_dummy_driver=True) - # the order in `worker_node_and_gpu_ids` does not necessarily match - # the machine boundaries. We need to make sure that workers in the - # same node are assigned consecutive ranks. 
- # examples: - # [('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [0]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [0]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [1]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [2]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [3]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [1]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [2]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [3])] # noqa - - # initialize worker ranks with -1 (unassigned) - worker_ranks = [-1 for x in worker_node_and_gpu_ids] - current_rank = 0 - while -1 in worker_ranks: - # whenever we find an unassigned worker, find the node - index = worker_ranks.index(-1) - current_node_id = worker_node_and_gpu_ids[index][0] - # assign ranks to all workers in the same node - for i, (node_id, _) in enumerate(worker_node_and_gpu_ids): - if node_id == current_node_id: - worker_ranks[i] = current_rank - current_rank += 1 - # with the above example, worker_ranks will be [0, 4, 5, 6, 7, 1, 2, 3] - node_workers = defaultdict(list) # node id -> list of worker ranks node_gpus = defaultdict(list) # node id -> list of gpu ids - for worker_rank, (node_id, gpu_ids) in zip(worker_ranks, - worker_node_and_gpu_ids): - node_workers[node_id].append(worker_rank) + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): + node_workers[node_id].append(i) # `gpu_ids` can be a list of strings or integers. # convert them to integers for consistency. # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs), @@ -202,16 +216,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", self._run_workers("update_environment_variables", all_args=all_args_to_update_environment_variables) - if len(node_gpus) == 1: - # in single node case, we don't need to get the IP address. - # the loopback address is sufficient - # NOTE: a node may have several IP addresses, one for each - # network interface. `get_ip()` might return any of them, - # while they might not work for communication inside the node - # if the network setup is complicated. Using the loopback address - # solves this issue, as it always works for communication inside - # the node. - driver_ip = "127.0.0.1" distributed_init_method = get_distributed_init_method( driver_ip, get_open_port()) @@ -221,8 +225,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", local_rank=node_workers[node_id].index(rank), rank=rank, distributed_init_method=distributed_init_method, - ) for rank, (node_id, - _) in zip(worker_ranks, worker_node_and_gpu_ids) + ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) ] self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) @@ -231,6 +234,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", max_concurrent_workers=self.parallel_config. max_parallel_loading_workers) + if self.use_ray_spmd_worker: + for pp_rank in range(self.parallel_config.pipeline_parallel_size): + self.pp_tp_workers.append([]) + for tp_rank in range( + self.parallel_config.tensor_parallel_size): + # PP=2, TP=4 + # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]] + rank = (pp_rank * self.parallel_config.tensor_parallel_size + ) + tp_rank + assert len(self.pp_tp_workers[pp_rank]) == tp_rank + assert pp_rank < len(self.pp_tp_workers) + self.pp_tp_workers[pp_rank].append(self.workers[rank]) + # This is the list of workers that are rank 0 of each TP group EXCEPT # global rank 0. 
These are the workers that will broadcast to the # rest of the workers. @@ -241,9 +257,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", self.non_driver_workers: List[RayWorkerWrapper] = [] # Enforce rank order for correct rank to return final output. - for rank, worker in sorted(zip(worker_ranks[1:], self.workers)): - # We need to skip the driver worker, which we - # do by skipping worker_ranks[0] which is always 0. + for index, worker in enumerate(self.workers): + # The driver worker is rank 0 and not in self.workers. + rank = index + 1 if rank % self.parallel_config.tensor_parallel_size == 0: self.tp_driver_workers.append(worker) else: @@ -376,16 +392,47 @@ def _compiled_ray_dag(self, enable_asyncio: bool): raise ValueError(f"Ray version {required_version} or greater is " f"required, but found {current_version}") - from ray.dag import InputNode, MultiOutputNode assert self.parallel_config.use_ray + from ray.dag import InputNode, MultiOutputNode + from ray.experimental.channel.torch_tensor_type import TorchTensorType - # Right now, compiled DAG requires at least 1 arg. We send - # a dummy value for now. It will be fixed soon. + logger.info("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = %s", + envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL) with InputNode() as input_data: - forward_dag = MultiOutputNode([ - worker.execute_model_spmd.bind( # type: ignore[attr-defined] - input_data) for worker in self.workers - ]) + # Example DAG: PP=2, TP=4 + # (ExecuteModelReq, None) -> 0 -> (ExecuteModelReq, IntermediateOutput) -> 4 -> SamplerOutput # noqa: E501 + # -> 1 -> (ExecuteModelReq, IntermediateOutput) -> 5 -> SamplerOutput # noqa: E501 + # -> 2 -> (ExecuteModelReq, IntermediateOutput) -> 6 -> SamplerOutput # noqa: E501 + # -> 3 -> (ExecuteModelReq, IntermediateOutput) -> 7 -> SamplerOutput # noqa: E501 + + # All workers in the first TP group will take in the + # ExecuteModelRequest as input. + outputs = [input_data for _ in self.pp_tp_workers[0]] + for pp_rank, tp_group in enumerate(self.pp_tp_workers): + # Each PP worker takes in the output of the previous PP worker, + # and the TP group executes in SPMD fashion. + outputs = [ + worker.execute_model_spmd. + bind( # type: ignore[attr-defined] + outputs[i]) for i, worker in enumerate(tp_group) + ] + + last_pp_rank = len(self.pp_tp_workers) - 1 + if pp_rank < last_pp_rank: + # Specify how intermediate tensors should be passed + # between pp stages, no need to specify for the last + # pp stage. 
+ transport = "nccl" \ + if envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL \ + else "auto" + outputs = [ + output.with_type_hint( + TorchTensorType(transport=transport)) + for output in outputs + ] + + forward_dag = MultiOutputNode(outputs) + return forward_dag.experimental_compile(enable_asyncio=enable_asyncio) def __del__(self): diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 58b864070f72..ac948331e81e 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -1,8 +1,8 @@ -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union from vllm.config import ParallelConfig from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest +from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.utils import get_ip, is_hip, is_tpu, is_xpu from vllm.worker.worker_base import WorkerWrapperBase @@ -31,9 +31,17 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: gpu_ids = ray.get_gpu_ids() return node_id, gpu_ids - def execute_model_spmd(self, execute_model_req: ExecuteModelRequest): - """Used only when SPMD worker and compiled DAG are both - enabled.""" + def execute_model_spmd( + self, req_or_tuple: Union[ExecuteModelRequest, + Tuple[ExecuteModelRequest, + IntermediateTensors]]): + """Execute model in SPMD fashion: used only when SPMD worker and + compiled DAG are both enabled. + + Args: + req_or_tuple: The request to execute the model, or a tuple + containing the request and intermediate tensors. + """ # TODO(swang): This is needed right now because Ray aDAG executes # on a background thread, so we need to reset torch's current # device. @@ -42,7 +50,17 @@ def execute_model_spmd(self, execute_model_req: ExecuteModelRequest): torch.cuda.set_device(self.worker.device) self.compiled_dag_cuda_device_set = True - return self.worker._execute_model_spmd(execute_model_req) + if isinstance(req_or_tuple, tuple): + execute_model_req, intermediate_tensors = req_or_tuple + else: + execute_model_req = req_or_tuple + intermediate_tensors = None + + output = self.worker._execute_model_spmd(execute_model_req, + intermediate_tensors) + if isinstance(output, IntermediateTensors): + return execute_model_req, output + return output ray_import_err = None diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 8a4d1958c65a..e56440693b89 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -285,7 +285,9 @@ def execute_model( return output def _execute_model_spmd( - self, execute_model_req: ExecuteModelRequest + self, + execute_model_req: ExecuteModelRequest, + intermediate_tensors: Optional[IntermediateTensors] = None ) -> Optional[List[SamplerOutput]]: """ Execute model in Single Program Multiple Data (SPMD) fashion. 
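(The signature change here is the heart of the pipeline handoff: every stage except the last returns a (request, IntermediateTensors) tuple that the next stage unpacks. Below is a plain-Python sketch of that data flow only, with stage_fn standing in for execute_model_spmd on each PP rank; the TP fan-out inside each stage is omitted and nothing here is the real Ray or vLLM object model.)

def run_pipeline(stage_fns, execute_model_req):
    # First stage receives the bare request; later stages receive the tuple
    # produced by their predecessor, exactly as execute_model_spmd expects.
    payload = execute_model_req
    for i, stage_fn in enumerate(stage_fns):
        out = stage_fn(payload)
        if i == len(stage_fns) - 1:
            return out        # last stage: SamplerOutput(s)
        payload = out         # non-last stage: (execute_model_req, IntermediateTensors)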
@@ -309,7 +311,7 @@ def _execute_model_spmd( return self.model_runner.execute_model( model_input, self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None) + if self.kv_cache is not None else None, intermediate_tensors) class WorkerWrapperBase: From 22e718ff1a51930231d87c89d6c43676af59860b Mon Sep 17 00:00:00 2001 From: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Date: Fri, 2 Aug 2024 15:50:00 -0700 Subject: [PATCH 0039/3246] [Misc] Revive to use loopback address for driver IP (#7091) Signed-off-by: Rui Qiao --- vllm/executor/ray_gpu_executor.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 46d216910a08..4a6825c01fcf 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -216,6 +216,16 @@ def sort_by_driver_then_worker_ip(worker): self._run_workers("update_environment_variables", all_args=all_args_to_update_environment_variables) + if len(node_gpus) == 1: + # in single node case, we don't need to get the IP address. + # the loopback address is sufficient + # NOTE: a node may have several IP addresses, one for each + # network interface. `get_ip()` might return any of them, + # while they might not work for communication inside the node + # if the network setup is complicated. Using the loopback address + # solves this issue, as it always works for communication inside + # the node. + driver_ip = "127.0.0.1" distributed_init_method = get_distributed_init_method( driver_ip, get_open_port()) From 708989341ef6361a5981d890a0e2f1b794323458 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 2 Aug 2024 16:18:45 -0700 Subject: [PATCH 0040/3246] [misc] add a flag to enable compile (#7092) --- vllm/envs.py | 4 ++++ vllm/worker/model_runner.py | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/vllm/envs.py b/vllm/envs.py index 5b8a65bd6545..595058bcbb02 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -174,6 +174,10 @@ def get_default_config_root(): lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")), + # Internal flag to enable Dynamo graph capture + "VLLM_TEST_DYNAMO_GRAPH_CAPTURE": + lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")), + # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 777344289958..f9c26e0c318b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -23,6 +23,7 @@ BatchPrefillWithPagedKVCacheWrapper = None FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 +import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, @@ -786,6 +787,11 @@ def load_model(self) -> None: "provided. Defaulting to scaling factors of 1.0. 
" "This may lead to less accurate results!") + if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE: + self.model = torch.compile(self.model, + fullgraph=True, + backend="eager") + def save_sharded_state( self, path: str, From ed812a73fae77bb520b739cfeaad36dbd61e2b03 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Fri, 2 Aug 2024 21:27:28 -0400 Subject: [PATCH 0041/3246] [ Frontend ] Multiprocessing for OpenAI Server with `zeromq` (#6883) Signed-off-by: Joe Runde Co-authored-by: Joe Runde Co-authored-by: Joe Runde Co-authored-by: Nick Hill Co-authored-by: Simon Mo --- tests/entrypoints/openai/test_disable_mp.py | 715 ++++++++++++++++++ vllm/engine/async_llm_engine.py | 27 +- vllm/engine/llm_engine.py | 36 +- vllm/engine/protocol.py | 84 ++ vllm/entrypoints/openai/api_server.py | 132 +++- vllm/entrypoints/openai/cli_args.py | 9 +- vllm/entrypoints/openai/logits_processors.py | 19 +- vllm/entrypoints/openai/rpc/__init__.py | 42 + vllm/entrypoints/openai/rpc/client.py | 248 ++++++ vllm/entrypoints/openai/rpc/server.py | 216 ++++++ vllm/entrypoints/openai/serving_chat.py | 16 +- vllm/entrypoints/openai/serving_completion.py | 19 +- vllm/entrypoints/openai/serving_embedding.py | 13 +- vllm/entrypoints/openai/serving_engine.py | 8 +- .../openai/serving_tokenization.py | 10 +- vllm/envs.py | 6 + .../outlines_logits_processors.py | 19 + vllm/tracing.py | 2 +- .../tokenizer_group/__init__.py | 19 +- vllm/utils.py | 28 +- 20 files changed, 1567 insertions(+), 101 deletions(-) create mode 100644 tests/entrypoints/openai/test_disable_mp.py create mode 100644 vllm/engine/protocol.py create mode 100644 vllm/entrypoints/openai/rpc/__init__.py create mode 100644 vllm/entrypoints/openai/rpc/client.py create mode 100644 vllm/entrypoints/openai/rpc/server.py diff --git a/tests/entrypoints/openai/test_disable_mp.py b/tests/entrypoints/openai/test_disable_mp.py new file mode 100644 index 000000000000..12c805413311 --- /dev/null +++ b/tests/entrypoints/openai/test_disable_mp.py @@ -0,0 +1,715 @@ +""" +Repeat of tests in test_completion.py with the non-mp backend. 
+""" + +# imports for guided decoding tests +import json +import re +import shutil +from tempfile import TemporaryDirectory +from typing import List + +import jsonschema +import openai # use the official client for correctness check +import pytest +# downloading lora to test lora requests +from huggingface_hub import snapshot_download +from openai import BadRequestError +from transformers import AutoTokenizer + +from vllm.transformers_utils.tokenizer import get_tokenizer + +from ...utils import RemoteOpenAIServer + +# any model with a chat template should work here +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" +# technically these adapters use a different base model, +# but we're not testing generation quality here +LORA_NAME = "typeof/zephyr-7b-beta-lora" +PA_NAME = "swapnilbp/llama_tweet_ptune" +# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also +# need to change to match the prompt adapter +PA_NUM_VIRTUAL_TOKENS = 8 + + +@pytest.fixture(scope="module") +def zephyr_lora_files(): + return snapshot_download(repo_id=LORA_NAME) + + +@pytest.fixture(scope="module") +def zephyr_lora_added_tokens_files(zephyr_lora_files): + tmp_dir = TemporaryDirectory() + tmp_model_dir = f"{tmp_dir.name}/zephyr" + shutil.copytree(zephyr_lora_files, tmp_model_dir) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + # Copy tokenizer to adapter and add some unique tokens + # 32000, 32001, 32002 + added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], + special_tokens=True) + assert added == 3 + tokenizer.save_pretrained(tmp_model_dir) + yield tmp_model_dir + tmp_dir.cleanup() + + +@pytest.fixture(scope="module") +def zephyr_pa_files(): + return snapshot_download(repo_id=PA_NAME) + + +@pytest.fixture(scope="module") +def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, + zephyr_pa_files): + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--max-num-seqs", + "128", + "--enforce-eager", + # lora config + "--enable-lora", + "--lora-modules", + f"zephyr-lora={zephyr_lora_files}", + f"zephyr-lora2={zephyr_lora_added_tokens_files}", + "--max-lora-rank", + "64", + "--max-cpu-loras", + "2", + # pa config + "--enable-prompt-adapter", + "--prompt-adapters", + f"zephyr-pa={zephyr_pa_files}", + f"zephyr-pa2={zephyr_pa_files}", + "--max-prompt-adapters", + "2", + "--max-prompt-adapter-token", + "128", + "--disable-frontend-multiprocessing" + ] + + +@pytest.fixture(scope="module") +def server(default_server_args): + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + yield remote_server + + +@pytest.fixture(scope="module") +def client(server): + return server.get_async_client() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # first test base model, then test loras, then test prompt adapters + "model_name,num_virtual_tokens", + [(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0), + ("zephyr-pa", PA_NUM_VIRTUAL_TOKENS), + ("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)], +) +async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, + num_virtual_tokens: int): + completion = await client.completions.create(model=model_name, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + + choice = completion.choices[0] + assert len(choice.text) >= 5 + assert choice.finish_reason == "length" + assert completion.usage == openai.types.CompletionUsage( 
+ completion_tokens=5, + prompt_tokens=6 + num_virtual_tokens, + total_tokens=11 + num_virtual_tokens) + + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert len(completion.choices[0].text) >= 1 + + +@pytest.mark.asyncio +async def test_added_lora_tokens(client: openai.AsyncOpenAI): + # test using token IDs + completion = await client.completions.create( + model="zephyr-lora2", + prompt=[0, 0, 32000, 32001, 32002], + echo=True, + max_tokens=5, + temperature=0.0, + ) + # Added tokens should appear in tokenized prompt + assert completion.choices[0].text.startswith("vllm1vllm2vllm3") + + +@pytest.mark.asyncio +async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI): + # test using token IDs + completion = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 32000, 32001, 32002], + echo=True, + max_tokens=5, + temperature=0.0, + ) + # Added tokens should not appear in tokenized prompt + assert "vllm" not in completion.choices[0].text + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # first test base model, then test loras, then test prompt adapters + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"], +) +async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=None, + ) + choice = completion.choices[0] + assert choice.logprobs is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # just test 1 lora and 1 pa hereafter + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=0, + ) + choice = completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.token_logprobs is not None + assert choice.logprobs.top_logprobs is not None + assert len(choice.logprobs.top_logprobs[0]) == 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=5, + ) + choice = completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.token_logprobs is not None + assert choice.logprobs.top_logprobs is not None + assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, + model_name: str): + + with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + # vLLM has higher default max_logprobs (20 instead of 5) to support + # both Completion API and Chat Completion API + logprobs=21, + ) + ... 
+ with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + stream = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + # vLLM has higher default max_logprobs (20 instead of 5) to support + # both Completion API and Chat Completion API + logprobs=30, + stream=True, + ) + async for chunk in stream: + ... + + # the server should still work afterwards + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert len(completion.choices[0].text) >= 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_completion_streaming(client: openai.AsyncOpenAI, + model_name: str): + prompt = "What is an LLM?" + + single_completion = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + ) + single_output = single_completion.choices[0].text + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True) + chunks: List[str] = [] + finish_reason_count = 0 + async for chunk in stream: + chunks.append(chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + # finish reason should only return in last block + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == "length" + assert chunk.choices[0].text + assert "".join(chunks) == single_output + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_completion_stream_options(client: openai.AsyncOpenAI, + model_name: str): + prompt = "What is the capital of France?" 
+ + # Test stream=True, stream_options= + # {"include_usage": False, "continuous_usage_stats": False} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": False, + "continuous_usage_stats": + False, + }) + + async for chunk in stream: + assert chunk.usage is None + + # Test stream=True, stream_options= + # {"include_usage": False, "continuous_usage_stats": True} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": False, + "continuous_usage_stats": + True, + }) + async for chunk in stream: + assert chunk.usage is None + + # Test stream=True, stream_options= + # {"include_usage": True, "continuous_usage_stats": False} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": + False, + }) + async for chunk in stream: + if chunk.choices[0].finish_reason is None: + assert chunk.usage is None + else: + assert chunk.usage is None + final_chunk = await stream.__anext__() + assert final_chunk.usage is not None + assert final_chunk.usage.prompt_tokens > 0 + assert final_chunk.usage.completion_tokens > 0 + assert final_chunk.usage.total_tokens == ( + final_chunk.usage.prompt_tokens + + final_chunk.usage.completion_tokens) + assert final_chunk.choices == [] + + # Test stream=True, stream_options= + # {"include_usage": True, "continuous_usage_stats": True} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": + True, + }) + async for chunk in stream: + assert chunk.usage is not None + assert chunk.usage.prompt_tokens > 0 + assert chunk.usage.completion_tokens > 0 + assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + + chunk.usage.completion_tokens) + if chunk.choices[0].finish_reason is not None: + final_chunk = await stream.__anext__() + assert final_chunk.usage is not None + assert final_chunk.usage.prompt_tokens > 0 + assert final_chunk.usage.completion_tokens > 0 + assert final_chunk.usage.total_tokens == ( + final_chunk.usage.prompt_tokens + + final_chunk.usage.completion_tokens) + assert final_chunk.choices == [] + + # Test stream=False, stream_options= + # {"include_usage": None} + with pytest.raises(BadRequestError): + await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": None}) + + # Test stream=False, stream_options= + # {"include_usage": True} + with pytest.raises(BadRequestError): + await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": True}) + + # Test stream=False, stream_options= + # {"continuous_usage_stats": None} + with pytest.raises(BadRequestError): + await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"continuous_usage_stats": None}) + + # Test stream=False, stream_options= + # {"continuous_usage_stats": True} + with pytest.raises(BadRequestError): + await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + 
stream_options={"continuous_usage_stats": True}) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): + # test both text and token IDs + for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2): + # test simple list + batch = await client.completions.create( + model=model_name, + prompt=prompts, + max_tokens=5, + temperature=0.0, + ) + assert len(batch.choices) == 2 + assert batch.choices[0].text == batch.choices[1].text + + # test n = 2 + batch = await client.completions.create( + model=model_name, + prompt=prompts, + n=2, + max_tokens=5, + temperature=0.0, + extra_body=dict( + # NOTE: this has to be true for n > 1 in vLLM, but not necessary + # for official client. + use_beam_search=True), + ) + assert len(batch.choices) == 4 + assert batch.choices[0].text != batch.choices[ + 1].text, "beam search should be different" + assert batch.choices[0].text == batch.choices[ + 2].text, "two copies of the same prompt should be the same" + assert batch.choices[1].text == batch.choices[ + 3].text, "two copies of the same prompt should be the same" + + # test streaming + batch = await client.completions.create( + model=model_name, + prompt=prompts, + max_tokens=5, + temperature=0.0, + stream=True, + ) + texts = [""] * 2 + async for chunk in batch: + assert len(chunk.choices) == 1 + choice = chunk.choices[0] + texts[choice.index] += choice.text + assert texts[0] == texts[1] + + +@pytest.mark.asyncio +async def test_logits_bias(client: openai.AsyncOpenAI): + prompt = "Hello, my name is" + max_tokens = 5 + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + + # Test exclusive selection + token_id = 1000 + completion = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=max_tokens, + temperature=0.0, + logit_bias={str(token_id): 100}, + seed=42, + ) + assert len(completion.choices[0].text) >= 5 + response_tokens = tokenizer(completion.choices[0].text, + add_special_tokens=False)["input_ids"] + expected_tokens = tokenizer(tokenizer.decode([token_id] * 5), + add_special_tokens=False)["input_ids"] + assert all([ + response == expected + for response, expected in zip(response_tokens, expected_tokens) + ]) + + # Test ban + completion = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=max_tokens, + temperature=0.0, + ) + response_tokens = tokenizer(completion.choices[0].text, + add_special_tokens=False)["input_ids"] + first_response = completion.choices[0].text + completion = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=max_tokens, + temperature=0.0, + logit_bias={str(token): -100 + for token in response_tokens}, + ) + assert first_response != completion.choices[0].text + + +@pytest.mark.asyncio +async def test_allowed_token_ids(client: openai.AsyncOpenAI): + prompt = "Hello, my name is" + max_tokens = 1 + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + + # Test exclusive selection + allowed_ids = [21555, 21557, 21558] + completion = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=max_tokens, + temperature=0.0, + seed=42, + extra_body=dict(allowed_token_ids=allowed_ids), + logprobs=1, + ) + response_tokens = completion.choices[0].logprobs.tokens + assert len(response_tokens) == 1 + assert tokenizer.convert_tokens_to_ids(response_tokens)[0] in allowed_ids + + +@pytest.mark.asyncio 
+@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_json_completion(client: openai.AsyncOpenAI, + guided_decoding_backend: str, + sample_json_schema): + completion = await client.completions.create( + model=MODEL_NAME, + prompt=f"Give an example JSON for an employee profile " + f"that fits this schema: {sample_json_schema}", + n=3, + temperature=1.0, + max_tokens=500, + extra_body=dict(guided_json=sample_json_schema, + guided_decoding_backend=guided_decoding_backend)) + + assert completion.id is not None + assert len(completion.choices) == 3 + for i in range(3): + output_json = json.loads(completion.choices[i].text) + jsonschema.validate(instance=output_json, schema=sample_json_schema) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_regex_completion(client: openai.AsyncOpenAI, + guided_decoding_backend: str, + sample_regex): + completion = await client.completions.create( + model=MODEL_NAME, + prompt=f"Give an example IPv4 address with this regex: {sample_regex}", + n=3, + temperature=1.0, + max_tokens=20, + extra_body=dict(guided_regex=sample_regex, + guided_decoding_backend=guided_decoding_backend)) + + assert completion.id is not None + assert len(completion.choices) == 3 + for i in range(3): + assert re.fullmatch(sample_regex, + completion.choices[i].text) is not None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_choice_completion(client: openai.AsyncOpenAI, + guided_decoding_backend: str, + sample_guided_choice): + completion = await client.completions.create( + model=MODEL_NAME, + prompt="The best language for type-safe systems programming is ", + n=2, + temperature=1.0, + max_tokens=10, + extra_body=dict(guided_choice=sample_guided_choice, + guided_decoding_backend=guided_decoding_backend)) + + assert completion.id is not None + assert len(completion.choices) == 2 + for i in range(2): + assert completion.choices[i].text in sample_guided_choice + + +@pytest.mark.asyncio +async def test_guided_grammar(client: openai.AsyncOpenAI, + sample_sql_statements): + + completion = await client.completions.create( + model=MODEL_NAME, + prompt=("Generate a sql state that select col_1 from " + "table_1 where it is equals to 1"), + temperature=1.0, + max_tokens=500, + extra_body=dict(guided_grammar=sample_sql_statements)) + + content = completion.choices[0].text + + # use Lark to parse the output, and make sure it's a valid parse tree + from lark import Lark + parser = Lark(sample_sql_statements) + parser.parse(content) + + # remove spaces for comparison b/c we removed them in the grammar + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "") + + assert content.strip() == ground_truth + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # first test base model, then test loras + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], +) +@pytest.mark.parametrize("logprobs_arg", [1, 0]) +async def test_echo_logprob_completion(client: openai.AsyncOpenAI, + model_name: str, logprobs_arg: int): + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + # test using text and token IDs + for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]): + completion = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + echo=True, + logprobs=logprobs_arg) + + prompt_text = 
tokenizer.decode(prompt) if isinstance(prompt, + list) else prompt + assert re.search(r"^" + prompt_text, completion.choices[0].text) + logprobs = completion.choices[0].logprobs + assert logprobs is not None + assert len(logprobs.text_offset) > 5 + assert (len(logprobs.token_logprobs) > 5 + and logprobs.token_logprobs[0] is None) + assert (len(logprobs.top_logprobs) > 5 + and logprobs.top_logprobs[0] is None) + for top_logprobs in logprobs.top_logprobs[1:]: + assert max(logprobs_arg, + 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert len(logprobs.tokens) > 5 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_decoding_type_error(client: openai.AsyncOpenAI, + guided_decoding_backend: str, + sample_json_schema, sample_regex): + with pytest.raises(openai.BadRequestError): + _ = await client.completions.create( + model=MODEL_NAME, + prompt="Give an example JSON that fits this schema: 42", + extra_body=dict(guided_json=42, + guided_decoding_backend=guided_decoding_backend)) + + with pytest.raises(openai.BadRequestError): + _ = await client.completions.create( + model=MODEL_NAME, + prompt="Give an example string that fits this regex", + extra_body=dict(guided_regex=sample_regex, + guided_json=sample_json_schema)) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index d3f9a0ab00f1..c39caca25cc7 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -7,7 +7,8 @@ from transformers import PreTrainedTokenizer import vllm.envs as envs -from vllm.config import DecodingConfig, EngineConfig, ModelConfig +from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout @@ -928,6 +929,14 @@ async def get_model_config(self) -> ModelConfig: else: return self.engine.get_model_config() + async def get_parallel_config(self) -> ParallelConfig: + """Get the parallel configuration of the vLLM engine.""" + if self.engine_use_ray: + return await self.engine.get_parallel_config.remote( # type: ignore + ) + else: + return self.engine.get_parallel_config() + async def get_decoding_config(self) -> DecodingConfig: """Get the decoding configuration of the vLLM engine.""" if self.engine_use_ray: @@ -936,6 +945,22 @@ async def get_decoding_config(self) -> DecodingConfig: else: return self.engine.get_decoding_config() + async def get_scheduler_config(self) -> SchedulerConfig: + """Get the scheduling configuration of the vLLM engine.""" + if self.engine_use_ray: + return await self.engine.get_scheduler_config.remote( # type: ignore + ) + else: + return self.engine.get_scheduler_config() + + async def get_lora_config(self) -> LoRAConfig: + """Get the lora configuration of the vLLM engine.""" + if self.engine_use_ray: + return await self.engine.get_lora_config.remote( # type: ignore + ) + else: + return self.engine.get_lora_config() + async def do_log_stats( self, scheduler_outputs: Optional[SchedulerOutputs] = None, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1efe2206abe8..3747f93b16cd 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -38,9 +38,8 @@ init_tracer) from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.detokenizer import Detokenizer -from 
vllm.transformers_utils.tokenizer_group import (AnyTokenizer, - BaseTokenizerGroup, - get_tokenizer_group) +from vllm.transformers_utils.tokenizer_group import ( + AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) from vllm.utils import Counter @@ -485,19 +484,12 @@ def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer: return self.get_tokenizer_group().get_lora_tokenizer( sequence.lora_request) - def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup: - init_kwargs = dict( - tokenizer_id=self.model_config.tokenizer, - enable_lora=bool(self.lora_config), - max_num_seqs=self.scheduler_config.max_num_seqs, - max_input_length=None, - tokenizer_mode=self.model_config.tokenizer_mode, - trust_remote_code=self.model_config.trust_remote_code, - revision=self.model_config.tokenizer_revision) - init_kwargs.update(tokenizer_init_kwargs) - - return get_tokenizer_group(self.parallel_config.tokenizer_pool_config, - **init_kwargs) + def _init_tokenizer(self) -> BaseTokenizerGroup: + return init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=self.scheduler_config, + parallel_config=self.parallel_config, + enable_lora=bool(self.lora_config)) def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) @@ -759,10 +751,22 @@ def get_model_config(self) -> ModelConfig: """Gets the model configuration.""" return self.model_config + def get_parallel_config(self) -> ParallelConfig: + """Gets the parallel configuration.""" + return self.parallel_config + def get_decoding_config(self) -> DecodingConfig: """Gets the decoding configuration.""" return self.decoding_config + def get_scheduler_config(self) -> SchedulerConfig: + """Gets the scheduler configuration.""" + return self.scheduler_config + + def get_lora_config(self) -> LoRAConfig: + """Gets the LoRA configuration.""" + return self.lora_config + def get_num_unfinished_requests(self) -> int: """Gets the number of unfinished requests.""" return sum(scheduler.get_num_unfinished_seq_groups() diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py new file mode 100644 index 000000000000..fc94ef6662e0 --- /dev/null +++ b/vllm/engine/protocol.py @@ -0,0 +1,84 @@ +from typing import (AsyncIterator, List, Mapping, Optional, Protocol, + runtime_checkable) + +from transformers import PreTrainedTokenizer + +from vllm.config import DecodingConfig, ModelConfig +from vllm.core.scheduler import SchedulerOutputs +from vllm.inputs.data import PromptInputs +from vllm.lora.request import LoRARequest +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.sequence import SamplerOutput + + +@runtime_checkable +class AsyncEngineClient(Protocol): + """Protocol class for Clients to AsyncLLMEngine""" + + @property + def is_running(self) -> bool: + ... + + @property + def is_stopped(self) -> bool: + ... + + @property + def errored(self) -> bool: + ... 
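Illustrative aside, not part of the diff: `AsyncEngineClient` is declared as a `runtime_checkable` `Protocol`, so the in-process `AsyncLLMEngine` and the new RPC client can both satisfy it structurally, with no shared base class. A minimal, self-contained sketch of that mechanism (class names here are hypothetical, not taken from the patch):

from typing import Protocol, runtime_checkable

@runtime_checkable
class EngineClient(Protocol):
    async def check_health(self) -> None:
        ...

class InProcessEngine:
    async def check_health(self) -> None:
        return None

# Structural check: InProcessEngine never inherits from EngineClient,
# yet isinstance() succeeds because the required member is present.
assert isinstance(InProcessEngine(), EngineClient)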
+ + async def generate( + self, + inputs: PromptInputs, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None + ) -> AsyncIterator[RequestOutput]: + """Generates outputs for a request""" + + async def encode( + self, + inputs: PromptInputs, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + ) -> AsyncIterator[EmbeddingRequestOutput]: + """Generate outputs for a request from an embedding model.""" + + async def abort(self, request_id: str) -> None: + """Abort a request. + + Args: + request_id: The unique id of the request. + """ + + async def get_model_config(self) -> ModelConfig: + """Get the model configuration of the vLLM engine.""" + + async def get_decoding_config(self) -> DecodingConfig: + """Get the decoding configuration of the vLLM engine.""" + + async def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> PreTrainedTokenizer: + """Get the appropriate Tokenizer for the request""" + + async def is_tracing_enabled(self) -> bool: + pass + + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None, + ) -> None: + pass + + async def check_health(self) -> None: + """Raise if unhealthy""" diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 0fe4dd245b5e..e330ee81f7e4 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -5,7 +5,8 @@ import signal from contextlib import asynccontextmanager from http import HTTPStatus -from typing import Optional, Set +from multiprocessing import Process +from typing import AsyncIterator, Set import fastapi import uvicorn @@ -17,8 +18,10 @@ from starlette.routing import Mount import vllm.envs as envs +from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import make_arg_parser # yapf conflicts with isort for this block @@ -31,6 +34,8 @@ EmbeddingRequest, ErrorResponse, TokenizeRequest, TokenizeResponse) +from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient +from vllm.entrypoints.openai.rpc.server import run_rpc_server # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion @@ -39,12 +44,12 @@ OpenAIServingTokenization) from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, get_open_port from vllm.version import __version__ as VLLM_VERSION TIMEOUT_KEEP_ALIVE = 5 # seconds -engine: AsyncLLMEngine +async_engine_client: AsyncEngineClient engine_args: AsyncEngineArgs openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion @@ -56,13 +61,22 @@ _running_tasks: Set[asyncio.Task] = set() +def model_is_embedding(model_name: str) -> bool: + return ModelConfig(model=model_name, + tokenizer=model_name, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16").embedding_mode + + @asynccontextmanager 
async def lifespan(app: fastapi.FastAPI): async def _force_log(): while True: await asyncio.sleep(10) - await engine.do_log_stats() + await async_engine_client.do_log_stats() if not engine_args.disable_log_stats: task = asyncio.create_task(_force_log()) @@ -72,6 +86,52 @@ async def _force_log(): yield +@asynccontextmanager +async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: + # Context manager to handle async_engine_client lifecycle + # Ensures everything is shutdown and cleaned up on error/exit + global engine_args + engine_args = AsyncEngineArgs.from_cli_args(args) + + # Backend itself still global for the silly lil' health handler + global async_engine_client + + # If manually triggered or embedding model, use AsyncLLMEngine in process. + # TODO: support embedding model via RPC. + if (model_is_embedding(args.model) + or args.disable_frontend_multiprocessing): + async_engine_client = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.OPENAI_API_SERVER) + yield async_engine_client + return + + # Otherwise, use the multiprocessing AsyncLLMEngine. + else: + # Start RPCServer in separate process (holds the AsyncLLMEngine). + port = get_open_port(envs.VLLM_RPC_PORT) + rpc_server_process = Process(target=run_rpc_server, + args=(engine_args, + UsageContext.OPENAI_API_SERVER, + port)) + rpc_server_process.start() + + # Build RPCClient, which conforms to AsyncEngineClient Protocol. + async_engine_client = AsyncEngineRPCClient(port) + await async_engine_client.setup() + + try: + yield async_engine_client + finally: + # Ensure rpc server process was terminated + rpc_server_process.terminate() + + # Close all open connections to the backend + async_engine_client.close() + + # Wait for server process to join + rpc_server_process.join() + + router = APIRouter() @@ -86,7 +146,7 @@ def mount_metrics(app: fastapi.FastAPI): @router.get("/health") async def health() -> Response: """Health check.""" - await openai_serving_chat.engine.check_health() + await async_engine_client.check_health() return Response(status_code=200) @@ -215,8 +275,8 @@ async def authentication(request: Request, call_next): async def build_server( + async_engine_client: AsyncEngineClient, args, - llm_engine: Optional[AsyncLLMEngine] = None, **uvicorn_kwargs, ) -> uvicorn.Server: app = build_app(args) @@ -226,14 +286,7 @@ async def build_server( else: served_model_names = [args.model] - global engine, engine_args - - engine_args = AsyncEngineArgs.from_cli_args(args) - engine = (llm_engine - if llm_engine is not None else AsyncLLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.OPENAI_API_SERVER)) - - model_config = await engine.get_model_config() + model_config = await async_engine_client.get_model_config() if args.disable_log_requests: request_logger = None @@ -246,7 +299,7 @@ async def build_server( global openai_serving_tokenization openai_serving_chat = OpenAIServingChat( - engine, + async_engine_client, model_config, served_model_names, args.response_role, @@ -257,7 +310,7 @@ async def build_server( return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) openai_serving_completion = OpenAIServingCompletion( - engine, + async_engine_client, model_config, served_model_names, lora_modules=args.lora_modules, @@ -266,13 +319,13 @@ async def build_server( return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) openai_serving_embedding = OpenAIServingEmbedding( - engine, + async_engine_client, model_config, served_model_names, request_logger=request_logger, 
) openai_serving_tokenization = OpenAIServingTokenization( - engine, + async_engine_client, model_config, served_model_names, lora_modules=args.lora_modules, @@ -304,32 +357,39 @@ async def build_server( return uvicorn.Server(config) -async def run_server(args, llm_engine=None, **uvicorn_kwargs) -> None: +async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) - server = await build_server( - args, - llm_engine, - **uvicorn_kwargs, - ) + shutdown_task = None + async with build_async_engine_client(args) as async_engine_client: + + server = await build_server( + async_engine_client, + args, + **uvicorn_kwargs, + ) + + loop = asyncio.get_running_loop() - loop = asyncio.get_running_loop() + server_task = loop.create_task(server.serve()) - server_task = loop.create_task(server.serve()) + def signal_handler() -> None: + # prevents the uvicorn signal handler to exit early + server_task.cancel() - def signal_handler() -> None: - # prevents the uvicorn signal handler to exit early - server_task.cancel() + loop.add_signal_handler(signal.SIGINT, signal_handler) + loop.add_signal_handler(signal.SIGTERM, signal_handler) - loop.add_signal_handler(signal.SIGINT, signal_handler) - loop.add_signal_handler(signal.SIGTERM, signal_handler) + try: + await server_task + except asyncio.CancelledError: + logger.info("Gracefully stopping http server") + shutdown_task = server.shutdown() - try: - await server_task - except asyncio.CancelledError: - print("Gracefully stopping http server") - await server.shutdown() + if shutdown_task: + # NB: Await server shutdown only after the backend context is exited + await shutdown_task if __name__ == "__main__": diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index a4192937980f..1facedac72ca 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -131,9 +131,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( "--return-tokens-as-token-ids", action="store_true", - help="When --max-logprobs is specified, represents single tokens as" - "strings of the form 'token_id:{token_id}' so that tokens that" + help="When --max-logprobs is specified, represents single tokens as " + "strings of the form 'token_id:{token_id}' so that tokens that " "are not JSON-encodable can be identified.") + parser.add_argument( + "--disable-frontend-multiprocessing", + action="store_true", + help="If specified, will run the OpenAI frontend server in the same " + "process as the model serving engine.") parser = AsyncEngineArgs.add_cli_args(parser) diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py index f8e04e7f18e0..84871fc83ef5 100644 --- a/vllm/entrypoints/openai/logits_processors.py +++ b/vllm/entrypoints/openai/logits_processors.py @@ -1,4 +1,4 @@ -from functools import lru_cache +from functools import lru_cache, partial from typing import Dict, FrozenSet, Iterable, List, Optional, Union import torch @@ -40,6 +40,14 @@ def _get_allowed_token_ids_logits_processor( return AllowedTokenIdsLogitsProcessor(allowed_token_ids) +def logit_bias_logits_processor(logit_bias: Dict[str, + float], token_ids: List[int], + logits: torch.Tensor) -> torch.Tensor: + for token_id, bias in logit_bias.items(): + logits[token_id] += bias + return logits + + def get_logits_processors( logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]], 
allowed_token_ids: Optional[List[int]], @@ -64,13 +72,8 @@ def get_logits_processors( raise ValueError("token_id in logit_bias contains " "out-of-vocab token id") - def logit_bias_logits_processor(token_ids: List[int], - logits: torch.Tensor) -> torch.Tensor: - for token_id, bias in clamped_logit_bias.items(): - logits[token_id] += bias - return logits - - logits_processors.append(logit_bias_logits_processor) + logits_processors.append( + partial(logit_bias_logits_processor, clamped_logit_bias)) if allowed_token_ids is not None: logits_processors.append( diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py new file mode 100644 index 000000000000..8a7b12201cab --- /dev/null +++ b/vllm/entrypoints/openai/rpc/__init__.py @@ -0,0 +1,42 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Mapping, Optional, Union + +from vllm.inputs import PromptInputs +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams + +VLLM_RPC_SUCCESS_STR = "SUCCESS" +VLLM_RPC_HEALTHY_STR = "HEALTHY" + + +@dataclass +class RPCGenerateRequest: + inputs: PromptInputs + sampling_params: SamplingParams + request_id: str + lora_request: Optional[LoRARequest] = None + trace_headers: Optional[Mapping[str, str]] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None + + +@dataclass +class RPCAbortRequest: + request_id: str + + +class RPCUtilityRequest(Enum): + IS_SERVER_READY = 1 + GET_MODEL_CONFIG = 2 + GET_DECODING_CONFIG = 3 + GET_PARALLEL_CONFIG = 4 + GET_SCHEDULER_CONFIG = 5 + GET_LORA_CONFIG = 6 + DO_LOG_STATS = 7 + CHECK_HEALTH = 8 + IS_TRACING_ENABLED = 9 + + +RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, + RPCUtilityRequest] diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py new file mode 100644 index 000000000000..45bf88b5bf57 --- /dev/null +++ b/vllm/entrypoints/openai/rpc/client.py @@ -0,0 +1,248 @@ +from contextlib import contextmanager +from typing import Any, AsyncIterator, Mapping, Optional + +import cloudpickle +import zmq +import zmq.asyncio + +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, + VLLM_RPC_HEALTHY_STR, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCGenerateRequest, RPCUtilityRequest) +from vllm.inputs import PromptInputs +from vllm.lora.request import LoRARequest +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs + + +class AsyncEngineRPCClient: + + def __init__(self, port: int): + self.context = zmq.asyncio.Context() + self.path = f"tcp://localhost:{port}" + + async def setup(self): + """Setup the client before it starts sending server requests.""" + + # Wait until server is ready. + await self.wait_for_server() + + # Get the configs. + self.model_config = await self._get_model_config_rpc() + self.decoding_config = await self._get_decoding_config_rpc() + self.tracing_flag = await self._is_tracing_enabled_rpc() + + # Create the tokenizer group. + # TODO: refactor OAI server to avoid needing this info. 
+ self.tokenizer = init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=(await self._get_scheduler_config_rpc()), + parallel_config=(await self._get_parallel_config_rpc()), + enable_lora=bool(await self._get_lora_config_rpc()), + ) + + def close(self): + """Destroy the ZeroMQ Context.""" + self.context.destroy() + + @contextmanager + def socket(self): + # Ensure client sockets are always closed after use + + # Connect to RPC socket for Request-Reply pattern, + # Note that we use DEALER to enable asynchronous communication + # to enable streaming. + socket = self.context.socket(zmq.constants.DEALER) + try: + socket.connect(self.path) + yield socket + finally: + socket.close() + + async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, + expected_type: Any, + error_message: str) -> Any: + """Send an RPC request that is expecting data back.""" + + with self.socket() as socket: + + # Ping RPCServer with a request. + await socket.send(cloudpickle.dumps(request)) + + # Await the data from the Server. + data = cloudpickle.loads(await socket.recv()) + + if not isinstance(data, expected_type): + # LoRAConfig can be None. + if expected_type == LoRAConfig and data is None: + pass + else: + raise ValueError(error_message) + + return data + + async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, + error_message: str): + """Send one-way RPC request to trigger an action.""" + with self.socket() as socket: + # Ping RPC Server with request. + await socket.send(cloudpickle.dumps(request)) + + # Await acknowledgement from RPCServer. + response = cloudpickle.loads(await socket.recv()) + + if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: + raise ValueError(error_message) + + return response + + async def get_tokenizer(self, lora_request: LoRARequest): + return await self.tokenizer.get_lora_tokenizer_async(lora_request) + + async def get_decoding_config(self) -> DecodingConfig: + return self.decoding_config + + async def get_model_config(self) -> ModelConfig: + return self.model_config + + async def is_tracing_enabled(self) -> bool: + return self.tracing_flag + + async def wait_for_server(self): + """Wait for the RPCServer to start up.""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.IS_SERVER_READY, + error_message="Unable to start RPC Server.") + + async def _get_model_config_rpc(self) -> ModelConfig: + """Get the ModelConfig object from the RPC Server""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_MODEL_CONFIG, + expected_type=ModelConfig, + error_message="Could not get ModelConfig from RPC Server") + + async def _get_decoding_config_rpc(self) -> DecodingConfig: + """Get DecodingConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_DECODING_CONFIG, + expected_type=DecodingConfig, + error_message="Could not get DecodingConfig from RPC Server") + + async def _get_parallel_config_rpc(self) -> ParallelConfig: + """Get ParallelConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_PARALLEL_CONFIG, + expected_type=ParallelConfig, + error_message="Could not get ParallelConfig from RPC Server") + + async def _get_scheduler_config_rpc(self) -> SchedulerConfig: + """Get SchedulerConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_SCHEDULER_CONFIG, + expected_type=SchedulerConfig, + error_message="Could not get SchedulerConfig from RPC 
Server") + + async def _get_lora_config_rpc(self): + """Get LoRAConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_LORA_CONFIG, + expected_type=LoRAConfig, + error_message="Could not get LoRAConfig from RPC Server") + + async def _is_tracing_enabled_rpc(self) -> ParallelConfig: + """Get is_tracing_enabled flag from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.IS_TRACING_ENABLED, + expected_type=bool, + error_message="Could not get is_tracing_enabled flag from RPC " + "Server") + + async def abort(self, request_id: str): + """Send an ABORT_REQUEST signal to the RPC Server""" + + await self._send_one_way_rpc_request( + request=RPCAbortRequest(request_id), + error_message=f"RPCAbortRequest {request_id} failed") + + async def do_log_stats(self): + """Send a DO_LOG_STATS signal to the RPC Server""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.DO_LOG_STATS, + error_message="RPCRequest DO_LOG_STATS failed.") + + async def generate( + self, + inputs: PromptInputs, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None + ) -> AsyncIterator[RequestOutput]: + """Send an RPCGenerateRequest to the RPCServer and stream responses.""" + + with self.socket() as socket: + + # Send RPCGenerateRequest to the RPCServer. + await socket.send_multipart([ + cloudpickle.dumps( + RPCGenerateRequest( + inputs=inputs, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request)) + ]) + + # Stream back the results from the RPC Server. + while True: + message = await socket.recv() + request_output = cloudpickle.loads(message) + + if isinstance(request_output, Exception): + raise request_output + + if request_output.finished: + break + yield request_output + + yield request_output + + async def check_health(self) -> None: + """Raise if unhealthy""" + + with self.socket() as socket: + + # Ping RPCServer with CHECK_HEALTH request. + await socket.send(cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH) + ) + + # Await the reply from the server. + # TODO: do we need an internal timeout here? + # Or do we expect the external probe to timeout and let this chill? 
+ health_message = cloudpickle.loads(await socket.recv()) + + if isinstance(health_message, Exception): + raise health_message + + if health_message != VLLM_RPC_HEALTHY_STR: + raise ValueError("Expected healthy response from backend but got " + "f{health_message}") + + async def encode(self, *args, + **kwargs) -> AsyncIterator[EmbeddingRequestOutput]: + raise NotImplementedError( + "Embeddings not supported with multiprocessing backend") diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py new file mode 100644 index 000000000000..7a72a6f732c9 --- /dev/null +++ b/vllm/entrypoints/openai/rpc/server.py @@ -0,0 +1,216 @@ +import asyncio +import signal +from typing import Any, Coroutine + +import cloudpickle +import zmq +import zmq.asyncio +from typing_extensions import Never + +from vllm import AsyncEngineArgs, AsyncLLMEngine +from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCGenerateRequest, RPCUtilityRequest) +from vllm.logger import init_logger +from vllm.usage.usage_lib import UsageContext + +logger = init_logger(__name__) + + +class AsyncEngineRPCServer: + + def __init__(self, async_engine_args: AsyncEngineArgs, + usage_context: UsageContext, port: int): + # Initialize engine first. + self.engine = AsyncLLMEngine.from_engine_args(async_engine_args, + usage_context) + + # Initialize context. + self.context = zmq.asyncio.Context() + + # Init socket for readiness state. + self.socket = self.context.socket(zmq.constants.ROUTER) + self.socket.bind(f"tcp://localhost:{port}") + + def cleanup(self): + """Cleanup all resources.""" + self.socket.close() + self.context.destroy() + + async def get_model_config(self, identity): + """Send the ModelConfig""" + model_config = await self.engine.get_model_config() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(model_config)]) + + async def get_decoding_config(self, identity): + """Send the DecodingConfig""" + decoding_config = await self.engine.get_decoding_config() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(decoding_config)]) + + async def get_lora_config(self, identity): + lora_config = await self.engine.get_lora_config() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(lora_config)]) + + async def get_scheduler_config(self, identity): + """Send the SchedulerConfig""" + parallel_config = await self.engine.get_scheduler_config() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(parallel_config)]) + + async def get_parallel_config(self, identity): + """Send the ParallelConfig""" + parallel_config = await self.engine.get_parallel_config() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(parallel_config)]) + + async def is_tracing_enabled(self, identity): + """Send the is_tracing_enabled flag""" + tracing_flag = await self.engine.is_tracing_enabled() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(tracing_flag)]) + + async def do_log_stats(self, identity): + """Log stats and confirm success.""" + await self.engine.do_log_stats() + + await self.socket.send_multipart([ + identity, + cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), + ]) + + async def is_server_ready(self, identity): + """Notify the client that we are ready.""" + await self.socket.send_multipart([ + identity, + cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), + ]) + + async def abort(self, identity, request: RPCAbortRequest): + """Abort request and notify the client of success.""" + # Abort 
the request in the llm engine. + await self.engine.abort(request.request_id) + + # Send confirmation to the client. + await self.socket.send_multipart([ + identity, + cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), + ]) + + async def generate(self, identity, generate_request: RPCGenerateRequest): + try: + results_generator = self.engine.generate( + generate_request.inputs, + sampling_params=generate_request.sampling_params, + request_id=generate_request.request_id, + lora_request=generate_request.lora_request, + trace_headers=generate_request.trace_headers, + prompt_adapter_request=generate_request.prompt_adapter_request) + + async for request_output in results_generator: + await self.socket.send_multipart( + [identity, cloudpickle.dumps(request_output)]) + + except Exception as e: + ### Notify client of all failures + await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) + + async def check_health(self, identity): + try: + await self.engine.check_health() + await self.socket.send_multipart( + [identity, cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)]) + except Exception as e: + await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) + + def _make_handler_coro(self, identity, + message) -> Coroutine[Any, Any, Never]: + """Route the zmq message to the handler coroutine.""" + + request = cloudpickle.loads(message) + + if isinstance(request, RPCGenerateRequest): + return self.generate(identity, request) + + elif isinstance(request, RPCAbortRequest): + return self.abort(identity, request) + + elif isinstance(request, RPCUtilityRequest): + if request == RPCUtilityRequest.GET_MODEL_CONFIG: + return self.get_model_config(identity) + elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: + return self.get_parallel_config(identity) + elif request == RPCUtilityRequest.GET_DECODING_CONFIG: + return self.get_decoding_config(identity) + elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: + return self.get_scheduler_config(identity) + elif request == RPCUtilityRequest.GET_LORA_CONFIG: + return self.get_lora_config(identity) + elif request == RPCUtilityRequest.DO_LOG_STATS: + return self.do_log_stats(identity) + elif request == RPCUtilityRequest.IS_SERVER_READY: + return self.is_server_ready(identity) + elif request == RPCUtilityRequest.CHECK_HEALTH: + return self.check_health(identity) + elif request == RPCUtilityRequest.IS_TRACING_ENABLED: + return self.is_tracing_enabled(identity) + else: + raise ValueError(f"Unknown RPCUtilityRequest type: {request}") + + else: + raise ValueError(f"Unknown RPCRequest type: {request}") + + async def run_server_loop(self): + """Inner RPC Server Loop""" + + running_tasks = set() + while True: + # Wait for a request. + identity, message = await self.socket.recv_multipart() + + # Process the request async. + task = asyncio.create_task( + self._make_handler_coro(identity, message)) + + # We need to keep around a strong reference to the task, + # to avoid the task disappearing mid-execution as running tasks + # can be GC'ed. Below is a common "fire-and-forget" tasks + # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task + running_tasks.add(task) + task.add_done_callback(running_tasks.discard) + + +async def run_server(server: AsyncEngineRPCServer): + # Put the server task into the asyncio loop. + loop = asyncio.get_running_loop() + server_task = loop.create_task(server.run_server_loop()) + + # Interruption handling. 
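Illustrative aside, not part of the diff: the RPC client above connects a DEALER socket while the server binds a ROUTER socket, which is what lets a single generate request receive a stream of replies. A minimal synchronous pyzmq sketch of that exchange (the patch itself uses zmq.asyncio; the port and payloads here are arbitrary):

import zmq

ctx = zmq.Context()
router = ctx.socket(zmq.ROUTER)
router.bind("tcp://127.0.0.1:5570")
dealer = ctx.socket(zmq.DEALER)
dealer.connect("tcp://127.0.0.1:5570")

dealer.send(b"ping")                          # client-side request
identity, payload = router.recv_multipart()   # ROUTER prepends the sender identity
for chunk in (b"part-1", b"part-2"):          # server streams several frames back
    router.send_multipart([identity, chunk])
print(dealer.recv(), dealer.recv())           # b'part-1' b'part-2'

dealer.close(); router.close(); ctx.term()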
+ def signal_handler() -> None: + # Kill the server on interrupt / terminate + server_task.cancel() + + loop.add_signal_handler(signal.SIGINT, signal_handler) + loop.add_signal_handler(signal.SIGTERM, signal_handler) + + try: + await server_task + except asyncio.CancelledError: + logger.info("vLLM ZMQ RPC Server was interrupted.") + finally: + # Clean up all resources. + server.cleanup() + + +def run_rpc_server(async_engine_args: AsyncEngineArgs, + usage_context: UsageContext, port: int): + server = AsyncEngineRPCServer(async_engine_args, usage_context, port) + asyncio.run(run_server(server)) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index c832cf2a24b5..ebb1d57fbb9a 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -8,7 +8,7 @@ from transformers import PreTrainedTokenizer from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.chat_utils import (ConversationMessage, load_chat_template, parse_chat_message_content) @@ -39,7 +39,7 @@ class OpenAIServingChat(OpenAIServing): def __init__( self, - engine: AsyncLLMEngine, + async_engine_client: AsyncEngineClient, model_config: ModelConfig, served_model_names: List[str], response_role: str, @@ -50,7 +50,7 @@ def __init__( chat_template: Optional[str], return_tokens_as_token_ids: bool = False, ): - super().__init__(engine=engine, + super().__init__(async_engine_client=async_engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -89,7 +89,8 @@ async def create_chat_completion( ) = self._maybe_get_adapters(request) model_config = self.model_config - tokenizer = await self.engine.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer( + lora_request) conversation: List[ConversationMessage] = [] mm_futures: List[Awaitable[MultiModalDataDict]] = [] @@ -161,7 +162,8 @@ async def create_chat_completion( if mm_data is not None: engine_inputs["multi_modal_data"] = mm_data - is_tracing_enabled = await self.engine.is_tracing_enabled() + is_tracing_enabled = ( + await self.async_engine_client.is_tracing_enabled()) trace_headers = None if is_tracing_enabled and raw_request: trace_headers = extract_trace_headers(raw_request.headers) @@ -169,7 +171,7 @@ async def create_chat_completion( and contains_trace_headers(raw_request.headers)): log_tracing_disabled_warning() - result_generator = self.engine.generate( + result_generator = self.async_engine_client.generate( engine_inputs, sampling_params, request_id, @@ -441,7 +443,7 @@ async def chat_completion_full_generator( async for res in result_generator: if raw_request is not None and await raw_request.is_disconnected(): # Abort the request if the client disconnects. 
- await self.engine.abort(request_id) + await self.async_engine_client.abort(request_id) return self.create_error_response("Client disconnected") final_res = res assert final_res is not None diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 7765c5903f34..edc83d83fbba 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -8,7 +8,7 @@ from transformers import PreTrainedTokenizer from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -42,7 +42,7 @@ class OpenAIServingCompletion(OpenAIServing): def __init__( self, - engine: AsyncLLMEngine, + async_engine_client: AsyncEngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -51,7 +51,7 @@ def __init__( request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, ): - super().__init__(engine=engine, + super().__init__(async_engine_client=async_engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -91,7 +91,8 @@ async def create_completion(self, request: CompletionRequest, prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.engine.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer( + lora_request) guided_decode_logits_processor = ( await self._guided_decode_logits_processor(request, tokenizer)) @@ -119,7 +120,8 @@ async def create_completion(self, request: CompletionRequest, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - is_tracing_enabled = await self.engine.is_tracing_enabled() + is_tracing_enabled = ( + await self.async_engine_client.is_tracing_enabled()) trace_headers = None if is_tracing_enabled: trace_headers = extract_trace_headers(raw_request.headers) @@ -127,7 +129,7 @@ async def create_completion(self, request: CompletionRequest, raw_request.headers): log_tracing_disabled_warning() - generator = self.engine.generate( + generator = self.async_engine_client.generate( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, sampling_params, request_id_item, @@ -168,7 +170,7 @@ async def create_completion(self, request: CompletionRequest, async for i, res in result_generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. - await self.engine.abort(f"{request_id}-{i}") + await self.async_engine_client.abort(f"{request_id}-{i}") return self.create_error_response("Client disconnected") final_res_batch[i] = res @@ -230,7 +232,8 @@ async def completion_stream_generator( # Abort the request if the client disconnects. 
if await raw_request.is_disconnected(): - await self.engine.abort(f"{request_id}-{prompt_idx}") + await self.async_engine_client.abort( + f"{request_id}-{prompt_idx}") raise StopAsyncIteration() for output in res.outputs: diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index bccc90894e79..e61c82f9a8a6 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -6,7 +6,7 @@ from fastapi import Request from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (EmbeddingRequest, EmbeddingResponse, @@ -56,13 +56,13 @@ class OpenAIServingEmbedding(OpenAIServing): def __init__( self, - engine: AsyncLLMEngine, + async_engine_client: AsyncEngineClient, model_config: ModelConfig, served_model_names: List[str], *, request_logger: Optional[RequestLogger], ): - super().__init__(engine=engine, + super().__init__(async_engine_client=async_engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=None, @@ -99,7 +99,8 @@ async def create_embedding(self, request: EmbeddingRequest, prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.engine.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer( + lora_request) pooling_params = request.to_pooling_params() @@ -124,7 +125,7 @@ async def create_embedding(self, request: EmbeddingRequest, "Prompt adapter is not supported " "for embedding models") - generator = self.engine.encode( + generator = self.async_engine_client.encode( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, pooling_params, request_id_item, @@ -146,7 +147,7 @@ async def create_embedding(self, request: EmbeddingRequest, async for i, res in result_generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. 
- await self.engine.abort(f"{request_id}-{i}") + await self.async_engine_client.abort(f"{request_id}-{i}") return self.create_error_response("Client disconnected") final_res_batch[i] = res diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 8c7929a12e9a..df4932d8fe18 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -8,7 +8,7 @@ from typing_extensions import Annotated from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -61,7 +61,7 @@ class OpenAIServing: def __init__( self, - engine: AsyncLLMEngine, + async_engine_client: AsyncEngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -72,7 +72,7 @@ def __init__( ): super().__init__() - self.engine = engine + self.async_engine_client = async_engine_client self.model_config = model_config self.max_model_len = model_config.max_model_len @@ -155,7 +155,7 @@ def create_streaming_error_response( async def _guided_decode_logits_processor( self, request: Union[ChatCompletionRequest, CompletionRequest], tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]: - decoding_config = await self.engine.get_decoding_config() + decoding_config = await self.async_engine_client.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend return await get_guided_decoding_logits_processor( diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 94e1b03ed403..c4350881a27a 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -1,9 +1,9 @@ from typing import List, Optional, Union from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine # yapf conflicts with isort for this block # yapf: disable +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.chat_utils import (ConversationMessage, load_chat_template, parse_chat_message_content) @@ -24,7 +24,7 @@ class OpenAIServingTokenization(OpenAIServing): def __init__( self, - engine: AsyncLLMEngine, + async_engine_client: AsyncEngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -32,7 +32,7 @@ def __init__( request_logger: Optional[RequestLogger], chat_template: Optional[str], ): - super().__init__(engine=engine, + super().__init__(async_engine_client=async_engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -57,7 +57,7 @@ async def create_tokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.engine.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer(lora_request) if isinstance(request, TokenizeChatRequest): model_config = self.model_config @@ -113,7 +113,7 @@ async def create_detokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.engine.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer(lora_request) self._log_inputs(request_id, request.tokens, diff --git a/vllm/envs.py b/vllm/envs.py index 595058bcbb02..a78bad6a2b27 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -4,6 +4,7 @@ if TYPE_CHECKING: VLLM_HOST_IP: str = 
"" VLLM_PORT: Optional[int] = None + VLLM_RPC_PORT: int = 5570 VLLM_USE_MODELSCOPE: bool = False VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60 VLLM_INSTANCE_ID: Optional[str] = None @@ -140,6 +141,11 @@ def get_default_config_root(): lambda: int(os.getenv('VLLM_PORT', '0')) if 'VLLM_PORT' in os.environ else None, + # used when the frontend api server is running in multi-processing mode, + # to communicate with the backend engine process over ZMQ. + 'VLLM_RPC_PORT': + lambda: int(os.getenv('VLLM_PORT', '5570')), + # If true, will load models from ModelScope instead of Hugging Face Hub. # note that the value is true or false, not numbers "VLLM_USE_MODELSCOPE": diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 1c8f6cccb3e9..554dcc0ed43e 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -21,6 +21,8 @@ from typing import Callable, DefaultDict, Dict, List, Union import torch +from lark import Lark +from outlines import grammars from outlines.caching import cache from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write from outlines.fsm.json_schema import build_regex_from_schema @@ -44,6 +46,23 @@ def __call__(self, input_ids: List[int], last_seq_id = hash(tuple(input_ids[:-1])) self._fsm_state[seq_id] = self._guide.get_next_state( state=self._fsm_state[last_seq_id], token_id=last_token) + else: + # Note: this is a hack. + # Lark pickling does not work properly (silent failure), + # which breaks the RPC (which uses python pickleing). + # We need to find a better solution. + # On the first time this is called, we simply re-create + # the Lark object. 
+ if isinstance(self._guide, CFGGuide): + self._guide.parser = Lark( + self._guide.cfg_string, + parser="lalr", + lexer="contextual", + propagate_positions=False, + maybe_placeholders=False, + regex=True, + import_paths=[grammars.GRAMMAR_PATH], + ) instruction = self._guide.get_next_instruction( state=self._fsm_state[seq_id]) diff --git a/vllm/tracing.py b/vllm/tracing.py index dc8377f2396f..7ac38e6a0f66 100644 --- a/vllm/tracing.py +++ b/vllm/tracing.py @@ -60,7 +60,7 @@ def get_span_exporter(endpoint): OTLPSpanExporter) elif protocol == "http/protobuf": from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( - OTLPSpanExporter) + OTLPSpanExporter) # type: ignore else: raise ValueError( f"Unsupported OTLP protocol '{protocol}' is configured") diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index 7a0436dd1fb1..eeab19899b02 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -1,6 +1,7 @@ from typing import Optional, Type -from vllm.config import TokenizerPoolConfig +from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, + TokenizerPoolConfig) from vllm.executor.ray_utils import ray from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup @@ -13,6 +14,22 @@ RayTokenizerGroupPool = None # type: ignore +def init_tokenizer_from_configs(model_config: ModelConfig, + scheduler_config: SchedulerConfig, + parallel_config: ParallelConfig, + enable_lora: bool): + init_kwargs = dict(tokenizer_id=model_config.tokenizer, + enable_lora=enable_lora, + max_num_seqs=scheduler_config.max_num_seqs, + max_input_length=None, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, + revision=model_config.tokenizer_revision) + + return get_tokenizer_group(parallel_config.tokenizer_pool_config, + **init_kwargs) + + def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], **init_kwargs) -> BaseTokenizerGroup: tokenizer_cls: Type[BaseTokenizerGroup] diff --git a/vllm/utils.py b/vllm/utils.py index c4c17bfbefc6..51bd72977a22 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -290,6 +290,10 @@ def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future: return _async_wrapper +class ProducerFinished: + pass + + def merge_async_iterators( *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]: """Merge multiple asynchronous iterators into a single iterator. @@ -298,9 +302,10 @@ def merge_async_iterators( When it yields, it yields a tuple (i, item) where i is the index of the iterator that yields the item. 
""" - queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue() + queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished, + Exception]] = asyncio.Queue() - finished = [False] * len(iterators) + producers = len(iterators) async def producer(i: int, iterator: AsyncIterator[T]): try: @@ -308,7 +313,8 @@ async def producer(i: int, iterator: AsyncIterator[T]): await queue.put((i, item)) except Exception as e: await queue.put(e) - finished[i] = True + # Signal to the consumer that we've finished + await queue.put(ProducerFinished()) _tasks = [ asyncio.create_task(producer(i, iterator)) @@ -316,9 +322,17 @@ async def producer(i: int, iterator: AsyncIterator[T]): ] async def consumer(): + remaining = producers try: - while not all(finished) or not queue.empty(): + while remaining or not queue.empty(): + # we think there is a race condition here item = await queue.get() + + if isinstance(item, ProducerFinished): + # Signal that a producer finished- not a real item + remaining -= 1 + continue + if isinstance(item, Exception): raise item yield item @@ -374,8 +388,10 @@ def get_distributed_init_method(ip: str, port: int) -> str: return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}" -def get_open_port() -> int: - port = envs.VLLM_PORT +def get_open_port(port: Optional[int] = None) -> int: + if port is None: + # Default behavior here is to return a port for multi-gpu communication + port = envs.VLLM_PORT if port is not None: while True: try: From 69ea15e5cc823b2bc040921ce516807fb7357dd1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 2 Aug 2024 21:05:16 -0700 Subject: [PATCH 0042/3246] [ci][distributed] shorten wait time if server hangs (#7098) --- tests/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils.py b/tests/utils.py index dd8af8e3afe7..974fece49f4b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -50,7 +50,7 @@ def _nvml(): class RemoteOpenAIServer: DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key - MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds + MAX_SERVER_START_WAIT_S = 120 # wait for server to start for 120 seconds def __init__( self, From 8c025fa7030350a81bfeb665c99ad622667bdac0 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 3 Aug 2024 12:31:27 +0800 Subject: [PATCH 0043/3246] [Frontend] Factor out chat message parsing (#7055) --- vllm/entrypoints/chat_utils.py | 28 +++++++++++++++---- vllm/entrypoints/openai/serving_chat.py | 17 ++++------- .../openai/serving_tokenization.py | 21 +++++++------- 3 files changed, 39 insertions(+), 27 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index fbb7f70b55e1..072450a6146e 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -1,7 +1,8 @@ import codecs -from dataclasses import dataclass, field +from dataclasses import dataclass from functools import lru_cache -from typing import Awaitable, Iterable, List, Optional, Union, cast, final +from typing import (Awaitable, Iterable, List, Optional, Tuple, Union, cast, + final) # yapf conflicts with isort for this block # yapf: disable @@ -65,8 +66,7 @@ class ConversationMessage(TypedDict): @dataclass(frozen=True) class ChatMessageParseResult: messages: List[ConversationMessage] - mm_futures: List[Awaitable[MultiModalDataDict]] = field( - default_factory=list) + mm_futures: List[Awaitable[MultiModalDataDict]] def load_chat_template(chat_template: Optional[str]) -> Optional[str]: @@ -174,7 +174,7 @@ def 
_parse_chat_message_content_parts( return ChatMessageParseResult(messages=messages, mm_futures=mm_futures) -def parse_chat_message_content( +def _parse_chat_message_content( message: ChatCompletionMessageParam, model_config: ModelConfig, tokenizer: PreTrainedTokenizer, @@ -190,3 +190,21 @@ def parse_chat_message_content( return _parse_chat_message_content_parts(role, content, model_config, tokenizer) + + +def parse_chat_messages( + messages: List[ChatCompletionMessageParam], + model_config: ModelConfig, + tokenizer: PreTrainedTokenizer, +) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]: + conversation: List[ConversationMessage] = [] + mm_futures: List[Awaitable[MultiModalDataDict]] = [] + + for msg in messages: + parse_result = _parse_chat_message_content(msg, model_config, + tokenizer) + + conversation.extend(parse_result.messages) + mm_futures.extend(parse_result.mm_futures) + + return conversation, mm_futures diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index ebb1d57fbb9a..d215754993e8 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,6 +1,5 @@ import time -from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List, - Optional) +from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional from typing import Sequence as GenericSequence from typing import Union @@ -11,7 +10,7 @@ from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.chat_utils import (ConversationMessage, load_chat_template, - parse_chat_message_content) + parse_chat_messages) from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( ChatCompletionLogProb, ChatCompletionLogProbs, @@ -92,15 +91,8 @@ async def create_chat_completion( tokenizer = await self.async_engine_client.get_tokenizer( lora_request) - conversation: List[ConversationMessage] = [] - mm_futures: List[Awaitable[MultiModalDataDict]] = [] - - for msg in request.messages: - chat_parsed_result = parse_chat_message_content( - msg, model_config, tokenizer) - - conversation.extend(chat_parsed_result.messages) - mm_futures.extend(chat_parsed_result.mm_futures) + conversation, mm_futures = parse_chat_messages( + request.messages, model_config, tokenizer) tool_dicts = None if request.tools is None else [ tool.model_dump() for tool in request.tools @@ -115,6 +107,7 @@ async def create_chat_completion( chat_template=request.chat_template or self.chat_template, **(request.chat_template_kwargs or {}), ) + assert isinstance(prompt, str) except Exception as e: logger.error("Error in applying chat template from request: %s", e) return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index c4350881a27a..5b6b979b9b9e 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -1,13 +1,11 @@ from typing import List, Optional, Union from vllm.config import ModelConfig -# yapf conflicts with isort for this block -# yapf: disable from vllm.engine.protocol import AsyncEngineClient -from vllm.entrypoints.chat_utils import (ConversationMessage, - load_chat_template, - parse_chat_message_content) +from vllm.entrypoints.chat_utils import load_chat_template, parse_chat_messages from vllm.entrypoints.logger import RequestLogger +# yapf conflicts with isort for this block +# yapf: disable from vllm.entrypoints.openai.protocol 
import (DetokenizeRequest, DetokenizeResponse, ErrorResponse, @@ -17,8 +15,11 @@ # yapf: enable from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing) +from vllm.logger import init_logger from vllm.utils import random_uuid +logger = init_logger(__name__) + class OpenAIServingTokenization(OpenAIServing): @@ -62,12 +63,12 @@ async def create_tokenize( if isinstance(request, TokenizeChatRequest): model_config = self.model_config - conversation: List[ConversationMessage] = [] + conversation, mm_futures = parse_chat_messages( + request.messages, model_config, tokenizer) - for message in request.messages: - result = parse_chat_message_content(message, model_config, - tokenizer) - conversation.extend(result.messages) + if mm_futures: + logger.warning( + "Multi-modal inputs are ignored during tokenization") prompt = tokenizer.apply_chat_template( add_generation_prompt=request.add_generation_prompt, From 04e55834254bf11770d544bbeebdbdb7731d9bbd Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 2 Aug 2024 21:33:53 -0700 Subject: [PATCH 0044/3246] [ci][distributed] merge distributed test commands (#7097) Co-authored-by: Cyrus Leung --- .buildkite/test-pipeline.yaml | 27 ++------- .../test_basic_distributed_correctness.py | 50 ++++++++++------ .../test_chunked_prefill_distributed.py | 35 +++++------- .../distributed/test_multimodal_broadcast.py | 57 +++++++++---------- 4 files changed, 78 insertions(+), 91 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 573c3740f0bb..93b3e3fe9166 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -82,20 +82,9 @@ steps: num_gpus: 2 commands: - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py - - TEST_DIST_MODEL=llava-hf/llava-v1.6-mistral-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - - 
TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py - - TEST_DIST_MODEL=llava-hf/llava-v1.6-mistral-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py + - TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py + - pytest -v -s distributed/test_chunked_prefill_distributed.py + - pytest -v -s distributed/test_multimodal_broadcast.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py @@ -107,11 +96,6 @@ steps: fast_check: true commands: - pytest -v -s distributed/test_pynccl.py - # We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here. - # See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context. - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py - label: Pipeline Parallelism Test @@ -279,9 +263,6 @@ steps: # NOTE: don't test llama model here, it seems hf implementation is buggy # see https://github.com/vllm-project/vllm/pull/5689 for details - pytest -v -s distributed/test_custom_all_reduce.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl - - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - TARGET_TEST_SUITE=A100 pytest -v -s distributed/test_basic_distributed_correctness.py - pytest -v -s -x lora/test_mixtral.py diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 7a0e5673b2cc..1de2ebab22db 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -1,15 +1,10 @@ """Compare the outputs of HF and distributed vLLM when using greedy sampling. -vLLM will allocate all the available memory, so we need to run the tests one -by one. The solution is to pass arguments (model name) by environment -variables. 
+ Run: ```sh cd $VLLM_PATH/tests -TEST_DIST_MODEL=facebook/opt-125m pytest \ - distributed/test_basic_distributed_correctness.py -TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \ - distributed/test_basic_distributed_correctness.py +pytest distributed/test_basic_distributed_correctness.py ``` """ import os @@ -19,27 +14,48 @@ from vllm.utils import cuda_device_count_stateless from ..models.utils import check_outputs_equal +from ..utils import fork_new_process_for_each_test -MODELS = [ - os.environ["TEST_DIST_MODEL"], -] -DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND" +TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4") @pytest.mark.skipif(cuda_device_count_stateless() < 2, reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize( + "model, distributed_executor_backend, attention_backend, test_suite", [ + ("facebook/opt-125m", "ray", "", "L4"), + ("facebook/opt-125m", "mp", "", "L4"), + ("meta-llama/Llama-2-7b-hf", "ray", "", "L4"), + ("meta-llama/Llama-2-7b-hf", "mp", "", "L4"), + ("facebook/opt-125m", "ray", "", "A100"), + ("facebook/opt-125m", "mp", "", "A100"), + ("facebook/opt-125m", "mp", "FLASHINFER", "A100"), + ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"), + ]) +@fork_new_process_for_each_test def test_models( hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, - max_tokens: int, + distributed_executor_backend: str, + attention_backend: str, + test_suite: str, ) -> None: - distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND) + + if test_suite != TARGET_TEST_SUITE: + pytest.skip(f"Skip test for {test_suite}") + + if model == "meta-llama/Llama-2-7b-hf" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa + # test ray adag + os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1" + os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1" + + if attention_backend: + os.environ["VLLM_ATTENTION_BACKEND"] = attention_backend + + dtype = "half" + max_tokens = 5 # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py index 1ef085b93379..10921a3852f8 100644 --- a/tests/distributed/test_chunked_prefill_distributed.py +++ b/tests/distributed/test_chunked_prefill_distributed.py @@ -1,46 +1,39 @@ """Compare the outputs of HF and distributed vLLM when using greedy sampling. -vLLM will allocate all the available memory, so we need to run the tests one -by one. The solution is to pass arguments (model name) by environment -variables. 
Run: ```sh -TEST_DIST_MODEL=facebook/opt-125m pytest \ - test_chunked_prefill_distributed.py -TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \ - test_chunked_prefill_distributed.py +pytest test_chunked_prefill_distributed.py ``` """ -import os import pytest from vllm.utils import cuda_device_count_stateless from ..models.utils import check_outputs_equal - -MODELS = [ - os.environ["TEST_DIST_MODEL"], -] -DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND" +from ..utils import fork_new_process_for_each_test @pytest.mark.skipif(cuda_device_count_stateless() < 2, reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [5]) -@pytest.mark.parametrize("chunked_prefill_token_size", [16]) +@pytest.mark.parametrize("model, distributed_executor_backend", [ + ("facebook/opt-125m", "ray"), + ("meta-llama/Llama-2-7b-hf", "ray"), + ("facebook/opt-125m", "mp"), + ("meta-llama/Llama-2-7b-hf", "mp"), +]) +@fork_new_process_for_each_test def test_models( hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, - max_tokens: int, - chunked_prefill_token_size: int, + distributed_executor_backend: str, ) -> None: - distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND) + + dtype = "half" + max_tokens = 5 + chunked_prefill_token_size = 16 # Add a chunked prefill config. max_num_seqs = min(chunked_prefill_token_size, 256) diff --git a/tests/distributed/test_multimodal_broadcast.py b/tests/distributed/test_multimodal_broadcast.py index a99917f58694..2c96358e2e6f 100644 --- a/tests/distributed/test_multimodal_broadcast.py +++ b/tests/distributed/test_multimodal_broadcast.py @@ -1,44 +1,41 @@ """Compare the outputs of HF and distributed vLLM when using greedy sampling. -The second test will hang if more than one test is run per command, so we need -to run the tests one by one. The solution is to pass arguments (model name) by -environment variables. 
Run: ```sh -TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf \ - test_multimodal_broadcast.py -TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct \ - test_multimodal_broadcast.py +pytest -s -v test_multimodal_broadcast.py ``` """ -import os import pytest from vllm.utils import cuda_device_count_stateless -model = os.environ["TEST_DIST_MODEL"] - -if model.startswith("llava-hf/llava-1.5"): - from ..models.test_llava import models, run_test -elif model.startswith("llava-hf/llava-v1.6"): - from ..models.test_llava_next import models, run_test -else: - raise NotImplementedError(f"Unsupported model: {model}") - - -@pytest.mark.parametrize("tensor_parallel_size", [2]) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner, vllm_runner, image_assets, - tensor_parallel_size: int, dtype: str, max_tokens: int, - num_logprobs: int) -> None: - if cuda_device_count_stateless() < tensor_parallel_size: - pytest.skip( - f"Need at least {tensor_parallel_size} GPUs to run the test.") - - distributed_executor_backend = os.getenv("DISTRIBUTED_EXECUTOR_BACKEND") +from ..utils import fork_new_process_for_each_test + + +@pytest.mark.skipif(cuda_device_count_stateless() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize("model, distributed_executor_backend", [ + ("llava-hf/llava-1.5-7b-hf", "ray"), + ("llava-hf/llava-v1.6-mistral-7b-hf", "ray"), + ("llava-hf/llava-1.5-7b-hf", "mp"), + ("llava-hf/llava-v1.6-mistral-7b-hf", "mp"), +]) +@fork_new_process_for_each_test +def test_models(hf_runner, vllm_runner, image_assets, model: str, + distributed_executor_backend: str) -> None: + + dtype = "half" + max_tokens = 5 + num_logprobs = 5 + tensor_parallel_size = 2 + + if model.startswith("llava-hf/llava-1.5"): + from ..models.test_llava import models, run_test + elif model.startswith("llava-hf/llava-v1.6"): + from ..models.test_llava_next import models, run_test + else: + raise NotImplementedError(f"Unsupported model: {model}") run_test( hf_runner, From a0d164567cd2a82d827c81a49a21e3f2c75a522d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 2 Aug 2024 22:32:04 -0700 Subject: [PATCH 0045/3246] [ci][distributed] disable ray dag tests (#7099) --- tests/distributed/test_pipeline_parallel.py | 43 +++++++++------------ 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index ab325e096692..8eb5ca9461c7 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -14,36 +14,29 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" -@pytest.mark.parametrize( - ("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, " - "MODEL_NAME, DIST_BACKEND, USE_RAY_ADAG, USE_RAY_ADAG_NCCL"), [ - (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", False, False), - (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False), - (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False), - (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", False, False), - (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False), - (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, False), - (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False), - (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False), - (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, False), - (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False), - (2, 2, 0, 
1, "meta-llama/Meta-Llama-3-8B", "ray", True, True), - (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True), - (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True), - (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, True), - (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True), - (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp", False, False), - (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False), - (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False), - (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp", False, False), - (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False), - ]) +@pytest.mark.parametrize(("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, " + "MODEL_NAME, DIST_BACKEND"), + [ + (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"), + (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"), + (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"), + (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"), + (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + ]) def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, - DIST_BACKEND, USE_RAY_ADAG, USE_RAY_ADAG_NCCL): + DIST_BACKEND): if VLLM_MULTI_NODE and DIST_BACKEND == "mp": pytest.skip("Skipping multi-node pipeline parallel test for " "multiprocessing distributed backend") + USE_RAY_ADAG_NCCL = 0 + USE_RAY_ADAG = 0 + pp_args = [ # use half precision for speed and memory savings in CI environment "--dtype", From 0c25435daa0a399460a676e7c9b604bd23ea2d22 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sat, 3 Aug 2024 13:36:14 +0800 Subject: [PATCH 0046/3246] [Model] Refactor and decouple weight loading logic for InternVL2 model (#7067) --- vllm/model_executor/models/intern_vit.py | 11 +++- vllm/model_executor/models/internvl.py | 82 ++++++++---------------- 2 files changed, 38 insertions(+), 55 deletions(-) diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index c6c692deca2e..54c933e3e495 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -4,7 +4,7 @@ # Copyright (c) 2023 OpenGVLab # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- -from typing import Optional +from typing import Iterable, Optional, Tuple import torch import torch.nn as nn @@ -16,6 +16,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader NORM2FN = { 'rms_norm': RMSNorm, @@ -268,3 +269,11 @@ def forward( encoder_outputs = self.encoder(inputs_embeds=hidden_states) return encoder_outputs + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index eabc283b1efd..474925127148 100644 --- a/vllm/model_executor/models/internvl.py +++ 
b/vllm/model_executor/models/internvl.py @@ -4,6 +4,7 @@ # Copyright (c) 2023 OpenGVLab # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- +import itertools from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union import torch @@ -414,58 +415,31 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - (".gate_up_proj", ".w1", 0), - (".gate_up_proj", ".w3", 1), - ] - params_dict = dict(self.named_parameters()) + def _filter_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], + prefix: str): for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if self.config.text_config.tie_word_embeddings \ - and "lm_head.weight" in name: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - # We only do sharding for language model - # and not vision model for now. - if "vision_embed_tokens" in name and self.vision_embed_tokens: - continue - if weight_name not in name: - continue - param = params_dict[name.replace(weight_name, param_name)] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - if "wqkv" in name: - config = self.config.text_config - kv_groups = (config.num_attention_heads // - config.num_key_value_heads) - head_dim = config.hidden_size // config.num_attention_heads - loaded_weight = loaded_weight.view(-1, 2 + kv_groups, - head_dim, - loaded_weight.shape[-1]) - wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1], - dim=1) - wq = wq.reshape(-1, wq.shape[-1]) - wk = wk.reshape(-1, wk.shape[-1]) - wv = wv.reshape(-1, wv.shape[-1]) - weight_loader = param.weight_loader - weight_loader(param, wq, 'q') - weight_loader(param, wk, 'k') - weight_loader(param, wv, 'v') - continue - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + name = name.split(".") + if prefix == name.pop(0): + name = ".".join(name) + yield name, loaded_weight + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # prepare weight iterators for components + vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3) + + # load vision encoder + vit_weights = self._filter_weights(vit_weights, "vision_model") + self.vision_model.load_weights(vit_weights) + + # load mlp projector + mlp_weights = self._filter_weights(mlp_weights, "mlp1") + mlp_params_dict = dict(self.mlp1.named_parameters()) + for name, loaded_weight in mlp_weights: + param = mlp_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load llm backbone + llm_weights = self._filter_weights(llm_weights, "language_model") + self.language_model.load_weights(llm_weights) From fb2c1c86c196aa1531435d0c445fbea4c9dd4aa5 Mon Sep 17 00:00:00 2001 From: Zach Zheng Date: Fri, 2 Aug 2024 22:38:15 -0700 Subject: [PATCH 0047/3246] [Bugfix] Fix block table for seqs that have prefix cache hits (#7018) --- 
tests/prefix_caching/test_prefix_caching.py | 56 +++++++++++++++++++++ vllm/attention/backends/flash_attn.py | 12 +++-- 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 7985001d34eb..9821dbd066a5 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -6,10 +6,17 @@ import pytest +from tests.kernels.utils import override_backend_env_variable from vllm.block import PhysicalTokenBlock from vllm.core.block_manager_v1 import CachedBlockAllocator from vllm.utils import Device +from ..models.utils import check_outputs_equal + +MODELS = [ + "facebook/opt-125m", +] + @pytest.mark.parametrize("block_size", [16]) @pytest.mark.parametrize("num_blocks", [16]) @@ -76,3 +83,52 @@ def test_eviction(num_blocks: int, ): assert (realloc_block != new_block) assert (new_block.block_hash == new_block_hash) assert (new_block.block_number == 2) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("cached_position", [0, 1]) +@pytest.mark.parametrize("use_v2_block_manager", [False, True]) +def test_mixed_requests( + hf_runner, + vllm_runner, + example_prompts, + model: str, + backend: str, + dtype: str, + max_tokens: int, + cached_position: int, + use_v2_block_manager: bool, + monkeypatch, +) -> None: + """ + Test the case when some sequences have the prefix cache hit + and the others don't. The cached position determines where + the sequence is at among the batch of prefills. + """ + override_backend_env_variable(monkeypatch, backend) + + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + + cached_prompt = example_prompts[cached_position] + with vllm_runner( + model, + dtype=dtype, + enable_prefix_caching=True, + use_v2_block_manager=use_v2_block_manager, + ) as vllm_model: + # Run the first prompt so the cache is populated + vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens) + + # Run all the promopts + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 00654dca2adf..26b3159682b3 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -209,6 +209,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.num_prefills = 0 self.num_prefill_tokens = 0 self.num_decode_tokens = 0 + self.has_prefix_cache_hit = False self.input_builder = input_builder self.runner = input_builder.runner @@ -219,7 +220,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): def _add_seq_group( self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool): + chunked_prefill_enabled: bool, prefix_cache_hit: bool): """Add a sequence group to the metadata. Specifically update/append 1. context length. 2. block table. @@ -252,7 +253,7 @@ def _add_seq_group( # only allowing multiple of block_size chunk size. # NOTE: This only works for oooooooxxx style attention. 
block_table = [] - if inter_data.prefix_cache_hit: + if prefix_cache_hit: # NOTE(woosuk): For flash-attn, the block table should # include the entries for the incoming prefill tokens. block_table = block_tables[seq_id] @@ -281,9 +282,14 @@ def build(self, seq_lens: List[int], query_lens: List[int], -1 if cuda graph is not used. batch_size: The maybe padded batch size. """ + prefix_cache_hit = any([ + inter_data.prefix_cache_hit + for inter_data in self.input_builder.inter_data_list + ]) for inter_data in self.input_builder.inter_data_list: self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled) + self.input_builder.chunked_prefill_enabled, + prefix_cache_hit) device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 From 99d7cabd7b8b789e837a0682982fd7ec94a843b1 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sat, 3 Aug 2024 13:40:19 +0800 Subject: [PATCH 0048/3246] [LoRA] ReplicatedLinear support LoRA (#7081) --- tests/lora/test_layers.py | 103 ++++++++++++++++++++++++++++++++++++++ vllm/lora/layers.py | 94 ++++++++++++++++++++++++++++++++++ vllm/lora/utils.py | 2 + 3 files changed, 199 insertions(+) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 6f33f56616fc..d8cc68d5e959 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -22,6 +22,7 @@ MergedColumnParallelLinearWithLoRA, MergedQKVParallelLinearWithLora, QKVParallelLinearWithLora, + ReplicatedLinearWithLoRA, RowParallelLinearWithLoRA, VocabParallelEmbeddingWithLoRA) # yapf: enable @@ -31,6 +32,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import get_rope @@ -545,6 +547,107 @@ def _pretest(): atol=atol) +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("stage", STAGES) +def test_linear_replicated(dist_init, num_loras, device, stage) -> None: + + torch.set_default_device(device) + punica_wrapper = PunicaWrapper(8192, 256, device) + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16) + + def create_random_linear_replicated_layer(): + + linear = ReplicatedLinear(4096, + 4096, + bias=False, + params_dtype=torch.float16) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = ReplicatedLinearWithLoRA(linear) + + lora_linear.create_lora_weights(max_loras, lora_config) + + return linear, lora_linear + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + linear, lora_linear = create_random_linear_replicated_layer() + lora_linear.set_mapping(punica_wrapper) + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_linear, + layer_weights=linear.weight, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float16, + ) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + + lora_result = lora_linear(torch.cat(inputs))[0] + + expected_results: List[torch.Tensor] = [] + for input_, lora_id in 
zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = linear(input_)[0] + result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_linear.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float16, + ) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + + lora_result = lora_linear(torch.cat(inputs))[0] + expected_result = linear(torch.cat(inputs))[0] + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("orientation", ["row", "column"]) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 3176badabbc7..42ec99e6ea2c 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -21,6 +21,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import ( @@ -262,6 +263,99 @@ def can_replace_layer( return type(source_layer) is VocabParallelEmbedding +class ReplicatedLinearWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: ReplicatedLinear) -> None: + super().__init__() + self.base_layer = base_layer + self.input_size = self.base_layer.input_size + self.output_size = self.base_layer.output_size + self.device = _get_lora_device(self.base_layer) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + self.lora_config = lora_config + lora_a_output_size = lora_config.max_lora_rank + self.lora_a_stacked = torch.zeros( + max_loras, + 1, + lora_a_output_size, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.lora_b_stacked = torch.zeros( + max_loras, + 1, + self.output_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, + self.lora_b_stacked, 1.0) + return output + + def forward(self, input_): + """Forward of ReplicatedLinearWithLoRA + + Args: + input_: Tensor whose last dimension is `input_size`. 
+ + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output = self.apply(input_, bias) + + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + return output, output_bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is ReplicatedLinear + + class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): """ LoRA on top of ColumnParallelLinear layer. diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 4513337299e1..ee983328e2c5 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -23,6 +23,7 @@ MergedColumnParallelLinearWithLoRA, MergedQKVParallelLinearWithLora, QKVParallelLinearWithLora, + ReplicatedLinearWithLoRA, RowParallelLinearWithLoRA, VocabParallelEmbeddingWithLoRA) # yapf: enable @@ -38,6 +39,7 @@ QKVParallelLinearWithLora, MergedQKVParallelLinearWithLora, RowParallelLinearWithLoRA, + ReplicatedLinearWithLoRA, LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLora, From 67d745cc68d9ad31bf683a88f00a1aee9782f541 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 2 Aug 2024 23:52:44 -0700 Subject: [PATCH 0049/3246] [CI] Temporarily turn off H100 performance benchmark (#7104) --- .../benchmark-pipeline.yaml | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml index 02c0ee534d72..8490c9f1da22 100644 --- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml @@ -42,20 +42,20 @@ steps: - name: devshm emptyDir: medium: Memory - - label: "H100" - agents: - queue: H100 - plugins: - - docker#v5.11.0: - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT - command: - - bash - - .buildkite/nightly-benchmarks/run-benchmarks-suite.sh - mount-buildkite-agent: true - propagate-environment: true - ipc: host - gpus: all - environment: - - VLLM_USAGE_SOURCE - - HF_TOKEN + # - label: "H100" + # agents: + # queue: H100 + # plugins: + # - docker#v5.11.0: + # image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + # command: + # - bash + # - .buildkite/nightly-benchmarks/run-benchmarks-suite.sh + # mount-buildkite-agent: true + # propagate-environment: true + # ipc: host + # gpus: all + # environment: + # - VLLM_USAGE_SOURCE + # - HF_TOKEN From 44dcb52e39ee6b2c9ef9e6497525e1e183c9d24b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 3 Aug 2024 10:44:53 -0700 Subject: [PATCH 0050/3246] [ci][test] finalize fork_new_process_for_each_test (#7114) --- tests/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/utils.py b/tests/utils.py index 974fece49f4b..666694299d39 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -360,6 +360,9 @@ def wait_for_gpu_memory_to_clear(devices: List[int], def fork_new_process_for_each_test(f): + """Decorator to fork a new process for each test function. + See https://github.com/vllm-project/vllm/issues/7053 for more details. 
+ """ @functools.wraps(f) def wrapper(*args, **kwargs): From 825b044863a8e3af82a82a80cd2617486cc829ca Mon Sep 17 00:00:00 2001 From: Jeff Fialho Date: Sat, 3 Aug 2024 20:01:38 -0300 Subject: [PATCH 0051/3246] [Frontend] Warn if user `max_model_len` is greater than derived `max_model_len` (#7080) Signed-off-by: Jefferson Fialho Co-authored-by: Nick Hill --- vllm/config.py | 19 +++++++++++++------ vllm/envs.py | 10 ++++++++++ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index ef56e2b6395b..028f4eed8f4a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -6,6 +6,7 @@ import torch from transformers import PretrainedConfig +import vllm.envs as envs from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.models import ModelRegistry @@ -1541,15 +1542,21 @@ def _get_and_verify_max_len( "Disabling sliding window is not supported for models " "model_max_length in the config. Please raise an issue " "so we can investigate.") - pass else: - raise ValueError( + msg = ( f"User-specified max_model_len ({max_model_len}) is greater " - "than the derived max_model_len " - f"({max_len_key}={derived_max_model_len} or model_max_length=" + f"than the derived max_model_len ({max_len_key}=" + f"{derived_max_model_len} or model_max_length=" f"{model_max_length} in model's config.json). This may lead " - "to incorrect model outputs or CUDA errors. Make sure the " - "value is correct and within the model context size.") + "to incorrect model outputs or CUDA errors.") + if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN: + logger.warning( + "%s Make sure the value is correct and within the " + "model context size.", msg) + else: + raise ValueError( + f"{msg} To allow overriding this maximum, set " + "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1") return int(max_model_len) diff --git a/vllm/envs.py b/vllm/envs.py index a78bad6a2b27..089a39d8e029 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -50,6 +50,7 @@ VLLM_NO_DEPRECATION_WARNING: bool = False CMAKE_BUILD_TYPE: Optional[str] = None VERBOSE: bool = False + VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False def get_default_cache_root(): @@ -331,6 +332,15 @@ def get_default_config_root(): # If set, vllm will skip the deprecation warnings. "VLLM_NO_DEPRECATION_WARNING": lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))), + + # If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows + # the user to specify a max sequence length greater than + # the max length derived from the model's config.json. + # To enable this, set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1. 
+ "VLLM_ALLOW_LONG_MAX_MODEL_LEN": + lambda: + (os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in + ("1", "true")), } # end-env-vars-definition From 654bc5ca49bde0969bc95e4b1dbe7fabbb8f631c Mon Sep 17 00:00:00 2001 From: Yihuan Bu <88394319+kevinbu233@users.noreply.github.com> Date: Sat, 3 Aug 2024 23:12:09 -0400 Subject: [PATCH 0052/3246] Support for guided decoding for offline LLM (#6878) Co-authored-by: Cyrus Leung --- docs/source/conf.py | 1 + tests/entrypoints/{openai => }/conftest.py | 22 ++- tests/entrypoints/llm/test_guided_generate.py | 142 ++++++++++++++++++ vllm/entrypoints/llm.py | 44 +++++- vllm/entrypoints/openai/protocol.py | 26 +++- .../guided_decoding/__init__.py | 26 +++- .../guided_decoding/guided_fields.py | 38 +++++ .../lm_format_enforcer_decoding.py | 39 +++++ .../guided_decoding/outlines_decoding.py | 26 +++- 9 files changed, 352 insertions(+), 12 deletions(-) rename tests/entrypoints/{openai => }/conftest.py (83%) create mode 100644 tests/entrypoints/llm/test_guided_generate.py create mode 100644 vllm/model_executor/guided_decoding/guided_fields.py diff --git a/docs/source/conf.py b/docs/source/conf.py index 1093b30bca11..f1eb8524d4e9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -111,6 +111,7 @@ def setup(app): "tqdm", "tensorizer", "pynvml", + "outlines", ] for mock_target in autodoc_mock_imports: diff --git a/tests/entrypoints/openai/conftest.py b/tests/entrypoints/conftest.py similarity index 83% rename from tests/entrypoints/openai/conftest.py rename to tests/entrypoints/conftest.py index 0837644f26bd..e7ef5637c8cc 100644 --- a/tests/entrypoints/openai/conftest.py +++ b/tests/entrypoints/conftest.py @@ -1,6 +1,26 @@ import pytest +@pytest.fixture +def sample_prompts(): + return [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + +@pytest.fixture +def sample_token_ids(): + return [ + [0], + [0, 1], + [0, 2, 1], + [0, 3, 1, 2], + ] + + @pytest.fixture def sample_regex(): return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" @@ -66,4 +86,4 @@ def sample_sql_statements(): table: "table_1" | "table_2" condition: column "=" number number: "1" | "2" -""") \ No newline at end of file +""") diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py new file mode 100644 index 000000000000..873e11542125 --- /dev/null +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -0,0 +1,142 @@ +import json +import re +import weakref + +import jsonschema +import pytest + +from vllm.entrypoints.llm import LLM +from vllm.outputs import RequestOutput +from vllm.sampling_params import SamplingParams + +from ...conftest import cleanup + +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" + + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, max_model_len=1024) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + del llm + cleanup() + + +@pytest.mark.skip_global_cleanup +def test_guided_regex(sample_regex, llm): + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + ) + outputs = llm.generate( + prompts=[ + f"Give an example IPv4 address with this regex: {sample_regex}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True, + guided_options_request=dict(guided_regex=sample_regex)) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, 
RequestOutput) + prompt = output.prompt + generated_text = output.outputs[0].text + print(generated_text) + assert generated_text is not None + assert re.fullmatch(sample_regex, generated_text) is not None + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +@pytest.mark.skip_global_cleanup +def test_guided_json_completion(sample_json_schema, llm): + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=1000, + ) + outputs = llm.generate( + prompts=[ + f"Give an example JSON for an employee profile " + f"that fits this schema: {sample_json_schema}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True, + guided_options_request=dict(guided_json=sample_json_schema)) + + assert outputs is not None + + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + output_json = json.loads(generated_text) + jsonschema.validate(instance=output_json, schema=sample_json_schema) + + +@pytest.mark.skip_global_cleanup +def test_guided_choice_completion(sample_guided_choice, llm): + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + ) + outputs = llm.generate( + prompts="The best language for type-safe systems programming is ", + sampling_params=sampling_params, + use_tqdm=True, + guided_options_request=dict(guided_choice=sample_guided_choice)) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + generated_text = output.outputs[0].text + print(generated_text) + assert generated_text is not None + assert generated_text in sample_guided_choice + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +@pytest.mark.skip_global_cleanup +def test_guided_grammar(sample_sql_statements, llm): + + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + max_tokens=1000, + ) + outputs = llm.generate( + prompts=("Generate a sql state that select col_1 from " + "table_1 where it is equals to 1"), + sampling_params=sampling_params, + use_tqdm=True, + guided_options_request=dict(guided_grammar=sample_sql_statements)) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + # use Lark to parse the output, and make sure it's a valid parse tree + from lark import Lark + parser = Lark(sample_sql_statements) + parser.parse(generated_text) + + # remove spaces for comparison b/c we removed them in the grammar + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace( + " ", "") + + assert generated_text.strip() == ground_truth + + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 62309ed345b1..262cba79e571 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -10,6 +10,9 @@ parse_and_batch_prompt) from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding import ( + GuidedDecodingRequest, get_local_guided_decoding_logits_processor) +from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params 
import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest @@ -262,6 +265,8 @@ def generate( use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -303,6 +308,14 @@ def generate( else: inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) + if isinstance(guided_options_request, dict): + if len(guided_options_request) > 1: + raise ValueError( + "You can only use one guided decoding but multiple is " + f"specified: {guided_options_request}") + guided_options_request = GuidedDecodingRequest( + **guided_options_request) + if sampling_params is None: # Use default sampling params. sampling_params = SamplingParams() @@ -311,7 +324,8 @@ def generate( inputs=inputs, params=sampling_params, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + guided_options=guided_options_request) outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, RequestOutput) @@ -508,6 +522,7 @@ def _validate_and_add_requests( Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], prompt_adapter_request: Optional[PromptAdapterRequest], + guided_options: Optional[GuidedDecodingRequest] = None, ) -> None: if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. @@ -523,6 +538,15 @@ def _validate_and_add_requests( raise ValueError("The lengths of prompts and lora_request " "must be the same.") + if isinstance(params, list): + params = [ + self._add_guided_processor(param, guided_options) + if isinstance(param, SamplingParams) else param + for param in params + ] + elif isinstance(params, SamplingParams): + params = self._add_guided_processor(params, guided_options) + # Add requests to the engine. 
for i, request_inputs in enumerate(inputs): self._add_request( @@ -548,6 +572,24 @@ def _add_request( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) + def _add_guided_processor( + self, + params: SamplingParams, + guided_options: Optional[GuidedDecodingRequest] = None): + if guided_options: + if guided_options.guided_decoding_backend is None: + decoding_config = self.llm_engine.get_decoding_config() + guided_options.guided_decoding_backend = ( + decoding_config.guided_decoding_backend) + guided_logits_processor = get_local_guided_decoding_logits_processor( #noqa + guided_options.guided_decoding_backend, guided_options, + self.get_tokenizer()) + if guided_logits_processor: + if params.logits_processors is None: + params.logits_processors = [] + params.logits_processors.append(guided_logits_processor) + return params + def _run_engine( self, *, use_tqdm: bool ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 3b35ae1ebd70..76318a127122 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1,6 +1,7 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py import time +from argparse import Namespace from typing import Any, Dict, List, Literal, Optional, Union import torch @@ -14,6 +15,23 @@ from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.utils import random_uuid +# torch is mocked during docs generation, +# so we have to provide the values as literals +_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807) + +try: + from sphinx.ext.autodoc.mock import _MockModule + + if isinstance(torch, _MockModule): + _LONG_INFO = _MOCK_LONG_INFO + else: + _LONG_INFO = torch.iinfo(torch.long) +except ModuleNotFoundError: + _LONG_INFO = torch.iinfo(torch.long) + +assert _LONG_INFO.min == _MOCK_LONG_INFO.min +assert _LONG_INFO.max == _MOCK_LONG_INFO.max + class OpenAIBaseModel(BaseModel): # OpenAI API does not allow extra fields @@ -108,9 +126,7 @@ class ChatCompletionRequest(OpenAIBaseModel): n: Optional[int] = 1 presence_penalty: Optional[float] = 0.0 response_format: Optional[ResponseFormat] = None - seed: Optional[int] = Field(None, - ge=torch.iinfo(torch.long).min, - le=torch.iinfo(torch.long).max) + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None @@ -327,9 +343,7 @@ class CompletionRequest(OpenAIBaseModel): max_tokens: Optional[int] = 16 n: int = 1 presence_penalty: Optional[float] = 0.0 - seed: Optional[int] = Field(None, - ge=torch.iinfo(torch.long).min, - le=torch.iinfo(torch.long).max) + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 50aa3ec379f4..4a2476dd6314 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -3,9 +3,10 @@ from vllm.entrypoints.openai.protocol import ( ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, CompletionRequest) -from 
vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( - get_lm_format_enforcer_guided_decoding_logits_processor) +from vllm.model_executor.guided_decoding.guided_fields import ( + GuidedDecodingRequest) from vllm.model_executor.guided_decoding.outlines_decoding import ( + get_local_outlines_guided_decoding_logits_processor, get_outlines_guided_decoding_logits_processor) from vllm.sampling_params import LogitsProcessor @@ -20,6 +21,8 @@ async def get_guided_decoding_logits_processor( return await get_outlines_guided_decoding_logits_processor( request, tokenizer) if guided_decoding_backend == 'lm-format-enforcer': + from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa + get_lm_format_enforcer_guided_decoding_logits_processor) return await get_lm_format_enforcer_guided_decoding_logits_processor( request, tokenizer) @@ -28,6 +31,25 @@ async def get_guided_decoding_logits_processor( "Must be one of 'outlines, 'lm-format-enforcer'") +def get_local_guided_decoding_logits_processor( + guided_decoding_backend: str, guided_options: GuidedDecodingRequest, + tokenizer) -> Optional[LogitsProcessor]: + # request = _adapt_request_for_tool_use(request) + + if guided_decoding_backend == 'outlines': + return get_local_outlines_guided_decoding_logits_processor( + guided_options, tokenizer) + if guided_decoding_backend == 'lm-format-enforcer': + from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa + get_local_lm_format_enforcer_guided_decoding_logits_processor) + return get_local_lm_format_enforcer_guided_decoding_logits_processor( + guided_options, tokenizer) + + raise ValueError( + f"Unknown guided decoding backend '{guided_decoding_backend}'. " + "Must be one of 'outlines, 'lm-format-enforcer'") + + def _adapt_request_for_tool_use(request: Union[CompletionRequest, ChatCompletionRequest]): # the legacy completion API does not support tool use diff --git a/vllm/model_executor/guided_decoding/guided_fields.py b/vllm/model_executor/guided_decoding/guided_fields.py new file mode 100644 index 000000000000..3082ac1510cc --- /dev/null +++ b/vllm/model_executor/guided_decoding/guided_fields.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional, TypedDict, Union + +from pydantic import BaseModel + + +class LLMGuidedOptions(TypedDict, total=False): + guided_json: Union[Dict, BaseModel, str] + guided_regex: str + guided_choice: List[str] + guided_grammar: str + guided_decoding_backend: str + guided_whitespace_pattern: str + guided_json_object: bool + + +@dataclass +class GuidedDecodingRequest: + """One of the fields will be used to retrieve the logit processor.""" + guided_json: Optional[Union[Dict, BaseModel, str]] = None + guided_regex: Optional[str] = None + guided_choice: Optional[List[str]] = None + guided_grammar: Optional[str] = None + guided_decoding_backend: Optional[str] = None + guided_whitespace_pattern: Optional[str] = None + guided_json_object: Optional[bool] = None + + def __post_init__(self): + """Validate that some fields are mutually exclusive.""" + guide_count = sum([ + self.guided_json is not None, self.guided_regex is not None, + self.guided_choice is not None, self.guided_grammar is not None, + self.guided_json_object is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding but multiple are " + f"specified: {self.__dict__}") diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py 
b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py index d0a5ca5592f9..b2188c9cbc2b 100644 --- a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -12,7 +12,10 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest) +from vllm.model_executor.guided_decoding.guided_fields import ( + GuidedDecodingRequest) from vllm.model_executor.guided_decoding.outlines_decoding import ( + get_local_outlines_guided_decoding_logits_processor, get_outlines_guided_decoding_logits_processor) from vllm.sampling_params import LogitsProcessor @@ -54,6 +57,42 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor( return logits_processor +def get_local_lm_format_enforcer_guided_decoding_logits_processor( + guided_options: GuidedDecodingRequest, + tokenizer) -> Optional[LogitsProcessor]: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + We cache logit processors by (guide, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying FSM. + """ + + tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer) + character_level_parser: CharacterLevelParser + if guided_options.guided_json: + schema = _normalize_json_schema_object(guided_options.guided_json) + character_level_parser = JsonSchemaParser(schema) + elif guided_options.guided_choice: + character_level_parser = UnionParser( + [StringParser(choice) for choice in guided_options.guided_choice]) + elif guided_options.guided_regex: + character_level_parser = RegexParser(guided_options.guided_regex) + elif guided_options.guided_grammar: + # CFG grammar not supported by LMFE, revert to outlines + return get_local_outlines_guided_decoding_logits_processor( + guided_options, tokenizer) + elif guided_options.guided_json_object: + # None means any json object + character_level_parser = JsonSchemaParser(None) + else: + return None + + logits_processor = build_vllm_logits_processor(tokenizer_data, + character_level_parser) + return logits_processor + + def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict: if isinstance(schema, str): return json_loads(schema) diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index 721f7e0530cb..bc62224dabec 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -10,6 +10,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest) +from vllm.model_executor.guided_decoding.guided_fields import ( + GuidedDecodingRequest) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) @@ -77,8 +79,27 @@ async def get_outlines_guided_decoding_logits_processor( mode, request.guided_whitespace_pattern) +def get_local_outlines_guided_decoding_logits_processor( + guided_options: GuidedDecodingRequest, tokenizer: PreTrainedTokenizerBase +) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, + None]: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. 
+ We cache logit processors by (guide, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying FSM. + """ + guide, mode = _get_guide_and_mode(guided_options) + if not guide or not mode: + return None + + return _get_logits_processor(guide, tokenizer, mode, + guided_options.guided_whitespace_pattern) + + def _get_guide_and_mode( - request: Union[CompletionRequest, ChatCompletionRequest] + request: Union[CompletionRequest, ChatCompletionRequest, + GuidedDecodingRequest] ) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: if request.guided_json: @@ -102,7 +123,8 @@ def _get_guide_and_mode( return choices_regex, GuidedDecodingMode.CHOICE elif request.guided_grammar: return request.guided_grammar, GuidedDecodingMode.GRAMMAR - elif (request.response_format is not None + elif (not isinstance(request, GuidedDecodingRequest) + and request.response_format is not None and request.response_format.type == "json_object"): return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR else: From 9fadc7b7a03f798036d0e8710587870e13bae759 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 3 Aug 2024 22:03:46 -0700 Subject: [PATCH 0053/3246] [misc] add zmq in collect env (#7119) --- collect_env.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/collect_env.py b/collect_env.py index 083cb768f539..244e4ddd5aed 100644 --- a/collect_env.py +++ b/collect_env.py @@ -65,6 +65,7 @@ "optree", "nccl", "transformers", + "zmq", } DEFAULT_PIP_PATTERNS = { @@ -77,6 +78,7 @@ "onnx", "nccl", "transformers", + "zmq", } From 83c644fe7ecee05d3ebe5057acb6e008d7e81eb8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 4 Aug 2024 00:22:19 -0700 Subject: [PATCH 0054/3246] [core][misc] simply output processing with shortcut code path (#7117) --- vllm/engine/output_processor/single_step.py | 39 ++++++++++++++++----- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 59eb4bc439d1..4a46c93f8425 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -81,6 +81,29 @@ def process_prompt_logprob(self, seq_group: SequenceGroup, def _process_sequence_group_outputs(self, seq_group: SequenceGroup, outputs: SequenceGroupOutput) -> None: + sampling_params = seq_group.sampling_params + if sampling_params.n == 1 and not sampling_params.use_beam_search: + # only have one output sample + sample = outputs.samples[0] + # only have one sequence + seq = seq_group.seqs[0] + seq.append_token_id(sample.output_token, sample.logprobs) + if sampling_params.detokenize and self.detokenizer: + new_char_count = self.detokenizer.decode_sequence_inplace( + seq, sampling_params) + else: + new_char_count = 0 + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count, + sampling_params, + lora_req=seq_group.lora_request, + ) + if seq.is_finished(): + for scheduler in self.scheduler: + scheduler.free_seq(seq) + return + # Process samples samples = outputs.samples parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) @@ -127,20 +150,20 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, child_seqs.append((parent, parent)) for seq, _ in child_seqs: - if seq_group.sampling_params.detokenize and self.detokenizer: + if sampling_params.detokenize and self.detokenizer: new_char_count = self.detokenizer.decode_sequence_inplace( - seq, seq_group.sampling_params) + seq, sampling_params) else: new_char_count = 0 self.stop_checker.maybe_stop_sequence( seq, 
new_char_count, - seq_group.sampling_params, + sampling_params, lora_req=seq_group.lora_request, ) # Non-beam search case - if not seq_group.sampling_params.use_beam_search: + if not sampling_params.use_beam_search: # For newly created child sequences, add them to the sequence group # and fork them in block manager if they are not finished. for seq, parent in child_seqs: @@ -164,8 +187,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Select the child sequences to keep in the sequence group. selected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = [] unselected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = [] - beam_width = seq_group.sampling_params.best_of - length_penalty = seq_group.sampling_params.length_penalty + beam_width = sampling_params.best_of + length_penalty = sampling_params.length_penalty # Select the newly finished sequences with the highest scores # to replace existing finished sequences. @@ -219,8 +242,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, best_running_seq = running_child_seqs[0][0] current_worst_seq = all_finished_seqs[beam_width - 1][0] stop_beam_search = self._check_beam_search_early_stopping( - seq_group.sampling_params.early_stopping, - seq_group.sampling_params, best_running_seq, current_worst_seq) + sampling_params.early_stopping, sampling_params, + best_running_seq, current_worst_seq) if stop_beam_search: # Stop the beam search and remove all the running sequences from From 179a6a36f2a585df49ce9c26701b1b9d894bd00e Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sun, 4 Aug 2024 16:12:41 +0800 Subject: [PATCH 0055/3246] [Model]Refactor MiniCPMV (#7020) Co-authored-by: Cyrus Leung --- docs/source/models/supported_models.rst | 2 +- .../models/idefics2_vision_model.py | 296 +++++ vllm/model_executor/models/minicpmv.py | 1023 ++++++++++------- vllm/model_executor/models/na_vit.py | 2 +- 4 files changed, 937 insertions(+), 386 deletions(-) create mode 100644 vllm/model_executor/models/idefics2_vision_model.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index a1ea366b82b0..fd5d154006ae 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -220,7 +220,7 @@ Vision Language Models - Phi-3-Vision - :code:`microsoft/Phi-3-vision-128k-instruct`, etc. - - * - :code:`MiniCPM-V` + * - :code:`MiniCPMV` - MiniCPM-V - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc. - diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py new file mode 100644 index 000000000000..cc448ed28d2d --- /dev/null +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -0,0 +1,296 @@ +# coding=utf-8 + +# adapted from https://github.com/huggingface/transformers/blob/v4.43.2/src/transformers/models/idefics2/modeling_idefics2.py +# Copyright 2024 The vLLM team. +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Idefics2 model.""" + +from typing import Optional + +import torch +from torch import nn +from transformers.models.idefics2.configuration_idefics2 import ( + Idefics2Config, Idefics2VisionConfig) +from xformers import ops as xops + +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig + + +class Idefics2VisionEmbeddings(nn.Module): + """ + This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings + ` to enable images of variable + resolution. + + The modifications are adapted from [Patch n' Pack: NaViT, a Vision + Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304) + which allows treating images in their native aspect ratio and without the + need to resize them to the same fixed size. In particular, we start from the + original pre-trained SigLIP model(which uses images of fixed-size square + images) and adapt it by training on images of variable resolutions. + """ + + def __init__(self, config: Idefics2VisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, + self.embed_dim) + + def forward( + self, + pixel_values: torch.FloatTensor, + patch_attention_mask: torch.BoolTensor, + ) -> torch.Tensor: + batch_size, _, max_im_h, max_im_w = pixel_values.shape + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + max_nb_patches_h, max_nb_patches_w = ( + max_im_h // self.patch_size, + max_im_w // self.patch_size, + ) + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, + 1 / self.num_patches_per_side) + position_ids = torch.full(size=(batch_size, + max_nb_patches_h * max_nb_patches_w), + fill_value=0) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + bucket_coords_h = torch.bucketize(fractional_coords_h, + boundaries, + right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, + boundaries, + right=True) + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + position_ids = position_ids.to(self.position_embedding.weight.device) + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class Idefics2VisionAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Idefics2Config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + 
self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" # noqa: E501 + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + quant_config=quant_config, + ) + self.out_proj = RowParallelLinear( + self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + ) + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + batch_size, q_len, _ = hidden_states.size() + qkv, _ = self.qkv_proj( + hidden_states + ) # batch_size, q_len, 3 * num_heads_per_partition * head_dim + query_states, key_states, value_states = qkv.chunk(3, dim=-1) + query_states = query_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + key_states = key_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + value_states = value_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + # see: https://facebookresearch.github.io/xformers/components/ops.html + out = xops.memory_efficient_attention_forward( + query_states, + key_states, + value_states, + p=self.dropout, + scale=self.scale, + ) + out = out.view(batch_size, q_len, -1) + attn_output, _ = self.out_proj(out) + return attn_output + + +class Idefics2VisionMLP(nn.Module): + + def __init__( + self, + config: Idefics2Config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class Idefics2EncoderLayer(nn.Module): + + def __init__(self, config: Idefics2Config): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Idefics2VisionAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = Idefics2VisionMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + + """ + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn(hidden_states) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Idefics2Encoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention + layers. 
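As a rough functional reference for the attention block defined above (not the kernel vLLM actually dispatches to): the xformers call consumes tensors laid out as (batch, seq, heads, head_dim), which, ignoring dropout, computes the same result as PyTorch's built-in scaled dot-product attention once the head axis is moved. A minimal, self-contained sketch with made-up sizes:

import torch
import torch.nn.functional as F

B, L, H, D = 2, 16, 12, 64        # illustrative sizes only
q = torch.randn(B, L, H, D)       # xformers layout: (batch, seq, heads, head_dim)
k = torch.randn(B, L, H, D)
v = torch.randn(B, L, H, D)

# Same math via torch SDPA, which expects (batch, heads, seq, head_dim);
# its default softmax scale is 1/sqrt(head_dim), matching `self.scale` above.
out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2),
                                     v.transpose(1, 2))
out = out.transpose(1, 2).reshape(B, L, H * D)   # back to (batch, seq, hidden)
print(out.shape)                                 # torch.Size([2, 16, 768])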
Each layer is a + [`Idefics2EncoderLayer`]. + + Args: + config: Idefics2Config + """ + + def __init__(self, config: Idefics2Config): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + Idefics2EncoderLayer(config) + for _ in range(config.num_hidden_layers) + ]) + + def forward( + self, + inputs_embeds: torch.Tensor, + ) -> torch.Tensor: + r""" + Args: + inputs_embeds (torch.Tensor): + Optionally, instead of passing `input_ids` you can choose to + directly pass an embedded representation. + This is useful if you want more control over how to convert + `input_ids` indices into associated vectorsthan the model's + internal embedding lookup matrix. + """ + hidden_states = inputs_embeds + for encoder_layer in self.layers: + layer_outputs = encoder_layer(hidden_states) + hidden_states = layer_outputs + return hidden_states + + +class Idefics2VisionTransformer(nn.Module): + + def __init__(self, config: Idefics2VisionConfig): + super().__init__() + embed_dim = config.hidden_size + self.config = config + self.embeddings = Idefics2VisionEmbeddings(config) + self.encoder = Idefics2Encoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + + def get_input_embeddings(self): + return self.embeddings + + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + ) -> torch.tensor: + hidden_states = self.embeddings( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask) + encoder_outputs = self.encoder(hidden_states) + last_hidden_state = self.post_layernorm(encoder_outputs) + return last_hidden_state diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 2a7fe7ba0eba..095bb49f6ba7 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -24,7 +24,8 @@ import math import re from functools import partial -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import (Any, Callable, Iterable, List, Optional, Tuple, TypedDict, + Union) import numpy as np import torch @@ -38,11 +39,14 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsVision from vllm.model_executor.models.llama import LlamaModel @@ -54,12 +58,45 @@ cached_get_tokenizer) from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData +from .idefics2_vision_model import Idefics2VisionTransformer + +logger = init_logger(__name__) + _KEYS_TO_MODIFY_MAPPING = { "llm.lm_head": "lm_head", "llm.model": "llm", } +class MiniCPMVImagePixelInputs(TypedDict): + pixel_values: List[torch.Tensor] + """ + Shape: `(batch_size * num_images, num_channels, height, width)` + + Note that the image size may vary, so we pass it as a list + instead of a batched tensor. 
+ """ + + image_bounds: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` + + This should be in `(start, stop)` format. + """ + + tgt_sizes: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` + + This should be in `(height, width)` format. + """ + + +MiniCPMVImageInputs = MiniCPMVImagePixelInputs + +DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) + + def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor): # abs_pos: L, C # tgt_size: (H, W) @@ -68,23 +105,25 @@ def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor): # tgt_size = int(math.sqrt(tgt_size)) dtype = abs_pos.dtype - return F.interpolate( + return (F.interpolate( abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), size=(tgt_size[0], tgt_size[1]), mode="bicubic", align_corners=False, - ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype) + ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)) # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 -def get_2d_sincos_pos_embed(embed_dim: int, - grid_size: Union[int, Tuple[int, int]], - cls_token: bool = False, - version: Tuple[int, int] = (2, 0)): +def get_2d_sincos_pos_embed( + embed_dim: int, + grid_size: Union[int, Tuple[int, int]], + cls_token: bool = False, + version: Tuple[int, int] = (2, 0), +): """ grid_size: int of the grid height and width return: - pos_embed: [grid_size*grid_size, embed_dim] or + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) """ if isinstance(grid_size, int): @@ -109,7 +148,7 @@ def get_2d_sincos_pos_embed(embed_dim: int, def get_2d_sincos_pos_embed_from_grid(embed_dim: int, - grid: Union[int, Tuple[int, int]], + grid: np.ndarray, version: Tuple[int, int] = (2, 0)): assert embed_dim % 2 == 0 @@ -127,7 +166,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, def get_1d_sincos_pos_embed_from_grid(embed_dim: int, - pos: int, + pos: np.ndarray, version: Tuple[int, int] = (2, 0)): """ embed_dim: output dimension for each position @@ -136,24 +175,24 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=np.float32) - omega /= embed_dim / 2. - omega = 1. 
/ 10000**omega # (D/2,) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) if version == (2, 0): pos = pos.reshape(-1) # (M,) - out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) else: - out = np.einsum('hw,d->hwd', pos, omega) # (H, W, D/2), outer product + out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product emb_sin = np.sin(out) # (H, W, D/2) emb_cos = np.cos(out) # (H, W, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D) return emb -class Resampler(nn.Module): +class BaseResampler(nn.Module): """ A 2D perceiver-resampler network with one cross attention layers by (grid_size**2) learnable queries and 2d sincos pos_emb @@ -161,89 +200,151 @@ class Resampler(nn.Module): A tensor with the shape of (grid_size**2, embed_dim) """ - default_norm_layer = partial(nn.LayerNorm, eps=1e-6) - - def __init__(self, - num_queries: int, - grid_size: int, - embed_dim: int, - num_heads: int, - kv_dim: Optional[int] = None, - norm_layer: nn.Module = default_norm_layer, - adaptive: bool = False, - max_size: Tuple[int, int] = (70, 70), - version: Tuple[int, int] = (2, 0)): + def __init__( + self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + ) -> None: super().__init__() - self.version = version - if self.version == (2, 0): - self.num_queries = grid_size**2 - else: - self.num_queries = num_queries - self.max_size = max_size + self.num_queries = num_queries self.embed_dim = embed_dim self.num_heads = num_heads - self.adaptive = adaptive self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) - trunc_normal_(self.query, std=.02) + trunc_normal_(self.query, std=0.02) if kv_dim is not None and kv_dim != embed_dim: - self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False) + self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False) else: - self.kv_proj = nn.Identity() + # Maintain the same return value with ReplicatedLinear.forward + self.kv_proj = lambda *args, **kwargs: ( + nn.Identity()(*args, **kwargs), + None, + ) self.attn = nn.MultiheadAttention(embed_dim, num_heads) self.ln_q = norm_layer(embed_dim) self.ln_kv = norm_layer(embed_dim) - self.ln_post = norm_layer(embed_dim) self.proj = nn.Parameter( (embed_dim**-0.5) * torch.randn(embed_dim, embed_dim)) - if self.version == (2, 0): - self.pos_embed = nn.Parameter( - torch.from_numpy( - get_2d_sincos_pos_embed( - embed_dim, grid_size, - version=self.version)).float()).requires_grad_(False) + def _init_weights(self, m: nn.Module) -> None: + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class Resampler2(BaseResampler): + + def __init__( + self, + grid_size: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + adaptive: bool = False, + ) -> None: + super().__init__(grid_size**2, embed_dim, num_heads, kv_dim, + norm_layer) + + self.adaptive = adaptive + + pos_embed_arr = get_2d_sincos_pos_embed(embed_dim, + grid_size, + version=(2, 
0)) + self.pos_embed = nn.Parameter( + torch.from_numpy(pos_embed_arr).float()).requires_grad_(False) + + self.apply(self._init_weights) + + def forward( + self, + x: torch.Tensor, + tgt_sizes: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ): + if self.adaptive: + pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, + tgt_sizes, + version=(2, 0)) + pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device, + dtype=x.dtype) else: - self._set_2d_pos_cache(self.max_size) + pos_embed = get_abs_pos(self.pos_embed, tgt_sizes) + + x, _ = self.kv_proj(x) + x = self.ln_kv(x).permute(1, 0, 2) + + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn( + self._repeat(q, N) + self.pos_embed.unsqueeze(1), + x + pos_embed.unsqueeze(1), + x, + attn_mask=attn_mask, + )[0] + x = out.permute(1, 0, 2) + + x = self.ln_post(x) + x = x @ self.proj + return x + + +class Resampler2_5(BaseResampler): + + def __init__( + self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + max_size: Tuple[int, int] = (70, 70), + ) -> None: + super().__init__(num_queries, embed_dim, num_heads, kv_dim, norm_layer) + + self.max_size = max_size + self._set_2d_pos_cache(self.max_size) self.apply(self._init_weights) def _set_2d_pos_cache(self, max_size: Tuple[int, int], - device: torch.types.Device = 'cpu'): - pos_embed = torch.from_numpy( - get_2d_sincos_pos_embed(self.embed_dim, - max_size, - version=self.version)).float().to(device) + device: torch.types.Device = "cpu") -> None: + pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, + max_size, + version=(2, 5)) + pos_embed = torch.from_numpy(pos_embed_arr).float().to(device) self.register_buffer("pos_embed", pos_embed, persistent=False) def _adjust_pos_cache(self, tgt_sizes: torch.Tensor, - device: torch.types.Device): - max_h = torch.max(tgt_sizes[:, 0]) - max_w = torch.max(tgt_sizes[:, 1]) + device: torch.types.Device) -> None: + max_h = tgt_sizes[:, 0].max().item() + max_w = tgt_sizes[:, 1].max().item() + assert isinstance(max_h, int) and isinstance(max_w, int) + if max_h > self.max_size[0] or max_w > self.max_size[1]: - self.max_size = [ + self.max_size = ( max(max_h, self.max_size[0]), - max(max_w, self.max_size[1]) - ] + max(max_w, self.max_size[1]), + ) self._set_2d_pos_cache(self.max_size, device) - def _init_weights(self, m: nn.Module): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def forward_2_5(self, - x: torch.Tensor, - tgt_sizes: Optional[torch.Tensor] = None): + def forward(self, x: torch.Tensor, + tgt_sizes: torch.Tensor) -> torch.Tensor: assert x.shape[0] == tgt_sizes.shape[0] bs = x.shape[0] @@ -254,25 +355,25 @@ def forward_2_5(self, self._adjust_pos_cache(tgt_sizes, device=device) - max_patch_len = torch.max(patch_len) + max_patch_len = patch_len.max().item() + assert isinstance(max_patch_len, int) + key_padding_mask = torch.zeros((bs, max_patch_len), dtype=torch.bool, device=device) pos_embed = [] for i in range(bs): - tgt_h, tgt_w = tgt_sizes[i] + tgt_h, tgt_w = tgt_sizes[i].tolist() pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape( (tgt_h * tgt_w, -1)).to(dtype)) # patches * D key_padding_mask[i, patch_len[i]:] = True - pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed, batch_first=True, 
padding_value=0.0).permute( 1, 0, 2) # BLD => L * B * D - - x = self.kv_proj(x) # B * L * D + x, _ = self.kv_proj(x) # B * L * D x = self.ln_kv(x).permute(1, 0, 2) # L * B * D q = self.ln_q(self.query) # Q * D @@ -281,7 +382,8 @@ def forward_2_5(self, self._repeat(q, bs), # Q * B * D x + pos_embed, # L * B * D + L * B * D x, - key_padding_mask=key_padding_mask)[0] + key_padding_mask=key_padding_mask, + )[0] # out: Q * B * D x = out.permute(1, 0, 2) # B * Q * D @@ -289,45 +391,6 @@ def forward_2_5(self, x = x @ self.proj return x - def forward_2(self, - x: torch.Tensor, - tgt_sizes: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None): - if self.adaptive: - pos_embed = torch.Tensor( - get_2d_sincos_pos_embed(self.embed_dim, - tgt_sizes)).float().to(device=x.device, - dtype=x.dtype) - else: - pos_embed = get_abs_pos(self.pos_embed, tgt_sizes) - - x = self.kv_proj(x) - x = self.ln_kv(x).permute(1, 0, 2) - - N = x.shape[1] - q = self.ln_q(self.query) - out = self.attn(self._repeat(q, N) + self.pos_embed.unsqueeze(1), - x + pos_embed.unsqueeze(1), - x, - attn_mask=attn_mask)[0] - x = out.permute(1, 0, 2) - - x = self.ln_post(x) - x = x @ self.proj - return x - - def forward(self, - x: torch.Tensor, - tgt_sizes: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None): - if self.version == (2, 0): - return self.forward_2(x, tgt_sizes=tgt_sizes, attn_mask=attn_mask) - else: - return self.forward_2_5(x, tgt_sizes=tgt_sizes) - - def _repeat(self, query, N: int): - return query.unsqueeze(1).repeat(1, N, 1) - def get_max_minicpmv_image_tokens(ctx: InputContext): hf_config = ctx.get_hf_config(PretrainedConfig) @@ -348,10 +411,7 @@ def dummy_image_for_minicpmv(hf_config: PretrainedConfig): def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int): hf_config = ctx.get_hf_config(PretrainedConfig) - # image_feature_size = get_max_minicpmv_image_tokens(ctx) - seq_data = dummy_seq_data_for_minicpmv(seq_len) - mm_data = dummy_image_for_minicpmv(hf_config) return seq_data, mm_data @@ -376,25 +436,36 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs): pattern = "(./)" image = multi_modal_data["image"] image_tags = re.findall(pattern, prompt) - assert len(image_tags) <= 1 - text_chunks = prompt.split(pattern) - new_prompt = text_chunks[0] \ - + image_processor.get_slice_image_placeholder(image.size) \ - + text_chunks[1] - new_token_ids = tokenizer.encode(new_prompt) - - llm_inputs = LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + if len(image_tags) == 0: + new_token_ids = token_ids + new_prompt = prompt + else: + if len(image_tags) > 1: + logger.warning("Multiple image input is not supported yet, " + "so any extra image tokens will be treated " + "as plain text.") + + text_chunks = prompt.split(pattern) + new_prompt = (text_chunks[0] + + image_processor.get_slice_image_placeholder(image.size) + + "".join(text_chunks[1:])) + + new_token_ids = tokenizer.encode(new_prompt) + + llm_inputs = LLMInputs( + prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data, + ) return llm_inputs -@MULTIMODAL_REGISTRY.register_image_input_mapper() -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv) -@INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv) -class MiniCPMV(nn.Module, SupportsVision): +class MiniCPMVBaseModel(nn.Module, SupportsVision): + """ + The abstract 
class of MiniCPMV can only be inherited, but cannot be + instantiated. + """ def __init__( self, @@ -419,8 +490,8 @@ def __init__( self.vpm = self.init_vision_module() param_dtype = torch.get_default_dtype() self.vpm.to(dtype=param_dtype) - self.vision_dim = self.vpm.embed_dim if self.version == (2, 0) \ - else self.vpm.embeddings.embed_dim + self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else + self.vpm.embeddings.embed_dim) self.embed_dim = self.config.hidden_size self.resampler = self.init_resampler(self.embed_dim, self.vision_dim) self.resampler.to(device="cuda", dtype=param_dtype) @@ -430,248 +501,100 @@ def __init__( self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() - def init_llm(self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): - if self.version == (2, 0): - return MiniCPMModel(config, - cache_config=cache_config, - quant_config=quant_config) - elif self.version == (2, 5): - return LlamaModel(config, - cache_config=cache_config, - quant_config=quant_config) - else: - return Qwen2Model(config, - cache_config=cache_config, - quant_config=quant_config) - - def init_vision_module(self): - if self.version == (2, 0): - try: - import timm - except ImportError: - raise ImportError( - 'Please install timm==0.9.10') from ImportError - default_dtype = torch.get_default_dtype() - torch.set_default_dtype(torch.float16) - model = timm.create_model('vit_so400m_patch14_siglip_384.webli', - pretrained=False, - num_classes=0, - dynamic_img_size=True, - dynamic_img_pad=True) - torch.set_default_dtype(default_dtype) - if isinstance(model, timm.models.VisionTransformer - ) and model.attn_pool is not None: - model.attn_pool = torch.nn.Identity() - - if self.config.drop_vision_last_layer: - model.blocks = model.blocks[:-1] - elif self.version == (2, 5): - from transformers.models.idefics2.modeling_idefics2 import ( - Idefics2VisionTransformer) - model = Idefics2VisionTransformer(self.config.vision_config) - if self.config.drop_vision_last_layer: - model.encoder.layers = model.encoder.layers[:-1] - else: - from vllm.model_executor.models.na_vit import ( - SiglipVisionTransformer) - if self.config._attn_implementation == 'flash_attention_2': - self.config.vision_config._attn_implementation \ - = 'flash_attention_2' - else: - # not support sdpa - self.config.vision_config._attn_implementation = 'eager' - model = SiglipVisionTransformer(self.config.vision_config) - if self.config.drop_vision_last_layer: - model.encoder.layers = model.encoder.layers[:-1] - return model - - def init_resampler(self, embed_dim: int, vision_dim: int): - default_dtype = torch.get_default_dtype() - torch.set_default_dtype(torch.float16) - if self.version == (2, 0): - resampler = Resampler(grid_size=int( - math.sqrt(self.config.query_num)), - num_queries=None, - embed_dim=embed_dim, - num_heads=embed_dim // 128, - kv_dim=vision_dim, - adaptive=True, - version=self.version) + def get_embedding( + self, + input_ids: torch.Tensor, + image_inputs: Optional[MiniCPMVImageInputs], + ) -> Tuple[torch.Tensor, torch.Tensor]: + vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids) + if hasattr(self.config, "scale_emb"): + vlm_embedding *= self.config.scale_emb + + if image_inputs is None: # No image + vision_hidden_states = torch.tensor([], device=input_ids.device) else: - resampler = Resampler(num_queries=self.config.query_num, - grid_size=None, - embed_dim=embed_dim, - num_heads=embed_dim // 128, - 
kv_dim=vision_dim, - adaptive=True, - version=self.version) - torch.set_default_dtype(default_dtype) - return resampler + vision_hidden_states = self.get_vision_hidden_states(image_inputs) + + # See NOTE in _parse_and_validate_inputs + image_bounds = image_inputs["image_bounds"] + if len(image_bounds) > 0: + image_indices = torch.stack([ + torch.arange(start, end, dtype=torch.long) + for start, end in image_bounds.tolist() + ]).to(vlm_embedding.device) + vlm_embedding.scatter_( + 0, + image_indices.view(-1, 1).repeat(1, + vlm_embedding.shape[-1]), + vision_hidden_states.view(-1, + vision_hidden_states.shape[-1]), + ) - def get_vision_embedding(self, - pixel_values: List[List[torch.Tensor]], - patch_attn_mask: Optional[torch.Tensor] = None, - tgt_sizes: Optional[torch.Tensor] = None, - version: Tuple[int, int] = (2, 0)): - if version == (2, 0): - res = [] - dtype = self.vpm.pos_embed.data.dtype - for pixel_value in pixel_values: - # V2.0 start - H, W = pixel_value[0].shape[-2:] - tgt_size = (math.ceil(H / self.vpm.patch_embed.patch_size[0]), - math.ceil(W / self.vpm.patch_embed.patch_size[0])) - # V2.0 end - vision_embedding = self.vpm.forward_features( - pixel_value.unsqueeze(0).type(dtype)) - if hasattr(self.vpm, 'num_prefix_tokens' - ) and self.vpm.num_prefix_tokens > 0: - vision_embedding = vision_embedding[:, self.vpm. - num_prefix_tokens:] - res.append(self.resampler(vision_embedding, tgt_size)) - return torch.vstack(res) - elif version == (2, 5): - vision_embedding = self.vpm( - pixel_values.type(dtype), - patch_attention_mask=patch_attn_mask).last_hidden_state - vision_embedding = self.resampler(vision_embedding, tgt_sizes) - else: - vision_embedding = self.vpm(pixel_values.type(dtype), - patch_attention_mask=patch_attn_mask, - tgt_sizes=tgt_sizes).last_hidden_state + return vlm_embedding, vision_hidden_states - def get_image_bounds(self, input_ids: torch.Tensor): + def _get_image_bounds(self, input_ids: torch.Tensor) -> torch.Tensor: tokenizer = cached_get_tokenizer(self.config._name_or_path, trust_remote_code=True) - if not hasattr(tokenizer, "slice_start_id"): - start_cond = input_ids == tokenizer.im_start_id - end_cond = input_ids == tokenizer.im_end_id - else: - start_cond = (input_ids == tokenizer.im_start_id) | ( - input_ids == tokenizer.slice_start_id) - end_cond = (input_ids == tokenizer.im_end_id) | ( - input_ids == tokenizer.slice_end_id) + start_cond = input_ids == tokenizer.im_start_id + end_cond = input_ids == tokenizer.im_end_id + if hasattr(tokenizer, "slice_start_id"): + start_cond |= (input_ids == tokenizer.slice_start_id) + end_cond |= (input_ids == tokenizer.slice_end_id) - image_start_tokens = torch.where(start_cond)[0] + image_start_tokens, = torch.where(start_cond) image_start_tokens += 1 - image_end_tokens = torch.where(end_cond)[0] + image_end_tokens, = torch.where(end_cond) valid_image_nums = max(len(image_start_tokens), len(image_end_tokens)) + if valid_image_nums == 0: - return [] - image_bound = torch.hstack([ + return torch.zeros((0, 2), device=input_ids.device) + + return torch.hstack([ image_start_tokens[:valid_image_nums].unsqueeze(-1), image_end_tokens[:valid_image_nums].unsqueeze(-1), ]) - return image_bound - - def get_vision_hidden_states(self, data: Dict[str, - Union[List[torch.Tensor], - torch.Tensor]]): - if "vision_hidden_states" not in data: - pixel_values = data["pixel_values"] - tgt_sizes = data["tgt_sizes"] - vision_hidden_states = [] - if self.version == (2, 0): - if pixel_values is not None and len(pixel_values) > 0: - 
vision_hidden_states = self.get_vision_embedding( - pixel_values) - else: - vision_hidden_states = torch.tensor([]).to( - data["input_ids"].device) - else: - device = self.vpm.embeddings.position_embedding.weight.device - dtype = self.vpm.embeddings.position_embedding.weight.dtype - all_pixel_values = [ - i.flatten(end_dim=1).permute(1, 0) for i in pixel_values - ] - if all_pixel_values: - tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32) - max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1]) - all_pixel_values = torch.nn.utils.rnn.pad_sequence( - all_pixel_values, batch_first=True, padding_value=0.0) - B, L, _ = all_pixel_values.shape - all_pixel_values = all_pixel_values.permute( - 0, 2, 1).reshape(B, 3, -1, L) - patch_attn_mask = torch.zeros((B, 1, max_patches), - dtype=torch.bool, - device=device) - if self.version == (2, 5): - for i in range(B): - patch_attn_mask[i, :tgt_sizes[i][0] * - tgt_sizes[i][1]] = True - vision_embedding = self.vpm( - all_pixel_values.type(dtype), - patch_attention_mask=patch_attn_mask - ).last_hidden_state - else: - for i in range(B): - patch_attn_mask[i, 0, :tgt_sizes[i][0] * - tgt_sizes[i][1]] = True - vision_embedding = self.vpm( - all_pixel_values.type(dtype), - patch_attention_mask=patch_attn_mask, - tgt_sizes=tgt_sizes).last_hidden_state - - vision_hidden_states = self.resampler( - vision_embedding, tgt_sizes) - - else: # no image - dummy_feature = [] - vision_hidden_states = dummy_feature - else: - vision_hidden_states = data["vision_hidden_states"] - - return vision_hidden_states - - def get_embedding(self, data: Dict[str, Union[List[torch.Tensor], - torch.Tensor]]): - input_ids = data["input_ids"] - - vision_hidden_states = self.get_vision_hidden_states(data) - if vision_hidden_states is not None and len(vision_hidden_states) > 0: - image_bounds = self.get_image_bounds(input_ids) - else: - image_bounds = [] - - if hasattr(self.config, 'scale_emb'): - vlm_embedding = self.llm.embed_tokens( - input_ids) * self.config.scale_emb - else: - vlm_embedding = self.llm.embed_tokens(input_ids) - vision_hidden_states = [ - i.type(vlm_embedding.dtype) if isinstance(i, torch.Tensor) else i - for i in vision_hidden_states - ] - - if len(vision_hidden_states) > 0 and len(image_bounds) > 0: - vision_hidden_states = torch.cat(vision_hidden_states, dim=0) - image_indices = torch.stack([ - torch.arange(r[0], r[1], dtype=torch.long) - for r in image_bounds - ]).to(vlm_embedding.device) - vlm_embedding.scatter_( - 0, - image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]), - vision_hidden_states.view(-1, vision_hidden_states.shape[-1])) - return vlm_embedding, vision_hidden_states - - def process_multimodal_inputs(self, inputs: Dict[str, - Union[List[torch.Tensor], - torch.Tensor]]): - pixel_values = [] - tgt_sizes = [] - for b in range(len(inputs["pixel_values"])): - pixel_values += inputs["pixel_values"][b] - tgt_sizes += inputs["tgt_sizes"][b] - return { - "pixel_values": pixel_values, - "input_ids": inputs["input_ids"], - "tgt_sizes": tgt_sizes - } + def _parse_and_validate_inputs( + self, + input_ids: torch.Tensor, + **kwargs: object, + ) -> Optional[MiniCPMVImageInputs]: + pixel_values = kwargs.pop("pixel_values", []) + tgt_sizes = kwargs.pop("tgt_sizes", []) + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(tgt_sizes, (torch.Tensor, list)): + raise ValueError("Incorrect type of target sizes. 
" + f"Got type: {type(tgt_sizes)}") + + if len(pixel_values) != len(tgt_sizes): + raise ValueError("Inconsistent batch lengths, found: " + f"{len(pixel_values)} vs. {len(tgt_sizes)}") + + pixel_values_flat: List[torch.Tensor] = [] + tgt_sizes_flat: List[torch.Tensor] = [] + for b in range(len(pixel_values)): + pixel_values_flat += pixel_values[b] + tgt_sizes_flat += tgt_sizes[b] + + # NOTE: Input IDs does not contain image tokens during memory profiling, + # so we allow it to be empty + if len(pixel_values_flat) != len(tgt_sizes_flat): + raise ValueError("Inconsistent flattened lengths, found: " + f"{len(pixel_values_flat)} vs. " + f"{len(tgt_sizes_flat)}") + + if len(pixel_values_flat) == 0: + return None + + return MiniCPMVImageInputs( + image_bounds=self._get_image_bounds(input_ids), + pixel_values=pixel_values_flat, + tgt_sizes=torch.stack(tgt_sizes_flat), + ) def forward( self, @@ -680,23 +603,20 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - **kwargs: object, - ): - inputs = { - "pixel_values": kwargs.pop("pixel_values", []), - "input_ids": input_ids, - "tgt_sizes": kwargs.pop("tgt_sizes", None), - } - inputs = self.process_multimodal_inputs(inputs) - - vlm_embeddings, vision_hidden_states = self.get_embedding(inputs) - - output = self.llm(input_ids=None, - positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - intermediate_tensors=intermediate_tensors, - inputs_embeds=vlm_embeddings) + **kwargs: Any, + ) -> torch.Tensor: + image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs) + + vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs) + + output = self.llm( + input_ids=None, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + inputs_embeds=vlm_embeddings, + ) return output def compute_logits(self, hidden_states: torch.Tensor, @@ -735,13 +655,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # the checkpoint. Skip them. continue use_default_weight_loading = False - if "vpm" in name or 'resampler' in name: - # We only do sharding for language model and - # not vision model for now. 
+ if self.is_default_weight_loading(name): use_default_weight_loading = True else: - for (param_name, weight_name, - shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue param = params_dict[name.replace(weight_name, param_name)] @@ -755,3 +672,341 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + def init_llm( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + raise NotImplementedError + + def init_vision_module(self) -> nn.Module: + raise NotImplementedError + + def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + raise NotImplementedError + + def get_vision_embedding( + self, + pixel_values: List[torch.Tensor], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + def get_vision_hidden_states(self, + data: MiniCPMVImageInputs) -> torch.Tensor: + raise NotImplementedError + + def is_default_weight_loading(self, name: str) -> bool: + raise NotImplementedError + + +class MiniCPMV2(MiniCPMVBaseModel): + + def __init__( + self, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__(config, multimodal_config, cache_config, quant_config) + assert self.version == (2, 0) + + def init_llm( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + return MiniCPMModel(config, + cache_config=cache_config, + quant_config=quant_config) + + def init_vision_module(self) -> nn.Module: + # TODO :refactor this vision model + try: + import timm + except ImportError: + raise ImportError("Please install timm==0.9.10") from ImportError + with set_default_torch_dtype(torch.float16): + model = timm.create_model( + "vit_so400m_patch14_siglip_384.webli", + pretrained=False, + num_classes=0, + dynamic_img_size=True, + dynamic_img_pad=True, + ) + + if (isinstance(model, timm.models.VisionTransformer) + and model.attn_pool is not None): + model.attn_pool = torch.nn.Identity() + + if self.config.drop_vision_last_layer: + model.blocks = model.blocks[:-1] + + return model + + def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + with set_default_torch_dtype(torch.float16): + resampler = Resampler2( + embed_dim=embed_dim, + num_heads=embed_dim // 128, + grid_size=int(math.sqrt(self.config.query_num)), + kv_dim=vision_dim, + adaptive=True, + ) + + return resampler + + def get_vision_embedding( + self, + pixel_values: List[torch.Tensor], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + res = [] + dtype = self.vpm.pos_embed.data.dtype + for pixel_value in pixel_values: + H, W = pixel_value[0].shape[-2:] + tgt_size = ( + math.ceil(H / self.vpm.patch_embed.patch_size[0]), + math.ceil(W / self.vpm.patch_embed.patch_size[0]), + ) + vision_embedding = self.vpm.forward_features( + pixel_value.unsqueeze(0).type(dtype)) + if (hasattr(self.vpm, "num_prefix_tokens") + and self.vpm.num_prefix_tokens > 0): + vision_embedding = vision_embedding[:, self.vpm. 
+ num_prefix_tokens:] + res.append(self.resampler(vision_embedding, tgt_size)) + return torch.vstack(res) + + def get_vision_hidden_states(self, + data: MiniCPMVImageInputs) -> torch.Tensor: + pixel_values = data["pixel_values"] + + return self.get_vision_embedding(pixel_values) + + def is_default_weight_loading(self, name: str) -> bool: + return "resampler" in name or "vpm" in name + + +class MiniCPMV2_5(MiniCPMVBaseModel): + + def __init__( + self, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__(config, multimodal_config, cache_config, quant_config) + assert self.version == (2, 5) + + def init_llm( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + return LlamaModel(config, + cache_config=cache_config, + quant_config=quant_config) + + def init_vision_module(self) -> nn.Module: + model = Idefics2VisionTransformer(self.config.vision_config) + if self.config.drop_vision_last_layer: + model.encoder.layers = model.encoder.layers[:-1] + return model + + def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + with set_default_torch_dtype(torch.float16): + resampler = Resampler2_5( + num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + ) + return resampler + + def get_vision_embedding( + self, + pixel_values: List[torch.Tensor], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + vision_embedding = self.vpm(pixel_values, + patch_attention_mask=patch_attn_mask) + vision_embedding = self.resampler(vision_embedding, tgt_sizes) + return vision_embedding + + def get_vision_hidden_states(self, + data: MiniCPMVImageInputs) -> torch.Tensor: + pixel_values = data["pixel_values"] + tgt_sizes = data["tgt_sizes"] + + device = self.vpm.embeddings.position_embedding.weight.device + dtype = self.vpm.embeddings.position_embedding.weight.dtype + all_pixel_values_lst = [ + i.flatten(end_dim=1).permute(1, 0) for i in pixel_values + ] + + max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item() + assert isinstance(max_patches, int) + + all_pixel_values = torch.nn.utils.rnn.pad_sequence( + all_pixel_values_lst, batch_first=True, padding_value=0.0) + B, L, _ = all_pixel_values.shape + all_pixel_values = all_pixel_values.permute(0, 2, + 1).reshape(B, 3, -1, L) + + patch_attn_mask = torch.zeros((B, 1, max_patches), + dtype=torch.bool, + device=device) + for i in range(B): + patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True + + return self.get_vision_embedding(all_pixel_values.type(dtype), + patch_attn_mask, tgt_sizes) + + def is_default_weight_loading(self, name: str) -> bool: + return "resampler" in name + + +# NOTE: Currently, information about this model is unavailable. We are +# temporarily using `MiniCPMVQwen2` as it's name. The name may need +# to be modified in the future. 
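The get_vision_hidden_states implementations above and below share the same masking idea: every image contributes tgt_h * tgt_w valid patches, and everything beyond that in the padded batch is ignored. A compact sketch of just that step, with invented sizes:

import torch

tgt_sizes = torch.tensor([[24, 32], [16, 16]])   # (height, width) in patches
max_patches = int((tgt_sizes[:, 0] * tgt_sizes[:, 1]).max())

patch_attn_mask = torch.zeros((len(tgt_sizes), 1, max_patches), dtype=torch.bool)
for i, (h, w) in enumerate(tgt_sizes.tolist()):
    patch_attn_mask[i, 0, :h * w] = True

print(patch_attn_mask.sum(dim=-1).squeeze(-1))   # tensor([768, 256])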
+class MiniCPMVQwen2(MiniCPMVBaseModel): + + def __init__( + self, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__(config, multimodal_config, cache_config, quant_config) + + def init_llm( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + return Qwen2Model(config, + cache_config=cache_config, + quant_config=quant_config) + + def init_vision_module(self) -> nn.Module: + # A custom version of SiglipVisionTransformer, won't work with TP + from vllm.model_executor.models.na_vit import SiglipVisionTransformer + + if self.config._attn_implementation == "flash_attention_2": + self.config.vision_config._attn_implementation = "flash_attention_2" + else: + # not support sdpa + self.config.vision_config._attn_implementation = "eager" + model = SiglipVisionTransformer(self.config.vision_config) + if self.config.drop_vision_last_layer: + model.encoder.layers = model.encoder.layers[:-1] + return model + + def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + with set_default_torch_dtype(torch.float16): + resampler = Resampler2_5( + num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + ) + + return resampler + + def get_vision_embedding( + self, + pixel_values: List[torch.Tensor], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + vision_embedding = self.vpm( + pixel_values, + patch_attention_mask=patch_attn_mask, + tgt_sizes=tgt_sizes, + ).last_hidden_state + return vision_embedding + + def get_vision_hidden_states(self, + data: MiniCPMVImageInputs) -> torch.Tensor: + pixel_values = data["pixel_values"] + tgt_sizes = data["tgt_sizes"] + + device = self.vpm.embeddings.position_embedding.weight.device + dtype = self.vpm.embeddings.position_embedding.weight.dtype + all_pixel_values_lst = [ + i.flatten(end_dim=1).permute(1, 0) for i in pixel_values + ] + + max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item() + assert isinstance(max_patches, int) + + all_pixel_values = torch.nn.utils.rnn.pad_sequence( + all_pixel_values_lst, batch_first=True, padding_value=0.0) + B, L, _ = all_pixel_values.shape + all_pixel_values = all_pixel_values.permute(0, 2, + 1).reshape(B, 3, -1, L) + + patch_attn_mask = torch.zeros((B, 1, max_patches), + dtype=torch.bool, + device=device) + for i in range(B): + patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True + vision_embedding = self.vpm( + all_pixel_values.type(dtype), + patch_attention_mask=patch_attn_mask, + tgt_sizes=tgt_sizes, + ).last_hidden_state + + return self.resampler(vision_embedding, tgt_sizes) + + def is_default_weight_loading(self, name: str) -> bool: + return "resampler" in name or "vpm" in name + + +@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv) +@INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv) +class MiniCPMV(MiniCPMVBaseModel): + """ + Different versions of MiniCPMV use different visual encoders and LLMs, + which is not conducive to the current integration logic of LoRA and + bitsandbytes in vLLM. Therefore, it is necessary to separate them. 
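The MiniCPMV wrapper below selects the concrete implementation inside __new__; since the object it returns is not an instance of the wrapper class, Python skips the wrapper's __init__ and the chosen class is fully responsible for construction. A minimal sketch of that pattern with placeholder classes (these are not the real vLLM types):

from typing import Tuple

def parse_version(version) -> Tuple[int, ...]:
    # "2.5" (or the float 2.5) -> (2, 5)
    return tuple(int(x) for x in str(version).split("."))

class _V2_0:
    def __init__(self, config):
        self.config = config

class _V2_5:
    def __init__(self, config):
        self.config = config

class Dispatcher:
    def __new__(cls, config: dict):
        version = parse_version(config.get("version", "2.5"))
        impl = _V2_0 if version == (2, 0) else _V2_5
        return impl(config)   # Dispatcher.__init__ is never run on this object

print(type(Dispatcher({"version": "2.0"})).__name__)   # _V2_0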
+ """ + + def __new__( + cls, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + if not hasattr(config, "version"): + if config.hidden_size == 2304 and config.query_num == 64: + version = (2, 0) + else: + version = (2, 5) + else: + version = str(config.version).split(".") + version = tuple([int(x) for x in version]) + # Dispatch class based on version + if version == (2, 0): + instance_class = MiniCPMV2 + elif version == (2, 5): + instance_class = MiniCPMV2_5 + else: + instance_class = MiniCPMVQwen2 + return instance_class(config, multimodal_config, cache_config, + quant_config) diff --git a/vllm/model_executor/models/na_vit.py b/vllm/model_executor/models/na_vit.py index 871e4128b66e..1d6f26f0d4fb 100644 --- a/vllm/model_executor/models/na_vit.py +++ b/vllm/model_executor/models/na_vit.py @@ -100,7 +100,7 @@ def _get_unpad_data(attention_mask): indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad( - torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return ( indices, cu_seqlens, From b1c9aa3daa7dcd981f0f77231b46883624b72dd0 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Sun, 4 Aug 2024 16:13:18 +0200 Subject: [PATCH 0056/3246] [Bugfix] [SpecDecode] Default speculative_draft_tensor_parallel_size to 1 when using MLPSpeculator (#7105) Signed-off-by: Thomas Parnell --- vllm/config.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 028f4eed8f4a..0524514f6633 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1068,7 +1068,7 @@ def maybe_create_spec_config( draft_parallel_config = ( SpeculativeConfig.create_draft_parallel_config( target_parallel_config, - speculative_draft_tensor_parallel_size)) + speculative_draft_tensor_parallel_size, draft_hf_config)) if num_speculative_tokens is None: raise ValueError( @@ -1136,15 +1136,23 @@ def _maybe_override_draft_max_model_len( @staticmethod def create_draft_parallel_config( target_parallel_config: ParallelConfig, - speculative_draft_tensor_parallel_size: Optional[int] + speculative_draft_tensor_parallel_size: Optional[int], + draft_hf_config: PretrainedConfig, ) -> ParallelConfig: """Create a parallel config for use by the draft worker. This is mostly a copy of the target parallel config, except the tp_size. 
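A hedged distillation of the rule implemented just below (function and argument names here are invented, not vLLM API): when the user does not set a draft tensor-parallel size, MLPSpeculator drafts default to tp=1 because they cannot currently run with tp>1, while other draft models inherit the target's tp.

def resolve_draft_tp(requested, target_tp, draft_model_type):
    if requested is None:
        return 1 if draft_model_type == "mlp_speculator" else target_tp
    if requested != 1:
        # Mirrors the existing restriction: only tp=1 drafts are supported so far.
        raise ValueError("speculative_draft_tensor_parallel_size must be 1 for now")
    return requested

print(resolve_draft_tp(None, 4, "mlp_speculator"))   # 1
print(resolve_draft_tp(None, 4, "llama"))            # 4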
""" if speculative_draft_tensor_parallel_size is None: - speculative_draft_tensor_parallel_size = \ - target_parallel_config.tensor_parallel_size + if draft_hf_config.model_type == "mlp_speculator": + speculative_draft_tensor_parallel_size = 1 + if target_parallel_config.tensor_parallel_size > 1: + logger.warning( + "MLPSpeculator cannot currently be run with tp>1; " + "setting speculative_draft_tensor_parallel_size=1") + else: + speculative_draft_tensor_parallel_size = \ + target_parallel_config.tensor_parallel_size elif speculative_draft_tensor_parallel_size != 1: # TODO(wooyeon): allow tp values larger than 1 raise ValueError( From 16a1cc9bb2b4bba82d78f329e5a89b44a5523ac8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 4 Aug 2024 11:31:51 -0700 Subject: [PATCH 0057/3246] [misc][distributed] improve libcudart.so finding (#7127) --- .../device_communicators/cuda_wrapper.py | 44 +++++++++---------- .../custom_all_reduce_utils.py | 4 +- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/vllm/distributed/device_communicators/cuda_wrapper.py b/vllm/distributed/device_communicators/cuda_wrapper.py index 5cac3c1d57bc..9c7f41a1f9d6 100644 --- a/vllm/distributed/device_communicators/cuda_wrapper.py +++ b/vllm/distributed/device_communicators/cuda_wrapper.py @@ -4,9 +4,6 @@ """ import ctypes -import glob -import os -import sys from dataclasses import dataclass from typing import Any, Dict, List, Optional @@ -36,24 +33,25 @@ class Function: argtypes: List[Any] -def get_pytorch_default_cudart_library_path() -> str: - # code borrowed from https://github.com/pytorch/pytorch/blob/1cae60a87e5bdda8bcf55724a862eeed98a9747e/torch/__init__.py#L284 # noqa - lib_folder = "cuda_runtime" - lib_name = "libcudart.so.*[0-9]" - lib_path = None - for path in sys.path: - nvidia_path = os.path.join(path, "nvidia") - if not os.path.exists(nvidia_path): - continue - candidate_lib_paths = glob.glob( - os.path.join(nvidia_path, lib_folder, "lib", lib_name)) - if candidate_lib_paths and not lib_path: - lib_path = candidate_lib_paths[0] - if lib_path: - break - if not lib_path: - raise ValueError(f"{lib_name} not found in the system path {sys.path}") - return lib_path +def find_loaded_library(lib_name) -> Optional[str]: + """ + According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, + the file `/proc/self/maps` contains the memory maps of the process, which includes the + shared libraries loaded by the process. We can use this file to find the path of the + a loaded library. 
+ """ # noqa + found = False + with open("/proc/self/maps") as f: + for line in f: + if lib_name in line: + found = True + break + if not found: + # the library is not loaded in the current process + return None + start = line.index("/") + path = line[start:].strip() + return path class CudaRTLibrary: @@ -100,7 +98,9 @@ class CudaRTLibrary: def __init__(self, so_file: Optional[str] = None): if so_file is None: - so_file = get_pytorch_default_cudart_library_path() + so_file = find_loaded_library("libcudart.so") + assert so_file is not None, \ + "libcudart.so is not loaded in the current process" if so_file not in CudaRTLibrary.path_to_library_cache: lib = ctypes.CDLL(so_file) CudaRTLibrary.path_to_library_cache[so_file] = lib diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py b/vllm/distributed/device_communicators/custom_all_reduce_utils.py index d27d7ee9a249..37ae94c671e3 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce_utils.py +++ b/vllm/distributed/device_communicators/custom_all_reduce_utils.py @@ -145,6 +145,7 @@ def can_actually_p2p( p_tgt.start() p_src.join() p_tgt.join() + assert p_src.exitcode == 0 and p_tgt.exitcode == 0 result: List[bool] = [] for src, tgt in zip(batch_src, batch_tgt): a = result_queue.get() @@ -221,7 +222,8 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool: # wrap raised exception to provide more information raise RuntimeError( f"Error happened when batch testing " - f"peer-to-peer access from {batch_src} to {batch_tgt}") from e + f"peer-to-peer access from {batch_src} to {batch_tgt}:\n" + f"{returned.stderr.decode()}") from e result = pickle.loads(returned.stdout) for _i, _j, r in zip(batch_src, batch_tgt, result): cache[f"{_i}->{_j}"] = r From f80ab3521ca2aa74e121e26a27b87da7a1065939 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 5 Aug 2024 06:37:08 +0800 Subject: [PATCH 0058/3246] Clean up remaining Punica C information (#7027) --- .github/workflows/clang-format.yml | 6 ------ cmake/utils.cmake | 2 +- format.sh | 6 ------ vllm/config.py | 2 +- vllm/lora/layers.py | 2 +- 5 files changed, 3 insertions(+), 15 deletions(-) diff --git a/.github/workflows/clang-format.yml b/.github/workflows/clang-format.yml index e9b6e28fa6bc..79b85d8cad0d 100644 --- a/.github/workflows/clang-format.yml +++ b/.github/workflows/clang-format.yml @@ -30,12 +30,6 @@ jobs: run: | EXCLUDES=( 'csrc/moe/topk_softmax_kernels.cu' - 'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu' - 'csrc/punica/bgmv/bgmv_config.h' - 'csrc/punica/bgmv/bgmv_impl.cuh' - 'csrc/punica/bgmv/vec_dtypes.cuh' - 'csrc/punica/punica_ops.cu' - 'csrc/punica/type_convert.h' ) find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \ | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \ diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 4869cad54113..69998b45be70 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -181,7 +181,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES) # # The torch cmake setup hardcodes the detected architecture flags in # `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it - # can't modified on a per-target basis, e.g. for the `punica` extension. + # can't modified on a per-target basis. # So, all the `-gencode` flags need to be extracted and removed from # `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method. 
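Stepping back to the find_loaded_library helper added in cuda_wrapper.py above: a quick, Linux-only way to see what those /proc/self/maps entries look like and to hand a discovered path straight to ctypes (paths and the presence of libcudart differ per system, so libc is used here as a stand-in):

import ctypes

# Shared objects currently mapped into this process (Linux only).
with open("/proc/self/maps") as f:
    libs = sorted({line.split()[-1] for line in f if ".so" in line})
for path in libs[:3]:
    print(path)

# Any of these absolute paths can be loaded directly; libc is normally present.
libc_paths = [p for p in libs if "libc.so" in p]
if libc_paths:
    libc = ctypes.CDLL(libc_paths[0])
    print(libc.getpid())   # same pid as os.getpid()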
# Since it's not possible to use `target_compiler_options` for adding target diff --git a/format.sh b/format.sh index abc688c702aa..baaebc811d40 100755 --- a/format.sh +++ b/format.sh @@ -242,12 +242,6 @@ echo 'vLLM isort: Done' # NOTE: Keep up to date with .github/workflows/clang-format.yml CLANG_FORMAT_EXCLUDES=( 'csrc/moe/topk_softmax_kernels.cu' - 'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu' - 'csrc/punica/bgmv/bgmv_config.h' - 'csrc/punica/bgmv/bgmv_impl.cuh' - 'csrc/punica/bgmv/vec_dtypes.cuh' - 'csrc/punica/punica_ops.cu' - 'csrc/punica/type_convert.h' ) # Format specified files with clang-format diff --git a/vllm/config.py b/vllm/config.py index 0524514f6633..35945e34452d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1304,7 +1304,7 @@ class LoRAConfig: long_lora_scaling_factors: Optional[Tuple[float]] = None def __post_init__(self): - # Keep this in sync with csrc/punica/bgmv/bgmv_config.h + # TODO: Increase the range of rank possible_max_ranks = (8, 16, 32, 64) possible_lora_extra_vocab_size = (0, 256, 512) if self.max_lora_rank not in possible_max_ranks: diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 42ec99e6ea2c..d3978ff6f4ff 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1073,7 +1073,7 @@ def create_lora_weights( lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None, ) -> None: - # Keep this in sync with csrc/punica/bgmv/bgmv_config.h + # TODO: Verify if this condition can be relaxed if 32000 < self.base_layer.vocab_size > 128512: raise ValueError("When using LoRA, vocab size must be " "32000 >= vocab_size <= 128512") From 7b86e7c9cd6541abdf5d083b0a8a98ee667a91d1 Mon Sep 17 00:00:00 2001 From: Alphi <52458637+HwwwwwwwH@users.noreply.github.com> Date: Mon, 5 Aug 2024 09:23:17 +0800 Subject: [PATCH 0059/3246] [Model] Add multi-image support for minicpmv (#7122) Co-authored-by: hezhihui Co-authored-by: Cyrus Leung --- tests/conftest.py | 5 +- tests/models/test_minicpmv.py | 146 ++++++++++++++++++++++--- vllm/model_executor/models/minicpmv.py | 56 ++++++---- vllm/multimodal/image.py | 2 +- 4 files changed, 172 insertions(+), 37 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 999ca60d07a4..c7a349f1e9e2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,7 @@ import os import sys from collections import UserList -from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar +from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union import pytest import torch @@ -508,7 +508,8 @@ def generate_greedy_logprobs( prompts: List[str], max_tokens: int, num_logprobs: int, - images: Optional[List[Image.Image]] = None, + images: Optional[Union[List[Image.Image], + List[List[Image.Image]]]] = None, stop_token_ids: Optional[List[int]] = None, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: greedy_logprobs_params = SamplingParams(temperature=0.0, diff --git a/tests/models/test_minicpmv.py b/tests/models/test_minicpmv.py index c57f0f8c0854..c3b2a7bcbaaf 100644 --- a/tests/models/test_minicpmv.py +++ b/tests/models/test_minicpmv.py @@ -14,6 +14,18 @@ pytestmark = pytest.mark.vlm + +class NestedInputs(UserDict): + + def __init__(self, model_inputs: BatchFeature): + super().__init__({"model_inputs": model_inputs}) + + self.model_inputs = model_inputs + + def to(self, device: torch.types.Device): + return NestedInputs(self.model_inputs.to(device)) + + # The image token is placed before "user" on purpose so that the test can pass HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ 
"stop_sign": @@ -23,7 +35,7 @@ "cherry_blossom": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \ "(./)\nWhat is the season?<|eot_id|>" \ - "<|start_header_id|>assistant<|end_header_id|>\n\n" + "<|start_header_id|>assistant<|end_header_id|>\n\n", }) models = ["openbmb/MiniCPM-Llama3-V-2_5"] @@ -94,22 +106,10 @@ def run_test( ] with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad(): - - class NestedInputs(UserDict): - - def __init__(self, model_inputs: BatchFeature): - super().__init__({"model_inputs": model_inputs}) - - self.model_inputs = model_inputs - - def to(self, device: torch.types.Device): - return NestedInputs(self.model_inputs.to(device)) - hf_processor = hf_model.processor hf_model.processor = lambda **kw: NestedInputs( hf_processor(**kw) # type: ignore ) - hf_outputs_per_image = [ hf_model.generate_greedy_logprobs_limit(prompts, max_tokens, @@ -161,3 +161,123 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, num_logprobs=num_logprobs, tensor_parallel_size=1, ) + + +HF_MULTIIMAGE_IMAGE_PROMPT = \ + "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \ + "(./)\n(./)\n" \ + "Describe these images.<|eot_id|>" \ + "<|start_header_id|>assistant<|end_header_id|>\n\n" + + +def run_multi_image_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + size_factors: List[float], + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + """Inference result should be the same between hf and vllm. + + All the image fixtures for the test is under tests/images. + For huggingface runner, we provide the PIL images as input. + For vllm runner, we provide MultiModalDataDict objects + and corresponding vision language config as input. + Note, the text input is also adjusted to abide by vllm contract. + The text output is sanitized to be able to compare with hf. + """ + images = [asset.pil_image for asset in image_assets] + + inputs_per_case = [ + ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], + [[rescale_image_size(image, factor) for image in images] + for factor in size_factors]) + ] + + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). 
+ + # max_model_len should be greater than image_feature_size + with vllm_runner(model, + max_model_len=4096, + max_num_seqs=1, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True) as vllm_model: + tokenizer = vllm_model.model.get_tokenizer() + stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id] + vllm_outputs_per_case = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images, + stop_token_ids=stop_token_ids) + for prompts, images in inputs_per_case + ] + + with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad(): + hf_processor = hf_model.processor + hf_model.processor = lambda **kw: NestedInputs( + hf_processor(**kw) # type: ignore + ) + hf_outputs_per_case = [ + hf_model.generate_greedy_logprobs_limit(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images, + tokenizer=tokenizer) + for prompts, images in inputs_per_case + ] + + for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, + vllm_outputs_per_case): + check_logprobs_close( + outputs_0_lst=[ + trunc_hf_output(hf_output) for hf_output in hf_outputs + ], + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, + size_factors, dtype: str, max_tokens: int, + num_logprobs: int) -> None: + run_multi_image_test( + hf_runner, + vllm_runner, + image_assets, + model, + size_factors=size_factors, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 095bb49f6ba7..038825959562 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -392,6 +392,20 @@ def forward(self, x: torch.Tensor, return x +def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]: + version_float = getattr(config, "version", None) + + # The old configs do not include version number + # TODO: Remove this after the HF repos are updated + if version_float is None: + if config.hidden_size == 2304 and config.query_num == 64: + return (2, 0) + return (2, 5) + + version_str = str(version_float) + return tuple(int(x) for x in version_str.split(".")) + + def get_max_minicpmv_image_tokens(ctx: InputContext): hf_config = ctx.get_hf_config(PretrainedConfig) return getattr(hf_config, "query_num", 64) @@ -421,36 +435,43 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs - model_config = ctx.model_config - + version = get_version_by_config(model_config.hf_config) tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) + image_processor = cached_get_image_processor(model_config.tokenizer) + + def get_placeholder(image_size: Tuple[int, int], num_image: int): + if version == (2, 0) or version == (2, 5): + return image_processor. 
\ + get_slice_image_placeholder(image_size) + return image_processor. \ + get_slice_image_placeholder(image_size, num_image) prompt = llm_inputs.get("prompt") if prompt is None: token_ids = llm_inputs.get("prompt_token_ids") prompt = tokenizer.decode(token_ids) - image_processor = cached_get_image_processor(model_config.tokenizer) pattern = "(./)" - image = multi_modal_data["image"] + images = multi_modal_data["image"] + if isinstance(images, Image.Image): + images = [images] image_tags = re.findall(pattern, prompt) if len(image_tags) == 0: new_token_ids = token_ids new_prompt = prompt else: - if len(image_tags) > 1: - logger.warning("Multiple image input is not supported yet, " - "so any extra image tokens will be treated " - "as plain text.") - text_chunks = prompt.split(pattern) - new_prompt = (text_chunks[0] + - image_processor.get_slice_image_placeholder(image.size) + - "".join(text_chunks[1:])) - + new_prompt_chunks: List[str] = [] + for i in range(len(images)): + new_prompt_chunks += [ + text_chunks[i], + get_placeholder(images[i].size, i) + ] + new_prompt_chunks.append(text_chunks[-1]) + new_prompt = "".join(new_prompt_chunks) new_token_ids = tokenizer.encode(new_prompt) llm_inputs = LLMInputs( @@ -478,14 +499,7 @@ def __init__( self.config = config self.multimodal_config = multimodal_config - if not hasattr(self.config, "version"): - if self.config.hidden_size == 2304 and self.config.query_num == 64: - self.version = (2, 0) - else: - self.version = (2, 5) - else: - self.version = str(self.config.version).split(".") - self.version = tuple([int(x) for x in self.version]) + self.version = get_version_by_config(self.config) self.llm = self.init_llm(config, cache_config, quant_config) self.vpm = self.init_vision_module() param_dtype = torch.get_default_dtype() diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 3b37ce9149fb..b6a3909e9563 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -113,7 +113,7 @@ def _get_hf_image_processor(self, model_config: ModelConfig): def _default_input_mapper(self, ctx: InputContext, data: object) -> MultiModalInputs: model_config = ctx.model_config - if isinstance(data, Image.Image): + if isinstance(data, (Image.Image, list)): image_processor = self._get_hf_image_processor(model_config) if image_processor is None: raise RuntimeError("No HuggingFace processor is available " From cc08fc7225616aeb6709a2e75e5ac47ace124985 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 5 Aug 2024 11:40:51 +0800 Subject: [PATCH 0060/3246] [Frontend] Reapply "Factor out code for running uvicorn" (#7095) --- vllm/entrypoints/api_server.py | 77 ++++++++++++++++-------- vllm/entrypoints/launcher.py | 46 +++++++++++++++ vllm/entrypoints/openai/api_server.py | 84 +++++++++------------------ 3 files changed, 125 insertions(+), 82 deletions(-) create mode 100644 vllm/entrypoints/launcher.py diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 66941442c8c9..672382717d11 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -5,21 +5,23 @@ We are also not going to accept PRs modifying this file, please change `vllm/entrypoints/openai/api_server.py` instead. 
""" - +import asyncio import json import ssl -from typing import AsyncGenerator +from argparse import Namespace +from typing import Any, AsyncGenerator, Optional -import uvicorn from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, Response, StreamingResponse from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.launcher import serve_http from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser, random_uuid +from vllm.version import __version__ as VLLM_VERSION logger = init_logger("vllm.entrypoints.api_server") @@ -81,6 +83,53 @@ async def stream_results() -> AsyncGenerator[bytes, None]: return JSONResponse(ret) +def build_app(args: Namespace) -> FastAPI: + global app + + app.root_path = args.root_path + return app + + +async def init_app( + args: Namespace, + llm_engine: Optional[AsyncLLMEngine] = None, +) -> FastAPI: + app = build_app(args) + + global engine + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = (llm_engine + if llm_engine is not None else AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.API_SERVER)) + + return app + + +async def run_server(args: Namespace, + llm_engine: Optional[AsyncLLMEngine] = None, + **uvicorn_kwargs: Any) -> None: + logger.info("vLLM API server version %s", VLLM_VERSION) + logger.info("args: %s", args) + + app = await init_app(args, llm_engine) + + shutdown_task = await serve_http( + app, + host=args.host, + port=args.port, + log_level=args.log_level, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) + + await shutdown_task + + if __name__ == "__main__": parser = FlexibleArgumentParser() parser.add_argument("--host", type=str, default=None) @@ -105,25 +154,5 @@ async def stream_results() -> AsyncGenerator[bytes, None]: parser.add_argument("--log-level", type=str, default="debug") parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() - engine_args = AsyncEngineArgs.from_cli_args(args) - engine = AsyncLLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.API_SERVER) - - app.root_path = args.root_path - logger.info("Available routes are:") - for route in app.routes: - if not hasattr(route, 'methods'): - continue - methods = ', '.join(route.methods) - logger.info("Route: %s, Methods: %s", route.path, methods) - - uvicorn.run(app, - host=args.host, - port=args.port, - log_level=args.log_level, - timeout_keep_alive=TIMEOUT_KEEP_ALIVE, - ssl_keyfile=args.ssl_keyfile, - ssl_certfile=args.ssl_certfile, - ssl_ca_certs=args.ssl_ca_certs, - ssl_cert_reqs=args.ssl_cert_reqs) + asyncio.run(run_server(args)) diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py new file mode 100644 index 000000000000..00826762f76a --- /dev/null +++ b/vllm/entrypoints/launcher.py @@ -0,0 +1,46 @@ +import asyncio +import signal +from typing import Any + +import uvicorn +from fastapi import FastAPI + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +async def serve_http(app: FastAPI, **uvicorn_kwargs: Any): + logger.info("Available routes are:") + for route in app.routes: + methods = getattr(route, "methods", None) + path = getattr(route, "path", None) + + if methods is None or path is None: + continue + 
+ logger.info("Route: %s, Methods: %s", path, ', '.join(methods)) + + config = uvicorn.Config(app, **uvicorn_kwargs) + server = uvicorn.Server(config) + + loop = asyncio.get_running_loop() + + server_task = loop.create_task(server.serve()) + + def signal_handler() -> None: + # prevents the uvicorn signal handler to exit early + server_task.cancel() + + async def dummy_shutdown() -> None: + pass + + loop.add_signal_handler(signal.SIGINT, signal_handler) + loop.add_signal_handler(signal.SIGTERM, signal_handler) + + try: + await server_task + return dummy_shutdown() + except asyncio.CancelledError: + logger.info("Gracefully stopping http server") + return server.shutdown() diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index e330ee81f7e4..a0190f3d66b1 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -2,15 +2,13 @@ import importlib import inspect import re -import signal +from argparse import Namespace from contextlib import asynccontextmanager from http import HTTPStatus from multiprocessing import Process from typing import AsyncIterator, Set -import fastapi -import uvicorn -from fastapi import APIRouter, Request +from fastapi import APIRouter, FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -22,6 +20,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.protocol import AsyncEngineClient +from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import make_arg_parser # yapf conflicts with isort for this block @@ -71,7 +70,7 @@ def model_is_embedding(model_name: str) -> bool: @asynccontextmanager -async def lifespan(app: fastapi.FastAPI): +async def lifespan(app: FastAPI): async def _force_log(): while True: @@ -135,7 +134,7 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: router = APIRouter() -def mount_metrics(app: fastapi.FastAPI): +def mount_metrics(app: FastAPI): # Add prometheus asgi middleware to route /metrics requests metrics_route = Mount("/metrics", make_asgi_app()) # Workaround for 307 Redirect for /metrics @@ -225,8 +224,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): return JSONResponse(content=generator.model_dump()) -def build_app(args): - app = fastapi.FastAPI(lifespan=lifespan) +def build_app(args: Namespace) -> FastAPI: + app = FastAPI(lifespan=lifespan) app.include_router(router) app.root_path = args.root_path @@ -274,11 +273,10 @@ async def authentication(request: Request, call_next): return app -async def build_server( +async def init_app( async_engine_client: AsyncEngineClient, - args, - **uvicorn_kwargs, -) -> uvicorn.Server: + args: Namespace, +) -> FastAPI: app = build_app(args) if args.served_model_name is not None: @@ -334,62 +332,31 @@ async def build_server( ) app.root_path = args.root_path - logger.info("Available routes are:") - for route in app.routes: - if not hasattr(route, 'methods'): - continue - methods = ', '.join(route.methods) - logger.info("Route: %s, Methods: %s", route.path, methods) - - config = uvicorn.Config( - app, - host=args.host, - port=args.port, - log_level=args.uvicorn_log_level, - timeout_keep_alive=TIMEOUT_KEEP_ALIVE, - ssl_keyfile=args.ssl_keyfile, - ssl_certfile=args.ssl_certfile, 
- ssl_ca_certs=args.ssl_ca_certs, - ssl_cert_reqs=args.ssl_cert_reqs, - **uvicorn_kwargs, - ) - - return uvicorn.Server(config) + return app async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) - shutdown_task = None async with build_async_engine_client(args) as async_engine_client: - - server = await build_server( - async_engine_client, - args, + app = await init_app(async_engine_client, args) + + shutdown_task = await serve_http( + app, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, **uvicorn_kwargs, ) - loop = asyncio.get_running_loop() - - server_task = loop.create_task(server.serve()) - - def signal_handler() -> None: - # prevents the uvicorn signal handler to exit early - server_task.cancel() - - loop.add_signal_handler(signal.SIGINT, signal_handler) - loop.add_signal_handler(signal.SIGTERM, signal_handler) - - try: - await server_task - except asyncio.CancelledError: - logger.info("Gracefully stopping http server") - shutdown_task = server.shutdown() - - if shutdown_task: - # NB: Await server shutdown only after the backend context is exited - await shutdown_task + # NB: Await server shutdown only after the backend context is exited + await shutdown_task if __name__ == "__main__": @@ -399,4 +366,5 @@ def signal_handler() -> None: description="vLLM OpenAI-Compatible RESTful API server.") parser = make_arg_parser(parser) args = parser.parse_args() + asyncio.run(run_server(args)) From c0d8f1636c58f5464e512eaabfed5aa29f2c5b7d Mon Sep 17 00:00:00 2001 From: Jungho Christopher Cho Date: Mon, 5 Aug 2024 15:22:12 +0900 Subject: [PATCH 0061/3246] [Model] SiglipVisionModel ported from transformers (#6942) Co-authored-by: Roger Wang --- examples/offline_inference_vision_language.py | 3 +- vllm/model_executor/models/paligemma.py | 79 +-- vllm/model_executor/models/siglip.py | 621 ++++++++++++++++++ 3 files changed, 650 insertions(+), 53 deletions(-) create mode 100644 vllm/model_executor/models/siglip.py diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 846246a2062a..ce9dc9e457c0 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -65,7 +65,8 @@ def run_phi3v(question): # PaliGemma def run_paligemma(question): - prompt = question + # PaliGemma has special prompt format for VQA + prompt = "caption en" llm = LLM(model="google/paligemma-3b-mix-224") return llm, prompt diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index fe91611cd30f..9ba53b8b59a2 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,9 +1,8 @@ from typing import Iterable, List, Literal, Optional, Tuple, TypedDict import torch -from PIL import Image from torch import nn -from transformers import PaliGemmaConfig, SiglipVisionConfig, SiglipVisionModel +from transformers import PaliGemmaConfig from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig @@ -18,9 +17,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import cached_get_tokenizer -from vllm.sequence import IntermediateTensors, 
SamplerOutput, SequenceData +from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsVision +from .siglip import (SiglipVisionModel, dummy_image_for_siglip, + dummy_seq_data_for_siglip, get_max_siglip_image_tokens) from .utils import merge_vision_embeddings logger = init_logger(__name__) @@ -32,55 +33,22 @@ def get_max_paligemma_image_tokens(ctx: InputContext): hf_config = ctx.get_hf_config(PaliGemmaConfig) - text_config = hf_config.text_config - - return text_config.num_image_tokens - - -def dummy_seq_data_for_paligemma( - hf_config: PaliGemmaConfig, - seq_len: int, - *, - image_token_id: int, - image_feature_size_override: Optional[int] = None, -): - if image_feature_size_override is None: - image_feature_size = hf_config.text_config.num_image_tokens - else: - image_feature_size = image_feature_size_override - - token_ids = [image_token_id] * image_feature_size - token_ids += [0] * (seq_len - image_feature_size) - return SequenceData(token_ids) - - -def dummy_image_for_paligemma( - hf_config: SiglipVisionConfig, - *, - image_width_override: Optional[int] = None, - image_height_override: Optional[int] = None, -): - width = height = hf_config.image_size - if image_width_override is not None: - width = image_width_override - if image_height_override is not None: - height = image_height_override + vision_config = hf_config.vision_config - image = Image.new("RGB", (width, height), color=0) - return {"image": image} + return get_max_siglip_image_tokens(vision_config) def dummy_data_for_paligemma(ctx: InputContext, seq_len: int): hf_config = ctx.get_hf_config(PaliGemmaConfig) vision_config = hf_config.vision_config - seq_data = dummy_seq_data_for_paligemma( - hf_config, + seq_data = dummy_seq_data_for_siglip( + vision_config, seq_len, image_token_id=hf_config.image_token_index, ) - mm_data = dummy_image_for_paligemma(vision_config) + mm_data = dummy_image_for_siglip(vision_config) return seq_data, mm_data @@ -208,30 +176,37 @@ def _parse_and_validate_image_input( data=self._validate_pixel_values(pixel_values), ) - def _image_pixels_to_features(self, vision_tower: SiglipVisionModel, - pixel_values: torch.Tensor) -> torch.Tensor: + def _image_pixels_to_features( + self, + vision_tower: SiglipVisionModel, + pixel_values: torch.Tensor, + ) -> torch.Tensor: target_dtype = vision_tower.get_input_embeddings().weight.dtype - image_outputs = vision_tower(pixel_values.to(dtype=target_dtype), - output_hidden_states=True) - - selected_image_features = image_outputs.last_hidden_state + image_features = vision_tower(pixel_values.to(dtype=target_dtype)) - return selected_image_features + return image_features def _process_image_pixels( - self, inputs: PaliGemmaImagePixelInputs) -> torch.Tensor: + self, + inputs: PaliGemmaImagePixelInputs, + ) -> torch.Tensor: assert self.vision_tower is not None pixel_values = inputs["data"] - return self._image_pixels_to_features(self.vision_tower, pixel_values) + return self._image_pixels_to_features( + self.vision_tower, + pixel_values, + ) def _process_image_input( - self, image_input: PaliGemmaImageInputs) -> torch.Tensor: + self, + image_input: PaliGemmaImageInputs, + ) -> torch.Tensor: assert self.vision_tower is not None - image_features = self._process_image_pixels(image_input) + image_features = self._process_image_pixels(image_input, ) return self.multi_modal_projector(image_features) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py new file mode 100644 index 000000000000..6faef45c9a6d --- 
/dev/null +++ b/vllm/model_executor/models/siglip.py @@ -0,0 +1,621 @@ +"""Implementation of SiglipVisionModel intended to be only used +within a vision language model.""" + +import math +from typing import Optional, Tuple + +import torch +from PIL import Image +from torch import nn +from transformers import SiglipConfig, SiglipVisionConfig +from transformers.models.siglip.modeling_siglip import SiglipAttention +from vllm_flash_attn import flash_attn_func +from xformers.ops import memory_efficient_attention + +from vllm.config import ModelConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.inputs import LLMInputs +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.multimodal.image import (cached_get_tokenizer, + repeat_and_pad_image_tokens) +from vllm.sequence import SequenceData + + +def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int: + assert image_size % patch_size == 0 + return image_size // patch_size + + +def get_siglip_num_patches(*, image_size: int, patch_size: int) -> int: + grid_length = get_siglip_patch_grid_length(image_size=image_size, + patch_size=patch_size) + return grid_length * grid_length + + +def get_siglip_image_feature_size(hf_config: SiglipVisionConfig) -> int: + return get_siglip_num_patches(image_size=hf_config.image_size, + patch_size=hf_config.patch_size) + + +def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int: + return get_siglip_image_feature_size(hf_config) + + +def dummy_seq_data_for_siglip( + hf_config: SiglipVisionConfig, + seq_len: int, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, +): + if image_feature_size_override is None: + image_feature_size = get_siglip_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + token_ids = [image_token_id] * image_feature_size + token_ids += [0] * (seq_len - image_feature_size) + return SequenceData(token_ids) + + +def dummy_image_for_siglip( + hf_config: SiglipVisionConfig, + *, + image_width_override: Optional[int] = None, + image_height_override: Optional[int] = None, +): + width = height = hf_config.image_size + if image_width_override is not None: + width = image_width_override + if image_height_override is not None: + height = image_height_override + + image = Image.new("RGB", (width, height), color=0) + return {"image": image} + + +def input_processor_for_siglip( + model_config: ModelConfig, + hf_config: SiglipVisionConfig, + llm_inputs: LLMInputs, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, +): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + tokenizer = cached_get_tokenizer(model_config.tokenizer) + + if image_feature_size_override is None: + image_feature_size = get_siglip_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + new_prompt, new_token_ids = repeat_and_pad_image_tokens( + tokenizer, + llm_inputs.get("prompt"), + llm_inputs["prompt_token_ids"], + image_token_id=image_token_id, + repeat_count=image_feature_size, + ) + + # NOTE: Create a defensive copy of the original inputs + return 
LLMInputs( + prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data, + ) + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa +class SiglipVisionEmbeddings(nn.Module): + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + self.position_embedding = VocabParallelEmbedding( + self.num_positions, self.embed_dim) + self.register_buffer( + "position_ids", + torch.arange(self.num_positions, dtype=torch.int64).expand( + (1, -1)), + persistent=False, + ) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, + width: int) -> torch.Tensor: + """ + This method is an adapted method for SigLIP (due to SigLIP not having + class embedding unlike other ViTs) that allows the model to interpolate + the pre-trained position encodings such that it can be usable on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + position_embeddings = self.position_embedding.weight.unsqueeze(0) + num_patches = embeddings.shape[1] + num_positions = position_embeddings.shape[1] + if num_patches == num_positions and height == width: + return position_embeddings + + dim = embeddings.shape[-1] + height = height // self.patch_size + width = width // self.patch_size + # we add a small number to avoid floating point error + # in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + 0.1, width + 0.1 + + patch_pos_embed = position_embeddings.reshape( + 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), + dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=( + height / math.sqrt(num_positions), + width / math.sqrt(num_positions), + ), + mode="bicubic", + align_corners=False, + ) + if (int(height) != patch_pos_embed.shape[-2] + or int(width) != patch_pos_embed.shape[-1]): + raise ValueError("Width or height does not match with " + "the interpolated position embeddings") + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def forward(self, + pixel_values: torch.Tensor, + interpolate_pos_encoding: bool = False) -> torch.Tensor: + _, _, height, width = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to( + dtype=target_dtype)) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding( + embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding( + self.position_ids) + return embeddings + + +# NOTE: Not used - kept for later when we TP the ViT +# TODO(ChristopherCho): Implement TP version of Attention +class SiglipTPAttention(nn.Module): + + def __init__( + self, + 
config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + if self.total_num_heads % tp_size != 0: + raise ValueError( + f"Number of attention heads ({self.total_num_heads}) " + "must be divisible by the tensor model parallel size" + f" ({tp_size}).") + + self.num_heads = self.total_num_heads // tp_size + self.head_dim = self.embed_dim // self.total_num_heads + if self.head_dim * self.total_num_heads != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads (got " + "`embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.qkv_size = self.num_heads * self.head_dim + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + quant_config=quant_config, + ) + self.out_proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + quant_config=quant_config, + ) + + self.attn_fn = self._basic_attention_forward + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + """Input shape: Batch x Time x Channel""" + batch_size, q_len, _ = hidden_states.size() + + qkv_states, _ = self.qkv_proj(hidden_states) + query_states, key_states, value_states = qkv_states.split( + [self.qkv_size] * 3, dim=-1) + + attn_output = self.attn_fn( + q=query_states, + k=key_states, + v=value_states, + batch_size=batch_size, + q_len=q_len, + ) + + attn_output, _ = self.out_proj(attn_output) + return attn_output + + def _basic_attention_forward(self, q, k, v, batch_size, q_len): + q = q.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + k = k.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + v = v.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + + k_v_seq_len = k.shape[-2] + attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale + + if attn_weights.size() != ( + batch_size, + self.num_heads, + q_len, + k_v_seq_len, + ): + raise ValueError( + "Attention weights should be of size " + f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}") + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, + dim=-1, + dtype=torch.float32).to(q.dtype) + attn_weights = nn.functional.dropout(attn_weights, + p=self.dropout, + training=self.training) + attn_output = torch.matmul(attn_weights, v) + + if attn_output.size() != ( + batch_size, + self.num_heads, + q_len, + self.head_dim, + ): + raise ValueError( + "`attn_output` should be of size " + f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + return attn_output + + +# NOTE: Not used - kept for later when we TP the ViT +# TODO(ChristopherCho): flash_attn_func is not working properly. +# It constantly throws a CUDA error. 
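+# Until these TP variants are enabled, SiglipEncoderLayer below uses the
+# stock transformers SiglipAttention implementation instead.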
+class SiglipFlashAttention2(SiglipTPAttention): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.attn_fn = self._flash_attention_forward + + # Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449 + # and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133 + def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args, + **kwargs): + """Implements the multihead softmax attention. + Arguments + --------- + q, k, v: The tensor containing the + query, key, and value. (B, S, H, D) + """ + + q = q.view(batch_size, q_len, self.num_heads, self.head_dim) + k = k.view(batch_size, q_len, self.num_heads, self.head_dim) + v = v.view(batch_size, q_len, self.num_heads, self.head_dim) + + attn_output = flash_attn_func( + q, + k, + v, + dropout_p=self.dropout, + causal=False, + ) + + attn_output = attn_output.reshape(batch_size, q_len, + self.embed_dim).contiguous() + + return attn_output + + +# NOTE: Not used - kept for later when we TP the ViT +class SiglipSdpaAttention(SiglipTPAttention): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False + self.attn_fn = self._sdpa_attention_forward + + def _sdpa_attention_forward(self, q, k, v, batch_size, q_len): + q = q.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + k = k.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + v = v.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout, is_causal=False, scale=self.scale) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, q_len, self.embed_dim) + + return attn_output + + +# NOTE: Not used - kept for later when we TP the ViT +class SiglipxFormersAttention(SiglipTPAttention): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.attn_fn = self._xformers_attention_forward + + def _xformers_attention_forward(self, q, k, v, batch_size, q_len): + q = q.view(batch_size, q_len, self.num_heads, self.head_dim) + k = k.view(batch_size, q_len, self.num_heads, self.head_dim) + v = v.view(batch_size, q_len, self.num_heads, self.head_dim) + + attn_output = memory_efficient_attention(q, + k, + v, + p=0.0, + scale=self.scale) + attn_output = attn_output.reshape(batch_size, q_len, + self.embed_dim).contiguous() + + return attn_output + + +# NOTE: Not used - kept for later when we TP the ViT +SIGLIP_ATTENTION_CLASSES = { + "eager": SiglipTPAttention, + "flash_attention_2": SiglipFlashAttention2, + "sdpa": SiglipSdpaAttention, + "xformers": SiglipxFormersAttention, +} + + +class SiglipMLP(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + + # For quantization, we require the hidden size to be a multiple of 64 + quantizable = (config.hidden_size % 64 == 0 + and config.intermediate_size % 64 == 0) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + quant_config=quant_config if quantizable else None, + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + quant_config=quant_config if quantizable else None, + ) + + def 
forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class SiglipEncoderLayer(nn.Module): + + def __init__( + self, + config: SiglipConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.embed_dim = config.hidden_size + + # TODO(ChristopherCho): use TP'ed Attention block + self.self_attn = SiglipAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = SiglipMLP( + config, + quant_config=quant_config, + ) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> Tuple[torch.Tensor]: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn(hidden_states=hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, None + + +class SiglipEncoder(nn.Module): + + def __init__( + self, + config: SiglipConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + SiglipEncoderLayer( + config, + quant_config=quant_config, + ) for _ in range(config.num_hidden_layers) + ]) + + def forward( + self, + inputs_embeds: torch.Tensor, + ) -> Tuple: + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states, _ = encoder_layer(hidden_states) + + return hidden_states + + +class SiglipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + # TODO(ChristopherCho): Implement vLLM version of MultiheadAttention + self.attention = torch.nn.MultiheadAttention( + config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config=config, quant_config=quant_config) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention(probe, hidden_state, hidden_state)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +class SiglipVisionTransformer(nn.Module): + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder( + config, + quant_config=quant_config, + ) + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + self.use_head = (True if not hasattr(config, "vision_use_head") else + config.vision_use_head) + if self.use_head: + self.head = SiglipMultiheadAttentionPoolingHead( + config=config, quant_config=quant_config) + + def forward( + self, + pixel_values: torch.Tensor, + interpolate_pos_encoding: bool = True, + ) -> torch.Tensor: + 
hidden_states = self.embeddings( + pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + encoder_outputs = self.encoder(inputs_embeds=hidden_states) + + last_hidden_state = self.post_layernorm(encoder_outputs) + + # TODO: add this back when pooled_output is used in inference + # if self.use_head: + # pooled_output = self.head(last_hidden_state) + + return last_hidden_state + + +class SiglipVisionModel(nn.Module): + config_class = SiglipVisionConfig + main_input_name = "pixel_values" + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.vision_model = SiglipVisionTransformer( + config, + quant_config, + ) + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + def forward( + self, + pixel_values: torch.Tensor, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + return self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) From 82a1b1a82b1fbb454c82a9ef95730b929c9b270c Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Mon, 5 Aug 2024 01:46:44 -0700 Subject: [PATCH 0062/3246] [Speculative decoding] Add periodic log with time spent in proposal/scoring/verification (#6963) --- tests/spec_decode/test_spec_decode_worker.py | 68 ++++++++++++++------ vllm/config.py | 8 ++- vllm/engine/arg_utils.py | 1 + vllm/spec_decode/spec_decode_worker.py | 68 ++++++++++++++++---- vllm/spec_decode/util.py | 15 +++++ 5 files changed, 125 insertions(+), 35 deletions(-) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 671c9bef294f..9ae1b4bc40f0 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -34,8 +34,11 @@ def test_correctly_calls_draft_model(k: int, batch_size: int, target_worker = mock_worker() metrics_collector = MagicMock(spec=AsyncMetricsCollector) worker = SpecDecodeWorker( - draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) + draft_worker, + target_worker, + mock_spec_decode_sampler(acceptance_sampler_method), + disable_logprobs=False, + metrics_collector=metrics_collector) exception_secret = 'artificial stop' draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) @@ -74,8 +77,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int, set_random_seed(1) worker = SpecDecodeWorker( - draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) + draft_worker, + target_worker, + mock_spec_decode_sampler(acceptance_sampler_method), + disable_logprobs=False, + metrics_collector=metrics_collector) worker.init_device() vocab_size = 32_000 @@ -159,8 +165,11 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int, set_random_seed(1) - worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, - metrics_collector) + worker = SpecDecodeWorker(draft_worker, + target_worker, + spec_decode_sampler, + disable_logprobs=False, + metrics_collector=metrics_collector) worker.init_device() proposal_token_ids = torch.randint(low=0, @@ -249,8 +258,11 @@ def test_correctly_formats_output(k: int, batch_size: int, set_random_seed(1) spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) - worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, - metrics_collector) + worker = 
SpecDecodeWorker(draft_worker, + target_worker, + spec_decode_sampler, + disable_logprobs=False, + metrics_collector=metrics_collector) worker.init_device() proposal_token_ids = torch.randint(low=0, @@ -479,9 +491,13 @@ def test_k_equals_zero(k: int, batch_size: int, set_random_seed(1) worker = SpecDecodeWorker( - draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), False, - metrics_collector) + proposer_worker=draft_worker, + scorer_worker=target_worker, + spec_decode_sampler=mock_spec_decode_sampler( + acceptance_sampler_method), + disable_logprobs=False, + metrics_collector=metrics_collector, + ) seq_group_metadata_list, _, _ = create_batch(batch_size, k, @@ -526,9 +542,13 @@ def test_empty_input_batch(k: int, batch_size: int, set_random_seed(1) worker = SpecDecodeWorker( - draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), False, - metrics_collector) + proposer_worker=draft_worker, + scorer_worker=target_worker, + spec_decode_sampler=mock_spec_decode_sampler( + acceptance_sampler_method), + disable_logprobs=False, + metrics_collector=metrics_collector, + ) seq_group_metadata_list, _, _ = create_batch(batch_size, k, @@ -560,8 +580,13 @@ def test_init_device(acceptance_sampler_method: str): spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) metrics_collector = MagicMock(spec=AsyncMetricsCollector) - worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, - False, metrics_collector) + worker = SpecDecodeWorker( + proposer_worker=draft_worker, + scorer_worker=target_worker, + spec_decode_sampler=spec_decode_sampler, + disable_logprobs=False, + metrics_collector=metrics_collector, + ) worker.init_device() draft_worker.init_device.assert_called_once() @@ -583,9 +608,11 @@ def test_initialize_cache(acceptance_sampler_method): target_worker = mock_worker() metrics_collector = MagicMock(spec=AsyncMetricsCollector) - worker = SpecDecodeWorker( - draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) + worker = SpecDecodeWorker(proposer_worker=draft_worker, + scorer_worker=target_worker, + spec_decode_sampler=mock_spec_decode_sampler( + acceptance_sampler_method), + metrics_collector=metrics_collector) kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023} worker.initialize_cache(**kwargs) @@ -725,7 +752,8 @@ def test_populate_seq_ids_with_bonus_tokens(): seq_group_metadata_list=seq_group_metadata_list, accepted_token_ids=accepted_token_ids, target_logprobs=target_token_logprobs, - k=k) + k=k, + stage_times=(0, 0, 0)) # Verify that _seq_with_bonus_token_in_last_step contains the following: # 1. 
Sequence IDs that were already present in # _seq_with_bonus_token_in_last_step but were not part of the current diff --git a/vllm/config.py b/vllm/config.py index 35945e34452d..bec0b63197ef 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -907,6 +907,7 @@ def maybe_create_spec_config( speculative_max_model_len: Optional[int], enable_chunked_prefill: bool, use_v2_block_manager: bool, + disable_log_stats: bool, speculative_disable_by_batch_size: Optional[int], ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_min: Optional[int], @@ -1095,7 +1096,8 @@ def maybe_create_spec_config( typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_alpha=\ typical_acceptance_sampler_posterior_alpha, - disable_logprobs=disable_logprobs + disable_logprobs=disable_logprobs, + disable_log_stats=disable_log_stats, ) @staticmethod @@ -1189,6 +1191,7 @@ def __init__( typical_acceptance_sampler_posterior_threshold: float, typical_acceptance_sampler_posterior_alpha: float, disable_logprobs: bool, + disable_log_stats: bool, ): """Create a SpeculativeConfig object. @@ -1221,6 +1224,8 @@ def __init__( sampling, target sampling, and after accepted tokens are determined. If set to False, log probabilities will be returned. + disable_log_stats: Whether to disable periodic printing of stage + times in speculative decoding. """ self.draft_model_config = draft_model_config self.draft_parallel_config = draft_parallel_config @@ -1235,6 +1240,7 @@ def __init__( self.typical_acceptance_sampler_posterior_alpha = \ typical_acceptance_sampler_posterior_alpha self.disable_logprobs = disable_logprobs + self.disable_log_stats = disable_log_stats self._verify_args() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 2737b50927f6..acc0551af015 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -792,6 +792,7 @@ def create_engine_config(self, ) -> EngineConfig: speculative_max_model_len=self.speculative_max_model_len, enable_chunked_prefill=self.enable_chunked_prefill, use_v2_block_manager=self.use_v2_block_manager, + disable_log_stats=self.disable_log_stats, ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, draft_token_acceptance_method=\ diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index ad8c0cee0b5b..690aad505e21 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -27,7 +27,7 @@ from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker from vllm.spec_decode.target_model_runner import TargetModelRunner -from vllm.spec_decode.util import (create_sequence_group_output, +from vllm.spec_decode.util import (Timer, create_sequence_group_output, get_all_num_logprobs, get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) @@ -75,7 +75,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_alpha=speculative_config. 
typical_acceptance_sampler_posterior_alpha, - disable_logprobs=speculative_config.disable_logprobs) + disable_logprobs=speculative_config.disable_logprobs, + disable_log_stats=speculative_config.disable_log_stats, + ) return spec_decode_worker @@ -116,6 +118,7 @@ def create_worker( typical_acceptance_sampler_posterior_threshold: float, typical_acceptance_sampler_posterior_alpha: float, disable_logprobs: bool, + disable_log_stats: bool, ) -> "SpecDecodeWorker": allow_zero_draft_token_step = True @@ -171,6 +174,7 @@ def create_worker( proposer_worker, scorer_worker, disable_logprobs=disable_logprobs, + disable_log_stats=disable_log_stats, disable_by_batch_size=disable_by_batch_size, spec_decode_sampler=spec_decode_sampler, allow_zero_draft_token_step=allow_zero_draft_token_step) @@ -180,7 +184,8 @@ def __init__( proposer_worker: ProposerWorkerBase, scorer_worker: WorkerBase, spec_decode_sampler: SpecDecodeBaseSampler, - disable_logprobs: bool, + disable_logprobs: bool = False, + disable_log_stats: bool = False, metrics_collector: Optional[AsyncMetricsCollector] = None, disable_by_batch_size: Optional[int] = None, allow_zero_draft_token_step: Optional[bool] = True, @@ -203,6 +208,8 @@ def __init__( disable_logprobs: If set to True, token log probabilities will not be output in both the draft worker and the target worker. If set to False, log probabilities will be output by both. + disable_log_stats: If set to True, disable periodic printing of + speculative stage times. disable_by_batch_size: If the batch size is larger than this, disable speculative decoding for new incoming requests. metrics_collector: Helper class for collecting metrics; can be set @@ -240,6 +247,7 @@ def __init__( # in the subsequent step. self.previous_hidden_states: Optional[HiddenStates] = None self._disable_logprobs = disable_logprobs + self._disable_log_stats = disable_log_stats def init_device(self) -> None: """Initialize both scorer and proposer models. @@ -525,28 +533,37 @@ def _run_speculative_decoding_step( execute_model_req.previous_hidden_states = self.previous_hidden_states self.previous_hidden_states = None - # Generate proposals using draft worker. - proposals = self.proposer_worker.get_spec_proposals( - execute_model_req, self._seq_with_bonus_token_in_last_step) + with Timer() as proposal_timer: + # Generate proposals using draft worker. 
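+            # Timing this stage lets _maybe_log_stage_times() report the
+            # average proposal time per speculative token.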
+ proposals = self.proposer_worker.get_spec_proposals( + execute_model_req, self._seq_with_bonus_token_in_last_step) if not self._allow_zero_draft_token_step and proposals.no_proposals: #TODO: Fix it #5814 raise RuntimeError("Cannot handle cases where distributed draft " "workers generate no tokens") - proposal_scores = self.scorer.score_proposals( - execute_model_req, - proposals, - ) - accepted_token_ids, target_logprobs = self._verify_tokens( - execute_model_req.seq_group_metadata_list, proposal_scores, - proposals, execute_model_req.num_lookahead_slots) + with Timer() as scoring_timer: + proposal_scores = self.scorer.score_proposals( + execute_model_req, + proposals, + ) + + with Timer() as verification_timer: + accepted_token_ids, target_logprobs = self._verify_tokens( + execute_model_req.seq_group_metadata_list, proposal_scores, + proposals, execute_model_req.num_lookahead_slots) + + stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots, + scoring_timer.elapsed_time_ms, + verification_timer.elapsed_time_ms) return self._create_output_sampler_list( execute_model_req.seq_group_metadata_list, accepted_token_ids, target_logprobs=target_logprobs, - k=execute_model_req.num_lookahead_slots) + k=execute_model_req.num_lookahead_slots, + stage_times=stage_times) @nvtx_range("spec_decode_worker._verify_tokens") def _verify_tokens( @@ -645,6 +662,7 @@ def _create_output_sampler_list( accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1] target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size] k: int, + stage_times: Tuple[float, float, float], ) -> List[SamplerOutput]: """Given the accepted token ids, create a list of SamplerOutput. @@ -722,8 +740,30 @@ def _create_output_sampler_list( if maybe_rejsample_metrics is not None: sampler_output_list[ 0].spec_decode_worker_metrics = maybe_rejsample_metrics + + # Log time spent in each stage periodically. + # This is periodic because the rejection sampler emits metrics + # periodically. + self._maybe_log_stage_times(*stage_times) + return sampler_output_list + def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float, + scoring_time_ms: float, + verification_time_ms: float) -> None: + """Log the speculative stage times. If stat logging is disabled, do + nothing. + """ + if self._disable_log_stats: + return + + logger.info( + "SpecDecodeWorker stage times: " + "average_time_per_proposal_tok_ms=%.02f " + "scoring_time_ms=%.02f verification_time_ms=%.02f", + average_time_per_proposal_tok_ms, scoring_time_ms, + verification_time_ms) + def _create_dummy_logprob_lists( self, batch_size: int, diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index ade546eef264..c6223a97dba1 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -1,3 +1,4 @@ +import time from contextlib import contextmanager from typing import Dict, List, Optional, Tuple @@ -214,3 +215,17 @@ def nvtx_range(msg, *args, **kwargs): yield finally: torch.cuda.nvtx.range_pop() + + +class Timer: + """Basic timer context manager for measuring CPU time. 
+ """ + + def __enter__(self): + self.start_time = time.time() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.end_time = time.time() + self.elapsed_time_s = self.end_time - self.start_time + self.elapsed_time_ms = self.elapsed_time_s * 1000 From e9630458c7b11732e147c120817c53420280d471 Mon Sep 17 00:00:00 2001 From: Bongwon Jang <152451401+bong-furiosa@users.noreply.github.com> Date: Tue, 6 Aug 2024 00:05:05 +0900 Subject: [PATCH 0063/3246] [SpecDecode] Support FlashInfer in DraftModelRunner (#6926) --- vllm/spec_decode/draft_model_runner.py | 47 ++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 0b755600ae82..b76a1ab4cf24 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -11,6 +11,17 @@ from vllm.attention.backends.rocm_flash_attn import ( ROCmFlashAttentionMetadata as FlashAttentionMetadata) +try: + from flashinfer import BatchDecodeWithPagedKVCacheWrapper + from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper + from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper + FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 +except ImportError: + BatchDecodeWithPagedKVCacheWrapper = None + CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None + BatchPrefillWithPagedKVCacheWrapper = None + FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 + from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) @@ -79,6 +90,11 @@ def __init__( return_hidden_states=return_hidden_states, ) + self.flashinfer_decode_workspace_buffer = None + self.flashinfer_decode_wrapper = None + self.flashinfer_prefill_workspace_buffer = None + self.flashinfer_prefill_wrapper = None + def _update_flash_attn_metadata(self, attn_metadata, num_seqs, num_queries): assert isinstance(attn_metadata, FlashAttentionMetadata) @@ -286,6 +302,37 @@ def execute_model( model_input.prompt_adapter_requests, model_input.prompt_adapter_mapping) + if self.attn_backend.get_name() == "flashinfer": + assert model_input.attn_metadata is not None + assert model_input.input_tokens is not None + if self.flashinfer_decode_workspace_buffer is None: + self.flashinfer_decode_workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.device) + self.flashinfer_decode_wrapper = \ + BatchDecodeWithPagedKVCacheWrapper( + self.flashinfer_decode_workspace_buffer, "NHD") + self.flashinfer_prefill_workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.device) + self.flashinfer_prefill_wrapper = \ + BatchPrefillWithPagedKVCacheWrapper( + self.flashinfer_prefill_workspace_buffer, "NHD") + + model_input.attn_metadata.prefill_wrapper = \ + self.flashinfer_prefill_wrapper + if model_input.attn_metadata.use_cuda_graph: + batch_size = model_input.input_tokens.shape[0] + model_input.attn_metadata.decode_wrapper = \ + self.graph_runners[model_input. 
+ virtual_engine][batch_size].flashinfer_decode_wrapper + else: + model_input.attn_metadata.decode_wrapper = \ + self.flashinfer_decode_wrapper + model_input.attn_metadata.begin_forward() + # Detect exec mode assert model_input.attn_metadata is not None use_cuda_graph = False From 003f8ee1287f90a7e8aa9b9e7d6246ac74ebefbe Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 5 Aug 2024 08:41:03 -0700 Subject: [PATCH 0064/3246] [BugFix] Use IP4 localhost form for zmq bind (#7163) --- vllm/entrypoints/openai/rpc/server.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 7a72a6f732c9..60bb23b9bde0 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -30,7 +30,9 @@ def __init__(self, async_engine_args: AsyncEngineArgs, # Init socket for readiness state. self.socket = self.context.socket(zmq.constants.ROUTER) - self.socket.bind(f"tcp://localhost:{port}") + # Note numeric form of localhost should be used for zmq bind(), + # see https://stackoverflow.com/a/8958414 + self.socket.bind(f"tcp://127.0.0.1:{port}") def cleanup(self): """Cleanup all resources.""" From 57f560aa23077ed9def5952ab81a65bc080ae234 Mon Sep 17 00:00:00 2001 From: Aditya Paliwal Date: Mon, 5 Aug 2024 09:26:14 -0700 Subject: [PATCH 0065/3246] [BugFix] Use args.trust_remote_code (#7121) --- vllm/entrypoints/openai/api_server.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index a0190f3d66b1..88f0bd4ee4db 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -60,11 +60,11 @@ _running_tasks: Set[asyncio.Task] = set() -def model_is_embedding(model_name: str) -> bool: +def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool: return ModelConfig(model=model_name, tokenizer=model_name, tokenizer_mode="auto", - trust_remote_code=False, + trust_remote_code=trust_remote_code, seed=0, dtype="float16").embedding_mode @@ -97,7 +97,7 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: # If manually triggered or embedding model, use AsyncLLMEngine in process. # TODO: support embedding model via RPC. 
- if (model_is_embedding(args.model) + if (model_is_embedding(args.model, args.trust_remote_code) or args.disable_frontend_multiprocessing): async_engine_client = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) From 997cf78308d292b03c8a1e68d8d1a1f599551937 Mon Sep 17 00:00:00 2001 From: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Date: Mon, 5 Aug 2024 11:10:16 -0700 Subject: [PATCH 0066/3246] [Misc] Fix typo in GroupCoordinator.recv() (#7167) Signed-off-by: Rui Qiao --- vllm/distributed/parallel_state.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index d7ca8fd82e1a..a20b92de81cd 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -713,8 +713,8 @@ def recv(self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None) -> torch.Tensor: - """Receives a tensor from the src rank.""" - """NOTE: `src` is the local rank of the destination rank.""" + """Receives a tensor from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" if src is None: src = (self.rank_in_group - 1) % self.world_size From 8571ac4672c8b599338cb95e23dfd624016aab36 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 5 Aug 2024 15:13:43 -0400 Subject: [PATCH 0067/3246] [Kernel] Update CUTLASS to 3.5.1 (#7085) --- CMakeLists.txt | 6 +- .../broadcast_load_epilogue_c3x.hpp | 192 ++++++++++-------- .../cutlass_w8a8/scaled_mm_c2x.cuh | 8 +- .../cutlass_w8a8/scaled_mm_c3x.cu | 30 +-- 4 files changed, 129 insertions(+), 107 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 922613ec5dda..e5ac5516c2e4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -193,8 +193,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git - # CUTLASS 3.5.0 - GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc + # CUTLASS 3.5.1 + GIT_TAG 06b21349bcf6ddf6a1686a47a137ad1446579db9 # Shallow clone with depth 1 GIT_SHALLOW TRUE GIT_PROGRESS TRUE @@ -237,7 +237,7 @@ define_gpu_extension_target( SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} - INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp index e4bc9752ed7d..58b1e8ff159f 100644 --- a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp +++ b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp @@ -64,8 +64,6 @@ using namespace detail; // Row vector broadcast template< - // Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least - // ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races int Stages, class CtaTileShapeMNK, class Element, @@ -73,14 +71,12 @@ template< int Alignment = 128 / sizeof_bits_v > struct Sm90RowOrScalarBroadcast { - static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - static_assert( - (cute::is_same_v>) || // row vector broadcast, e.g. 
per-col alpha/bias - (cute::is_same_v>)); // batched row vector broadcast + static_assert(Stages == 0, "Row broadcast doesn't support smem usage"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); - // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem - struct SharedStorage { - alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_row; + struct SharedStorage { + array_aligned(CtaTileShapeMNK{})> smem; }; // This struct has been modified to have a bool indicating that ptr_row is a @@ -100,6 +96,12 @@ struct Sm90RowOrScalarBroadcast { return args; } + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { @@ -118,15 +120,15 @@ struct Sm90RowOrScalarBroadcast { CUTLASS_HOST_DEVICE Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) - : params(params), - smem_row(const_cast(shared_storage.smem_row.data())) { } + : params(params) + , smem(const_cast(shared_storage.smem.data())) { } Params params; - Element* smem_row; + Element *smem = nullptr; CUTLASS_DEVICE bool is_producer_load_needed() const { - return true; + return false; } CUTLASS_DEVICE bool @@ -139,78 +141,76 @@ struct Sm90RowOrScalarBroadcast { return (!params.row_broadcast && *(params.ptr_row) == Element(0)); } - template - struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks { - CUTLASS_DEVICE - ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params) - : gRow(cute::forward(gRow)), - sRow(cute::forward(sRow)), - params(params) {} - - GTensor gRow; // (CTA_M,CTA_N) - STensor sRow; // (CTA_M,CTA_N,PIPE) - Params const& params; - - CUTLASS_DEVICE void - begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) { - if (!params.row_broadcast) { - return; - } - - if (issue_tma_load) { - // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size - constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v / 8; - cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes); - // Issue the TMA bulk copy - auto bulk_copy = Copy_Atom{}.with(*full_mbarrier_ptr); - // Filter so we don't issue redundant copies over stride-0 modes - int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; - copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index))); - } - } - }; - template CUTLASS_DEVICE auto get_producer_load_callbacks(ProducerLoadArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); - Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) - Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - - constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; - return ProducerLoadCallbacks( - cute::move(gRow), cute::move(sRow), params); + return EmptyProducerLoadCallbacks{}; } - template + template struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { 
CUTLASS_DEVICE - ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params) - : tCrRow(cute::forward(tCrRow)), - tCsRow(cute::forward(tCsRow)), - params(params) {} - - RTensor tCrRow; // (CPY,CPY_M,CPY_N) - STensor tCsRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) + ConsumerStoreCallbacks( + GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, + GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, + SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, + CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_) + : tGS_gRow(tGS_gRow_) + , tGS_sRow(tGS_sRow_) + , tGS_cRow(tGS_cRow_) + , tiled_G2S(tiled_g2s_) + , tSR_sRow(tSR_sRow_) + , tSR_rRow(tSR_rRow_) + , tCcRow(tCcRow_) + , residue_tCcRow(residue_tCcRow_) + , params(params_) {} + + GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) + GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) + GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) + Tiled_G2S tiled_G2S; + + SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcRow; // (m, n) + ThrNum thr_num; Params const& params; CUTLASS_DEVICE void - previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + begin() { if (!params.row_broadcast) { - fill(tCrRow, *(params.ptr_row)); + fill(tSR_rRow, *(params.ptr_row)); return; } + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); + Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + + for (int i = 0; i < size(tGS_gRow_flt); ++i) { + if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + continue; // OOB of SMEM, + } + if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) { + tGS_sRow_flt(i) = tGS_gRow_flt(i); + } + else { + tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds. 
+ } + } + synchronize(); + } + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { if (epi_m == 0) { // Assumes M-major subtile loop - // Filter so we don't issue redundant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; - copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow)); + if (!params.row_broadcast) return; // Do not issue LDS when row is scalar + Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); + copy(tSR_sRow_flt, tSR_rRow_flt); } } @@ -221,7 +221,7 @@ struct Sm90RowOrScalarBroadcast { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentSize; ++i) { - frg_row[i] = tCrRow(epi_v * FragmentSize + i); + frg_row[i] = tSR_rRow(epi_v * FragmentSize + i); } return frg_row; @@ -234,17 +234,41 @@ struct Sm90RowOrScalarBroadcast { > CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + using ThreadCount = decltype(size(args.tiled_copy)); - Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - Tensor tCsRow = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) - sRow, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow)); // (CPY,CPY_M,CPY_N) - - constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; - return ConsumerStoreCallbacks( - cute::move(tCrRow), cute::move(tCsRow), params); + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); + Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + Tensor sRow = make_tensor(make_smem_ptr(smem), + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) + //// G2S: Gmem to Smem + auto tiled_g2s = make_tiled_copy(Copy_Atom{}, + Layout< Shape<_1, ThreadCount>, + Stride<_0, _1>>{}, + Layout<_1>{}); + auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); + Tensor tGS_gRow = thr_g2s.partition_S(gRow); + Tensor tGS_sRow = thr_g2s.partition_D(sRow); + + //// G2S: Coord + auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); + Tensor tGS_cRow = thr_g2s.partition_S(cRow); + + //// S2R: Smem to Reg + Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + + return ConsumerStoreCallbacks( + tGS_gRow, + tGS_sRow, + tGS_cRow, tiled_g2s, + tSR_sRow, + tSR_rRow, + args.tCcD, + args.residue_cD, + ThreadCount{}, + params); } }; @@ -285,6 +309,12 @@ struct Sm90ColOrScalarBroadcast { return args; } + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh index ba620e85117b..be8a5c0e54e8 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh +++ 
b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh @@ -10,8 +10,6 @@ #include "cute/atom/mma_atom.hpp" #include "cutlass/numeric_types.h" -#include "cutlass/util/device_memory.h" - #include "cutlass/cutlass.h" #include "cutlass/gemm_coord.h" #include "cutlass/arch/mma_sm75.h" @@ -301,12 +299,14 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, // Launch the CUTLASS GEMM kernel. typename Gemm::Op gemm_op; size_t workspace_size = gemm_op.get_workspace_size(args); - cutlass::device_memory::allocation workspace(workspace_size); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); CUTLASS_CHECK(gemm_op.can_implement(args)); - cutlass::Status status = gemm_op(args, workspace.get(), stream); + cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream); CUTLASS_CHECK(status); } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu index b3f5b6208660..088185188770 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu @@ -18,8 +18,6 @@ #include "cute/atom/mma_atom.hpp" #include "cutlass/numeric_types.h" -#include "cutlass/util/device_memory.h" - #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" @@ -72,13 +70,9 @@ struct ScaledEpilogueBase { 0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, Stride, Int<0>, Int<0>>>; - using ScaleBDescriptor = - cutlass::epilogue::collective::detail::RowBroadcastDescriptor< - EpilogueDescriptor, float>; - using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< - ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape, - typename ScaleBDescriptor::Element, Stride, Int<1>, Int<0>>>; + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, + Stride, Int<1>, Int<0>>>; }; /* @@ -154,12 +148,8 @@ struct ScaledEpilogueBias cutlass::multiply_add, ElementD, float, cutlass::FloatRoundStyle::round_to_nearest>; - using BiasDescriptor = - cutlass::epilogue::collective::detail::RowBroadcastDescriptor< - EpilogueDescriptor, ElementD>; - using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast< - BiasDescriptor::Stages, typename EpilogueDescriptor::TileShape, ElementD, + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, ElementD, Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, false>; public: @@ -251,12 +241,12 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, int64_t ldb = b.stride(1); int64_t ldc = out.stride(0); - using StrideA = Stride, Int<0>>; - using StrideB = Stride, Int<0>>; + using StrideA = Stride, int64_t>; + using StrideB = Stride, int64_t>; using StrideC = typename Gemm::StrideC; - StrideA a_stride{lda, Int<1>{}, Int<0>{}}; - StrideB b_stride{ldb, Int<1>{}, Int<0>{}}; + StrideA a_stride{lda, Int<1>{}, 0}; + StrideB b_stride{ldb, Int<1>{}, 0}; StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; using GemmKernel = typename Gemm::GemmKernel; @@ -282,11 +272,13 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, CUTLASS_CHECK(gemm_op.can_implement(args)); size_t workspace_size = gemm_op.get_workspace_size(args); - cutlass::device_memory::allocation workspace(workspace_size); + auto const workspace_options = + 
torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); - cutlass::Status status = gemm_op.run(args, workspace.get(), stream); + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); CUTLASS_CHECK(status); } From 6e4852ce28ad57dc440067778464ac61e0621899 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 5 Aug 2024 16:00:01 -0400 Subject: [PATCH 0068/3246] [CI/Build] Suppress divide-by-zero and missing return statement warnings (#7001) --- csrc/attention/dtype_bfloat16.cuh | 8 ++++++++ csrc/quantization/awq/dequantize.cuh | 1 + csrc/quantization/fp8/nvidia/quant_utils.cuh | 5 +++-- csrc/quantization/gptq_marlin/gptq_marlin.cu | 18 ++++++++++++------ 4 files changed, 24 insertions(+), 8 deletions(-) diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 3cdcb95e0809..97a25baa1fc0 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -94,6 +94,7 @@ inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { #else return __bfloat1622float2(val); #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { @@ -102,6 +103,7 @@ inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { #else return __bfloat162bfloat162(val); #endif + __builtin_unreachable(); // Suppress missing return statement warning } // Vector addition. @@ -115,6 +117,7 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { return __hadd(a, b); #endif #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { @@ -123,6 +126,7 @@ inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { #else return __hadd2(a, b); #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) { @@ -170,6 +174,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { #else return __hmul(a, b); #endif + __builtin_unreachable(); // Suppress missing return statement warning } template <> @@ -179,6 +184,7 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { #else return __hmul2(a, b); #endif + __builtin_unreachable(); // Suppress missing return statement warning } template <> @@ -289,6 +295,7 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, #else return __hfma2(a, b, c); #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, @@ -298,6 +305,7 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, #else return __hfma2(bf162bf162(a), b, c); #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) { diff --git a/csrc/quantization/awq/dequantize.cuh b/csrc/quantization/awq/dequantize.cuh index 813ec6716cf5..5fa4b5f64027 100644 --- a/csrc/quantization/awq/dequantize.cuh +++ b/csrc/quantization/awq/dequantize.cuh @@ -95,6 +95,7 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) { return result; #endif + __builtin_unreachable(); // Suppress missing return statement warning } } // namespace awq diff --git 
a/csrc/quantization/fp8/nvidia/quant_utils.cuh b/csrc/quantization/fp8/nvidia/quant_utils.cuh index e32684eaed24..f8cd1dcba4ab 100644 --- a/csrc/quantization/fp8/nvidia/quant_utils.cuh +++ b/csrc/quantization/fp8/nvidia/quant_utils.cuh @@ -475,6 +475,7 @@ __inline__ __device__ uint8_t scaled_vec_conversion( __NV_SATFINITE, fp8_type); return (uint8_t)res; #endif + __builtin_unreachable(); // Suppress missing return statement warning } // float -> fp8 @@ -508,7 +509,7 @@ __inline__ __device__ Tout convert(const Tin& x) { } #endif assert(false); - return {}; // Squash missing return statement warning + __builtin_unreachable(); // Suppress missing return statement warning } template @@ -521,7 +522,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { } #endif assert(false); - return {}; // Squash missing return statement warning + __builtin_unreachable(); // Suppress missing return statement warning } // The following macro is used to dispatch the conversion function based on diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index edf19365c809..e2b0f2b05816 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -1130,12 +1130,12 @@ __global__ void Marlin( }; auto fetch_zp_to_registers = [&](int k, int full_pipe) { - if constexpr (has_zp) { - // This code does not handle group_blocks == 0, - // which signifies act_order. - // has_zp implies AWQ, which doesn't have act_order, - static_assert(group_blocks != 0); + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + if constexpr (has_zp) { int pipe = full_pipe % stages; if constexpr (group_blocks == -1) { @@ -1161,7 +1161,13 @@ __global__ void Marlin( cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; From 4cf1dc39be80d81ddda9e7e55f4742a6bd57920c Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 5 Aug 2024 17:22:57 -0400 Subject: [PATCH 0069/3246] [Bugfix][CI/Build] Fix CUTLASS FetchContent (#7171) --- CMakeLists.txt | 2 -- 1 file changed, 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e5ac5516c2e4..8de0c034a7cb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -195,8 +195,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") GIT_REPOSITORY https://github.com/nvidia/cutlass.git # CUTLASS 3.5.1 GIT_TAG 06b21349bcf6ddf6a1686a47a137ad1446579db9 - # Shallow clone with depth 1 - GIT_SHALLOW TRUE GIT_PROGRESS TRUE ) FetchContent_MakeAvailable(cutlass) From 4db5176d9758b720b05460c50ace3c01026eb158 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Mon, 5 Aug 2024 14:39:48 -0700 Subject: [PATCH 0070/3246] bump version to v0.5.4 (#7139) --- docs/source/getting_started/installation.rst | 2 +- vllm/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index 57ad8bacedfc..5294003aa926 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -48,7 +48,7 @@ You 
can install vLLM using pip: .. code-block:: console - $ export VLLM_VERSION=0.5.2 # vLLM's main branch version is currently set to latest released tag + $ export VLLM_VERSION=0.5.4 # vLLM's main branch version is currently set to latest released tag $ pip install https://vllm-wheels.s3.us-west-2.amazonaws.com/nightly/vllm-${VLLM_VERSION}-cp38-abi3-manylinux1_x86_64.whl $ # You can also access a specific commit $ # export VLLM_COMMIT=... diff --git a/vllm/version.py b/vllm/version.py index 693065471063..247036f1d621 100644 --- a/vllm/version.py +++ b/vllm/version.py @@ -9,4 +9,4 @@ stacklevel=2) __commit__ = "COMMIT_HASH_PLACEHOLDER" -__version__ = "0.5.3.post1" +__version__ = "0.5.4" From dfb1a15dcb4c24bf7ff0ba7ddfc5d623ad519d7d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 5 Aug 2024 15:59:22 -0700 Subject: [PATCH 0071/3246] [ci][frontend] deduplicate tests (#7101) --- tests/entrypoints/openai/test_completion.py | 14 +- tests/entrypoints/openai/test_disable_mp.py | 715 -------------------- 2 files changed, 6 insertions(+), 723 deletions(-) delete mode 100644 tests/entrypoints/openai/test_disable_mp.py diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 50add84087a9..05f667231738 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -87,15 +87,13 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, ] -@pytest.fixture(scope="module") -def server(default_server_args): +@pytest.fixture(scope="module", + params=["", "--disable-frontend-multiprocessing"]) +def client(default_server_args, request): + if request.param: + default_server_args.append(request.param) with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: - yield remote_server - - -@pytest.fixture(scope="module") -def client(server): - return server.get_async_client() + yield remote_server.get_async_client() @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_disable_mp.py b/tests/entrypoints/openai/test_disable_mp.py deleted file mode 100644 index 12c805413311..000000000000 --- a/tests/entrypoints/openai/test_disable_mp.py +++ /dev/null @@ -1,715 +0,0 @@ -""" -Repeat of tests in test_completion.py with the non-mp backend. 
-""" - -# imports for guided decoding tests -import json -import re -import shutil -from tempfile import TemporaryDirectory -from typing import List - -import jsonschema -import openai # use the official client for correctness check -import pytest -# downloading lora to test lora requests -from huggingface_hub import snapshot_download -from openai import BadRequestError -from transformers import AutoTokenizer - -from vllm.transformers_utils.tokenizer import get_tokenizer - -from ...utils import RemoteOpenAIServer - -# any model with a chat template should work here -MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" -# technically these adapters use a different base model, -# but we're not testing generation quality here -LORA_NAME = "typeof/zephyr-7b-beta-lora" -PA_NAME = "swapnilbp/llama_tweet_ptune" -# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also -# need to change to match the prompt adapter -PA_NUM_VIRTUAL_TOKENS = 8 - - -@pytest.fixture(scope="module") -def zephyr_lora_files(): - return snapshot_download(repo_id=LORA_NAME) - - -@pytest.fixture(scope="module") -def zephyr_lora_added_tokens_files(zephyr_lora_files): - tmp_dir = TemporaryDirectory() - tmp_model_dir = f"{tmp_dir.name}/zephyr" - shutil.copytree(zephyr_lora_files, tmp_model_dir) - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - # Copy tokenizer to adapter and add some unique tokens - # 32000, 32001, 32002 - added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], - special_tokens=True) - assert added == 3 - tokenizer.save_pretrained(tmp_model_dir) - yield tmp_model_dir - tmp_dir.cleanup() - - -@pytest.fixture(scope="module") -def zephyr_pa_files(): - return snapshot_download(repo_id=PA_NAME) - - -@pytest.fixture(scope="module") -def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, - zephyr_pa_files): - return [ - # use half precision for speed and memory savings in CI environment - "--dtype", - "bfloat16", - "--max-model-len", - "8192", - "--max-num-seqs", - "128", - "--enforce-eager", - # lora config - "--enable-lora", - "--lora-modules", - f"zephyr-lora={zephyr_lora_files}", - f"zephyr-lora2={zephyr_lora_added_tokens_files}", - "--max-lora-rank", - "64", - "--max-cpu-loras", - "2", - # pa config - "--enable-prompt-adapter", - "--prompt-adapters", - f"zephyr-pa={zephyr_pa_files}", - f"zephyr-pa2={zephyr_pa_files}", - "--max-prompt-adapters", - "2", - "--max-prompt-adapter-token", - "128", - "--disable-frontend-multiprocessing" - ] - - -@pytest.fixture(scope="module") -def server(default_server_args): - with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: - yield remote_server - - -@pytest.fixture(scope="module") -def client(server): - return server.get_async_client() - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - # first test base model, then test loras, then test prompt adapters - "model_name,num_virtual_tokens", - [(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0), - ("zephyr-pa", PA_NUM_VIRTUAL_TOKENS), - ("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)], -) -async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, - num_virtual_tokens: int): - completion = await client.completions.create(model=model_name, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) - - assert completion.id is not None - assert completion.choices is not None and len(completion.choices) == 1 - - choice = completion.choices[0] - assert len(choice.text) >= 5 - assert choice.finish_reason == "length" - assert completion.usage == openai.types.CompletionUsage( 
- completion_tokens=5, - prompt_tokens=6 + num_virtual_tokens, - total_tokens=11 + num_virtual_tokens) - - # test using token IDs - completion = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - ) - assert len(completion.choices[0].text) >= 1 - - -@pytest.mark.asyncio -async def test_added_lora_tokens(client: openai.AsyncOpenAI): - # test using token IDs - completion = await client.completions.create( - model="zephyr-lora2", - prompt=[0, 0, 32000, 32001, 32002], - echo=True, - max_tokens=5, - temperature=0.0, - ) - # Added tokens should appear in tokenized prompt - assert completion.choices[0].text.startswith("vllm1vllm2vllm3") - - -@pytest.mark.asyncio -async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI): - # test using token IDs - completion = await client.completions.create( - model=MODEL_NAME, - prompt=[0, 0, 32000, 32001, 32002], - echo=True, - max_tokens=5, - temperature=0.0, - ) - # Added tokens should not appear in tokenized prompt - assert "vllm" not in completion.choices[0].text - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - # first test base model, then test loras, then test prompt adapters - "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"], -) -async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): - # test using token IDs - completion = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - logprobs=None, - ) - choice = completion.choices[0] - assert choice.logprobs is None - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - # just test 1 lora and 1 pa hereafter - "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], -) -async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): - # test using token IDs - completion = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - logprobs=0, - ) - choice = completion.choices[0] - assert choice.logprobs is not None - assert choice.logprobs.token_logprobs is not None - assert choice.logprobs.top_logprobs is not None - assert len(choice.logprobs.top_logprobs[0]) == 1 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], -) -async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): - # test using token IDs - completion = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - logprobs=5, - ) - choice = completion.choices[0] - assert choice.logprobs is not None - assert choice.logprobs.token_logprobs is not None - assert choice.logprobs.top_logprobs is not None - assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], -) -async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, - model_name: str): - - with pytest.raises( - (openai.BadRequestError, openai.APIError)): # test using token IDs - await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - # vLLM has higher default max_logprobs (20 instead of 5) to support - # both Completion API and Chat Completion API - logprobs=21, - ) - ... 
- with pytest.raises( - (openai.BadRequestError, openai.APIError)): # test using token IDs - stream = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - # vLLM has higher default max_logprobs (20 instead of 5) to support - # both Completion API and Chat Completion API - logprobs=30, - stream=True, - ) - async for chunk in stream: - ... - - # the server should still work afterwards - completion = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - ) - assert len(completion.choices[0].text) >= 0 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], -) -async def test_completion_streaming(client: openai.AsyncOpenAI, - model_name: str): - prompt = "What is an LLM?" - - single_completion = await client.completions.create( - model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - ) - single_output = single_completion.choices[0].text - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True) - chunks: List[str] = [] - finish_reason_count = 0 - async for chunk in stream: - chunks.append(chunk.choices[0].text) - if chunk.choices[0].finish_reason is not None: - finish_reason_count += 1 - # finish reason should only return in last block - assert finish_reason_count == 1 - assert chunk.choices[0].finish_reason == "length" - assert chunk.choices[0].text - assert "".join(chunks) == single_output - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], -) -async def test_completion_stream_options(client: openai.AsyncOpenAI, - model_name: str): - prompt = "What is the capital of France?" 
- - # Test stream=True, stream_options= - # {"include_usage": False, "continuous_usage_stats": False} - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": False, - "continuous_usage_stats": - False, - }) - - async for chunk in stream: - assert chunk.usage is None - - # Test stream=True, stream_options= - # {"include_usage": False, "continuous_usage_stats": True} - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": False, - "continuous_usage_stats": - True, - }) - async for chunk in stream: - assert chunk.usage is None - - # Test stream=True, stream_options= - # {"include_usage": True, "continuous_usage_stats": False} - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": True, - "continuous_usage_stats": - False, - }) - async for chunk in stream: - if chunk.choices[0].finish_reason is None: - assert chunk.usage is None - else: - assert chunk.usage is None - final_chunk = await stream.__anext__() - assert final_chunk.usage is not None - assert final_chunk.usage.prompt_tokens > 0 - assert final_chunk.usage.completion_tokens > 0 - assert final_chunk.usage.total_tokens == ( - final_chunk.usage.prompt_tokens + - final_chunk.usage.completion_tokens) - assert final_chunk.choices == [] - - # Test stream=True, stream_options= - # {"include_usage": True, "continuous_usage_stats": True} - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": True, - "continuous_usage_stats": - True, - }) - async for chunk in stream: - assert chunk.usage is not None - assert chunk.usage.prompt_tokens > 0 - assert chunk.usage.completion_tokens > 0 - assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + - chunk.usage.completion_tokens) - if chunk.choices[0].finish_reason is not None: - final_chunk = await stream.__anext__() - assert final_chunk.usage is not None - assert final_chunk.usage.prompt_tokens > 0 - assert final_chunk.usage.completion_tokens > 0 - assert final_chunk.usage.total_tokens == ( - final_chunk.usage.prompt_tokens + - final_chunk.usage.completion_tokens) - assert final_chunk.choices == [] - - # Test stream=False, stream_options= - # {"include_usage": None} - with pytest.raises(BadRequestError): - await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=False, - stream_options={"include_usage": None}) - - # Test stream=False, stream_options= - # {"include_usage": True} - with pytest.raises(BadRequestError): - await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=False, - stream_options={"include_usage": True}) - - # Test stream=False, stream_options= - # {"continuous_usage_stats": None} - with pytest.raises(BadRequestError): - await client.completions.create( - model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=False, - stream_options={"continuous_usage_stats": None}) - - # Test stream=False, stream_options= - # {"continuous_usage_stats": True} - with pytest.raises(BadRequestError): - await client.completions.create( - model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=False, - 
stream_options={"continuous_usage_stats": True}) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], -) -async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): - # test both text and token IDs - for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2): - # test simple list - batch = await client.completions.create( - model=model_name, - prompt=prompts, - max_tokens=5, - temperature=0.0, - ) - assert len(batch.choices) == 2 - assert batch.choices[0].text == batch.choices[1].text - - # test n = 2 - batch = await client.completions.create( - model=model_name, - prompt=prompts, - n=2, - max_tokens=5, - temperature=0.0, - extra_body=dict( - # NOTE: this has to be true for n > 1 in vLLM, but not necessary - # for official client. - use_beam_search=True), - ) - assert len(batch.choices) == 4 - assert batch.choices[0].text != batch.choices[ - 1].text, "beam search should be different" - assert batch.choices[0].text == batch.choices[ - 2].text, "two copies of the same prompt should be the same" - assert batch.choices[1].text == batch.choices[ - 3].text, "two copies of the same prompt should be the same" - - # test streaming - batch = await client.completions.create( - model=model_name, - prompt=prompts, - max_tokens=5, - temperature=0.0, - stream=True, - ) - texts = [""] * 2 - async for chunk in batch: - assert len(chunk.choices) == 1 - choice = chunk.choices[0] - texts[choice.index] += choice.text - assert texts[0] == texts[1] - - -@pytest.mark.asyncio -async def test_logits_bias(client: openai.AsyncOpenAI): - prompt = "Hello, my name is" - max_tokens = 5 - tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) - - # Test exclusive selection - token_id = 1000 - completion = await client.completions.create( - model=MODEL_NAME, - prompt=prompt, - max_tokens=max_tokens, - temperature=0.0, - logit_bias={str(token_id): 100}, - seed=42, - ) - assert len(completion.choices[0].text) >= 5 - response_tokens = tokenizer(completion.choices[0].text, - add_special_tokens=False)["input_ids"] - expected_tokens = tokenizer(tokenizer.decode([token_id] * 5), - add_special_tokens=False)["input_ids"] - assert all([ - response == expected - for response, expected in zip(response_tokens, expected_tokens) - ]) - - # Test ban - completion = await client.completions.create( - model=MODEL_NAME, - prompt=prompt, - max_tokens=max_tokens, - temperature=0.0, - ) - response_tokens = tokenizer(completion.choices[0].text, - add_special_tokens=False)["input_ids"] - first_response = completion.choices[0].text - completion = await client.completions.create( - model=MODEL_NAME, - prompt=prompt, - max_tokens=max_tokens, - temperature=0.0, - logit_bias={str(token): -100 - for token in response_tokens}, - ) - assert first_response != completion.choices[0].text - - -@pytest.mark.asyncio -async def test_allowed_token_ids(client: openai.AsyncOpenAI): - prompt = "Hello, my name is" - max_tokens = 1 - tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) - - # Test exclusive selection - allowed_ids = [21555, 21557, 21558] - completion = await client.completions.create( - model=MODEL_NAME, - prompt=prompt, - max_tokens=max_tokens, - temperature=0.0, - seed=42, - extra_body=dict(allowed_token_ids=allowed_ids), - logprobs=1, - ) - response_tokens = completion.choices[0].logprobs.tokens - assert len(response_tokens) == 1 - assert tokenizer.convert_tokens_to_ids(response_tokens)[0] in allowed_ids - - -@pytest.mark.asyncio 
-@pytest.mark.parametrize("guided_decoding_backend", - ["outlines", "lm-format-enforcer"]) -async def test_guided_json_completion(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_json_schema): - completion = await client.completions.create( - model=MODEL_NAME, - prompt=f"Give an example JSON for an employee profile " - f"that fits this schema: {sample_json_schema}", - n=3, - temperature=1.0, - max_tokens=500, - extra_body=dict(guided_json=sample_json_schema, - guided_decoding_backend=guided_decoding_backend)) - - assert completion.id is not None - assert len(completion.choices) == 3 - for i in range(3): - output_json = json.loads(completion.choices[i].text) - jsonschema.validate(instance=output_json, schema=sample_json_schema) - - -@pytest.mark.asyncio -@pytest.mark.parametrize("guided_decoding_backend", - ["outlines", "lm-format-enforcer"]) -async def test_guided_regex_completion(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_regex): - completion = await client.completions.create( - model=MODEL_NAME, - prompt=f"Give an example IPv4 address with this regex: {sample_regex}", - n=3, - temperature=1.0, - max_tokens=20, - extra_body=dict(guided_regex=sample_regex, - guided_decoding_backend=guided_decoding_backend)) - - assert completion.id is not None - assert len(completion.choices) == 3 - for i in range(3): - assert re.fullmatch(sample_regex, - completion.choices[i].text) is not None - - -@pytest.mark.asyncio -@pytest.mark.parametrize("guided_decoding_backend", - ["outlines", "lm-format-enforcer"]) -async def test_guided_choice_completion(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_guided_choice): - completion = await client.completions.create( - model=MODEL_NAME, - prompt="The best language for type-safe systems programming is ", - n=2, - temperature=1.0, - max_tokens=10, - extra_body=dict(guided_choice=sample_guided_choice, - guided_decoding_backend=guided_decoding_backend)) - - assert completion.id is not None - assert len(completion.choices) == 2 - for i in range(2): - assert completion.choices[i].text in sample_guided_choice - - -@pytest.mark.asyncio -async def test_guided_grammar(client: openai.AsyncOpenAI, - sample_sql_statements): - - completion = await client.completions.create( - model=MODEL_NAME, - prompt=("Generate a sql state that select col_1 from " - "table_1 where it is equals to 1"), - temperature=1.0, - max_tokens=500, - extra_body=dict(guided_grammar=sample_sql_statements)) - - content = completion.choices[0].text - - # use Lark to parse the output, and make sure it's a valid parse tree - from lark import Lark - parser = Lark(sample_sql_statements) - parser.parse(content) - - # remove spaces for comparison b/c we removed them in the grammar - ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "") - - assert content.strip() == ground_truth - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - # first test base model, then test loras - "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], -) -@pytest.mark.parametrize("logprobs_arg", [1, 0]) -async def test_echo_logprob_completion(client: openai.AsyncOpenAI, - model_name: str, logprobs_arg: int): - tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) - # test using text and token IDs - for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]): - completion = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - echo=True, - logprobs=logprobs_arg) - - prompt_text = 
tokenizer.decode(prompt) if isinstance(prompt, - list) else prompt - assert re.search(r"^" + prompt_text, completion.choices[0].text) - logprobs = completion.choices[0].logprobs - assert logprobs is not None - assert len(logprobs.text_offset) > 5 - assert (len(logprobs.token_logprobs) > 5 - and logprobs.token_logprobs[0] is None) - assert (len(logprobs.top_logprobs) > 5 - and logprobs.top_logprobs[0] is None) - for top_logprobs in logprobs.top_logprobs[1:]: - assert max(logprobs_arg, - 1) <= len(top_logprobs) <= logprobs_arg + 1 - assert len(logprobs.tokens) > 5 - - -@pytest.mark.asyncio -@pytest.mark.parametrize("guided_decoding_backend", - ["outlines", "lm-format-enforcer"]) -async def test_guided_decoding_type_error(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_json_schema, sample_regex): - with pytest.raises(openai.BadRequestError): - _ = await client.completions.create( - model=MODEL_NAME, - prompt="Give an example JSON that fits this schema: 42", - extra_body=dict(guided_json=42, - guided_decoding_backend=guided_decoding_backend)) - - with pytest.raises(openai.BadRequestError): - _ = await client.completions.create( - model=MODEL_NAME, - prompt="Give an example string that fits this regex", - extra_body=dict(guided_regex=sample_regex, - guided_json=sample_json_schema)) From 789937af2edb6c1ff847c3cbf0c773fb06602a5f Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 6 Aug 2024 01:29:43 +0200 Subject: [PATCH 0072/3246] [Doc] [SpecDecode] Update MLPSpeculator documentation (#7100) Signed-off-by: Thomas Parnell --- docs/source/models/spec_decode.rst | 49 ++++++++++++++++++++ vllm/model_executor/models/mlp_speculator.py | 9 ++++ 2 files changed, 58 insertions(+) diff --git a/docs/source/models/spec_decode.rst b/docs/source/models/spec_decode.rst index 87a52360c084..be901fa881b1 100644 --- a/docs/source/models/spec_decode.rst +++ b/docs/source/models/spec_decode.rst @@ -69,6 +69,55 @@ matching n-grams in the prompt. For more information read `this thread. `_ or +`this technical report `_. + +.. code-block:: python + + from vllm import LLM, SamplingParams + + prompts = [ + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + llm = LLM( + model="meta-llama/Meta-Llama-3.1-70B-Instruct", + tensor_parallel_size=4, + speculative_model="ibm-fms/llama3-70b-accelerator", + speculative_draft_tensor_parallel_size=1, + use_v2_block_manager=True, + ) + outputs = llm.generate(prompts, sampling_params) + + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +Note that these speculative models currently need to be run without tensor parallelism, although +it is possible to run the main model using tensor parallelism (see example above). Since the +speculative models are relatively small, we still see significant speedups. However, this +limitation will be fixed in a future release. 
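+
+As a minimal single-GPU sketch (assuming the 8B target model and its
+accelerator both fit on one device; the accelerator repo id is assumed from the
+ibm-fms collection listed below), the same API can be used without any tensor
+parallelism:
+
+.. code-block:: python
+
+    from vllm import LLM
+
+    llm = LLM(
+        model="meta-llama/Meta-Llama-3-8B-Instruct",
+        speculative_model="ibm-fms/llama3-8b-accelerator",
+        use_v2_block_manager=True,
+    )
+    outputs = llm.generate("The future of AI is")
+    print(outputs[0].outputs[0].text)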
+ +A variety of speculative models of this type are available on HF hub: + +* `llama-13b-accelerator `_ +* `llama3-8b-accelerator `_ +* `codellama-34b-accelerator `_ +* `llama2-70b-accelerator `_ +* `llama3-70b-accelerator `_ +* `granite-3b-code-instruct-accelerator `_ +* `granite-8b-code-instruct-accelerator `_ +* `granite-7b-instruct-accelerator `_ +* `granite-20b-code-instruct-accelerator `_ + + Resources for vLLM contributors ------------------------------- * `A Hacker's Guide to Speculative Decoding in vLLM `_ diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index d3aec06a92fd..95a655fbbf37 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -56,6 +56,15 @@ def forward(self, x): class MLPSpeculator(nn.Module): + """ + An implementation of the speculative models introduced in + "Accelerating Production LLMs with Combined Token/Embedding + Speculators" + https://arxiv.org/pdf/2404.19124 + + Trained speculators of this type are available on HF hub at: + https://huggingface.co/ibm-fms and https://huggingface.co/ibm-granite + """ def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None: super().__init__() From 89b8db6bb2ce2948073c21231f103c76456844da Mon Sep 17 00:00:00 2001 From: Jacob Schein Date: Mon, 5 Aug 2024 16:35:47 -0700 Subject: [PATCH 0073/3246] [Bugfix] Specify device when loading LoRA and embedding tensors (#7129) Co-authored-by: Jacob Schein --- vllm/lora/models.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 017a1002bb9a..279477562a94 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -248,7 +248,7 @@ def from_local_checkpoint( f" target modules in {expected_lora_modules}" f" but received {unexpected_modules}." 
f" Please verify that the loaded LoRA module is correct") - tensors = torch.load(lora_bin_file_path) + tensors = torch.load(lora_bin_file_path, map_location=device) else: raise ValueError(f"{lora_dir} doesn't contain tensors") @@ -257,7 +257,8 @@ def from_local_checkpoint( embeddings = safetensors.torch.load_file( new_embeddings_tensor_path) elif os.path.isfile(new_embeddings_bin_file_path): - embeddings = torch.load(new_embeddings_bin_file_path) + embeddings = torch.load(new_embeddings_bin_file_path, + map_location=device) rank = config["r"] lora_alpha = config["lora_alpha"] From ef527be06c4064f3a2753a3b2c7ede862fe459e8 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 5 Aug 2024 16:41:27 -0700 Subject: [PATCH 0074/3246] [MISC] Use non-blocking transfer in prepare_input (#7172) --- vllm/attention/backends/flash_attn.py | 27 ++++++++++++--------------- vllm/attention/backends/flashinfer.py | 23 +++++++++++------------ vllm/attention/backends/utils.py | 27 ++++++++++++--------------- vllm/worker/model_runner.py | 15 ++++++++------- 4 files changed, 43 insertions(+), 49 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 26b3159682b3..8a895bbdc2dd 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -13,7 +13,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) -from vllm.utils import make_tensor_with_pad +from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: from vllm.worker.model_runner import ModelInputForGPUBuilder @@ -310,7 +310,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], for i, block_table in enumerate(self.block_tables): if block_table: input_block_tables[i, :len(block_table)] = block_table - block_tables = torch.tensor(input_block_tables, device=device) + block_tables = torch.from_numpy(input_block_tables).to( + device=device, non_blocking=True) else: block_tables = make_tensor_with_pad( self.block_tables, @@ -320,15 +321,15 @@ def build(self, seq_lens: List[int], query_lens: List[int], ) assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - context_lens_tensor = torch.tensor(self.context_lens, - dtype=torch.int, - device=device) - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=device) - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=device) + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) @@ -344,10 +345,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=query_start_loc.dtype, out=query_start_loc[1:]) - slot_mapping_tensor = torch.tensor(self.slot_mapping, - dtype=torch.long, - device=device) - return FlashAttentionMetadata( num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 91abaab78dcb..03188164a963 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ 
-21,7 +21,8 @@ compute_slot_mapping_start_idx, is_block_tables_empty) from vllm.attention.ops.paged_attn import PagedAttention -from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad +from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, + make_tensor_with_pad) if TYPE_CHECKING: from vllm.worker.model_runner import ModelInputForGPUBuilder @@ -356,7 +357,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], for i, block_table in enumerate(self.block_tables): if block_table: input_block_tables[i, :len(block_table)] = block_table - block_tables = torch.tensor(input_block_tables, device=device) + block_tables = torch.from_numpy(input_block_tables).to( + device, non_blocking=True) last_paged_kv_indptr = self.paged_kv_indptr[-1] self.paged_kv_indptr.extend([last_paged_kv_indptr] * @@ -371,12 +373,13 @@ def build(self, seq_lens: List[int], query_lens: List[int], ) assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=device) - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=device) + assert device is not None + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) @@ -392,10 +395,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=query_start_loc.dtype, out=query_start_loc[1:]) - slot_mapping_tensor = torch.tensor(self.slot_mapping, - dtype=torch.long, - device=device) - if len(self.paged_kv_indptr) > 0: paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, device="cpu", diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index bca1370343b7..f7cb2ee99650 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -4,7 +4,7 @@ import torch from vllm.attention import AttentionMetadata, AttentionMetadataBuilder -from vllm.utils import make_tensor_with_pad +from vllm.utils import async_tensor_h2d, make_tensor_with_pad # Error string(s) for encoder/decoder # unsupported attention scenarios @@ -181,7 +181,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], for i, block_table in enumerate(self.block_tables): if block_table: input_block_tables[i, :len(block_table)] = block_table - block_tables = torch.tensor(input_block_tables, device=device) + block_tables = torch.from_numpy(input_block_tables).to( + device, non_blocking=True) else: block_tables = make_tensor_with_pad( self.block_tables, @@ -191,15 +192,15 @@ def build(self, seq_lens: List[int], query_lens: List[int], ) assert max_query_len > 0, "query_lens: {}".format(query_lens) - context_lens_tensor = torch.tensor(self.context_lens, - dtype=torch.int, - device=device) - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=device) - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=device) + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, + self.runner.pin_memory) + slot_mapping_tensor = 
async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) @@ -215,10 +216,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=query_start_loc.dtype, out=query_start_loc[1:]) - slot_mapping_tensor = torch.tensor(self.slot_mapping, - dtype=torch.long, - device=device) - return self._metadata_cls( # type: ignore num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f9c26e0c318b..8b744a438e81 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -50,7 +50,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceGroupMetadata) -from vllm.utils import (CudaMemoryProfiler, flatten_2d_lists, +from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, flatten_2d_lists, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available) from vllm.worker.model_runner_base import ( @@ -549,12 +549,13 @@ def build(self) -> ModelInputForGPU: # Tokens and positions. input_tokens.extend([0] * cuda_graph_pad_size) input_positions.extend([0] * cuda_graph_pad_size) - input_tokens_tensor = torch.tensor(input_tokens, - dtype=torch.long, - device=self.runner.device) - input_positions_tensor = torch.tensor(input_positions, - dtype=torch.long, - device=self.runner.device) + assert self.runner.device is not None + input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long, + self.runner.device, + self.runner.pin_memory) + input_positions_tensor = async_tensor_h2d(input_positions, torch.long, + self.runner.device, + self.runner.pin_memory) # Sequence and query lengths. 
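(The `async_tensor_h2d` calls introduced in the hunks above replace per-field `torch.tensor(..., device=device)` construction. The helper itself lives in `vllm/utils.py` and is not shown in this patch, so the snippet below is only a minimal sketch of the assumed pinned-memory, non-blocking host-to-device pattern; the `_sketch` name and exact body are illustrative, not taken from vLLM.)

    import torch
    from typing import List

    def async_tensor_h2d_sketch(data: List[int], dtype: torch.dtype,
                                target_device: torch.device,
                                pin_memory: bool) -> torch.Tensor:
        # Stage the Python list in host memory first; a pinned (page-locked)
        # staging buffer is what allows the copy below to run asynchronously.
        t = torch.tensor(data, dtype=dtype, device="cpu", pin_memory=pin_memory)
        # non_blocking=True enqueues the H2D copy on the current CUDA stream
        # instead of stalling the Python thread until the transfer finishes.
        return t.to(device=target_device, non_blocking=True)

Without a pinned source buffer, `non_blocking=True` degrades to an ordinary synchronous copy, which is why every call site above threads `self.runner.pin_memory` through.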
seq_lens.extend([1] * cuda_graph_pad_size) From 360bd67cf0ea4a79a59c1aae736cc495a5a63ec5 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 6 Aug 2024 07:54:23 +0800 Subject: [PATCH 0075/3246] [Core] Support loading GGUF model (#5191) Co-authored-by: Michael Goin --- .github/workflows/clang-format.yml | 5 + CMakeLists.txt | 1 + csrc/ops.h | 9 + csrc/quantization/gguf/dequantize.cuh | 531 +++++ csrc/quantization/gguf/ggml-common.h | 969 +++++++++ csrc/quantization/gguf/gguf_kernel.cu | 242 +++ csrc/quantization/gguf/mmq.cuh | 600 ++++++ csrc/quantization/gguf/mmvq.cuh | 182 ++ csrc/quantization/gguf/vecdotq.cuh | 1745 +++++++++++++++++ csrc/torch_bindings.cpp | 12 + examples/gguf_inference.py | 38 + format.sh | 5 + requirements-common.txt | 1 + tests/models/test_gguf.py | 76 + tests/quantization/test_lm_head.py | 6 +- vllm/_custom_ops.py | 32 + vllm/config.py | 1 + vllm/engine/arg_utils.py | 3 + vllm/model_executor/layers/linear.py | 82 +- .../layers/quantization/__init__.py | 2 + .../layers/quantization/base_config.py | 26 +- .../layers/quantization/gguf.py | 165 ++ .../layers/vocab_parallel_embedding.py | 59 +- vllm/model_executor/model_loader/loader.py | 94 +- .../model_loader/weight_utils.py | 47 +- vllm/model_executor/models/llama.py | 7 + vllm/model_executor/models/qwen2.py | 1 + vllm/transformers_utils/config.py | 40 +- vllm/transformers_utils/tokenizer.py | 10 +- 29 files changed, 4970 insertions(+), 21 deletions(-) create mode 100644 csrc/quantization/gguf/dequantize.cuh create mode 100644 csrc/quantization/gguf/ggml-common.h create mode 100644 csrc/quantization/gguf/gguf_kernel.cu create mode 100644 csrc/quantization/gguf/mmq.cuh create mode 100644 csrc/quantization/gguf/mmvq.cuh create mode 100644 csrc/quantization/gguf/vecdotq.cuh create mode 100644 examples/gguf_inference.py create mode 100644 tests/models/test_gguf.py create mode 100644 vllm/model_executor/layers/quantization/gguf.py diff --git a/.github/workflows/clang-format.yml b/.github/workflows/clang-format.yml index 79b85d8cad0d..d5f37396e69d 100644 --- a/.github/workflows/clang-format.yml +++ b/.github/workflows/clang-format.yml @@ -30,6 +30,11 @@ jobs: run: | EXCLUDES=( 'csrc/moe/topk_softmax_kernels.cu' + 'csrc/quantization/gguf/ggml-common.h' + 'csrc/quantization/gguf/dequantize.cuh' + 'csrc/quantization/gguf/vecdotq.cuh' + 'csrc/quantization/gguf/mmq.cuh' + 'csrc/quantization/gguf/mmvq.cuh' ) find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \ | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \ diff --git a/CMakeLists.txt b/CMakeLists.txt index 8de0c034a7cb..784fea05ea73 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -208,6 +208,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/quantization/gptq_marlin/awq_marlin_repack.cu" + "csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/fp8/fp8_marlin.cu" "csrc/custom_all_reduce.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" diff --git a/csrc/ops.h b/csrc/ops.h index 3bd4a9eda5ee..e9e5f79a4a6f 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -107,6 +107,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits); +torch::Tensor ggml_dequantize(torch::Tensor W, int8_t type, int64_t m, + int64_t n); + +torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, int8_t type, 
+ int64_t row); + +torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int8_t type, + int64_t row); + torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, diff --git a/csrc/quantization/gguf/dequantize.cuh b/csrc/quantization/gguf/dequantize.cuh new file mode 100644 index 000000000000..03c080f645f0 --- /dev/null +++ b/csrc/quantization/gguf/dequantize.cuh @@ -0,0 +1,531 @@ +// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/convert.cu +// Dequant functions +static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const block_q4_0 * x = (const block_q4_0 *) vx; + + const dfloat d = x[ib].d; + + const int vui = x[ib].qs[iqs]; + + v.x = __int2half_rn(vui & 0xF); + v.y = __int2half_rn(vui >> 4); + + v = __hsub2(v, __floats2half2_rn(8.0f, 8.0f)); + v = __hmul2(v, {d, d}); +} + +static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const block_q4_1 * x = (const block_q4_1 *) vx; + + const dfloat d = __low2half(x[ib].dm); + const dfloat m = __high2half(x[ib].dm); + + const int vui = x[ib].qs[iqs]; + + v.x = __int2half_rn(vui & 0xF); + v.y = __int2half_rn(vui >> 4); + + v = __hmul2(v, {d, d}); + v = __hadd2(v, {m, m}); +} + +static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const block_q5_0 * x = (const block_q5_0 *) vx; + + const dfloat d = x[ib].d; + + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; + const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10; + + v.x = __int2half_rn((x[ib].qs[iqs] & 0xf) | xh_0); + v.y = __int2half_rn((x[ib].qs[iqs] >> 4) | xh_1); + + v = __hsub2(v, __floats2half2_rn(16.0f, 16.0f)); + v = __hmul2(v, {d, d}); +} + +static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const block_q5_1 * x = (const block_q5_1 *) vx; + + const dfloat d = __low2half(x[ib].dm); + const dfloat m = __high2half(x[ib].dm); + + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; + const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10; + + v.x = __int2half_rn((x[ib].qs[iqs] & 0xf) | xh_0); + v.y = __int2half_rn((x[ib].qs[iqs] >> 4) | xh_1); + + v = __hmul2(v, {d, d}); + v = __hadd2(v, {m, m}); +} + +static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const block_q8_0 * x = (const block_q8_0 *) vx; + + const dfloat d = x[ib].d; + + v.x = __int2half_rn(x[ib].qs[iqs + 0]); + v.y = __int2half_rn(x[ib].qs[iqs + 1]); + + v = __hmul2(v, {d, d}); +} + +template +static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) { + const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x); + + if (i >= k) { + return; + } + + const int ib = i/qk; // block index + const int iqs = (i%qk)/qr; // quant index + const int iybs = i - i%qk; // y block start index + const int y_offset = qr == 1 ? 
1 : qk/2; + + // dequantize + dfloat2 v; + dequantize_kernel(vx, ib, iqs, v); + + y[iybs + iqs + 0] = v.x; + y[iybs + iqs + y_offset] = v.y; +} + +template +static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_q2_K * x = (const block_q2_K *) vx; + + const int tid = threadIdx.x; + const int n = tid/32; + const int l = tid - 32*n; + const int is = 8*n + l/16; + + const uint8_t q = x[i].qs[32*n + l]; + dst_t * y = yy + i*QK_K + 128*n; + + half dall = __low2half(x[i].dm); + half dmin = __high2half(x[i].dm); + y[l+ 0] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+0] & 0xF) * ((q >> 0) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+0] >> 4))); + y[l+32] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+2] & 0xF) * ((q >> 2) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+2] >> 4))); + y[l+64] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+4] & 0xF) * ((q >> 4) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+4] >> 4))); + y[l+96] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+6] & 0xF) * ((q >> 6) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+6] >> 4))); +} + +template +static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_q3_K * x = (const block_q3_K *) vx; + + const int r = threadIdx.x/4; + const int tid = r/2; + const int is0 = r%2; + const int l0 = 16*is0 + 4*(threadIdx.x%4); + const int n = tid / 4; + const int j = tid - 4*n; + + uint8_t m = 1 << (4*n + j); + int is = 8*n + 2*j + is0; + int shift = 2*j; + + int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) : + is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) : + is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) : + (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4); + half d_all = x[i].d; + half dl = __hmul(d_all, __int2half_rn(us - 32)); + + dst_t * y = yy + i*QK_K + 128*n + 32*j; + const uint8_t * q = x[i].qs + 32*n; + const uint8_t * hm = x[i].hmask; + + for (int l = l0; l < l0+4; ++l) y[l] = __hmul(dl, __int2half_rn((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 
0 : 4))); +} + +static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { + if (j < 4) { + d = q[j] & 63; m = q[j + 4] & 63; + } else { + d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + } +} + +template +static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const block_q4_K * x = (const block_q4_K *) vx; + + const int i = blockIdx.x; + + // assume 32 threads + const int tid = threadIdx.x; + const int il = tid/8; + const int ir = tid%8; + const int is = 2*il; + const int n = 4; + + dst_t * y = yy + i*QK_K + 64*il + n*ir; + + const half dall = __low2half(x[i].dm); + const half dmin = __high2half(x[i].dm); + + const uint8_t * q = x[i].qs + 32*il + n*ir; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[i].scales, sc, m); + const half d1 = __hmul(dall, __int2half_rn(sc)); + const half m1 = __hmul(dmin, __int2half_rn(m)); + get_scale_min_k4(is + 1, x[i].scales, sc, m); + const half d2 = __hmul(dall, __int2half_rn(sc)); + const half m2 = __hmul(dmin, __int2half_rn(m)); + for (int l = 0; l < n; ++l) { + y[l + 0] = __hsub(__hmul(d1, __int2half_rn(q[l] & 0xF)), m1); + y[l +32] = __hsub(__hmul(d2, __int2half_rn(q[l] >> 4)), m2); + } +} + +template +static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const block_q5_K * x = (const block_q5_K *) vx; + + const int i = blockIdx.x; + + // assume 64 threads - this is very slightly better than the one below + const int tid = threadIdx.x; + const int il = tid/16; // il is in 0...3 + const int ir = tid%16; // ir is in 0...15 + const int is = 2*il; // is is in 0...6 + + dst_t * y = yy + i*QK_K + 64*il + 2*ir; + + const half dall = __low2half(x[i].dm); + const half dmin = __high2half(x[i].dm); + + const uint8_t * ql = x[i].qs + 32*il + 2*ir; + const uint8_t * qh = x[i].qh + 2*ir; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[i].scales, sc, m); + const half d1 = __hmul(dall, __int2half_rn(sc)); const half m1 = __hmul(dmin, __int2half_rn(m)); + get_scale_min_k4(is + 1, x[i].scales, sc, m); + const half d2 = __hmul(dall, __int2half_rn(sc)); const half m2 = __hmul(dmin, __int2half_rn(m)); + + uint8_t hm = 1 << (2*il); + y[ 0] = __hsub(__hmul(d1, __int2half_rn((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0))), m1); + y[ 1] = __hsub(__hmul(d1, __int2half_rn((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0))), m1); + hm <<= 1; + y[32] = __hsub(__hmul(d2, __int2half_rn((ql[0] >> 4) + (qh[0] & hm ? 16 : 0))), m2); + y[33] = __hsub(__hmul(d2, __int2half_rn((ql[1] >> 4) + (qh[1] & hm ? 
16 : 0))), m2); +} + +template +static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const block_q6_K * x = (const block_q6_K *) vx; + + const int i = blockIdx.x; + + // assume 64 threads - this is very slightly better than the one below + const int tid = threadIdx.x; + const int ip = tid/32; // ip is 0 or 1 + const int il = tid - 32*ip; // 0...32 + const int is = 8*ip + il/16; + + dst_t * y = yy + i*QK_K + 128*ip + il; + + const half d = x[i].d; + + const uint8_t * ql = x[i].ql + 64*ip + il; + const uint8_t qh = x[i].qh[32*ip + il]; + const int8_t * sc = x[i].scales + is; + + y[ 0] = __hmul(d, __int2half_rn(sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32))); + y[32] = __hmul(d, __int2half_rn(sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32))); + y[64] = __hmul(d, __int2half_rn(sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32))); + y[96] = __hmul(d, __int2half_rn(sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32))); +} + +template +static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_iq2_xxs * x = (const block_iq2_xxs *) vx; + + const int tid = threadIdx.x; + const int il = tid/8; // 0...3 + const int ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint16_t * q2 = x[i].qs + 4*ib; + const uint8_t * aux8 = (const uint8_t *)q2; + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]); + const uint32_t aux32 = q2[2] | (q2[3] << 16); + const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.25f; + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127]; + for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f)); +} + +template +static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_iq2_xs * x = (const block_iq2_xs *) vx; + + const int tid = threadIdx.x; + const int il = tid/8; // 0...3 + const int ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint16_t * q2 = x[i].qs + 4*ib; + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511)); + const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f; + const uint8_t signs = ksigns_iq2xs[q2[il] >> 9]; + for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f)); + +} + +template +static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_iq2_s * x = (const block_iq2_s *) vx; + + const int tid = threadIdx.x; + const int il = tid/8; // 0...3 + const int ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300))); + const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f; + const uint8_t signs = x[i].qs[QK_K/8+4*ib+il]; + for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? 
-1.f : 1.f)); +} + +template +static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_iq3_xxs * x = (const block_iq3_xxs *) vx; + + const int tid = threadIdx.x; + const int il = tid/8; // 0...3 + const int ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint8_t * q3 = x[i].qs + 8*ib; + const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib; + const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]); + const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]); + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.5f; + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127]; + for (int j = 0; j < 4; ++j) { + y[j+0] = __float2half(d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f)); + y[j+4] = __float2half(d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f)); + } +} + +template +static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_iq3_s * x = (const block_iq3_s *) vx; + + const int tid = threadIdx.x; + const int il = tid/8; // 0...3 + const int ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint8_t * qs = x[i].qs + 8*ib; + const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256))); + const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f; + const uint8_t signs = x[i].signs[4*ib + il]; + for (int j = 0; j < 4; ++j) { + y[j+0] = __float2half(d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f)); + y[j+4] = __float2half(d * grid2[j] * (signs & kmask_iq2xs[j+4] ? 
-1.f : 1.f)); + } +} + +template +static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_iq1_s * x = (const block_iq1_s *) vx; + + const int tid = threadIdx.x; + const int il = tid/8; // 0...3 + const int ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const int i8 = 4*ib+il; + uint8_t h = x[i].scales[i8/2] >> 4*(i8%2); + const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5))); + const float d = __half2float(x[i].d) * (2*(h & 7) + 1); + for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j]); +} + +template +static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL); + + const int tid = threadIdx.x; + const int il = tid/8; // 0...3 + const int ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 4*il; + const uint8_t * q4 = x[ib].qs + 4*il; + const float d = __half2float(x[ib].d); + for (int j = 0; j < 4; ++j) { + y[j+ 0] = __float2half(d * kvalues_iq4nl[q4[j] & 0xf]); + y[j+16] = __float2half(d * kvalues_iq4nl[q4[j] >> 4]); + } + +} + +template +static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const int i = blockIdx.x; + const block_iq4_xs * x = (const block_iq4_xs *)vx; + + const int tid = threadIdx.x; + const int il = tid/8; // 0...3 + const int ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 4*il; + const uint8_t * q4 = x[i].qs + 16*ib + 4*il; + const float d = __half2float(x[i].d) * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32); + for (int j = 0; j < 4; ++j) { + y[j+ 0] = __float2half(d * kvalues_iq4nl[q4[j] & 0xf]); + y[j+16] = __float2half(d * kvalues_iq4nl[q4[j] >> 4]); + } +} + +template +static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) { + const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE); + dequantize_block<<>>(vx, y, k); +} + +template +static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_q2_K<<>>(vx, y); +} + +template +static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_q3_K<<>>(vx, y); +} + +template +static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_q4_K<<>>(vx, y); +} + +template +static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_q5_K<<>>(vx, y); +} + +template +static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_q6_K<<>>(vx, y); +} + +template +static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq2_xxs<<>>(vx, y); +} + +template +static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq2_xs<<>>(vx, y); +} + +template +static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int k, 
cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq2_s<<>>(vx, y); +} + +template +static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq3_xxs<<>>(vx, y); +} + +template +static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq3_s<<>>(vx, y); +} + +template +static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq1_s<<>>(vx, y); +} + +template +static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_iq4_nl<<>>(vx, y); +} + +template +static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_iq4_xs<<>>(vx, y); +} + +static to_fp16_cuda_t ggml_get_to_fp16_cuda(int type) { + switch (type) { + case 2: + return dequantize_block_cuda; + case 3: + return dequantize_block_cuda; + case 6: + return dequantize_block_cuda; + case 7: + return dequantize_block_cuda; + case 8: + return dequantize_block_cuda; + case 10: + return dequantize_row_q2_K_cuda; + case 11: + return dequantize_row_q3_K_cuda; + case 12: + return dequantize_row_q4_K_cuda; + case 13: + return dequantize_row_q5_K_cuda; + case 14: + return dequantize_row_q6_K_cuda; + case 16: + return dequantize_row_iq2_xxs_cuda; + case 17: + return dequantize_row_iq2_xs_cuda; + case 18: + return dequantize_row_iq3_xxs_cuda; + case 19: + return dequantize_row_iq1_s_cuda; + case 20: + return dequantize_row_iq4_nl_cuda; + case 21: + return dequantize_row_iq3_s_cuda; + case 22: + return dequantize_row_iq2_s_cuda; + case 23: + return dequantize_row_iq4_xs_cuda; + default: + return nullptr; + } +} \ No newline at end of file diff --git a/csrc/quantization/gguf/ggml-common.h b/csrc/quantization/gguf/ggml-common.h new file mode 100644 index 000000000000..d7989d84bf68 --- /dev/null +++ b/csrc/quantization/gguf/ggml-common.h @@ -0,0 +1,969 @@ +// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 +#define WARP_SIZE 32 +#define K_SCALE_SIZE 12 +#define CUDA_DEQUANTIZE_BLOCK_SIZE 256 +#define CUDA_QUANTIZE_BLOCK_SIZE 256 +#define GGML_CUDA_DMMV_X 32 +#define GGML_CUDA_MMV_Y 1 + + +// Data Structures +// QK = number of values after dequantization +// QR = QK / number of values before dequantization +// QI = number of 32 bit integers before dequantization + +#define QK4_0 32 +#define QR4_0 2 +#define QI4_0 (QK4_0 / (4 * QR4_0)) +typedef struct { + half d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; + +#define QK4_1 32 +#define QR4_1 2 +#define QI4_1 (QK4_1 / (4 * QR4_1)) +typedef struct { + half2 dm; // dm.x = delta, dm.y = min + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; + +#define QK5_0 32 +#define QR5_0 2 +#define QI5_0 (QK5_0 / (4 * QR5_0)) +typedef struct { + half d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; + +#define QK5_1 32 +#define QR5_1 2 +#define QI5_1 (QK5_1 / (4 * QR5_1)) +typedef struct { + half2 dm; // dm.x = delta, dm.y = min + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; + +#define QK8_0 32 
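// Worked example of the QK / QR / QI convention documented at the top of this
// header (added here for illustration only; the numbers follow directly from
// the definitions above and are not part of the upstream file):
//   block_q4_0 holds QK4_0 = 32 dequantized values packed two-to-a-byte, i.e.
//   16 bytes of quants, so QR4_0 = 32 / 16 = 2 and those bytes are consumed as
//   QI4_0 = QK4_0 / (4 * QR4_0) = 32 / 8 = 4 32-bit integers per block.
//   block_q8_0 stores one byte per value, so QR8_0 = 1 and
//   QI8_0 = QK8_0 / (4 * QR8_0) = 32 / 4 = 8 32-bit integers per block.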
+#define QR8_0 1 +#define QI8_0 (QK8_0 / (4 * QR8_0)) +typedef struct { + half d; // delta + int8_t qs[QK8_0]; // quants +} block_q8_0; + +#define QK8_1 32 +#define QR8_1 1 +#define QI8_1 (QK8_1 / (4 * QR8_1)) +typedef struct { + half2 ds; // ds.x = delta, ds.y = sum + int8_t qs[QK8_0]; // quants +} block_q8_1; + +#define QR2_K 4 +#define QI2_K (QK_K / (4*QR2_K)) +typedef struct { + uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits + uint8_t qs[QK_K/4]; // quants + half2 dm; // super-block scale for quantized scales/mins +} block_q2_K; + +#define QR3_K 4 +#define QI3_K (QK_K / (4*QR3_K)) +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits + uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits + half d; // super-block scale +} block_q3_K; + +#define QR4_K 2 +#define QI4_K (QK_K / (4*QR4_K)) +typedef struct { + half2 dm; // super-block scale for quantized scales/mins + uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; + +#define QR5_K 2 +#define QI5_K (QK_K / (4*QR5_K)) +typedef struct { + half2 dm; // super-block scale for quantized scales/mins + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; + +#define QR6_K 2 +#define QI6_K (QK_K / (4*QR6_K)) +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales + half d; // delta +} block_q6_K; + +#define QR2_XXS 8 +#define QI2_XXS (QK_K / (4*QR2_XXS)) +typedef struct { + half d; + uint16_t qs[QK_K/8]; +} block_iq2_xxs; + +#define QR2_XS 8 +#define QI2_XS (QK_K / (4*QR2_XS)) +typedef struct { + half d; + uint16_t qs[QK_K/8]; + uint8_t scales[QK_K/32]; +} block_iq2_xs; + +#define QR2_S 8 +#define QI2_S (QK_K / (4*QR2_S)) +typedef struct { + half d; + uint8_t qs[QK_K/4]; + uint8_t qh[QK_K/32]; + uint8_t scales[QK_K/32]; +} block_iq2_s; + +#define QR3_XXS 8 +#define QI3_XXS (QK_K / (4*QR3_XXS)) +typedef struct { + half d; + uint8_t qs[3*(QK_K/8)]; +} block_iq3_xxs; + +#define QR3_XS 8 +#define QI3_XS (QK_K / (4*QR3_XS)) +#define IQ3S_N_SCALE QK_K/64 +typedef struct { + half d; + uint8_t qs[QK_K/4]; + uint8_t qh[QK_K/32]; + uint8_t signs[QK_K/8]; + uint8_t scales[IQ3S_N_SCALE]; +} block_iq3_s; + +#define QR1_S 8 +#define QI1_S (QK_K / (4*QR1_S)) +typedef struct { + half d; + uint8_t qs[QK_K/8]; + uint8_t scales[QK_K/16]; +} block_iq1_s; + +#define QK4_NL 32 +#define QR4_NL 2 +#define QI4_NL (QK4_NL / (4*QR4_NL)) +typedef struct { + half d; + uint8_t qs[QK4_NL/2]; +} block_iq4_nl; + +#define QR4_XS 8 +#define QI4_XS (QK_K / (4*QR4_XS)) +typedef struct { + half d; + uint16_t scales_h; + uint8_t scales_l[QK_K/64]; + uint8_t qs[QK_K/2]; +} block_iq4_xs; + +static const __device__ uint64_t iq2xxs_grid[256] = { + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808, + 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819, + 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819, + 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b, + 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808, + 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08, + 0x0808082b08080808, 
0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b, + 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819, + 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08, + 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, + 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08, + 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808, + 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808, + 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919, + 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819, + 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08, + 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908, + 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819, + 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808, + 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808, + 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908, + 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808, + 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08, + 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819, + 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819, + 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819, + 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908, + 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19, + 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819, + 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b, + 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808, + 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908, + 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08, + 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08, + 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908, + 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819, + 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808, + 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808, + 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19, + 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819, + 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, + 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b, + 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08, + 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808, + 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908, + 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b, + 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819, + 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08, + 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08, + 0x192b081908080808, 0x192b081919191919, 
0x192b082b08192b08, 0x192b082b192b0808, + 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b, + 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b, + 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908, + 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819, + 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808, + 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908, + 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b, + 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808, + 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b, + 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b, + 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808, + 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19, + 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908, +}; + +static const __device__ uint64_t iq2xs_grid[512] = { + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b, + 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919, + 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b, + 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919, + 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808, + 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819, + 0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819, + 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, + 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b, + 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b, + 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908, + 0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908, + 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919, + 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808, + 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919, + 0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908, + 0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, + 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, + 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08, + 0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808, + 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808, + 0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819, + 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908, + 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819, + 0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808, + 0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b, + 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819, + 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819, + 
0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808, + 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908, + 0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19, + 0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b, + 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b, + 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919, + 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808, + 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819, + 0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819, + 0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b, + 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908, + 0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808, + 0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819, + 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808, + 0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, + 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808, + 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808, + 0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908, + 0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908, + 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808, + 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b, + 0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819, + 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, + 0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908, + 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808, + 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908, + 0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919, + 0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08, + 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19, + 0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b, + 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b, + 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808, + 0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08, + 0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b, + 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908, + 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b, + 0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908, + 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, + 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808, + 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808, + 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08, + 0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819, + 0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919, + 0x190808192b2b082b, 
0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808, + 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808, + 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819, + 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819, + 0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908, + 0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908, + 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b, + 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908, + 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908, + 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908, + 0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808, + 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, + 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819, + 0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819, + 0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808, + 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b, + 0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819, + 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819, + 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08, + 0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808, + 0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19, + 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919, + 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, + 0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19, + 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b, + 0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808, + 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b, + 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b, + 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, + 0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b, + 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808, + 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819, + 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808, + 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808, + 0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08, + 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b, + 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19, + 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08, + 0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919, + 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08, + 0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08, + 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908, + 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908, + 0x2b19081908080808, 0x2b190819082b082b, 
0x2b190819192b1908, 0x2b19082b1919192b, + 0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908, + 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808, + 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b, + 0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808, + 0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808, + 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19, + 0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08, + 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808, + 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b, + 0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808, + 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b, + 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b, +}; + +static const __device__ uint64_t iq2s_grid[1024] = { + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, + 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b, + 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919, + 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b, + 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919, + 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b, + 0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919, + 0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808, + 0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908, + 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b, + 0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908, + 0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08, + 0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19, + 0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819, + 0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919, + 0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b, + 0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, + 0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908, + 0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919, + 0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908, + 0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b, + 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919, + 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b, + 0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, + 0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908, + 0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b, + 0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b, + 0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08, + 0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, + 0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819, + 
0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808, + 0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908, + 0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b, + 0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908, + 0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08, + 0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819, + 0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808, + 0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08, + 0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819, + 0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b, + 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908, + 0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919, + 0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b, + 0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919, + 0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808, + 0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819, + 0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919, + 0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919, + 0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808, + 0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819, + 0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b, + 0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908, + 0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, + 0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, + 0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919, + 0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b, + 0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919, + 0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b, + 0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819, + 0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919, + 0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908, + 0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b, + 0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908, + 0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b, + 0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908, + 0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08, + 0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908, + 0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819, + 0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819, + 0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808, + 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08, + 0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19, + 0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819, + 0x0819190819081908, 
0x081919081908192b, 0x0819190819082b19, 0x0819190819190808, + 0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819, + 0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919, + 0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808, + 0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19, + 0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08, + 0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b, + 0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908, + 0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808, + 0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819, + 0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908, + 0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819, + 0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808, + 0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808, + 0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819, + 0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908, + 0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08, + 0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819, + 0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b, + 0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b, + 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08, + 0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19, + 0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819, + 0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919, + 0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908, + 0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808, + 0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808, + 0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908, + 0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808, + 0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08, + 0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08, + 0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908, + 0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919, + 0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808, + 0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819, + 0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908, + 0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08, + 0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819, + 0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808, + 0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808, + 0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819, + 0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808, + 0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908, + 0x082b2b1908190808, 0x082b2b192b191919, 
0x082b2b2b08082b2b, 0x082b2b2b082b082b, + 0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, + 0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, + 0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b, + 0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808, + 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b, + 0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19, + 0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819, + 0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08, + 0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b, + 0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908, + 0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b, + 0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b, + 0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919, + 0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808, + 0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819, + 0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908, + 0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08, + 0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08, + 0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819, + 0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919, + 0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908, + 0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b, + 0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908, + 0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b, + 0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908, + 0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08, + 0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819, + 0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808, + 0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819, + 0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919, + 0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808, + 0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808, + 0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08, + 0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819, + 0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919, + 0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808, + 0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819, + 0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919, + 0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808, + 0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b, + 0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908, + 0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808, + 0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 
0x19082b2b08081908, + 0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b, + 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908, + 0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b, + 0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908, + 0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b, + 0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908, + 0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08, + 0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908, + 0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b, + 0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908, + 0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08, + 0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819, + 0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919, + 0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808, + 0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19, + 0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b, + 0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919, + 0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808, + 0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819, + 0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908, + 0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919, + 0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808, + 0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808, + 0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b, + 0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919, + 0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808, + 0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b, + 0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808, + 0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919, + 0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b, + 0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08, + 0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919, + 0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808, + 0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b, + 0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908, + 0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808, + 0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808, + 0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808, + 0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908, + 0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808, + 0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808, + 0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b, + 0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908, + 
0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808, + 0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808, + 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819, + 0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919, + 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b, + 0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808, + 0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819, + 0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b, + 0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908, + 0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08, + 0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908, + 0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919, + 0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819, + 0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908, + 0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b, + 0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808, + 0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819, + 0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908, + 0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919, + 0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808, + 0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808, + 0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808, + 0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919, + 0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908, + 0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908, + 0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08, + 0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819, + 0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b, + 0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808, + 0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819, + 0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908, + 0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819, + 0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808, + 0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808, + 0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b, + 0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908, + 0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808, + 0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908, + 0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819, + 0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819, + 0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808, + 0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b, + 0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b, + 0x2b2b080808191908, 
0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819, + 0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b, + 0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b, + 0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b, + 0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819, + 0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19, + 0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819, + 0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908, + 0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808, + 0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b, +}; + +static const __device__ uint32_t iq3xxs_grid[256] = { + 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, + 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, + 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404, + 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e, + 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c, + 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c, + 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34, + 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c, + 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c, + 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04, + 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c, + 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414, + 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434, + 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c, + 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e, + 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24, + 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24, + 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c, + 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c, + 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14, + 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414, + 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e, + 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404, + 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c, + 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c, + 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14, + 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c, + 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 
0x34042424, 0x34042434, 0x34043424, 0x340c140c, + 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14, + 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14, + 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c, + 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, +}; + +static const __device__ uint32_t iq3xs_grid[512] = { + 0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14, + 0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414, + 0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24, + 0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c, + 0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c, + 0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34, + 0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c, + 0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414, + 0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c, + 0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404, + 0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434, + 0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c, + 0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404, + 0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414, + 0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414, + 0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404, + 0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c, + 0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c, + 0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404, + 0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e, + 0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14, + 0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c, + 0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424, + 0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c, + 0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c, + 0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e, + 0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e, + 0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e, + 0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424, + 0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e, + 0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424, + 0x141c1c0c, 
0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404, + 0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c, + 0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e, + 0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c, + 0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c, + 0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c, + 0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404, + 0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04, + 0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c, + 0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414, + 0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c, + 0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c, + 0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424, + 0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c, + 0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c, + 0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414, + 0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c, + 0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e, + 0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04, + 0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424, + 0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14, + 0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34, + 0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c, + 0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434, + 0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c, + 0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424, + 0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24, + 0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24, + 0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e, + 0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c, + 0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c, + 0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c, + 0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404, +}; + +static const __device__ uint64_t iq1s_grid[512] = { + 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000, + 0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01, + 0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100, + 0xffffff0000010000, 
0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00, + 0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101, + 0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100, + 0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00, + 0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff, + 0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000, + 0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000, + 0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001, + 0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff, + 0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01, + 0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001, + 0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00, + 0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001, + 0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100, + 0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000, + 0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000, + 0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000, + 0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff, + 0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff, + 0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01, + 0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100, + 0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff, + 0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000, + 0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101, + 0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff, + 0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff, + 0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001, + 0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01, + 0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101, + 0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100, + 0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00, + 0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001, + 0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff, + 0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000, + 0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000, + 0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100, + 0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100, + 0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01, + 0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff, + 0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101, + 0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000, + 0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff, + 0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000, + 0x00ff000001ffff01, 0x00ff000001000000, 
0x00ff0001ff000101, 0x00ff000100ffffff, + 0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00, + 0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101, + 0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000, + 0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000, + 0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000, + 0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100, + 0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000, + 0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001, + 0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff, + 0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000, + 0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000, + 0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000, + 0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000, + 0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff, + 0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000, + 0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001, + 0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01, + 0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100, + 0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000, + 0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00, + 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100, + 0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000, + 0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001, + 0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00, + 0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff, + 0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100, + 0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff, + 0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000, + 0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff, + 0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff, + 0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00, + 0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001, + 0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001, + 0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01, + 0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000, + 0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101, + 0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00, + 0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100, + 0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101, + 0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101, + 0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000, + 0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff, + 0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 
0x0001010100ff01ff, + 0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101, + 0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff, + 0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101, + 0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001, + 0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff, + 0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff, + 0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01, + 0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff, + 0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100, + 0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001, + 0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00, + 0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff, + 0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff, + 0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000, + 0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000, + 0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101, + 0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001, + 0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000, + 0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101, + 0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000, + 0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001, + 0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000, + 0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100, + 0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000, + 0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000, + 0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100, + 0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff, + 0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff, + 0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00, + 0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101, + 0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000, + 0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00, + 0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000, + 0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff, + 0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101, + 0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff, + 0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00, + 0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff, +}; + +static const __device__ uint8_t ksigns_iq2xs[128] = { + 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15, + 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159, + 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175, + 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63, + 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 
78, 207, + 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95, + 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111, + 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255, +}; + +static const __device__ uint64_t ksigns64[128] = { + 0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff, + 0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff, + 0xff000000ff000000, 0x00000000ff0000ff, 0x00000000ff00ff00, 0xff000000ff00ffff, + 0x00000000ffff0000, 0xff000000ffff00ff, 0xff000000ffffff00, 0x00000000ffffffff, + 0xff0000ff00000000, 0x000000ff000000ff, 0x000000ff0000ff00, 0xff0000ff0000ffff, + 0x000000ff00ff0000, 0xff0000ff00ff00ff, 0xff0000ff00ffff00, 0x000000ff00ffffff, + 0x000000ffff000000, 0xff0000ffff0000ff, 0xff0000ffff00ff00, 0x000000ffff00ffff, + 0xff0000ffffff0000, 0x000000ffffff00ff, 0x000000ffffffff00, 0xff0000ffffffffff, + 0xff00ff0000000000, 0x0000ff00000000ff, 0x0000ff000000ff00, 0xff00ff000000ffff, + 0x0000ff0000ff0000, 0xff00ff0000ff00ff, 0xff00ff0000ffff00, 0x0000ff0000ffffff, + 0x0000ff00ff000000, 0xff00ff00ff0000ff, 0xff00ff00ff00ff00, 0x0000ff00ff00ffff, + 0xff00ff00ffff0000, 0x0000ff00ffff00ff, 0x0000ff00ffffff00, 0xff00ff00ffffffff, + 0x0000ffff00000000, 0xff00ffff000000ff, 0xff00ffff0000ff00, 0x0000ffff0000ffff, + 0xff00ffff00ff0000, 0x0000ffff00ff00ff, 0x0000ffff00ffff00, 0xff00ffff00ffffff, + 0xff00ffffff000000, 0x0000ffffff0000ff, 0x0000ffffff00ff00, 0xff00ffffff00ffff, + 0x0000ffffffff0000, 0xff00ffffffff00ff, 0xff00ffffffffff00, 0x0000ffffffffffff, + 0xffff000000000000, 0x00ff0000000000ff, 0x00ff00000000ff00, 0xffff00000000ffff, + 0x00ff000000ff0000, 0xffff000000ff00ff, 0xffff000000ffff00, 0x00ff000000ffffff, + 0x00ff0000ff000000, 0xffff0000ff0000ff, 0xffff0000ff00ff00, 0x00ff0000ff00ffff, + 0xffff0000ffff0000, 0x00ff0000ffff00ff, 0x00ff0000ffffff00, 0xffff0000ffffffff, + 0x00ff00ff00000000, 0xffff00ff000000ff, 0xffff00ff0000ff00, 0x00ff00ff0000ffff, + 0xffff00ff00ff0000, 0x00ff00ff00ff00ff, 0x00ff00ff00ffff00, 0xffff00ff00ffffff, + 0xffff00ffff000000, 0x00ff00ffff0000ff, 0x00ff00ffff00ff00, 0xffff00ffff00ffff, + 0x00ff00ffffff0000, 0xffff00ffffff00ff, 0xffff00ffffffff00, 0x00ff00ffffffffff, + 0x00ffff0000000000, 0xffffff00000000ff, 0xffffff000000ff00, 0x00ffff000000ffff, + 0xffffff0000ff0000, 0x00ffff0000ff00ff, 0x00ffff0000ffff00, 0xffffff0000ffffff, + 0xffffff00ff000000, 0x00ffff00ff0000ff, 0x00ffff00ff00ff00, 0xffffff00ff00ffff, + 0x00ffff00ffff0000, 0xffffff00ffff00ff, 0xffffff00ffffff00, 0x00ffff00ffffffff, + 0xffffffff00000000, 0x00ffffff000000ff, 0x00ffffff0000ff00, 0xffffffff0000ffff, + 0x00ffffff00ff0000, 0xffffffff00ff00ff, 0xffffffff00ffff00, 0x00ffffff00ffffff, + 0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff, + 0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff, +}; + +static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128}; +static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; + + +typedef half dfloat; // dequantize float +typedef half2 dfloat2; +typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); +typedef void (*to_fp16_cuda_t)(const void * __restrict__ x, dfloat * __restrict__ y, int k, cudaStream_t stream); +typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs); 
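+// The function-pointer typedefs here describe the per-quantization-type hooks
+// used by the GGUF kernels: dequantization to fp16, shared-memory tile
+// allocation and loading for the tiled matmul (mmq) path, and the dot-product
+// routines against Q8_1-quantized activations (mmvq/mmq).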
+typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc);
+typedef void (*load_tiles_cuda_t)(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row);
+typedef float (*vec_dot_q_mul_mat_cuda_t)(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, const int & i, const int & j, const int & k);
+
+// Utility functions
+
+#if defined(USE_ROCM)
+
+#ifndef __has_builtin
+  #define __has_builtin(x) 0
+#endif
+
+typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
+static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
+    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
+    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+#if __has_builtin(__builtin_elementwise_sub_sat)
+    const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
+    return reinterpret_cast<const int&>(c);
+#else
+    int8x4_t c;
+    int16_t tmp;
+#pragma unroll
+    for (int i = 0; i < 4; i++) {
+        tmp = va[i] - vb[i];
+        if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
+        if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
+        c[i] = tmp;
+    }
+    return reinterpret_cast<int&>(c);
+#endif // __has_builtin(__builtin_elementwise_sub_sat)
+}
+
+static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
+#if __has_builtin(__builtin_amdgcn_sdot4)
+    c = __builtin_amdgcn_sdot4(a, b, c, false);
+#else
+    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
+    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+    c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
+#endif
+    return c;
+}
+#endif // defined(USE_ROCM)
diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/quantization/gguf/gguf_kernel.cu
new file mode 100644
index 000000000000..9beae1bec403
--- /dev/null
+++ b/csrc/quantization/gguf/gguf_kernel.cu
@@ -0,0 +1,242 @@
+#include <cuda_fp16.h>
+#include <torch/all.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include "ggml-common.h"
+#include "vecdotq.cuh"
+#include "dequantize.cuh"
+#include "mmvq.cuh"
+#include "mmq.cuh"
+
+// Q8 gemv
+static __global__ void quantize_q8_1(const half* __restrict__ x,
+                                     void* __restrict__ vy, const int kx,
+                                     const int kx_padded) {
+  const int ix = blockDim.x * blockIdx.x + threadIdx.x;
+  if (ix >= kx_padded) {
+    return;
+  }
+  const int iy = blockDim.y * blockIdx.y + threadIdx.y;
+  const int i_padded = iy * kx_padded + ix;
+
+  block_q8_1* y = (block_q8_1*)vy;
+
+  const int ib = i_padded / QK8_1;   // block index
+  const int iqs = i_padded % QK8_1;  // quant index
+
+  const float xi = ix < kx ? __half2float(x[iy * kx + ix]) : 0.0f;
+  float amax = fabsf(xi);
+  float sum = xi;
+
+  // Warp-wide butterfly reduction: max(|x|) gives the block scale, and the
+  // block sum is kept alongside it for the dot-product kernels.
+#pragma unroll
+  for (int mask = 16; mask > 0; mask >>= 1) {
+    amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
+    sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
+  }
+
+  const float d = amax / 127;
+  const int8_t q = amax == 0.0f ?
0 : roundf(xi / d); + + y[ib].qs[iqs] = q; + + if (iqs > 0) { + return; + } + + y[ib].ds.x = __float2half(d); + y[ib].ds.y = __float2half(sum); +} + +static void quantize_row_q8_1_cuda(const half* x, void* vy, const int kx, + const int ky, cudaStream_t stream) { + const int64_t kx_padded = (kx + 512 - 1) / 512 * 512; + const int block_num_x = + (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; + const dim3 num_blocks(block_num_x, ky, 1); + const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1); + quantize_q8_1<<>>(x, vy, kx, kx_padded); +} + +torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight + int8_t type, int64_t m, int64_t n) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(W)); + auto options = + torch::TensorOptions().dtype(torch::kFloat16).device(W.device()); + at::Tensor DW = torch::empty({m, n}, options); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(type); + to_fp16_cuda((void*)W.data_ptr(), (half*)DW.data_ptr(), m * n, stream); + return DW; +} + +torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight + torch::Tensor X, // input + int8_t type, int64_t row) { + int col = X.sizes()[1]; + const int padded = (col + 512 - 1) / 512 * 512; + const at::cuda::OptionalCUDAGuard device_guard(device_of(X)); + auto options = + torch::TensorOptions().dtype(torch::kFloat16).device(W.device()); + at::Tensor Y = torch::empty({1, row}, options); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + options = torch::TensorOptions().dtype(torch::kInt32).device(W.device()); + at::Tensor quant_X = torch::empty({1, padded / 32 * 9}, options); + quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col, 1, + stream); + switch (type) { + case 2: + mul_mat_vec_q4_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (half*)Y.data_ptr(), col, row, stream); + break; + case 3: + mul_mat_vec_q4_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (half*)Y.data_ptr(), col, row, stream); + break; + case 6: + mul_mat_vec_q5_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (half*)Y.data_ptr(), col, row, stream); + break; + case 7: + mul_mat_vec_q5_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (half*)Y.data_ptr(), col, row, stream); + break; + case 8: + mul_mat_vec_q8_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (half*)Y.data_ptr(), col, row, stream); + break; + case 10: + mul_mat_vec_q2_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (half*)Y.data_ptr(), col, row, stream); + break; + case 11: + mul_mat_vec_q3_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (half*)Y.data_ptr(), col, row, stream); + break; + case 12: + mul_mat_vec_q4_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (half*)Y.data_ptr(), col, row, stream); + break; + case 13: + mul_mat_vec_q5_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (half*)Y.data_ptr(), col, row, stream); + break; + case 14: + mul_mat_vec_q6_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (half*)Y.data_ptr(), col, row, stream); + break; + case 16: + mul_mat_vec_iq2_xxs_q8_1_cuda((void*)W.data_ptr(), + (void*)quant_X.data_ptr(), + (half*)Y.data_ptr(), col, row, stream); + break; + case 17: + mul_mat_vec_iq2_xs_q8_1_cuda((void*)W.data_ptr(), + (void*)quant_X.data_ptr(), + (half*)Y.data_ptr(), col, row, stream); + break; + case 18: + mul_mat_vec_iq3_xxs_q8_1_cuda((void*)W.data_ptr(), 
+ (void*)quant_X.data_ptr(), + (half*)Y.data_ptr(), col, row, stream); + break; + case 19: + mul_mat_vec_iq1_s_q8_1_cuda((void*)W.data_ptr(), + (void*)quant_X.data_ptr(), + (half*)Y.data_ptr(), col, row, stream); + break; + case 20: + mul_mat_vec_iq4_nl_q8_1_cuda((void*)W.data_ptr(), + (void*)quant_X.data_ptr(), + (half*)Y.data_ptr(), col, row, stream); + break; + case 21: + mul_mat_vec_iq3_s_q8_1_cuda((void*)W.data_ptr(), + (void*)quant_X.data_ptr(), + (half*)Y.data_ptr(), col, row, stream); + break; + case 22: + mul_mat_vec_iq2_s_q8_1_cuda((void*)W.data_ptr(), + (void*)quant_X.data_ptr(), + (half*)Y.data_ptr(), col, row, stream); + break; + case 23: + mul_mat_vec_iq4_xs_q8_1_cuda((void*)W.data_ptr(), + (void*)quant_X.data_ptr(), + (half*)Y.data_ptr(), col, row, stream); + break; + } + return Y; +} + +torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight + torch::Tensor X, // input + int8_t type, int64_t row) { + int col = X.sizes()[1]; + int padded = (col + 512 - 1) / 512 * 512; + int batch = X.sizes()[0]; + const at::cuda::OptionalCUDAGuard device_guard(device_of(X)); + auto options = + torch::TensorOptions().dtype(torch::kFloat16).device(W.device()); + at::Tensor Y = torch::empty({batch, row}, options); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + options = torch::TensorOptions().dtype(torch::kInt32).device(W.device()); + at::Tensor quant_X = torch::empty({batch, padded / 32 * 9}, options); + quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col, + batch, stream); + + switch (type) { + case 2: + ggml_mul_mat_q4_0_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), + col, row, batch, padded, row, stream); + break; + case 3: + ggml_mul_mat_q4_1_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), + col, row, batch, padded, row, stream); + break; + case 6: + ggml_mul_mat_q5_0_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), + col, row, batch, padded, row, stream); + break; + case 7: + ggml_mul_mat_q5_1_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), + col, row, batch, padded, row, stream); + break; + case 8: + ggml_mul_mat_q8_0_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), + col, row, batch, padded, row, stream); + break; + case 10: + ggml_mul_mat_q2_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), + col, row, batch, padded, row, stream); + break; + case 11: + ggml_mul_mat_q3_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), + col, row, batch, padded, row, stream); + break; + case 12: + ggml_mul_mat_q4_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), + col, row, batch, padded, row, stream); + break; + case 13: + ggml_mul_mat_q5_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), + col, row, batch, padded, row, stream); + break; + case 14: + ggml_mul_mat_q6_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), + col, row, batch, padded, row, stream); + break; + } + return Y; +} \ No newline at end of file diff --git a/csrc/quantization/gguf/mmq.cuh b/csrc/quantization/gguf/mmq.cuh new file mode 100644 index 000000000000..d13efd596531 --- /dev/null +++ b/csrc/quantization/gguf/mmq.cuh @@ -0,0 +1,600 @@ +// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu +template 
+static __device__ __forceinline__ void mul_mat_q( + const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_col_y = nrows_y / QK8_1; + const int blocks_per_warp = WARP_SIZE / qi; + + const int & ncols_dst = ncols_y; + + const int row_dst_0 = blockIdx.x*mmq_y; + const int & row_x_0 = row_dst_0; + + const int col_dst_0 = blockIdx.y*mmq_x; + const int & col_y_0 = col_dst_0; + + int * tile_x_ql = nullptr; + half2 * tile_x_dm = nullptr; + int * tile_x_qh = nullptr; + int * tile_x_sc = nullptr; + + allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc); + + __shared__ int tile_y_qs[mmq_x * WARP_SIZE]; + __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1]; + + float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}}; + + for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) { + + load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, + threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x); + +#pragma unroll + for (int ir = 0; ir < qr; ++ir) { + const int kqs = ir*WARP_SIZE + threadIdx.x; + const int kbxd = kqs / QI8_1; + +#pragma unroll + for (int i = 0; i < mmq_x; i += nwarps) { + const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses + const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd]; + const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE; + tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1); + } + +#pragma unroll + for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) { + const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x; + const int kby = threadIdx.x % (WARP_SIZE/QI8_1); + const int col_y_eff = min(col_y_0 + ids, ncols_y-1); + + // if the sum is not needed it's faster to transform the scale to f32 ahead of time + const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds; + half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby]; + if (need_sum) { + *dsi_dst = *dsi_src; + } else { + float * dfi_dst = (float *) dsi_dst; + *dfi_dst = __low2float(*dsi_src); + } + } + + __syncthreads(); + +// #pragma unroll // unrolling this loop causes too much register pressure + for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) { +#pragma unroll + for (int j = 0; j < mmq_x; j += nwarps) { +#pragma unroll + for (int i = 0; i < mmq_y; i += WARP_SIZE) { + sum[i/WARP_SIZE][j/nwarps] += vec_dot( + tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, + threadIdx.x + i, threadIdx.y + j, k); + } + } + } + __syncthreads(); + } + } + +#pragma unroll + for (int j = 0; j < mmq_x; j += nwarps) { + const int col_dst = col_dst_0 + j + threadIdx.y; + if (col_dst >= ncols_dst) { + return; + } + +#pragma unroll + for (int i = 0; i < mmq_y; i += WARP_SIZE) { + const int row_dst = row_dst_0 + threadIdx.x + i; + if (row_dst >= nrows_dst) { + continue; + } + dst[col_dst*nrows_dst + row_dst] = __float2half(sum[i/WARP_SIZE][j/nwarps]); + } + } +} + +#if defined(USE_ROCM) +#define MMQ_X_Q4_0 64 +#define MMQ_Y_Q4_0 128 +#define NWARPS_Q4_0 8 +#else +#define MMQ_X_Q4_0 4 +#define MMQ_Y_Q4_0 32 +#define NWARPS_Q4_0 4 +#endif + +template 
static __global__ void +#if defined(USE_ROCM) +__launch_bounds__(WARP_SIZE*NWARPS_Q4_0, 2) +#endif +mul_mat_q4_0( + const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int mmq_x = MMQ_X_Q4_0; + const int mmq_y = MMQ_Y_Q4_0; + const int nwarps = NWARPS_Q4_0; + + mul_mat_q, + load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +static void ggml_mul_mat_q4_0_q8_1_cuda( + const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + int mmq_x = MMQ_X_Q4_0; + int mmq_y = MMQ_Y_Q4_0; + int nwarps = NWARPS_Q4_0; + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q4_0<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q4_0<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +#if defined(USE_ROCM) +#define MMQ_X_Q4_1 64 +#define MMQ_Y_Q4_1 128 +#define NWARPS_Q4_1 8 +#else +#define MMQ_X_Q4_1 4 +#define MMQ_Y_Q4_1 32 +#define NWARPS_Q4_1 4 +#endif + +template static __global__ void +#if defined(USE_ROCM) +__launch_bounds__(WARP_SIZE*NWARPS_Q4_1, 2) +#endif +mul_mat_q4_1( + const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int mmq_x = MMQ_X_Q4_1; + const int mmq_y = MMQ_Y_Q4_1; + const int nwarps = NWARPS_Q4_1; + + mul_mat_q, + load_tiles_q4_1, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +static void ggml_mul_mat_q4_1_q8_1_cuda( + const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + int mmq_x = MMQ_X_Q4_1; + int mmq_y = MMQ_Y_Q4_1; + int nwarps = NWARPS_Q4_1; + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q4_1<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q4_1<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +#if defined(USE_ROCM) +#define MMQ_X_Q5_0 64 +#define MMQ_Y_Q5_0 128 +#define NWARPS_Q5_0 8 +#else +#define MMQ_X_Q5_0 4 +#define MMQ_Y_Q5_0 32 +#define NWARPS_Q5_0 4 +#endif + +template static __global__ void +#if defined(USE_ROCM) +__launch_bounds__(WARP_SIZE*NWARPS_Q5_0, 2) +#endif +mul_mat_q5_0( + const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int mmq_x = MMQ_X_Q5_0; + const int mmq_y = MMQ_Y_Q5_0; + const int nwarps = NWARPS_Q5_0; + + mul_mat_q, + load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> + (vx, vy, 
dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +static void ggml_mul_mat_q5_0_q8_1_cuda( + const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + const int mmq_x = MMQ_X_Q5_0; + const int mmq_y = MMQ_Y_Q5_0; + const int nwarps = NWARPS_Q5_0; + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q5_0<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q5_0<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +#if defined(USE_ROCM) +#define MMQ_X_Q5_1 64 +#define MMQ_Y_Q5_1 128 +#define NWARPS_Q5_1 8 +#else +#define MMQ_X_Q5_1 4 +#define MMQ_Y_Q5_1 32 +#define NWARPS_Q5_1 4 +#endif + +template static __global__ void +#if defined(USE_ROCM) +__launch_bounds__(WARP_SIZE*NWARPS_Q5_1, 2) +#endif +mul_mat_q5_1( + const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int mmq_x = MMQ_X_Q5_1; + const int mmq_y = MMQ_Y_Q5_1; + const int nwarps = NWARPS_Q5_1; + + mul_mat_q, + load_tiles_q5_1, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +static void ggml_mul_mat_q5_1_q8_1_cuda( + const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + const int mmq_x = MMQ_X_Q5_1; + const int mmq_y = MMQ_Y_Q5_1; + const int nwarps = NWARPS_Q5_1; + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q5_1<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q5_1<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +#if defined(USE_ROCM) +#define MMQ_X_Q8_0 64 +#define MMQ_Y_Q8_0 128 +#define NWARPS_Q8_0 8 +#else +#define MMQ_X_Q8_0 4 +#define MMQ_Y_Q8_0 32 +#define NWARPS_Q8_0 4 +#endif + +template static __global__ void +#if defined(USE_ROCM) +__launch_bounds__(WARP_SIZE*NWARPS_Q8_0, 2) +#endif +mul_mat_q8_0( + const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int mmq_x = MMQ_X_Q8_0; + const int mmq_y = MMQ_Y_Q8_0; + const int nwarps = NWARPS_Q8_0; + + mul_mat_q, + load_tiles_q8_0, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +static void ggml_mul_mat_q8_0_q8_1_cuda( + const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + const int mmq_x = MMQ_X_Q8_0; + const int mmq_y = MMQ_Y_Q8_0; + const int nwarps = NWARPS_Q8_0; + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = 
(ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q8_0<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q8_0<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +#if defined(USE_ROCM) +#define MMQ_X_Q2_K 64 +#define MMQ_Y_Q2_K 128 +#define NWARPS_Q2_K 8 +#else +#define MMQ_X_Q2_K 4 +#define MMQ_Y_Q2_K 32 +#define NWARPS_Q2_K 4 +#endif + +template static __global__ void +#if defined(USE_ROCM) +__launch_bounds__(WARP_SIZE*NWARPS_Q2_K, 2) +#endif +mul_mat_q2_K( + const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int mmq_x = MMQ_X_Q2_K; + const int mmq_y = MMQ_Y_Q2_K; + const int nwarps = NWARPS_Q2_K; + + mul_mat_q, + load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +static void ggml_mul_mat_q2_K_q8_1_cuda( + const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + const int mmq_x = MMQ_X_Q2_K; + const int mmq_y = MMQ_Y_Q2_K; + const int nwarps = NWARPS_Q2_K; + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q2_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q2_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +#if defined(USE_ROCM) +#define MMQ_X_Q3_K 64 +#define MMQ_Y_Q3_K 128 +#define NWARPS_Q3_K 8 +#else +#define MMQ_X_Q3_K 4 +#define MMQ_Y_Q3_K 32 +#define NWARPS_Q3_K 4 +#endif + +template static __global__ void +#if defined(USE_ROCM) +__launch_bounds__(WARP_SIZE*NWARPS_Q3_K, 2) +#endif +mul_mat_q3_K( + const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + + const int mmq_x = MMQ_X_Q3_K; + const int mmq_y = MMQ_Y_Q3_K; + const int nwarps = NWARPS_Q3_K; + + mul_mat_q, + load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +static void ggml_mul_mat_q3_K_q8_1_cuda( + const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + const int mmq_x = MMQ_X_Q3_K; + const int mmq_y = MMQ_Y_Q3_K; + const int nwarps = NWARPS_Q3_K; + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q3_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q3_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +#if defined(USE_ROCM) +#define 
MMQ_X_Q4_K 64 +#define MMQ_Y_Q4_K 128 +#define NWARPS_Q4_K 8 +#else +#define MMQ_X_Q4_K 4 +#define MMQ_Y_Q4_K 32 +#define NWARPS_Q4_K 4 +#endif + +template static __global__ void +#if defined(USE_ROCM) +__launch_bounds__(WARP_SIZE*NWARPS_Q4_K, 2) +#endif +mul_mat_q4_K( + const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int mmq_x = MMQ_X_Q4_K; + const int mmq_y = MMQ_Y_Q4_K; + const int nwarps = NWARPS_Q4_K; + + mul_mat_q, + load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +static void ggml_mul_mat_q4_K_q8_1_cuda( + const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + const int mmq_x = MMQ_X_Q4_K; + const int mmq_y = MMQ_Y_Q4_K; + const int nwarps = NWARPS_Q4_K; + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q4_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q4_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +#if defined(USE_ROCM) +#define MMQ_X_Q5_K 64 +#define MMQ_Y_Q5_K 128 +#define NWARPS_Q5_K 8 +#else +#define MMQ_X_Q5_K 4 +#define MMQ_Y_Q5_K 32 +#define NWARPS_Q5_K 4 +#endif + +template static __global__ void +#if defined(USE_ROCM) +__launch_bounds__(WARP_SIZE*NWARPS_Q5_K, 2) +#endif +mul_mat_q5_K( + const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int mmq_x = MMQ_X_Q5_K; + const int mmq_y = MMQ_Y_Q5_K; + const int nwarps = NWARPS_Q5_K; + + mul_mat_q, + load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +static void ggml_mul_mat_q5_K_q8_1_cuda( + const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + const int mmq_x = MMQ_X_Q5_K; + const int mmq_y = MMQ_Y_Q5_K; + const int nwarps = NWARPS_Q5_K; + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q5_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q5_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +#if defined(USE_ROCM) +#define MMQ_X_Q6_K 64 +#define MMQ_Y_Q6_K 128 +#define NWARPS_Q6_K 8 +#else +#define MMQ_X_Q6_K 4 +#define MMQ_Y_Q6_K 32 +#define NWARPS_Q6_K 4 +#endif + +template static __global__ void +#if defined(USE_ROCM) +__launch_bounds__(WARP_SIZE*NWARPS_Q6_K, 2) +#endif +mul_mat_q6_K( + const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { 
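+    // Same pattern as the preceding quant types: select the Q6_K tile shape
+    // and dispatch to the generic mul_mat_q kernel with the Q6_K tile loader
+    // and dot-product routine.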
+ const int mmq_x = MMQ_X_Q6_K; + const int mmq_y = MMQ_Y_Q6_K; + const int nwarps = NWARPS_Q6_K; + + mul_mat_q, + load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +} + +static void ggml_mul_mat_q6_K_q8_1_cuda( + const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + const int mmq_x = MMQ_X_Q6_K; + const int mmq_y = MMQ_Y_Q6_K; + const int nwarps = NWARPS_Q6_K; + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q6_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q6_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} diff --git a/csrc/quantization/gguf/mmvq.cuh b/csrc/quantization/gguf/mmvq.cuh new file mode 100644 index 000000000000..ef2ea072392d --- /dev/null +++ b/csrc/quantization/gguf/mmvq.cuh @@ -0,0 +1,182 @@ +// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu +template +static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, const int ncols, const int nrows) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + +// partial sum for each thread + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) { + const int ibx = row*blocks_per_row + i; // x block index + + const int iby = i * (qk/QK8_1); // y block index that aligns with ibx + + const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs); + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[row] = __float2half(tmp); + } +} + +static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, 
GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, 
const int nrows, cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_iq3_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} diff --git a/csrc/quantization/gguf/vecdotq.cuh b/csrc/quantization/gguf/vecdotq.cuh new file mode 100644 index 000000000000..78c749d3f3bc --- /dev/null +++ b/csrc/quantization/gguf/vecdotq.cuh @@ -0,0 +1,1745 @@ +// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/vecdotq.cuh +// and https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu +static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) { + const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment + int x32 = 0; + x32 |= x16[0] << 0; + x32 |= x16[1] << 16; + return x32; +} + +static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) { + const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment + int x32 = 0; + x32 |= x16[0] << 0; + x32 |= x16[1] << 16; + return x32; +} + +static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) { + return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment +} + +static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) { + return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment +} + + +#define 
VDR_Q4_0_Q8_1_MMVQ 2 +#define VDR_Q4_0_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl( + const int * v, const int * u, const float & d4, const half2 & ds8) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + sumi = __dp4a(vi0, u[2*i+0], sumi); + sumi = __dp4a(vi1, u[2*i+1], sumi); + } + + const float2 ds8f = __half22float2(ds8); + + // second part effectively subtracts 8 from each quant value + return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y); +#endif +} + +#define VDR_Q4_1_Q8_1_MMVQ 2 +#define VDR_Q4_1_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl( + const int * v, const int * u, const half2 & dm4, const half2 & ds8) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + sumi = __dp4a(vi0, u[2*i+0], sumi); + sumi = __dp4a(vi1, u[2*i+1], sumi); + } + + const float2 tmp = __half22float2(__hmul2(dm4, ds8)); + const float d4d8 = tmp.x; + const float m4s8 = tmp.y; + + // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it + return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1)); +#endif +} + +#define VDR_Q5_0_Q8_1_MMVQ 2 +#define VDR_Q5_0_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl( + const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits + vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4 + vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 + vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 + vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 + sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values + + int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits + vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 + vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 + vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 + vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 + sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values + } + + const float2 ds8f = __half22float2(ds8); + + // second part effectively subtracts 16 from each quant value + return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y); +#endif +} + + +#define VDR_Q5_1_Q8_1_MMVQ 2 +#define VDR_Q5_1_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl( + const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits + vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4 + vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 + vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 + vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 + sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values + + int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits + vi1 
|= (vh[i] >> 12) & 0x00000010; // 16 -> 4 + vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 + vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 + vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 + sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values + } + + const float2 tmp = __half22float2(__hmul2(dm5, ds8)); + const float d5d8 = tmp.x; + const float m5s8 = tmp.y; + + // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it + return sumi*d5d8 + m5s8 / (QI5_1 / vdr); +#endif +} + +#define VDR_Q8_0_Q8_1_MMVQ 2 +#define VDR_Q8_0_Q8_1_MMQ 8 + +template static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl( + const int * v, const int * u, const float & d8_0, const float & d8_1) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + // SIMD dot product of quantized values + sumi = __dp4a(v[i], u[i], sumi); + } + return d8_0*d8_1 * sumi; +#endif +} + +template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl( + const int * v, const int * u, const half2 & dm8, const half2 & ds8) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + // SIMD dot product of quantized values + sumi = __dp4a(v[i], u[i], sumi); + } + + const float2 tmp = __half22float2(__hmul2(dm8, ds8)); + const float d8d8 = tmp.x; + const float m8s8 = tmp.y; + + // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it + return sumi*d8d8 + m8s8 / (QI8_1 / vdr); +#endif +} + +#define VDR_Q2_K_Q8_1_MMVQ 1 +#define VDR_Q2_K_Q8_1_MMQ 2 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( + const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales, + const half2 & dm2, const float * __restrict__ d8) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR2_K; ++i) { + const int sc = scales[2*i]; + + const int vi = (v >> (2*i)) & 0x03030303; + + sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product + + // fill int with 4x m + int m = sc >> 4; + m |= m << 8; + m |= m << 16; + sumf_m += d8[i] * __dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values + } + + const float2 dm2f = __half22float2(dm2); + + return dm2f.x*sumf_d - dm2f.y*sumf_m; +#endif +} + +static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales, + const half2 & dm2, const float & d8) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + int sumi_d = 0; + int sumi_m = 0; + +#pragma unroll + for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) { + int sumi_d_sc = 0; + + const int sc = scales[i0 / (QI8_1/2)]; + + // fill int with 4x m + int m = sc >> 4; + m |= m << 8; + m |= m << 16; + +#pragma unroll + for (int i = i0; i < i0 + QI8_1/2; ++i) { + sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product + sumi_m = __dp4a(m, u[i], sumi_m); // multiply sum of q8_1 values with m + } + + sumi_d += sumi_d_sc * (sc & 0xF); + } + + const float2 dm2f = __half22float2(dm2); + + return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m); +#endif +} + +#define VDR_Q3_K_Q8_1_MMVQ 1 +#define VDR_Q3_K_Q8_1_MMQ 2 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq( + const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * 
__restrict__ scales, + const int & scale_offset, const float & d3, const float * __restrict__ d8) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + + float sumf = 0.0f; + +#pragma unroll + for (int i = 0; i < QR3_K; ++i) { + const int isc = scale_offset + 2*i; + + const int isc_low = isc % (QK_K/32); + const int sc_shift_low = 4 * (isc / (QK_K/32)); + const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF; + + const int isc_high = isc % (QK_K/64); + const int sc_shift_high = 2 * (isc / (QK_K/64)); + const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4; + + const int sc = (sc_low | sc_high) - 32; + + const int vil = (vl >> (2*i)) & 0x03030303; + + const int vih = ((vh >> i) << 2) & 0x04040404; + + const int vi = __vsubss4(vil, vih); + + sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product + } + + return d3 * sumf; +#endif +} + +static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales, + const float & d3, const float & d8) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + int sumi = 0; + +#pragma unroll + for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) { + int sumi_sc = 0; + + for (int i = i0; i < i0 + QI8_1/2; ++i) { + sumi_sc = __dp4a(v[i], u[i], sumi_sc); // SIMD dot product + } + + sumi += sumi_sc * scales[i0 / (QI8_1/2)]; + } + + return d3*d8 * sumi; +#endif +} + +#define VDR_Q4_K_Q8_1_MMVQ 2 +#define VDR_Q4_K_Q8_1_MMQ 8 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR4_K; ++i) { + const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F; + const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F; + + const int dot1 = __dp4a(v1i, u[2*i+1], __dp4a(v0i, u[2*i+0], 0)); // SIMD dot product + const int dot2 = __dp4a(0x01010101, u[2*i+1], __dp4a(0x01010101, u[2*i+0], 0)); // sum of u + + sumf_d += d8[i] * (dot1 * sc[i]); + sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values + } + + const float2 dm4f = __half22float2(dm4); + return dm4f.x*sumf_d - dm4f.y*sumf_m; +#endif +} + +static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) { + int sumi_d = 0; + +#pragma unroll + for (int j = 0; j < QI8_1; ++j) { + sumi_d = __dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product + } + + const float2 ds8f = __half22float2(ds8[i]); + + sumf_d += ds8f.x * (sc[i] * sumi_d); + sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val + } + + const float2 dm4f = __half22float2(dm4); + + return dm4f.x*sumf_d - dm4f.y*sumf_m; +#endif +} + +#define VDR_Q5_K_Q8_1_MMVQ 2 +#define VDR_Q5_K_Q8_1_MMQ 8 + +static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( + const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc, 
+ const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR5_K; ++i) { + const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F; + const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F; + + const int vh0i = ((vh[0] >> i) << 4) & 0x10101010; + const int vh1i = ((vh[1] >> i) << 4) & 0x10101010; + + const int v0i = vl0i | vh0i; + const int v1i = vl1i | vh1i; + + const int dot1 = __dp4a(v0i, u[2*i+0], __dp4a(v1i, u[2*i+1], 0)); // SIMD dot product + const int dot2 = __dp4a(0x01010101, u[2*i+0], __dp4a(0x01010101, u[2*i+1], 0)); // sum of u + + sumf_d += d8[i] * (dot1 * sc[i]); + sumf_m += d8[i] * (dot2 * m[i]); + } + + const float2 dm5f = __half22float2(dm5); + return dm5f.x*sumf_d - dm5f.y*sumf_m; +#endif +} + +static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) { + int sumi_d = 0; + +#pragma unroll + for (int j = 0; j < QI8_1; ++j) { + sumi_d = __dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product + } + + const float2 ds8f = __half22float2(ds8[i]); + + sumf_d += ds8f.x * (sc[i] * sumi_d); + sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val + } + + const float2 dm4f = __half22float2(dm4); + + return dm4f.x*sumf_d - dm4f.y*sumf_m; +#endif +} + +#define VDR_Q6_K_Q8_1_MMVQ 1 +#define VDR_Q6_K_Q8_1_MMQ 8 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq( + const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales, + const float & d, const float * __restrict__ d8) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + float sumf = 0.0f; + +#pragma unroll + for (int i = 0; i < QR6_K; ++i) { + const int sc = scales[4*i]; + const int vil = (vl >> (4*i)) & 0x0F0F0F0F; + const int vih = ((vh >> (4*i)) << 4) & 0x30303030; + const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32 + + sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product + } + + return d*sumf; +#endif +} + +static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc, + const float & d6, const float * __restrict__ d8) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + float sumf_d = 0.0f; + +#pragma unroll + for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) { + int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale + +#pragma unroll + for (int i = i0; i < i0 + 2; ++i) { + sumi_d.x = __dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product + sumi_d.x = __dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product + + sumi_d.y = __dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product + sumi_d.y = __dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product + } + + sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y); + } + + return d6 * sumf_d; +#endif +} + +static __device__ __forceinline__ float vec_dot_q4_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; + + int 
v[VDR_Q4_0_Q8_1_MMVQ]; + int u[2*VDR_Q4_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_uint8(bq4_0->qs, iqs + i); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0); + } + + return vec_dot_q4_0_q8_1_impl(v, u, __half2float(bq4_0->d), bq8_1->ds); +} + +template static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; + __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0]; + *x_ql = tile_x_qs; + *x_dm = (half2 *) tile_x_d; +} + +template static __device__ __forceinline__ void load_tiles_q4_0( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + const int kbx = k / QI4_0; + const int kqsx = k % QI4_0; + + const block_q4_0 * bx0 = (const block_q4_0 *) vx; + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + if (need_check) { + i = min(i, i_max); + } + const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx; + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); + // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) { + int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row; + if (need_check) { + i = min(i, i_max); + } + const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd; + x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = __half2float(bxi->d); + } +} + +static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + (void)x_qh; (void)x_sc; + + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + const float * x_dmf = (const float *) x_dm; + + int u[2*VDR_Q4_0_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE]; + } + + return vec_dot_q4_0_q8_1_impl + (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0], + y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); +} + +static __device__ __forceinline__ float vec_dot_q4_1_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq; + + int v[VDR_Q4_1_Q8_1_MMVQ]; + int u[2*VDR_Q4_1_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_uint8_aligned(bq4_1->qs, iqs + i); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1); + } + + return vec_dot_q4_1_q8_1_impl(v, u, bq4_1->dm, bq8_1->ds); +} + +template static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * 
(WARP_SIZE/QI4_1) + mmq_y/QI4_1]; + *x_ql = tile_x_qs; + *x_dm = tile_x_dm; +} + +template static __device__ __forceinline__ void load_tiles_q4_1( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + const int kbx = k / QI4_1; + const int kqsx = k % QI4_1; + + const block_q4_1 * bx0 = (const block_q4_1 *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + if (need_check) { + i = min(i, i_max); + } + const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx; + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) { + int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row; + if (need_check) { + i = min(i, i_max); + } + const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd; + x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm; + } +} + +static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + + int u[2*VDR_Q4_1_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE]; + } + + return vec_dot_q4_1_q8_1_impl + (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1], + y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); +} + +static __device__ __forceinline__ float vec_dot_q5_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq; + + int vl[VDR_Q5_0_Q8_1_MMVQ]; + int vh[VDR_Q5_0_Q8_1_MMVQ]; + int u[2*VDR_Q5_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) { + vl[i] = get_int_from_uint8(bq5_0->qs, iqs + i); + vh[i] = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i)); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_0); + } + + return vec_dot_q5_0_q8_1_impl(vl, vh, u, __half2float(bq5_0->d), bq8_1->ds); +} + +template static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; + __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0]; + + *x_ql = tile_x_ql; + *x_dm = (half2 *) tile_x_d; +} + +template static __device__ __forceinline__ void load_tiles_q5_0( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + const int kbx = k / QI5_0; + const int kqsx = k % QI5_0; + + const block_q5_0 * bx0 = (const block_q5_0 *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + const block_q5_0 * bxi = bx0 + i*blocks_per_row + 
kbx; + const int ql = get_int_from_uint8(bxi->qs, kqsx); + const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0)); + + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + qs0 = __vsubss4(qs0, 0x10101010); // subtract 16 + + x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; + + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 + + x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; + const int kbxd = k % blocks_per_tile_x_row; + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) { + int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd; + x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = __half2float(bxi->d); + } +} + +static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0; + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + int u[2*VDR_Q5_0_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE]; + } + + return vec_dot_q8_0_q8_1_impl + (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); +} + +static __device__ __forceinline__ float vec_dot_q5_1_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq; + + int vl[VDR_Q5_1_Q8_1_MMVQ]; + int vh[VDR_Q5_1_Q8_1_MMVQ]; + int u[2*VDR_Q5_1_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) { + vl[i] = get_int_from_uint8_aligned(bq5_1->qs, iqs + i); + vh[i] = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i)); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1); + } + + return vec_dot_q5_1_q8_1_impl(vl, vh, u, bq5_1->dm, bq8_1->ds); +} + +template static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_1) + mmq_y/QI5_1]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; +} + +template static __device__ __forceinline__ void load_tiles_q5_1( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + const int kbx = k / QI5_1; + const int kqsx = k % QI5_1; + + const block_q5_1 * bx0 = (const block_q5_1 *) vx; + 
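+        // The loop below widens each q5_1 value to 8 bits in shared memory: the low 4 bits come from
+        // qs and the 5th bit is spliced in from qh. Unlike q5_0, no offset of 16 is subtracted here,
+        // since q5_1 carries an explicit minimum in its dm field.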
+#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx; + + const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); + const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1)); + + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + + x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; + + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + + x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) { + int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm; + } +} + +static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1; + + int u[2*VDR_Q5_1_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE]; + } + + return vec_dot_q8_1_q8_1_impl + (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); +} + +static __device__ __forceinline__ float vec_dot_q8_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq; + + int v[VDR_Q8_0_Q8_1_MMVQ]; + int u[VDR_Q8_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_int8(bq8_0->qs, iqs + i); + u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + } + + return vec_dot_q8_0_q8_1_impl(v, u, __half2float(bq8_0->d), __low2float(bq8_1->ds)); +} + +template static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; + __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0]; + + *x_ql = tile_x_qs; + *x_dm = (half2 *) tile_x_d; +} + +template static __device__ __forceinline__ void load_tiles_q8_0( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + const int kbx = k / QI8_0; + const int kqsx = k % QI8_0; + float * x_dmf = (float *) x_dm; + + const block_q8_0 * bx0 = (const block_q8_0 *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + const block_q8_0 * bxi = 
bx0 + i*blocks_per_row + kbx; + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI8_0; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) { + int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd; + x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = __half2float(bxi->d); + } +} + +static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + return vec_dot_q8_0_q8_1_impl + (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0], + y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]); +} + +static __device__ __forceinline__ float vec_dot_q2_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q2_K * bq2_K = (const block_q2_K *) vbq; + + const int bq8_offset = QR2_K * (iqs / QI8_1); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); + + const uint8_t * scales = bq2_K->scales + scale_offset; + + const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs); + int u[QR2_K]; + float d8[QR2_K]; + +#pragma unroll + for (int i = 0; i < QR2_K; ++ i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); + d8[i] = __low2float(bq8_1[bq8_offset + i].ds); + } + + return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8); +} + +template static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_sc = tile_x_sc; +} + +template static __device__ __forceinline__ void load_tiles_q2_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + const int kbx = k / QI2_K; + const int kqsx = k % QI2_K; + + const block_q2_K * bx0 = (const block_q2_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx; + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI2_K; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) { + int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd; + x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + const block_q2_K * bxi = bx0 + i*blocks_per_row 
+ (k % (WARP_SIZE/4)) / (QI2_K/4); + x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4)); + } +} + +static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + const int kbx = k / QI2_K; + const int ky = (k % QI2_K) * QR2_K; + const float * y_df = (const float *) y_ds; + + int v[QR2_K*VDR_Q2_K_Q8_1_MMQ]; + + const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); + const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2)); + +#pragma unroll + for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) { + v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303; + } + + const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4; + + const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE; + return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]); +} + +static __device__ __forceinline__ float vec_dot_q3_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q3_K * bq3_K = (const block_q3_K *) vbq; + + const int bq8_offset = QR3_K * (iqs / (QI3_K/2)); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); + + const float d = __half2float(bq3_K->d); + + const int vl = get_int_from_uint8(bq3_K->qs, iqs); + + // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted + const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset; + + int u[QR3_K]; + float d8[QR3_K]; + +#pragma unroll + for (int i = 0; i < QR3_K; ++i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); + d8[i] = __low2float(bq8_1[bq8_offset + i].ds); + } + + return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8); +} + +template static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI3_K) + mmq_y/QI3_K]; + __shared__ int tile_x_qh[mmq_y * (WARP_SIZE/2) + mmq_y/2]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_qh = tile_x_qh; + *x_sc = tile_x_sc; +} + +template static __device__ __forceinline__ void load_tiles_q3_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + const int kbx = k / QI3_K; + const int kqsx = k % QI3_K; + + const block_q3_K * bx0 = (const block_q3_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + if (need_check) { + i = min(i, i_max); + } + const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx; + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI3_K; + const int kbxd = k % blocks_per_tile_x_row; + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) { + int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y; + if (need_check) { + i = min(i, i_max); + } + 
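+        // convert the q3_K super-block scale d from half to float once and keep it in the shared x_dm tile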
const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd; + x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = __half2float(bxi->d); + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) { + int i = i0 + i_offset * 2 + k / (WARP_SIZE/2); + if (need_check) { + i = min(i, i_max); + } + const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2); + // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted + x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2)); + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); + if (need_check) { + i = min(i, i_max); + } + const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4); + + const int ksc = k % (QI3_K/4); + + const int ksc_low = ksc % (QI3_K/8); + const int shift_low = 4 * (ksc / (QI3_K/8)); + const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F; + + const int ksc_high = QI3_K/8; + const int shift_high = 2 * ksc; + const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030; + + const int sc = __vsubss4(sc_low | sc_high, 0x20202020); + + x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc; + } +} + +static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + const int kbx = k / QI3_K; + const int ky = (k % QI3_K) * QR3_K; + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; + + int v[QR3_K*VDR_Q3_K_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) { + const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); + const int shift = 2 * ((ky % 32) / 8); + const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303; + + const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); + const int vlh = (vh << 2) & 0x04040404; + + v[l] = __vsubss4(vll, vlh); + } + + const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE; + return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]); +} + +static __device__ __forceinline__ float vec_dot_q4_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const block_q4_K * bq4_K = (const block_q4_K *) vbq; + + int v[2]; + int u[2*QR4_K]; + float d8[QR4_K]; + + // iqs is in 0,2..30. 
bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6 + const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2)); + + // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12 + // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44 + // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76 + // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108 + + const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); + v[0] = q4[0]; + v[1] = q4[4]; + + const uint16_t * scales = (const uint16_t *)bq4_K->scales; + uint16_t aux[2]; + const int j = bq8_offset/2; + if (j < 2) { + aux[0] = scales[j+0] & 0x3f3f; + aux[1] = scales[j+2] & 0x3f3f; + } else { + aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); + aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); + } + const uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; + + for (int i = 0; i < QR4_K; ++i) { + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + d8[i] = __low2float(bq8i->ds); + + const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); + u[2*i+0] = q8[0]; + u[2*i+1] = q8[4]; + } + + return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8); +} + +template static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_K) + mmq_y/QI4_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_sc = tile_x_sc; +} + +template static __device__ __forceinline__ void load_tiles_q4_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + const int kbx = k / QI4_K; // == 0 if QK_K == 256 + const int kqsx = k % QI4_K; // == k if QK_K == 256 + + const block_q4_K * bx0 = (const block_q4_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx; + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256 + const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) { + int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y; + if (need_check) { + i = min(i, i_max); + } + const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd; + x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8); + + const int * scales = (const int *) bxi->scales; + + const int ksc = k % (WARP_SIZE/8); + // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 + int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits + scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits + + x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; + } +} + +static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat( + 
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + (void)x_qh; + + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8); + + const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE; + return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8, + x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]); +} + +static __device__ __forceinline__ float vec_dot_q5_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q5_K * bq5_K = (const block_q5_K *) vbq; + + int vl[2]; + int vh[2]; + int u[2*QR5_K]; + float d8[QR5_K]; + + const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2)); + const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); + const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4)); + + vl[0] = ql[0]; + vl[1] = ql[4]; + + vh[0] = qh[0] >> bq8_offset; + vh[1] = qh[4] >> bq8_offset; + + const uint16_t * scales = (const uint16_t *)bq5_K->scales; + uint16_t aux[2]; + const int j = bq8_offset/2; + if (j < 2) { + aux[0] = scales[j+0] & 0x3f3f; + aux[1] = scales[j+2] & 0x3f3f; + } else { + aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); + aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); + } + const uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; + +#pragma unroll + for (int i = 0; i < QR5_K; ++i) { + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + d8[i] = __low2float(bq8i->ds); + + const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); + u[2*i+0] = q8[0]; + u[2*i+1] = q8[4]; + } + + return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8); +} + +template static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_sc = tile_x_sc; +} + +template static __device__ __forceinline__ void load_tiles_q5_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + const int kbx = k / QI5_K; // == 0 if QK_K == 256 + const int kqsx = k % QI5_K; // == k if QK_K == 256 + + const block_q5_K * bx0 = (const block_q5_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx; + const int ky = QR5_K*kqsx; + + const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; + + const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4)); + const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010; + const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010; + + const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0; + const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4); + + x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0; + x_ql[i * (2*WARP_SIZE + 1) + kq1] = 
ql1 | qh1; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256 + const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) { + int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd; + x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8); + + const int * scales = (const int *) bxi->scales; + + const int ksc = k % (WARP_SIZE/8); + + // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 + int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits + scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits + + x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; + } +} + +static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8); + + const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k; + const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE; + return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8, + x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]); +} + +static __device__ __forceinline__ float vec_dot_q6_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q6_K * bq6_K = (const block_q6_K *) vbq; + + const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4); + const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8); + const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4)); + + const int vl = get_int_from_uint8(bq6_K->ql, iqs); + const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift; + + const int8_t * scales = bq6_K->scales + scale_offset; + + int u[QR6_K]; + float d8[QR6_K]; + +#pragma unroll + for (int i = 0; i < QR6_K; ++i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1); + d8[i] = __low2float(bq8_1[bq8_offset + 2*i].ds); + } + + return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, __half2float(bq6_K->d), d8); +} + +template static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_sc = tile_x_sc; +} + +template static __device__ __forceinline__ void load_tiles_q6_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + const int kbx = k / QI6_K; // == 0 if 
QK_K == 256 + const int kqsx = k % QI6_K; // == k if QK_K == 256 + + const block_q6_K * bx0 = (const block_q6_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx; + const int ky = QR6_K*kqsx; + + const int ql = get_int_from_uint8(bxi->ql, kqsx); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; + + const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4)); + const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030; + const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030; + + const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0; + const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2); + + x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 + const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) { + int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = __half2float(bxi->d); + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4; + + x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8)); + } +} + +static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]); + + const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k; + const int index_y = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE; + return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]); +} + +static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq; + + const int ib32 = iqs; + const uint16_t * q2 = bq2->qs + 4*ib32; + const uint8_t * aux8 = (const uint8_t *)q2; + const int8_t * q8 = bq8_1[ib32].qs; + uint32_t aux32 = q2[2] | (q2[3] << 16); + int sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); + const uint8_t signs = ksigns_iq2xs[aux32 & 127]; + for (int j = 0; j < 8; ++j) { + sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? 
-1 : 1); + } + q8 += 8; + aux32 >>= 7; + } + const float d = __half2float(bq2->d) * (0.5f + aux32) * __half2float(bq8_1[ib32].ds.x) * 0.25f; + return d * sumi; +} + +static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq; + + const int ib32 = iqs; + const uint16_t * q2 = bq2->qs + 4*ib32; + const int8_t * q8 = bq8_1[ib32].qs; + const uint8_t ls1 = bq2->scales[ib32] & 0xf; + const uint8_t ls2 = bq2->scales[ib32] >> 4; + int sumi1 = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi1 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + int sumi2 = 0; + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi2 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + const float d = __half2float(bq2->d) * __half2float(bq8_1[ib32].ds.x) * 0.25f; + return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2); +} + +static __device__ __forceinline__ float vec_dot_iq2_s_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + const block_iq2_s * bq2 = (const block_iq2_s *) vbq; + + const int ib32 = iqs; + const int8_t * q8 = bq8_1[ib32].qs; + const uint8_t * signs = bq2->qs + QK_K/8 + 4*ib32; + const uint8_t ls1 = bq2->scales[ib32] & 0xf; + const uint8_t ls2 = bq2->scales[ib32] >> 4; + int sumi1 = 0; + for (int l = 0; l < 2; ++l) { + const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300))); + const uint32_t signs0 = __vcmpeq4(((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); + const uint32_t signs1 = __vcmpeq4(((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); + const int grid_l = __vsub4(grid[0] ^ signs0, signs0); + const int grid_h = __vsub4(grid[1] ^ signs1, signs1); + sumi1 = __dp4a(grid_l, *((const int *)q8 + 0), sumi1); + sumi1 = __dp4a(grid_h, *((const int *)q8 + 1), sumi1); + q8 += 8; + } + int sumi2 = 0; + for (int l = 2; l < 4; ++l) { + const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300))); + const uint32_t signs0 = __vcmpeq4(((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); + const uint32_t signs1 = __vcmpeq4(((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); + const int grid_l = __vsub4(grid[0] ^ signs0, signs0); + const int grid_h = __vsub4(grid[1] ^ signs1, signs1); + sumi2 = __dp4a(grid_l, *((const int *)q8 + 0), sumi2); + sumi2 = __dp4a(grid_h, *((const int *)q8 + 1), sumi2); + q8 += 8; + } + const float d = __half2float(bq2->d) * __low2float(bq8_1[ib32].ds) * 0.25f; + return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2); +#endif +} + +static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq; + + const int ib32 = iqs; + const uint8_t * q3 = bq2->qs + 8*ib32; + const uint16_t * gas = (const uint16_t *)(bq2->qs + QK_K/4) + 2*ib32; + const int8_t * q8 = bq8_1[ib32].qs; + 
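+    // aux32 (built below from the two `gas` halves) packs four 7-bit sign-pattern
+    // indices (looked up via ksigns64) in its low 28 bits and the 4-bit block scale
+    // in its top bits; after the four `aux32 >>= 7` shifts in the loop only the
+    // scale remains, which is why the final multiplier uses (0.5f + aux32).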
uint32_t aux32 = gas[0] | (gas[1] << 16); + int sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint32_t * grid1 = iq3xxs_grid + q3[2*l+0]; + const uint32_t * grid2 = iq3xxs_grid + q3[2*l+1]; + const uint32_t * signs = (const uint32_t *)(ksigns64 + (aux32 & 127)); + const int grid_l = __vsub4(grid1[0] ^ signs[0], signs[0]); + const int grid_h = __vsub4(grid2[0] ^ signs[1], signs[1]); + sumi = __dp4a(grid_l, *((int *)q8+0), sumi); + sumi = __dp4a(grid_h, *((int *)q8+1), sumi); + q8 += 8; + aux32 >>= 7; + } + const float d = __half2float(bq2->d) * (0.5f + aux32) * __low2float(bq8_1[ib32].ds) * 0.5f; + return d * sumi; +#endif +} + +static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + const block_iq3_s * bq2 = (const block_iq3_s *) vbq; + + const int ib32 = iqs; + const uint8_t * qs = bq2->qs + 8*ib32; + const int8_t * q8 = bq8_1[ib32].qs; + int sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint32_t * grid1 = iq3xs_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)); + const uint32_t * grid2 = iq3xs_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)); + uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); + uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); + const int grid_l = __vsub4(grid1[0] ^ signs0, signs0); + const int grid_h = __vsub4(grid2[0] ^ signs1, signs1); + sumi = __dp4a(grid_l, *((int *)q8+0), sumi); + sumi = __dp4a(grid_h, *((int *)q8+1), sumi); + q8 += 8; + } + const float d = __half2float(bq2->d) * (0.5f + ((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds) * 0.5f; + return d * sumi; +#endif +} + +static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + const block_iq1_s * bq1 = (const block_iq1_s *) vbq; + + const int ib32 = iqs; + int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; + const uint8_t h1 = bq1->scales[2*ib32+0]; + const uint8_t h2 = bq1->scales[2*ib32+1]; + const int * q8 = (const int *)bq8_1[ib32].qs; + const int * grid1 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5))); + const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1))); + const int * grid3 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5))); + const int * grid4 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1))); + for (int j = 0; j < 2; ++j) { + sumi1 = __dp4a(q8[j+0], grid1[j], sumi1); + sumi2 = __dp4a(q8[j+2], grid2[j], sumi2); + sumi3 = __dp4a(q8[j+4], grid3[j], sumi3); + sumi4 = __dp4a(q8[j+6], grid4[j], sumi4); + } + const float d = __half2float(bq1->d) * __low2float(bq8_1[ib32].ds); + return d * (sumi1 * (2*(h1 & 7) + 1) + sumi2 * (2*((h1 >> 4) & 7) + 1) + + sumi3 * (2*(h2 & 7) + 1) + sumi4 * (2*((h2 >> 4) & 7) + 1)); +#endif +} + +static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4, const uint8_t * values, + int & val1, int & val2) { + + uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32; + aux32 = q4 & 0x0f0f0f0f; + uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8); + uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8); + val1 = v1 | (v2 << 16); + aux32 = (q4 >> 4) & 0x0f0f0f0f; + v1 = values[q8[0]] | (values[q8[1]] << 8); + v2 = values[q8[2]] | 
(values[q8[3]] << 8); + val2 = v1 | (v2 << 16); +} + +static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + + const block_iq4_nl * bq = (const block_iq4_nl *) vbq; + + const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs; + const int32_t * q8 = (const int32_t *)bq8_1->qs + iqs; + + const uint8_t * values = (const uint8_t *)kvalues_iq4nl; + + int v1, v2; + int sumi1 = 0, sumi2 = 0; + for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) { + const uint32_t aux = q4[2*l] | (q4[2*l+1] << 16); + get_int_from_table_16(aux, values, v1, v2); + sumi1 = __dp4a(v1, q8[l+0], sumi1); + sumi2 = __dp4a(v2, q8[l+4], sumi2); + } + const float d = __half2float(bq->d) * __low2float(bq8_1->ds); + return d * (sumi1 + sumi2); +#endif +} + + +static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq; + const uint8_t * values = (const uint8_t *)kvalues_iq4nl; + + // iqs is 0...7 + const int ib32 = iqs; + const int32_t * q8 = (const int *)bq8_1[ib32].qs; + const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32; + const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4); + const float d = __half2float(bq4->d) * (ls - 32) * __low2float(bq8_1[ib32].ds); + int v1, v2; + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 4; ++j) { + get_int_from_table_16(q4[j], values, v1, v2); + sumi1 = __dp4a(v1, q8[j+0], sumi1); + sumi2 = __dp4a(v2, q8[j+4], sumi2); + } + return d * (sumi1 + sumi2); +#endif +} \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 7c0d617fc8b3..b35fd471ed4f 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -145,6 +145,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("awq_marlin_repack", &awq_marlin_repack); ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack); + // Dequantization for GGML. + ops.def("ggml_dequantize", &ggml_dequantize); + ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize); + + // mmvq kernel for GGML. + ops.def("ggml_mul_mat_vec_a8", &ggml_mul_mat_vec_a8); + ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8); + + // mmq kernel for GGML. + ops.def("ggml_mul_mat_a8", &ggml_mul_mat_a8); + ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8); + // fp8_marlin Optimized Quantized GEMM for FP8 weight-only. ops.def("fp8_marlin_gemm", &fp8_marlin_gemm); ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm); diff --git a/examples/gguf_inference.py b/examples/gguf_inference.py new file mode 100644 index 000000000000..09a5fcc22e55 --- /dev/null +++ b/examples/gguf_inference.py @@ -0,0 +1,38 @@ +from huggingface_hub import hf_hub_download + +from vllm import LLM, SamplingParams + + +def run_gguf_inference(model_path): + PROMPT_TEMPLATE = "<|system|>\n{system_message}\n<|user|>\n{prompt}\n<|assistant|>\n" # noqa: E501 + system_message = "You are a friendly chatbot who always responds in the style of a pirate." # noqa: E501 + # Sample prompts. 
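+    # Each raw question below is wrapped in the chat template defined above
+    # (system message + user turn) before being passed to llm.generate().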
+ prompts = [ + "How many helicopters can a human eat in one sitting?", + "What's the future of AI?", + ] + prompts = [ + PROMPT_TEMPLATE.format(system_message=system_message, prompt=prompt) + for prompt in prompts + ] + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0, max_tokens=128) + + # Create an LLM. + llm = LLM(model=model_path, + tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + gpu_memory_utilization=0.95) + + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +if __name__ == "__main__": + repo_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF" + filename = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf" + model = hf_hub_download(repo_id, filename=filename) + run_gguf_inference(model) diff --git a/format.sh b/format.sh index baaebc811d40..a8fd95a1ea44 100755 --- a/format.sh +++ b/format.sh @@ -242,6 +242,11 @@ echo 'vLLM isort: Done' # NOTE: Keep up to date with .github/workflows/clang-format.yml CLANG_FORMAT_EXCLUDES=( 'csrc/moe/topk_softmax_kernels.cu' + 'csrc/quantization/gguf/ggml-common.h' + 'csrc/quantization/gguf/dequantize.cuh' + 'csrc/quantization/gguf/vecdotq.cuh' + 'csrc/quantization/gguf/mmq.cuh' + 'csrc/quantization/gguf/mmvq.cuh' ) # Format specified files with clang-format diff --git a/requirements-common.txt b/requirements-common.txt index 3b8d473c1fe7..d8c95bf77240 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -22,3 +22,4 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 pyzmq +gguf == 0.9.1 diff --git a/tests/models/test_gguf.py b/tests/models/test_gguf.py new file mode 100644 index 000000000000..5971179f0121 --- /dev/null +++ b/tests/models/test_gguf.py @@ -0,0 +1,76 @@ +""" +Tests gguf models against unquantized models generations +Note: To pass the test, quantization higher than Q4 should be used +""" + +import os + +import pytest +from huggingface_hub import hf_hub_download + +from tests.quantization.utils import is_quant_method_supported + +from .utils import check_logprobs_close + +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +MAX_MODEL_LEN = 1024 + +# FIXME: Move this to confest +MODELS = [ + ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", + hf_hub_download("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", + filename="tinyllama-1.1b-chat-v1.0.Q4_0.gguf")), + ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", + hf_hub_download("duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF", + filename="TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf")), + ("Qwen/Qwen2-1.5B-Instruct", + hf_hub_download("Qwen/Qwen2-1.5B-Instruct-GGUF", + filename="qwen2-1_5b-instruct-q4_k_m.gguf")), + ("Qwen/Qwen2-1.5B-Instruct", + hf_hub_download("legraphista/Qwen2-1.5B-Instruct-IMat-GGUF", + filename="Qwen2-1.5B-Instruct.IQ4_XS.gguf")), +] + + +@pytest.mark.skipif(not is_quant_method_supported("gguf"), + reason="gguf is not supported on this GPU type.") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + vllm_runner, + example_prompts, + model, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + original_model, gguf_model = model + + # Run unquantized model. 
+ with vllm_runner(model_name=original_model, + dtype=dtype, + max_model_len=MAX_MODEL_LEN, + enforce_eager=True, + tensor_parallel_size=1) as original_model: + + original_outputs = original_model.generate_greedy_logprobs( + example_prompts[:-1], max_tokens, num_logprobs) + + # Run gguf model. + with vllm_runner(model_name=gguf_model, + dtype=dtype, + max_model_len=MAX_MODEL_LEN, + enforce_eager=True, + tensor_parallel_size=1) as gguf_model: + gguf_outputs = gguf_model.generate_greedy_logprobs( + example_prompts[:-1], max_tokens, num_logprobs) + + check_logprobs_close( + outputs_0_lst=original_outputs, + outputs_1_lst=gguf_outputs, + name_0="original", + name_1="gguf", + ) diff --git a/tests/quantization/test_lm_head.py b/tests/quantization/test_lm_head.py index dd9a016807df..ad526a406510 100644 --- a/tests/quantization/test_lm_head.py +++ b/tests/quantization/test_lm_head.py @@ -7,11 +7,12 @@ import pytest import torch -from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinLinearMethod) from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod +from vllm.model_executor.layers.vocab_parallel_embedding import ( + UnquantizedEmbeddingMethod) PROMPT = "On the surface of Mars, we found" @@ -37,7 +38,8 @@ def test_lm_head( lm_head_layer.linear_method, (GPTQLinearMethod, GPTQMarlinLinearMethod, MarlinLinearMethod)) else: - assert isinstance(lm_head_layer.linear_method, UnquantizedLinearMethod) + assert isinstance(lm_head_layer.linear_method, + UnquantizedEmbeddingMethod) print( vllm_model.generate_greedy(prompts=["Hello my name is"], diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ad7e5bd19933..e3e2c5536a2b 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -404,6 +404,38 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, workspace, size_m, size_n, size_k) +# gguf +def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int): + return torch.ops._C.ggml_dequantize(W, quant_type, m, n) + + +def ggml_mul_mat_vec( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: int, +): + return torch.ops._C.ggml_mul_mat_vec(W, X, quant_type, row) + + +def ggml_mul_mat_vec_a8( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: int, +): + return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row) + + +def ggml_mul_mat_a8( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: int, +): + return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row) + + # moe def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, diff --git a/vllm/config.py b/vllm/config.py index bec0b63197ef..4b968f549d90 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -582,6 +582,7 @@ class LoadFormat(str, enum.Enum): DUMMY = "dummy" TENSORIZER = "tensorizer" SHARDED_STATE = "sharded_state" + GGUF = "gguf" BITSANDBYTES = "bitsandbytes" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index acc0551af015..935a509cdb7c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -672,6 +672,9 @@ def from_cli_args(cls, args: argparse.Namespace): return engine_args def create_engine_config(self, ) -> EngineConfig: + # gguf file needs a specific model loader and doesn't use hf_repo + if self.model.endswith(".gguf"): + self.quantization = self.load_format = "gguf" # bitsandbytes quantization needs a 
specific model loader # so we make sure the quant method and the load format are consistent diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index b6e280ae6504..cd53c2b91621 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -3,7 +3,7 @@ import torch import torch.nn.functional as F -from torch.nn.parameter import Parameter +from torch.nn.parameter import Parameter, UninitializedParameter from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -311,6 +311,17 @@ def __init__(self, def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) + + # Special case for GGUF + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.weight_type = loaded_weight.item() + + # Materialize GGUF UninitializedParameter + if is_gguf_weight and isinstance(param, UninitializedParameter): + param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype) + param_data = param.data if output_dim is not None: shard_size = param_data.shape[output_dim] @@ -398,6 +409,27 @@ def weight_loader(self, loaded_weight: torch.Tensor, loaded_shard_id: Optional[int] = None): + # Special case for GGUF + # initialize GGUF param after we know the quantize type + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.data[loaded_shard_id].copy_(loaded_weight) + param.shard_weight_type[loaded_shard_id] = loaded_weight.item() + return + + if is_gguf_weight and isinstance(param, UninitializedParameter): + from gguf.constants import GGML_QUANT_SIZES + + ori_shape = param.tensor_shape + weight_types = self.qweight_type.shard_weight_type.values() + row_size = [] + for weight_type in weight_types: + block_size, type_size = GGML_QUANT_SIZES[weight_type] + row_size.append(ori_shape[1] // block_size * type_size) + q_shape = (ori_shape[0], max(row_size)) + param.materialize(q_shape, dtype=loaded_weight.dtype) + param_data = param.data output_dim = getattr(param, "output_dim", None) # Special case for AQLM codebooks. 
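The hunk above sizes the GGUF qweight lazily: the UninitializedParameter is only materialized once every shard's quant type is known, and its second dimension is the byte width of a quantized row taken from GGML_QUANT_SIZES. Below is a minimal sketch of that arithmetic, assuming the `gguf` Python package pinned in requirements-common.txt; the `in_features` value and the Q4_K type are illustrative choices, not values from this patch.

```python
# Illustrative sketch only: how `row_size` in the hunk above is computed for one
# GGUF quant type. Q4_K and in_features=4096 are assumptions for the example.
from gguf.constants import GGML_QUANT_SIZES, GGMLQuantizationType

in_features = 4096  # ori_shape[1]: unquantized row length, in weights
block_size, type_size = GGML_QUANT_SIZES[GGMLQuantizationType.Q4_K]
# Every `block_size` weights are packed into `type_size` bytes, so the quantized
# row that UninitializedParameter.materialize() must hold is this many bytes wide:
row_bytes = in_features // block_size * type_size
print(block_size, type_size, row_bytes)  # 256 144 2304 for Q4_K
```

The code takes `max(row_size)` over the shards, presumably because a fused layer's shards may use different quant types, so the buffer is sized for the widest quantized row.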
@@ -460,6 +492,13 @@ def weight_loader(self, shard_offset = loaded_weight.shape[output_dim] * \ loaded_shard_id + if is_gguf_weight: + shard_size = loaded_weight.shape[output_dim] + shard_offset = loaded_weight.shape[output_dim] * \ + loaded_shard_id + param.shard_id.append(loaded_shard_id) + param.shard_size[loaded_shard_id] = loaded_weight.shape + param_data = param_data.narrow(output_dim, shard_offset, shard_size) start_idx = tp_rank * shard_size @@ -563,6 +602,29 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: Optional[str] = None): + + # Special case for GGUF + # initialize GGUF param after we know the quantize type + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type and loaded_shard_id is not None: + idx_map = {"q": 0, "k": 1, "v": 2} + param.data[idx_map[loaded_shard_id]].copy_(loaded_weight) + param.shard_weight_type[loaded_shard_id] = loaded_weight.item() + return + + if is_gguf_weight and isinstance(param, UninitializedParameter): + from gguf.constants import GGML_QUANT_SIZES + + ori_shape = param.tensor_shape + weight_types = self.qweight_type.shard_weight_type.values() + row_size = [] + for weight_type in weight_types: + block_size, type_size = GGML_QUANT_SIZES[weight_type] + row_size.append(ori_shape[1] // block_size * type_size) + q_shape = (ori_shape[0], max(row_size)) + param.materialize(q_shape, dtype=loaded_weight.dtype) + param_data = param.data output_dim = getattr(param, "output_dim", None) # Special case for AQLM codebooks. @@ -650,6 +712,13 @@ def weight_loader(self, shard_size, shard_offset = adjust_bitsandbytes_shard( param, orig_qkv_offsets, loaded_shard_id) + if is_gguf_weight: + param.shard_id.append(loaded_shard_id) + param.shard_size[loaded_shard_id] = loaded_weight.shape + input_dim = getattr(param, "input_dim", None) + input_size = loaded_weight.shape[input_dim] + param_data = param_data.narrow(input_dim, 0, input_size) + param_data = param_data.narrow(output_dim, shard_offset, shard_size) if loaded_shard_id == "q": @@ -755,6 +824,17 @@ def __init__(self, def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() input_dim = getattr(param, "input_dim", None) + + # Special case for GGUF + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.weight_type = loaded_weight.item() + + # Materialize GGUF UninitializedParameter + if is_gguf_weight and isinstance(param, UninitializedParameter): + param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype) + param_data = param.data if input_dim is not None: shard_size = param_data.shape[input_dim] diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 13da6376ec29..db2a24556169 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -13,6 +13,7 @@ DeepSpeedFPConfig) from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config from vllm.model_executor.layers.quantization.fp8 import Fp8Config +from vllm.model_executor.layers.quantization.gguf import GGUFConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) @@ -31,6 +32,7 @@ # The order of gptq methods is 
important for config.py iteration over # override_quantization_method(..) "marlin": MarlinConfig, + "gguf": GGUFConfig, "gptq_marlin_24": GPTQMarlin24Config, "gptq_marlin": GPTQMarlinConfig, "awq_marlin": AWQMarlinConfig, diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index f5ff27b9f14b..75fa8249cd3c 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -1,5 +1,6 @@ +import inspect from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Type import torch from torch import nn @@ -23,6 +24,14 @@ def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor: Expects create_weights to have been called before on the layer.""" raise NotImplementedError + # Not required functions + def embedding(self, layer: torch.nn.Module, *args, + **kwargs) -> torch.Tensor: + """Gather embeddings in the layer based on indices in the input tensor. + + Expects create_weights to have been called before on the layer.""" + raise NotImplementedError + def process_weights_after_loading(self, layer: nn.Module) -> None: """Process the weight after loading. @@ -31,6 +40,21 @@ def process_weights_after_loading(self, layer: nn.Module) -> None: return +def method_has_implemented_embedding( + method_class: Type[QuantizeMethodBase]) -> bool: + """ + Not all quant methods have embedding implemented, so we need to check that + it exists for our given method. We check this by making sure the function + has been changed from the base implementation. + """ + base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding", + None) + class_embedding = inspect.getattr_static(method_class, "embedding", None) + + return (class_embedding is not None + and class_embedding is not base_embedding) + + class QuantizationConfig(ABC): """Base class for quantization configs.""" diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py new file mode 100644 index 000000000000..a4e0a4d50960 --- /dev/null +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -0,0 +1,165 @@ +from typing import Any, Dict, List, Optional + +import gguf +import torch +from torch.nn.parameter import Parameter, UninitializedParameter + +from vllm import _custom_ops as ops +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.utils import set_weight_attrs + + +class GGUFConfig(QuantizationConfig): + """Config class for GGUF.""" + + def __init__(self, ) -> None: + pass + + def __repr__(self) -> str: + return ("GGUFConfig()") + + def get_name(self) -> str: + return "gguf" + + def get_supported_act_dtypes(self) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] # no extra configs. 
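+    # GGUF embeds its quantization metadata in the checkpoint tensors themselves,
+    # so from_config below is invoked with an empty dict (see get_quant_config in
+    # model_loader/weight_utils.py).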
+ + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig": + if get_tensor_model_parallel_world_size() > 1: + raise ValueError( + "GGUF quantization hasn't supported tensor parallelism yet.") + return cls() + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + return GGUFLinearMethod(self) + elif isinstance(layer, VocabParallelEmbedding): + return GGUFEmbeddingMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor, + qweight_type: int) -> torch.Tensor: + # use dequantize mulmat for IQmatrix, mmq for k-quants + if qweight_type >= 16: + block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] + shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) + weight = ops.ggml_dequantize(qweight, qweight_type, *shape) + y = x @ weight.T + else: + y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0]) + return y + + +class GGUFLinearMethod(LinearMethodBase): + """Linear method for GGUF. + + Args: + quant_config: The GGUF quantization config. + """ + + def __init__(self, quant_config: GGUFConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + output_size_per_partition = sum(output_partition_sizes) + + tensor_shape = (output_size_per_partition, input_size_per_partition) + qweight = UninitializedParameter(requires_grad=False) + set_weight_attrs( + qweight, { + "input_dim": 1, + "output_dim": 0, + "tensor_shape": tensor_shape, + "is_gguf_weight": True, + "shard_size": {}, + "shard_id": [], + }) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("qweight", qweight) + + qweight_type = Parameter(torch.empty(len(output_partition_sizes), + dtype=torch.uint8), + requires_grad=False) + set_weight_attrs( + qweight_type, { + "is_gguf_weight_type": True, + "weight_type": 0, + "shard_weight_type": {}, + "ignore_warning": True + }) + set_weight_attrs(qweight_type, extra_weight_attrs) + layer.register_parameter("qweight_type", qweight_type) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + shard_size = getattr(layer.qweight, "shard_size", None) + shard_id = getattr(layer.qweight, "shard_id", None) + + if shard_id and shard_size: + result = [] + offset = 0 + # dequantize shard weights respectively + shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id + for id in shard_id: + shard_weight = layer.qweight[ + offset:offset + + shard_size[id][0], :shard_size[id][1]].contiguous() + qweight_type = layer.qweight_type.shard_weight_type[id] + result.append(_fuse_mul_mat(x, shard_weight, qweight_type)) + offset += shard_size[id][0] + out = torch.cat(result, axis=1) + else: + qweight = layer.qweight + qweight_type = layer.qweight_type.weight_type + out = _fuse_mul_mat(x, qweight, qweight_type) + if bias is not None: + out.add_(bias) + return out + + +class GGUFEmbeddingMethod(GGUFLinearMethod): + """Embedding method for GGUF. + + Args: + quant_config: The GGUF quantization config. 
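+
+    Lookups gather the selected quantized rows with index_select and dequantize
+    them on the fly via ops.ggml_dequantize (see ``embedding`` below); unquantized
+    (F32/F16) tables fall back to a plain torch.embedding.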
+ """ + + def embedding(self, layer: torch.nn.Module, + x: torch.Tensor) -> torch.Tensor: + qweight = layer.qweight + qweight_type = layer.qweight_type.weight_type + + block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] + hidden_size = qweight.shape[1] // type_size * block_size + if qweight_type < 2: + return torch.embedding(qweight, x) + x_flat = x.flatten() + quant = torch.index_select(qweight, dim=0, index=x_flat) + dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size, + x_flat.shape[0]) + return dequant.view(*x.shape, hidden_size) diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 74aeb964274b..3ba15573c217 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -3,19 +3,46 @@ import torch import torch.nn.functional as F -from torch.nn.parameter import Parameter +from torch.nn.parameter import Parameter, UninitializedParameter from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) from vllm.model_executor.utils import set_weight_attrs DEFAULT_VOCAB_PADDING_SIZE = 64 +class UnquantizedEmbeddingMethod(QuantizeMethodBase): + """Unquantized method for embeddings.""" + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + """Create weights for embedding layer.""" + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return F.linear(x, layer.weight, bias) + + def embedding(self, layer: torch.nn.Module, + input_: torch.Tensor) -> torch.Tensor: + return F.embedding(input_, layer.weight) + + def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: """Pad the vocab size to the given value.""" @@ -199,7 +226,19 @@ def __init__(self, if quant_config is not None: linear_method = quant_config.get_quant_method(self, prefix=prefix) if linear_method is None: - linear_method = UnquantizedLinearMethod() + linear_method = UnquantizedEmbeddingMethod() + + # If we are making an embedding layer, then our quantization linear + # method must implement the embedding operation. If we are another + # layer type like ParallelLMHead, this is not important. 
+ is_embedding_layer = type(self.__class__) is VocabParallelEmbedding + linear_method_implements_embedding = method_has_implemented_embedding( + type(linear_method)) + if is_embedding_layer and not linear_method_implements_embedding: + raise NotImplementedError( + f"The class {type(linear_method).__name__} must implement " + "the 'embedding' method, see UnquantizedEmbeddingMethod.") + self.linear_method: QuantizeMethodBase = linear_method if params_dtype is None: @@ -306,6 +345,14 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): output_dim = getattr(param, "output_dim", None) packed_dim = getattr(param, "packed_dim", None) + # If the parameter is a gguf weight, then load it directly. + if getattr(param, "is_gguf_weight_type", None): + param.data.copy_(loaded_weight) + param.weight_type = loaded_weight.item() + return + elif isinstance(param, UninitializedParameter): + param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype) + # If parameter does not have output dim, then it should # be copied onto all gpus (e.g. g_idx for act_order gptq). if output_dim is None: @@ -344,7 +391,8 @@ def forward(self, input_): else: masked_input = input_ # Get the embeddings. - output_parallel = F.embedding(masked_input.long(), self.weight) + output_parallel = self.linear_method.embedding(self, + masked_input.long()) # Mask the output embedding. if self.tp_size > 1: output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) @@ -389,6 +437,7 @@ def __init__(self, super().__init__(num_embeddings, embedding_dim, params_dtype, org_num_embeddings, padding_size, quant_config, prefix) + if bias: self.bias = Parameter( torch.empty(self.num_embeddings_per_partition, diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index f72515e01482..a5c5cb87bc46 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -10,11 +10,13 @@ from contextlib import contextmanager from typing import Any, Dict, Generator, List, Optional, Tuple, Type +import gguf import huggingface_hub import numpy as np import torch from huggingface_hub import HfApi, hf_hub_download from torch import nn +from transformers import AutoModelForCausalLM from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, LoRAConfig, ModelConfig, MultiModalConfig, @@ -31,8 +33,9 @@ from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, - get_quant_config, initialize_dummy_weights, np_cache_weights_iterator, - pt_weights_iterator, safetensors_weights_iterator) + get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator, + initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, + safetensors_weights_iterator) from vllm.model_executor.models.interfaces import (has_inner_state, supports_lora, supports_vision) @@ -948,6 +951,90 @@ def load_model(self, *, model_config: ModelConfig, return model.eval() +class GGUFModelLoader(BaseModelLoader): + """ + Model loader that can load GGUF files. This is useful for loading models + that are quantized with GGUF and saved in the GGUF format. This loader + supports loading both full models and sharded models. 
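+
+    GGUF tensor names are mapped back to their HF counterparts via
+    gguf.get_tensor_name_map plus a dummy meta-device model, and the quantized
+    tensors are then streamed in through gguf_quant_weights_iterator.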
+ """ + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _prepare_weights(self, model_name_or_path: str): + if os.path.isfile(model_name_or_path): + return model_name_or_path + else: + raise ValueError(f"{model_name_or_path} is not a file.") + + def _get_gguf_weights_map(self, model_config: ModelConfig): + """ + GGUF uses this naming convention for their tensors from HF checkpoint: + `blk.N.BB.weight` and `blk.N.BB.bias` + where N signifies the block number of a layer, and BB signifies the + attention/mlp layer components. + See "Standardized tensor names" in + https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details. + """ + config = model_config.hf_config + model_type = config.model_type + # hack: ggufs have a different name than transformers + if model_type == "cohere": + model_type = "command-r" + arch = None + for key, value in gguf.MODEL_ARCH_NAMES.items(): + if value == model_type: + arch = key + break + if arch is None: + raise RuntimeError(f"Unknown gguf model_type: {model_type}") + num_layers = config.num_hidden_layers + name_map = gguf.get_tensor_name_map(arch, num_layers) + with torch.device("meta"): + dummy_model = AutoModelForCausalLM.from_config(config) + state_dict = dummy_model.state_dict() + + gguf_to_hf_name_map = {} + for hf_name in state_dict: + name, suffix = hf_name.rsplit(".", 1) + gguf_name = name_map.get_name(name) + gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name + return gguf_to_hf_name_map + + def _get_weights_iterator( + self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str] + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + return gguf_quant_weights_iterator(model_name_or_path, + gguf_to_hf_name_map) + + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: + + local_model_path = self._prepare_weights(model_config.model) + gguf_weights_map = self._get_gguf_weights_map(model_config) + # we can only know if tie word embeddings after mapping weights + if "lm_head.weight" in get_gguf_extra_tensor_names( + local_model_path, gguf_weights_map): + model_config.hf_config.update({"tie_word_embeddings": True}) + + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, + lora_config, multimodal_config, + cache_config) + model.load_weights( + self._get_weights_iterator(local_model_path, gguf_weights_map)) + return model + + def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: """Get a model loader based on the load format.""" @@ -966,4 +1053,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: if load_config.load_format == LoadFormat.BITSANDBYTES: return BitsAndBytesModelLoader(load_config) + if load_config.load_format == LoadFormat.GGUF: + return GGUFModelLoader(load_config) + return DefaultModelLoader(load_config) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 5e142e8cb8b8..250561654b14 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ 
b/vllm/model_executor/model_loader/weight_utils.py @@ -6,9 +6,10 @@ import os import tempfile from collections import defaultdict -from typing import Any, Generator, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union import filelock +import gguf import huggingface_hub.constants import numpy as np import torch @@ -121,6 +122,11 @@ def get_quant_config(model_config: ModelConfig, load_config: LoadConfig) -> QuantizationConfig: quant_cls = get_quantization_config(model_config.quantization) + + # GGUF doesn't have config file + if model_config.quantization == "gguf": + return quant_cls.from_config({}) + # Read the quantization config from the HF model config, if available. hf_quant_config = getattr(model_config.hf_config, "quantization_config", None) @@ -409,6 +415,45 @@ def pt_weights_iterator( torch.cuda.empty_cache() +def get_gguf_extra_tensor_names( + gguf_file: str, gguf_to_hf_name_map: Dict[str, str]) -> List[str]: + reader = gguf.GGUFReader(gguf_file) + expected_gguf_keys = set(gguf_to_hf_name_map.keys()) + exact_gguf_keys = set([tensor.name for tensor in reader.tensors]) + extra_keys = expected_gguf_keys - exact_gguf_keys + return [gguf_to_hf_name_map[key] for key in extra_keys] + + +def gguf_quant_weights_iterator( + gguf_file: str, gguf_to_hf_name_map: Dict[str, str] +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """ + Iterate over the quant weights in the model gguf files and convert + them to torch tensors + """ + + reader = gguf.GGUFReader(gguf_file) + + for tensor in reader.tensors: + weight_type = tensor.tensor_type + name = gguf_to_hf_name_map[tensor.name] + + if weight_type.name != "F32": + weight_type_name = name.replace("weight", "qweight_type") + weight_type = torch.tensor(weight_type) + yield weight_type_name, weight_type + + for tensor in reader.tensors: + weight = tensor.data + weight_type = tensor.tensor_type + name = gguf_to_hf_name_map[tensor.name] + + if weight_type.name != "F32": + name = name.replace("weight", "qweight") + param = torch.tensor(weight) + yield name, param + + def kv_cache_scales_loader( filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int, model_type: Optional[str]) -> Iterable[Tuple[int, float]]: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 048c292c79c8..023ae2a18d41 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -140,6 +140,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) + self.o_proj = RowParallelLinear( input_size=self.total_num_heads * self.head_dim, output_size=hidden_size, @@ -148,12 +149,17 @@ def __init__( prefix=f"{prefix}.o_proj", ) + is_neox_style = True + if quant_config is not None and quant_config.get_name() == "gguf": + is_neox_style = False + self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, + is_neox_style=is_neox_style, ) self.attn = Attention(self.num_heads, self.head_dim, @@ -279,6 +285,7 @@ def __init__( self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, + quant_config=quant_config, ) else: self.embed_tokens = PPMissingLayer() diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 99fdd993943b..a66a1eee7c16 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -238,6 +238,7 @@ def __init__( self.embed_tokens = 
VocabParallelEmbedding( config.vocab_size, config.hidden_size, + quant_config=quant_config, ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 3d13631b9b2b..5f04b39ef524 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -1,7 +1,10 @@ import contextlib -from typing import Dict, Optional, Type +from pathlib import Path +from typing import Dict, Optional, Type, Union from transformers import GenerationConfig, PretrainedConfig +from transformers.models.auto.modeling_auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger @@ -36,18 +39,29 @@ AutoConfig.register(name, cls) -def get_config(model: str, - trust_remote_code: bool, - revision: Optional[str] = None, - code_revision: Optional[str] = None, - rope_scaling: Optional[dict] = None, - rope_theta: Optional[float] = None) -> PretrainedConfig: +def get_config( + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[dict] = None, + rope_theta: Optional[float] = None, + **kwargs, +) -> PretrainedConfig: + + # Separate model folder from file path for GGUF models + is_gguf = Path(model).is_file() and Path(model).suffix == ".gguf" + if is_gguf: + kwargs["gguf_file"] = Path(model).name + model = Path(model).parent + try: config = AutoConfig.from_pretrained( model, trust_remote_code=trust_remote_code, revision=revision, - code_revision=code_revision) + code_revision=code_revision, + **kwargs) except ValueError as e: if (not trust_remote_code and "requires you to execute the configuration file" in str(e)): @@ -64,12 +78,22 @@ def get_config(model: str, config = config_class.from_pretrained(model, revision=revision, code_revision=code_revision) + + # Special architecture mapping check for GGUF models + if is_gguf: + if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: + raise RuntimeError( + f"Can't get gguf config for {config.model_type}.") + model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type] + config.update({"architectures": [model_type]}) + for key, value in [("rope_scaling", rope_scaling), ("rope_theta", rope_theta)]: if value is not None: logger.info("Updating %s from %r to %r", key, getattr(config, key, None), value) config.update({key: value}) + return config diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index c515f46ecc29..bf26d889d138 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -1,4 +1,5 @@ import os +from pathlib import Path from typing import Optional, Union import huggingface_hub @@ -55,7 +56,7 @@ def __len__(self): def get_tokenizer( - tokenizer_name: str, + tokenizer_name: Union[str, Path], *args, tokenizer_mode: str = "auto", trust_remote_code: bool = False, @@ -91,6 +92,13 @@ def get_tokenizer( if "truncation_side" not in kwargs: kwargs["truncation_side"] = "left" + # Separate model folder from file path for GGUF models + is_gguf = Path(tokenizer_name).is_file() and Path( + tokenizer_name).suffix == ".gguf" + if is_gguf: + kwargs["gguf_file"] = Path(tokenizer_name).name + tokenizer_name = Path(tokenizer_name).parent + try: tokenizer = AutoTokenizer.from_pretrained( tokenizer_name, From e3c664bfcb14a41e43ddb6078ed1464ae9b7852f Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Mon, 5 Aug 2024 
17:39:22 -0700 Subject: [PATCH 0076/3246] [Build] Add initial conditional testing spec (#6841) --- .buildkite/test-pipeline.yaml | 390 ++++++++++++++++++++-------------- 1 file changed, 234 insertions(+), 156 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 93b3e3fe9166..6f38cd313f11 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -5,11 +5,47 @@ # https://github.com/vllm-project/buildkite-ci/blob/main/scripts/test-template-aws.j2 # to generate the final pipeline yaml file. +# Documentation +# label(str): the name of the test. emoji allowed. +# fast_check(bool): whether to run this on each commit on fastcheck pipeline. +# fast_check_only(bool): run this test on fastcheck pipeline only +# command(str): the single command to run for tests. incompatible with commands. +# commands(list): the list of commands to run for test. incompatbile with command. +# mirror_hardwares(list): the list of hardwares to run the test on as well. currently only supports [amd] +# gpu(str): override the GPU selection for the test. default is on L4 GPUs. currently only supports a100 +# num_gpus(int): override the number of GPUs for the test. default to 1 GPU. currently support 2,4. +# num_nodes(int): whether to simulate multi-node setup by launch multiple containers on one host, +# in this case, commands must be specified. the first command runs on first host, the second +# command runs on the second host. +# working_dir(str): specify the place where command should execute, default to /vllm-workspace/tests +# source_file_dependencies(list): the list of prefix to opt-in the test for, if empty, the test will always run. + +# When adding a test +# - If the test belong to an existing group, add it there +# - If the test is short, add to any existing step +# - If the test takes more than 10min, then it is okay to create a new step. +# Note that all steps execute in parallel. steps: -- label: Async Engine, Inputs, Utils, Worker Test +##### fast check tests ##### + +- label: Documentation Build # 2min + working_dir: "/vllm-workspace/test_docs/docs" fast_check: true - fast_check_only: true + no_gpu: True + commands: + - pip install -r requirements-docs.txt + - SPHINXOPTS=\"-W\" make html + +- label: Async Engine, Inputs, Utils, Worker Test # 15min + fast_check: true + source_file_dependencies: + - vllm/ + - tests/async_engine + - tests/test_inputs + - tests/multimodal + - tests/test_utils + - tests/worker commands: - pytest -v -s async_engine # Async Engine - pytest -v -s test_inputs.py @@ -17,31 +53,12 @@ steps: - pytest -v -s test_utils.py # Utils - pytest -v -s worker # Worker -- label: Metrics, Tracing Test - fast_check: true - fast_check_only: true - commands: - - pytest -v -s metrics # Metrics - - "pip install \ - opentelemetry-sdk \ - opentelemetry-api \ - opentelemetry-exporter-otlp \ - opentelemetry-semantic-conventions-ai" # Tracing - - pytest -v -s tracing - -- label: Regression Test - mirror_hardwares: [amd] - fast_check: true - command: pytest -v -s test_regression.py - working_dir: "/vllm-workspace/tests" # optional - -- label: AsyncEngine Test - #mirror_hardwares: [amd] - command: pytest -v -s async_engine - -- label: Basic Correctness Test +- label: Basic Correctness Test # 30min mirror_hardwares: [amd] fast_check: true + source_file_dependencies: + - vllm/ + - tests/basic_correctness commands: # This flashinfer installation will fail on AMD ROCm, so it is set as optional. 
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl || true @@ -50,215 +67,264 @@ steps: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py - -- label: Core Test + +- label: Core Test # 10min mirror_hardwares: [amd] fast_check: true + source_file_dependencies: + - vllm/core + - vllm/distributed + - tests/core commands: - pytest -v -s core -- label: Distributed Comm Ops Test - #mirror_hardwares: [amd] - working_dir: "/vllm-workspace/tests" - num_gpus: 2 - commands: - - pytest -v -s distributed/test_comm_ops.py - - pytest -v -s distributed/test_shm_broadcast.py - -- label: 2 Node Tests (4 GPUs in total) - working_dir: "/vllm-workspace/tests" - num_gpus: 2 - num_nodes: 2 - commands: - - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up) - - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py - - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py - - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up) - - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py - -- label: Distributed Tests (2 GPUs) +- label: Entrypoints Test # 20min + fast_check: true mirror_hardwares: [amd] - working_dir: "/vllm-workspace/tests" - num_gpus: 2 + source_file_dependencies: + - vllm/entrypoints + - tests/entrypoints commands: - - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - - TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py - - pytest -v -s distributed/test_chunked_prefill_distributed.py - - pytest -v -s distributed/test_multimodal_broadcast.py - - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py + - pytest -v -s entrypoints/llm + - pytest -v -s entrypoints/openai -- label: Distributed Tests (4 GPUs) - #mirror_hardwares: [amd] +- label: Distributed Tests (4 GPUs) # 10min working_dir: "/vllm-workspace/tests" num_gpus: 4 fast_check: true + source_file_dependencies: + - vllm/ + - tests/distributed + - tests/spec_decode/e2e/test_integration_dist_tp4 commands: - pytest -v -s distributed/test_pynccl.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py -- label: Pipeline Parallelism Test - working_dir: "/vllm-workspace/tests" - num_gpus: 4 +##### fast check tests ##### +##### 1 GPU test ##### + +- label: Metrics, Tracing Test # 10min + source_file_dependencies: + - vllm/ + - tests/metrics + - tests/tracing commands: - - pytest -v -s distributed/test_pipeline_parallel.py + - pytest -v -s metrics + - "pip install \ + opentelemetry-sdk \ + opentelemetry-api \ + opentelemetry-exporter-otlp \ + opentelemetry-semantic-conventions-ai" + - pytest -v -s tracing -- label: Engine Test +- label: Regression Test # 5min mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + - tests/test_regression + command: pytest -v -s test_regression.py + working_dir: "/vllm-workspace/tests" # 
optional + +- label: Engine Test # 10min + mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + - tests/engine + - tests/tokenization commands: - pytest -v -s engine test_sequence.py test_config.py test_logger.py # OOM in the CI unless we run this separately - pytest -v -s tokenization -- label: Entrypoints Test - fast_check: true - mirror_hardwares: [amd] - - commands: - - pytest -v -s entrypoints/llm - - pytest -v -s entrypoints/openai - -- label: Examples Test +- label: Examples Test # 12min working_dir: "/vllm-workspace/examples" mirror_hardwares: [amd] + source_file_dependencies: + - vllm/entrypoints + - examples/ commands: - # install tensorizer for tensorize_vllm_model.py - - pip install awscli tensorizer + - pip install awscli tensorizer # for llava example and tensorizer test - python3 offline_inference.py - python3 cpu_offload.py - python3 offline_inference_with_prefix.py - python3 llm_engine_example.py - - python3 offline_inference_vision_language.py + - python3 llava_example.py - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors -- label: Inputs Test - #mirror_hardwares: [amd] - commands: - - pytest -v -s test_inputs.py - - pytest -v -s multimodal - -# - label: Kernels Test %N -# #mirror_hardwares: [amd] -# commands: -# - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl -# - pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT -# parallelism: 4 - -- label: Models Test - #mirror_hardwares: [amd] +- label: Models Test # 1hr10min + source_file_dependencies: + - vllm/ + - tests/models commands: - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl - pytest -v -s models -m \"not vlm\" -- label: Vision Language Models Test +- label: Vision Language Models Test # 42min mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ commands: - pytest -v -s models -m vlm -- label: Prefix Caching Test +- label: Prefix Caching Test # 7min mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + - tests/prefix_caching commands: - pytest -v -s prefix_caching -- label: Samplers Test - #mirror_hardwares: [amd] +- label: Samplers Test # 18min + source_file_dependencies: + - vllm/model_executor/layers + - vllm/sampling_metadata.py + - tests/samplers command: pytest -v -s samplers -- label: LogitsProcessor Test +- label: LogitsProcessor Test # 5min mirror_hardwares: [amd] + source_file_dependencies: + - vllm/model_executor/layers + - tests/test_logits_processor command: pytest -v -s test_logits_processor.py -- label: Utils Test - commands: - - pytest -v -s test_utils.py - - pytest -v -s test_embedded_commit.py - -- label: Worker Test - mirror_hardwares: [amd] - command: pytest -v -s worker - -- label: Speculative decoding tests - #mirror_hardwares: [amd] +- label: Speculative decoding tests # 22min + source_file_dependencies: + - vllm/spec_decode + - tests/spec_decode commands: # See https://github.com/vllm-project/vllm/issues/5152 - export VLLM_ATTENTION_BACKEND=XFORMERS - pytest -v -s spec_decode -# - label: LoRA Test %N -# #mirror_hardwares: [amd] -# command: pytest -v -s lora 
--shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py -# parallelism: 4 - -# - label: LoRA Long Context (Distributed) -# #mirror_hardwares: [amd] -# num_gpus: 4 -# # This test runs llama 13B, so it is required to run on 4 GPUs. -# commands: -# # FIXIT: find out which code initialize cuda before running the test -# # before the fix, we need to use spawn to test it -# - export VLLM_WORKER_MULTIPROC_METHOD=spawn -# - pytest -v -s -x lora/test_long_context.py - -- label: Tensorizer Test - #mirror_hardwares: [amd] - fast_check: true +- label: LoRA Test %N # 30min each + source_file_dependencies: + - vllm/lora + - csrc/punica + - tests/lora + command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py + parallelism: 4 + +- label: Kernels Test %N # 30min each + source_file_dependencies: + - csrc/ + - vllm/attention + - tests/kernels + commands: + - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl + - pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + parallelism: 4 + +- label: Tensorizer Test # 11min + soft_fail: true + source_file_dependencies: + - vllm/model_executor/model_loader + - tests/tensorizer_loader commands: - apt-get install -y curl libsodium23 - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -v -s tensorizer_loader -- label: Metrics Test - mirror_hardwares: [amd] - command: pytest -v -s metrics - -- label: Quantization Test - #mirror_hardwares: [amd] - command: pytest -v -s quantization - -- label: Tracing Test - commands: - - "pip install \ - opentelemetry-sdk \ - opentelemetry-api \ - opentelemetry-exporter-otlp \ - opentelemetry-semantic-conventions-ai" - - pytest -v -s tracing - -- label: Benchmarks +- label: Benchmarks # 9min working_dir: "/vllm-workspace/.buildkite" mirror_hardwares: [amd] + source_file_dependencies: + - benchmarks/ commands: - pip install aiohttp - bash run-benchmarks.sh -- label: LM Eval Small Models +- label: Quantization Test # 15min + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + - tests/quantization + command: pytest -v -s quantization + +- label: LM Eval Small Models # 53min working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization commands: - pip install lm-eval - export VLLM_WORKER_MULTIPROC_METHOD=spawn - bash ./run-tests.sh -c configs/models-small.txt -t 1 -- label: LM Eval Large Models - gpu: a100 + +##### 1 GPU test ##### +##### multi gpus test ##### + +- label: Distributed Comm Ops Test # 7min + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: + - vllm/distributed + - tests/distributed + commands: + - pytest -v -s distributed/test_comm_ops.py + - pytest -v -s distributed/test_shm_broadcast.py + +- label: 2 Node Tests (4 GPUs in total) # 16min + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + num_nodes: 2 + source_file_dependencies: + - vllm/ + - tests/distributed/test_same_node + commands: + - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up) + - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py + - VLLM_MULTI_NODE=1 pytest -v -s 
distributed/test_pipeline_parallel.py + - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up) + - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py + +- label: Distributed Tests (2 GPUs) # 28min + mirror_hardwares: [amd] + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: + - vllm/ + - tests/distributed + commands: + - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py + - TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py + - pytest -v -s distributed/test_chunked_prefill_distributed.py + - pytest -v -s distributed/test_multimodal_broadcast.py + - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py + +- label: Pipeline Parallelism Test # 23min + working_dir: "/vllm-workspace/tests" num_gpus: 4 - working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + source_file_dependencies: + - vllm/ + - tests/distributed/test_pipeline_parallel commands: - - pip install lm-eval - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - bash ./run-tests.sh -c configs/models-large.txt -t 4 + - pytest -v -s distributed/test_pipeline_parallel.py -- label: Documentation Build - working_dir: "/vllm-workspace/test_docs/docs" - fast_check: true - no_gpu: True +- label: LoRA Long Context (Distributed) # 11min + # This test runs llama 13B, so it is required to run on 4 GPUs. + num_gpus: 4 + source_file_dependencies: + - vllm/lora + - csrc/punica + - tests/lora/test_long_context commands: - - pip install -r requirements-docs.txt - - SPHINXOPTS=\"-W\" make html + # FIXIT: find out which code initialize cuda before running the test + # before the fix, we need to use spawn to test it + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s -x lora/test_long_context.py + +##### multi gpus test ##### +##### A100 test ##### -- label: Distributed Tests (A100) +- label: Distributed Tests (A100) # optional gpu: a100 num_gpus: 4 + source_file_dependencies: + - vllm/ commands: # NOTE: don't test llama model here, it seems hf implementation is buggy # see https://github.com/vllm-project/vllm/pull/5689 for details @@ -266,3 +332,15 @@ steps: - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl - TARGET_TEST_SUITE=A100 pytest -v -s distributed/test_basic_distributed_correctness.py - pytest -v -s -x lora/test_mixtral.py + +- label: LM Eval Large Models # optional + gpu: a100 + num_gpus: 4 + working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + commands: + - pip install lm-eval + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - bash ./run-tests.sh -c configs/models-large.txt -t 4 From 9118217f58c39040aa935b7c85250c7364ffa72d Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 6 Aug 2024 09:57:25 +0800 Subject: [PATCH 0077/3246] [LoRA] Relax LoRA condition (#7146) --- tests/lora/test_layers.py | 2 +- tests/lora/test_punica_variation.py | 2 +- vllm/config.py | 5 +++-- vllm/lora/layers.py | 6 +++--- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index d8cc68d5e959..ad86f7bdf610 100644 
--- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -420,7 +420,7 @@ def create_random_embedding_layer(): @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) +@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512]) @pytest.mark.parametrize("stage", STAGES) def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, stage) -> None: diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index 7e73ea67ee5f..5bf3f72e7d97 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -25,7 +25,7 @@ BATCHES = [1, 4, 16, 32] NUM_LORA = [1, 4, 8, 16, 32, 64, 128] DTYPES = [torch.float16, torch.bfloat16] -MAX_RANKS = [1, 4, 8, 16, 32, 64, 128] +MAX_RANKS = [1, 4, 8, 16, 32, 64, 128, 256] SCALES = [0.5] SEED = [0] CUDA_DEVICES = [f"cuda:{0}"] diff --git a/vllm/config.py b/vllm/config.py index 4b968f549d90..3cc197f3d655 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1311,8 +1311,9 @@ class LoRAConfig: long_lora_scaling_factors: Optional[Tuple[float]] = None def __post_init__(self): - # TODO: Increase the range of rank - possible_max_ranks = (8, 16, 32, 64) + # Setting the maximum rank to 256 should be able to satisfy the vast + # majority of applications. + possible_max_ranks = (8, 16, 32, 64, 128, 256) possible_lora_extra_vocab_size = (0, 256, 512) if self.max_lora_rank not in possible_max_ranks: raise ValueError( diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index d3978ff6f4ff..e3316059dc6d 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1073,10 +1073,10 @@ def create_lora_weights( lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None, ) -> None: - # TODO: Verify if this condition can be relaxed - if 32000 < self.base_layer.vocab_size > 128512: + # TODO: Verify if this condition can be further relaxed + if 32000 < self.base_layer.vocab_size > 257024: raise ValueError("When using LoRA, vocab size must be " - "32000 >= vocab_size <= 128512") + "32000 >= vocab_size <= 257024") self.lora_a_stacked = torch.zeros( ( max_loras, From 1f26efbb3a5e6dad0b98421dd697167c42a50629 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 6 Aug 2024 16:55:31 +0800 Subject: [PATCH 0078/3246] [Model] Support SigLIP encoder and alternative decoders for LLaVA models (#7153) Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> --- requirements-test.txt | 3 + tests/models/test_llava.py | 36 +++- tests/models/test_llava_next.py | 9 +- tests/models/test_paligemma.py | 9 +- tests/models/test_registry.py | 2 +- vllm/model_executor/model_loader/loader.py | 38 +++- vllm/model_executor/model_loader/utils.py | 8 +- vllm/model_executor/models/__init__.py | 16 +- vllm/model_executor/models/clip.py | 24 ++- vllm/model_executor/models/internvl.py | 24 +-- vllm/model_executor/models/llava.py | 211 +++++++++--------- vllm/model_executor/models/llava_next.py | 239 ++++++++++++--------- vllm/model_executor/models/siglip.py | 49 ++++- vllm/model_executor/models/utils.py | 56 ++++- 14 files changed, 455 insertions(+), 269 deletions(-) diff --git a/requirements-test.txt b/requirements-test.txt index 5f3fd15c7ee5..62d6cc49eade 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -20,6 +20,9 @@ sentence-transformers # required for embedding compressed-tensors==0.4.0 # required for compressed-tensors timm # required for 
internvl test +# TODO: Add this after fully implementing llava(mantis) +# git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test + # Benchmarking aiohttp diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py index 79ab58c364f6..749d3353717e 100644 --- a/tests/models/test_llava.py +++ b/tests/models/test_llava.py @@ -1,10 +1,11 @@ from typing import List, Optional, Tuple, Type import pytest -from transformers import AutoTokenizer +from transformers import AutoConfig, AutoTokenizer from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets from .utils import check_logprobs_close @@ -18,9 +19,11 @@ "USER: \nWhat is the season?\nASSISTANT:", }) -IMAGE_TOKEN_ID = 32000 - -models = ["llava-hf/llava-1.5-7b-hf"] +models = [ + "llava-hf/llava-1.5-7b-hf", + # TODO: Get this model to produce meaningful output in vLLM + # "TIGER-Lab/Mantis-8B-siglip-llama3", +] def vllm_to_hf_output(vllm_output: Tuple[List[int], str, @@ -29,12 +32,15 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, """Sanitize vllm output to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output + config = AutoConfig.from_pretrained(model) + image_token_id = config.image_token_index + tokenizer = AutoTokenizer.from_pretrained(model) eos_token_id = tokenizer.eos_token_id hf_output_ids = [ token_id for idx, token_id in enumerate(output_ids) - if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID + if token_id != image_token_id or output_ids[idx - 1] != image_token_id ] assert output_str[0] == " " @@ -67,6 +73,17 @@ def run_test( Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. """ + # NOTE: For local use; this isn't tested in CI yet (see TODO above) + if model.startswith("TIGER-Lab/Mantis"): + from mantis.models.mllava import MLlavaProcessor + + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] + mantis_processor = MLlavaProcessor.from_pretrained( + model, torch_dtype=torch_dtype) + assert isinstance(mantis_processor, MLlavaProcessor) + else: + mantis_processor = None + images = [asset.pil_image for asset in image_assets] inputs_per_image = [( @@ -94,6 +111,15 @@ def run_test( ] with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model: + if mantis_processor is not None: + + def process(*args, **kwargs): + output = mantis_processor(*args, **kwargs) + output["pixel_values"] = output["pixel_values"].to(torch_dtype) + return output + + hf_model.processor = process + hf_outputs_per_image = [ hf_model.generate_greedy_logprobs_limit(prompts, max_tokens, diff --git a/tests/models/test_llava_next.py b/tests/models/test_llava_next.py index b6d72dee5c5b..60c7fc33b72f 100644 --- a/tests/models/test_llava_next.py +++ b/tests/models/test_llava_next.py @@ -1,7 +1,7 @@ from typing import List, Optional, Tuple, Type, overload import pytest -from transformers import AutoTokenizer +from transformers import AutoConfig, AutoTokenizer from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs @@ -23,8 +23,6 @@ f"{_PREFACE} USER: \nWhat is the season? 
ASSISTANT:", }) -IMAGE_TOKEN_ID = 32000 - models = ["llava-hf/llava-v1.6-vicuna-7b-hf"] @@ -34,12 +32,15 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, """Sanitize vllm output to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output + config = AutoConfig.from_pretrained(model) + image_token_id = config.image_token_index + tokenizer = AutoTokenizer.from_pretrained(model) eos_token_id = tokenizer.eos_token_id hf_output_ids = [ token_id for idx, token_id in enumerate(output_ids) - if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID + if token_id != image_token_id or output_ids[idx - 1] != image_token_id ] assert output_str[0] == " " diff --git a/tests/models/test_paligemma.py b/tests/models/test_paligemma.py index e1c39ee6fecb..f3f682b1c2cd 100644 --- a/tests/models/test_paligemma.py +++ b/tests/models/test_paligemma.py @@ -2,7 +2,7 @@ from typing import List, Optional, Tuple, Type import pytest -from transformers import AutoTokenizer +from transformers import AutoConfig, AutoTokenizer from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs @@ -20,8 +20,6 @@ "What is in the picture?", }) -IMAGE_TOKEN_ID = 257152 - models = ["google/paligemma-3b-mix-224"] # ROCm Triton FA can run into compilation issues with these models due to, @@ -37,12 +35,15 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, """Sanitize vllm output to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output + config = AutoConfig.from_pretrained(model) + image_token_id = config.image_token_index + tokenizer = AutoTokenizer.from_pretrained(model) eos_token_id = tokenizer.eos_token_id hf_output_ids = [ token_id for idx, token_id in enumerate(output_ids) - if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID + if token_id != image_token_id or output_ids[idx - 1] != image_token_id ] hf_output_str = output_str diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 547ab10051f1..b058e2755c24 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -6,4 +6,4 @@ @pytest.mark.parametrize("model_cls", _MODELS) def test_registry_imports(model_cls): # Ensure all model classes can be imported successfully - ModelRegistry.load_model_cls(model_cls) + ModelRegistry.resolve_model_cls([model_cls]) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index a5c5cb87bc46..44c04c9ba8dd 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -16,7 +16,7 @@ import torch from huggingface_hub import HfApi, hf_hub_download from torch import nn -from transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM, PretrainedConfig from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, LoRAConfig, ModelConfig, MultiModalConfig, @@ -143,6 +143,22 @@ def _get_model_initialization_kwargs( return extra_kwargs +def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig], *, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + scheduler_config: Optional[SchedulerConfig]) -> nn.Module: + extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config, + multimodal_config, + scheduler_config) + + return model_class(config=hf_config, + cache_config=cache_config, + 
quant_config=quant_config, + **extra_kwargs) + + def _initialize_model( model_config: ModelConfig, load_config: LoadConfig, @@ -151,15 +167,17 @@ def _initialize_model( cache_config: CacheConfig, scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module: """Initialize a model with the given configurations.""" - model_class = get_model_architecture(model_config)[0] - quant_config = _get_quantization_config(model_config, load_config) - - return model_class(config=model_config.hf_config, - cache_config=cache_config, - quant_config=quant_config, - **_get_model_initialization_kwargs( - model_class, lora_config, multimodal_config, - scheduler_config)) + model_class, _ = get_model_architecture(model_config) + + return build_model( + model_class, + model_config.hf_config, + quant_config=_get_quantization_config(model_config, load_config), + lora_config=lora_config, + multimodal_config=multimodal_config, + cache_config=cache_config, + scheduler_config=scheduler_config, + ) class BaseModelLoader(ABC): diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index f7e0f56c1a46..331b859d2ade 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -28,13 +28,7 @@ def get_model_architecture( and "MixtralForCausalLM" in architectures): architectures = ["QuantMixtralForCausalLM"] - for arch in architectures: - model_cls = ModelRegistry.load_model_cls(arch) - if model_cls is not None: - return (model_cls, arch) - raise ValueError( - f"Model architectures {architectures} are not supported for now. " - f"Supported architectures: {ModelRegistry.get_supported_archs()}") + return ModelRegistry.resolve_model_cls(architectures) def get_architecture_class_name(model_config: ModelConfig) -> str: diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 94c3cea98be7..ebb77a802d5c 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,6 +1,6 @@ import functools import importlib -from typing import Dict, List, Optional, Type +from typing import Dict, List, Optional, Tuple, Type import torch.nn as nn @@ -126,7 +126,7 @@ def _get_model(model_arch: str): return getattr(module, model_cls_name, None) @staticmethod - def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: + def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: if model_arch in _OOT_MODELS: return _OOT_MODELS[model_arch] if model_arch not in _MODELS: @@ -143,6 +143,18 @@ def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: return ModelRegistry._get_model(model_arch) + @staticmethod + def resolve_model_cls( + architectures: List[str]) -> Tuple[Type[nn.Module], str]: + for arch in architectures: + model_cls = ModelRegistry._try_load_model_cls(arch) + if model_cls is not None: + return (model_cls, arch) + + raise ValueError( + f"Model architectures {architectures} are not supported for now. 
" + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + @staticmethod def get_supported_archs() -> List[str]: return list(_MODELS.keys()) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index b4f628061f19..805ade39389d 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -1,6 +1,6 @@ """Minimal implementation of CLIPVisionModel intended to be only used within a vision language model.""" -from typing import Optional +from typing import Iterable, Optional, Tuple import torch import torch.nn as nn @@ -14,6 +14,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.image import (cached_get_tokenizer, repeat_and_pad_image_tokens) from vllm.sequence import SequenceData @@ -32,7 +33,7 @@ def get_clip_num_patches(*, image_size: int, patch_size: int) -> int: def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int: return get_clip_num_patches(image_size=hf_config.image_size, - patch_size=hf_config.patch_size) + patch_size=hf_config.patch_size) + 1 def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int: @@ -291,3 +292,22 @@ def forward(self, pixel_values: Optional[torch.Tensor] = None): @property def device(self): return next(self.parameters()).device + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + layer_count = len(self.vision_model.encoder.layers) + + for name, loaded_weight in weights: + # post_layernorm is not needed in CLIPVisionModel + if "vision_model.post_layernorm" in name: + continue + # omit layers when num_hidden_layers_override is set + if "vision_model.encoder.layers." 
in name: + layer_idx = int(name.split(".")[3]) + if layer_idx >= layer_count: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 474925127148..8850fd7c6763 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -18,7 +18,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.intern_vit import InternVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -29,7 +28,8 @@ from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, get_clip_num_patches) from .interfaces import SupportsVision -from .utils import merge_vision_embeddings +from .utils import (filter_weights, init_vllm_registered_model, + merge_vision_embeddings) IMG_START = '' IMG_END = '' @@ -283,10 +283,8 @@ def __init__(self, self.vision_model = InternVisionModel( config.vision_config, num_hidden_layers_override=num_hidden_layers) - llm_class = ModelRegistry.load_model_cls( - config.text_config.architectures[0]) - self.language_model = llm_class(config.text_config, cache_config, - quant_config) + self.language_model = init_vllm_registered_model( + config.text_config, cache_config, quant_config) vit_hidden_size = config.vision_config.hidden_size llm_hidden_size = config.text_config.hidden_size @@ -415,24 +413,16 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def _filter_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], - prefix: str): - for name, loaded_weight in weights: - name = name.split(".") - if prefix == name.pop(0): - name = ".".join(name) - yield name, loaded_weight - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # prepare weight iterators for components vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3) # load vision encoder - vit_weights = self._filter_weights(vit_weights, "vision_model") + vit_weights = filter_weights(vit_weights, "vision_model") self.vision_model.load_weights(vit_weights) # load mlp projector - mlp_weights = self._filter_weights(mlp_weights, "mlp1") + mlp_weights = filter_weights(mlp_weights, "mlp1") mlp_params_dict = dict(self.mlp1.named_parameters()) for name, loaded_weight in mlp_weights: param = mlp_params_dict[name] @@ -441,5 +431,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight) # load llm backbone - llm_weights = self._filter_weights(llm_weights, "language_model") + llm_weights = filter_weights(llm_weights, "language_model") self.language_model.load_weights(llm_weights) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 4e7e6c47f0a0..9a11bcc4c54c 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,34 +1,30 @@ -from typing import Iterable, List, Literal, Optional, Tuple, TypedDict +import itertools +from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union import torch import torch.nn as nn -from transformers import CLIPVisionConfig, LlavaConfig +from 
transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.clip import CLIPVisionModel -from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors, SamplerOutput -from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, - get_max_clip_image_tokens, input_processor_for_clip) +from .clip import (CLIPVisionModel, dummy_image_for_clip, + dummy_seq_data_for_clip, get_max_clip_image_tokens, + input_processor_for_clip) from .interfaces import SupportsVision -from .utils import merge_vision_embeddings - -_KEYS_TO_MODIFY_MAPPING = { - "language_model.lm_head": "lm_head", - "language_model.model": "language_model", -} +from .siglip import (SiglipVisionModel, dummy_image_for_siglip, + dummy_seq_data_for_siglip, get_max_siglip_image_tokens, + input_processor_for_siglip) +from .utils import (filter_weights, init_vllm_registered_model, + merge_vision_embeddings) # TODO(xwjiang): Run benchmark and decide if TP. @@ -67,25 +63,48 @@ def get_max_llava_image_tokens(ctx: InputContext): vision_config = hf_config.vision_config if isinstance(vision_config, CLIPVisionConfig): - return get_max_clip_image_tokens(vision_config) - - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + num_image_tokens = get_max_clip_image_tokens(vision_config) + elif isinstance(vision_config, SiglipVisionConfig): + num_image_tokens = get_max_siglip_image_tokens(vision_config) + else: + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + strategy = hf_config.vision_feature_select_strategy + if strategy == "default": + return num_image_tokens - 1 + elif strategy == "full": + return num_image_tokens + else: + raise ValueError(f"Unexpected select feature strategy: {strategy}") def dummy_data_for_llava(ctx: InputContext, seq_len: int): hf_config = ctx.get_hf_config(LlavaConfig) vision_config = hf_config.vision_config + image_feature_size = get_max_llava_image_tokens(ctx) + if isinstance(vision_config, CLIPVisionConfig): seq_data = dummy_seq_data_for_clip( vision_config, seq_len, image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, ) mm_data = dummy_image_for_clip(vision_config) return seq_data, mm_data + elif isinstance(vision_config, SiglipVisionConfig): + seq_data = dummy_seq_data_for_siglip( + vision_config, + seq_len, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + + mm_data = dummy_image_for_siglip(vision_config) + return seq_data, mm_data msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) @@ -100,12 +119,49 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): hf_config = ctx.get_hf_config(LlavaConfig) 
vision_config = hf_config.vision_config + image_feature_size = get_max_llava_image_tokens(ctx) + if isinstance(vision_config, CLIPVisionConfig): return input_processor_for_clip( model_config, vision_config, llm_inputs, image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + elif isinstance(vision_config, SiglipVisionConfig): + return input_processor_for_siglip( + model_config, + vision_config, + llm_inputs, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +def _init_vision_tower(hf_config: LlavaConfig): + vision_config = hf_config.vision_config + + # Initialize the vision tower only up to the required feature layer + vision_feature_layer = hf_config.vision_feature_layer + if vision_feature_layer < 0: + num_hidden_layers = hf_config.vision_config.num_hidden_layers \ + + vision_feature_layer + 1 + else: + num_hidden_layers = vision_feature_layer + 1 + + if isinstance(vision_config, CLIPVisionConfig): + return CLIPVisionModel( + vision_config, + num_hidden_layers_override=num_hidden_layers, + ) + elif isinstance(vision_config, SiglipVisionConfig): + return SiglipVisionModel( + vision_config, + num_hidden_layers_override=num_hidden_layers, ) msg = f"Unsupported vision config: {type(vision_config)}" @@ -128,36 +184,15 @@ def __init__(self, self.config = config self.multimodal_config = multimodal_config - # Initialize the vision tower only up to the required feature layer - vision_feature_layer = config.vision_feature_layer - if vision_feature_layer < 0: - num_hidden_layers = config.vision_config.num_hidden_layers \ - + vision_feature_layer + 1 - else: - num_hidden_layers = vision_feature_layer + 1 - # TODO: Optionally initializes this for supporting embeddings. 
- self.vision_tower = CLIPVisionModel( - config.vision_config, num_hidden_layers_override=num_hidden_layers) + self.vision_tower = _init_vision_tower(config) self.multi_modal_projector = LlavaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act) - self.quant_config = quant_config - self.language_model = LlamaModel(config.text_config, cache_config, - quant_config) - self.unpadded_vocab_size = config.text_config.vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.text_config.hidden_size, - org_num_embeddings=self.language_model.org_vocab_size, - quant_config=quant_config) - logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.text_config.vocab_size, - logit_scale) - self.sampler = Sampler() + self.language_model = init_vllm_registered_model( + config.text_config, cache_config, quant_config) def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size @@ -198,8 +233,11 @@ def _select_image_features(self, image_features: torch.Tensor, *, raise ValueError(f"Unexpected select feature strategy: {strategy}") - def _image_pixels_to_features(self, vision_tower: CLIPVisionModel, - pixel_values: torch.Tensor) -> torch.Tensor: + def _image_pixels_to_features( + self, + vision_tower: Union[CLIPVisionModel, SiglipVisionModel], + pixel_values: torch.Tensor, + ) -> torch.Tensor: # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower @@ -272,7 +310,8 @@ def forward( if image_input is not None: vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.get_input_embeddings(input_ids) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) inputs_embeds = merge_vision_embeddings( input_ids, inputs_embeds, vision_embeddings, @@ -282,68 +321,44 @@ def forward( else: inputs_embeds = None - hidden_states = self.language_model(input_ids, - positions, - kv_caches, - attn_metadata, - None, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + None, + inputs_embeds=inputs_embeds) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits + return self.language_model.compute_logits(hidden_states, + sampling_metadata) def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens + return self.language_model.sample(logits, sampling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # only doing this for language model part for now. 
- stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - # post_layernorm is not needed in CLIPVisionModel - if "vision_model.post_layernorm" in name: - continue - for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): - if key_to_modify in name: - name = name.replace(key_to_modify, new_key) - use_default_weight_loading = False - if "vision" in name: - if self.vision_tower is not None: - # We only do sharding for language model and - # not vision model for now. - use_default_weight_loading = True - else: - for (param_name, weight_name, - shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - param = params_dict[name.replace(weight_name, param_name)] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - use_default_weight_loading = True - if use_default_weight_loading and name in params_dict: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + # prepare weight iterators for components + vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3) + + # load vision encoder + vit_weights = filter_weights(vit_weights, "vision_tower") + self.vision_tower.load_weights(vit_weights) + + # load mlp projector + mlp_weights = filter_weights(mlp_weights, "multi_modal_projector") + mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) + for name, loaded_weight in mlp_weights: + param = mlp_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load llm backbone + llm_weights = filter_weights(llm_weights, "language_model") + self.language_model.load_weights(llm_weights) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 4a67b9a583ea..9abc480f60de 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -1,9 +1,10 @@ +import itertools from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union import torch import torch.nn as nn from PIL import Image -from transformers import CLIPVisionConfig, LlavaNextConfig +from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig from transformers.models.llava_next.modeling_llava_next import ( get_anyres_image_grid_shape, unpad_image) from typing_extensions import NotRequired @@ -12,23 +13,23 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.logger import init_logger -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.clip import CLIPVisionModel -from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from 
vllm.sequence import IntermediateTensors, SamplerOutput -from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, +from .clip import (CLIPVisionModel, dummy_image_for_clip, + dummy_seq_data_for_clip, get_clip_image_feature_size, get_clip_patch_grid_length, input_processor_for_clip) from .interfaces import SupportsVision from .llava import LlavaMultiModalProjector -from .utils import merge_vision_embeddings +from .siglip import (SiglipVisionModel, dummy_image_for_siglip, + dummy_seq_data_for_siglip, get_siglip_image_feature_size, + get_siglip_patch_grid_length, input_processor_for_siglip) +from .utils import (filter_weights, init_vllm_registered_model, + merge_vision_embeddings) logger = init_logger(__name__) @@ -104,30 +105,42 @@ def get_llava_next_image_feature_size( image_size=vision_config.image_size, patch_size=vision_config.patch_size, ) - base_feature_size = num_patches * num_patches - - num_patch_height, num_patch_width = get_anyres_image_grid_shape( - image_size=(input_height, input_width), - grid_pinpoints=hf_config.image_grid_pinpoints, - patch_size=vision_config.image_size, + base_feature_size = get_clip_image_feature_size(vision_config) + elif isinstance(vision_config, SiglipVisionConfig): + num_patches = get_siglip_patch_grid_length( + image_size=vision_config.image_size, + patch_size=vision_config.patch_size, ) + base_feature_size = get_siglip_image_feature_size(vision_config) + else: + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + strategy = hf_config.vision_feature_select_strategy + if strategy == "default": + base_feature_size -= 1 + elif strategy == "full": + pass + else: + raise ValueError(f"Unexpected select feature strategy: {strategy}") - ( - unpadded_feature_size, - newline_feature_size, - ) = _get_llava_next_num_unpadded_features(input_height, input_width, - num_patches, - num_patch_height, - num_patch_width) + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_size=(input_height, input_width), + grid_pinpoints=hf_config.image_grid_pinpoints, + patch_size=vision_config.image_size, + ) - return unpadded_feature_size + newline_feature_size + base_feature_size + ( + unpadded_feature_size, + newline_feature_size, + ) = _get_llava_next_num_unpadded_features(input_height, input_width, + num_patches, num_patch_height, + num_patch_width) - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + return unpadded_feature_size + newline_feature_size + base_feature_size def get_max_llava_next_image_tokens(ctx: InputContext): - return get_llava_next_image_feature_size( ctx.get_hf_config(LlavaNextConfig), input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, @@ -155,6 +168,21 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int): image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, ) + return seq_data, mm_data + elif isinstance(vision_config, SiglipVisionConfig): + seq_data = dummy_seq_data_for_siglip( + vision_config, + seq_len, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + + mm_data = dummy_image_for_siglip( + vision_config, + image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH, + image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, + ) + return seq_data, mm_data msg = f"Unsupported vision config: {type(vision_config)}" @@ -194,6 +222,40 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) 
+ elif isinstance(vision_config, SiglipVisionConfig): + return input_processor_for_siglip( + model_config, + vision_config, + llm_inputs, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +def _init_vision_tower(hf_config: LlavaNextConfig): + vision_config = hf_config.vision_config + + # Initialize the vision tower only up to the required feature layer + vision_feature_layer = hf_config.vision_feature_layer + if vision_feature_layer < 0: + num_hidden_layers = hf_config.vision_config.num_hidden_layers \ + + vision_feature_layer + 1 + else: + num_hidden_layers = vision_feature_layer + 1 + + if isinstance(vision_config, CLIPVisionConfig): + return CLIPVisionModel( + vision_config, + num_hidden_layers_override=num_hidden_layers, + ) + elif isinstance(vision_config, SiglipVisionConfig): + return SiglipVisionModel( + vision_config, + num_hidden_layers_override=num_hidden_layers, + ) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) @@ -215,36 +277,15 @@ def __init__(self, self.config = config self.multimodal_config = multimodal_config - # Initialize the vision tower only up to the required feature layer - vision_feature_layer = config.vision_feature_layer - if vision_feature_layer < 0: - num_hidden_layers = config.vision_config.num_hidden_layers \ - + vision_feature_layer + 1 - else: - num_hidden_layers = vision_feature_layer + 1 - # TODO: Optionally initializes this for supporting embeddings. - self.vision_tower = CLIPVisionModel( - config.vision_config, num_hidden_layers_override=num_hidden_layers) + self.vision_tower = _init_vision_tower(config) self.multi_modal_projector = LlavaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act) - self.quant_config = quant_config - self.language_model = LlamaModel(config.text_config, cache_config, - quant_config) - self.unpadded_vocab_size = config.text_config.vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.text_config.hidden_size, - org_num_embeddings=self.language_model.org_vocab_size, - quant_config=quant_config) - logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.text_config.vocab_size, - logit_scale) - self.sampler = Sampler() + self.language_model = init_vllm_registered_model( + config.text_config, cache_config, quant_config) self.image_newline = nn.Parameter( torch.empty(config.text_config.hidden_size)) @@ -310,8 +351,11 @@ def _select_image_features(self, image_features: torch.Tensor, *, raise ValueError(f"Unexpected select feature strategy: {strategy}") - def _image_pixels_to_features(self, vision_tower: CLIPVisionModel, - pixel_values: torch.Tensor) -> torch.Tensor: + def _image_pixels_to_features( + self, + vision_tower: Union[CLIPVisionModel, SiglipVisionModel], + pixel_values: torch.Tensor, + ) -> torch.Tensor: # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower @@ -496,7 +540,8 @@ def forward( if image_input is not None: vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.get_input_embeddings(input_ids) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) inputs_embeds = merge_vision_embeddings( 
input_ids, inputs_embeds, vision_embeddings, @@ -506,68 +551,54 @@ def forward( else: inputs_embeds = None - hidden_states = self.language_model(input_ids, - positions, - kv_caches, - attn_metadata, - None, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + None, + inputs_embeds=inputs_embeds) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits + return self.language_model.compute_logits(hidden_states, + sampling_metadata) def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens + return self.language_model.sample(logits, sampling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # only doing this for language model part for now. - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - # post_layernorm is not needed in CLIPVisionModel - if "vision_model.post_layernorm" in name: - continue - for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): - if key_to_modify in name: - name = name.replace(key_to_modify, new_key) - use_default_weight_loading = False - if "vision" in name: - if self.vision_tower is not None: - # We only do sharding for language model and - # not vision model for now. 
- use_default_weight_loading = True - else: - for (param_name, weight_name, - shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - param = params_dict[name.replace(weight_name, param_name)] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - use_default_weight_loading = True - if use_default_weight_loading and name in params_dict: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + # prepare weight iterators for components + vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee( + weights, 4) + + # load vision encoder + vit_weights = filter_weights(vit_weights, "vision_tower") + self.vision_tower.load_weights(vit_weights) + + # load mlp projector + mlp_weights = filter_weights(mlp_weights, "multi_modal_projector") + mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) + for name, loaded_weight in mlp_weights: + param = mlp_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load newline + newline_weights = filter_weights(newline_weights, "image_newline") + for name, loaded_weight in newline_weights: + assert name == "" + param = self.image_newline + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load llm backbone + llm_weights = filter_weights(llm_weights, "language_model") + self.language_model.load_weights(llm_weights) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 6faef45c9a6d..5ba14f73394f 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -2,12 +2,12 @@ within a vision language model.""" import math -from typing import Optional, Tuple +from typing import Iterable, Optional, Tuple import torch from PIL import Image from torch import nn -from transformers import SiglipConfig, SiglipVisionConfig +from transformers import SiglipVisionConfig from transformers.models.siglip.modeling_siglip import SiglipAttention from vllm_flash_attn import flash_attn_func from xformers.ops import memory_efficient_attention @@ -22,13 +22,15 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.image import (cached_get_tokenizer, repeat_and_pad_image_tokens) from vllm.sequence import SequenceData def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int: - assert image_size % patch_size == 0 + # Since interpolation is applied, the image size need not be divisible + # assert image_size % patch_size == 0 return image_size // patch_size @@ -454,7 +456,7 @@ class SiglipEncoderLayer(nn.Module): def __init__( self, - config: SiglipConfig, + config: SiglipVisionConfig, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -474,7 +476,7 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - ) -> Tuple[torch.Tensor]: + ) -> Tuple[torch.Tensor, None]: residual = hidden_states hidden_states = self.layer_norm1(hidden_states) @@ -493,22 +495,27 @@ class SiglipEncoder(nn.Module): def __init__( self, - config: SiglipConfig, + config: SiglipVisionConfig, quant_config: Optional[QuantizationConfig] = None, + 
num_hidden_layers_override: Optional[int] = None, ): super().__init__() self.config = config + + if num_hidden_layers_override is None: + num_hidden_layers = config.num_hidden_layers + else: + num_hidden_layers = num_hidden_layers_override + self.layers = nn.ModuleList([ - SiglipEncoderLayer( - config, - quant_config=quant_config, - ) for _ in range(config.num_hidden_layers) + SiglipEncoderLayer(config, quant_config=quant_config) + for _ in range(num_hidden_layers) ]) def forward( self, inputs_embeds: torch.Tensor, - ) -> Tuple: + ) -> torch.Tensor: hidden_states = inputs_embeds for encoder_layer in self.layers: hidden_states, _ = encoder_layer(hidden_states) @@ -553,6 +560,7 @@ def __init__( self, config: SiglipVisionConfig, quant_config: Optional[QuantizationConfig] = None, + num_hidden_layers_override: Optional[int] = None, ): super().__init__() self.config = config @@ -562,6 +570,7 @@ def __init__( self.encoder = SiglipEncoder( config, quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, ) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) @@ -600,11 +609,13 @@ def __init__( self, config: SiglipVisionConfig, quant_config: Optional[QuantizationConfig] = None, + num_hidden_layers_override: Optional[int] = None, ): super().__init__() self.vision_model = SiglipVisionTransformer( config, quant_config, + num_hidden_layers_override=num_hidden_layers_override, ) def get_input_embeddings(self) -> nn.Module: @@ -619,3 +630,19 @@ def forward( pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + layer_count = len(self.vision_model.encoder.layers) + + for name, loaded_weight in weights: + # omit layers when num_hidden_layers_override is set + if "vision_model.encoder.layers." in name: + layer_idx = int(name.split(".")[3]) + if layer_idx >= layer_count: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 91b4a27814bf..d1bb030c6c90 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,22 +1,70 @@ -from typing import Dict, List, Protocol, Tuple +from typing import Dict, Iterable, List, Optional, Protocol, Tuple import torch +import torch.nn as nn from torch.func import functional_call +from transformers import PretrainedConfig +from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig, + SchedulerConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.loader import build_model +from vllm.model_executor.models import ModelRegistry from vllm.multimodal import BatchedTensors from vllm.utils import is_pin_memory_available +def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str): + """ + Helper function to load weights for inner vLLM models. 
+ + See also: + :ref:`init_vllm_registered_model` + """ + for name, loaded_weight in weights: + name = name.split(".") + if prefix == name.pop(0): + name = ".".join(name) + yield name, loaded_weight + + +def init_vllm_registered_model( + hf_config: PretrainedConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig], + *, + lora_config: Optional[LoRAConfig] = None, + multimodal_config: Optional[MultiModalConfig] = None, + scheduler_config: Optional[SchedulerConfig] = None, +) -> nn.Module: + """ + Helper function to initialize an inner model registered to vLLM, + based on the arguments passed to the outer vLLM model. + """ + model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures) + + return build_model( + model_class, + hf_config, + cache_config, + quant_config, + lora_config=lora_config, + multimodal_config=multimodal_config, + scheduler_config=scheduler_config, + ) + + def merge_vision_embeddings(input_ids: torch.Tensor, inputs_embeds: torch.Tensor, vision_embeddings: BatchedTensors, image_token_id: int) -> torch.Tensor: """ - Merge `vision_embeddings` into `inputs_embeds` by overwriting the positions - in `inputs_embeds` corresponding to placeholder image tokens in `input_ids`. + Merge ``vision_embeddings`` into ``inputs_embeds`` by overwriting the + positions in ``inputs_embeds`` corresponding to placeholder image tokens in + ``input_ids``. Note: - This updates `inputs_embeds` in place. + This updates ``inputs_embeds`` in place. """ mask = (input_ids == image_token_id) num_expected_tokens = mask.sum() From a3bbbfa1d8c2f30581d37c6f30429d648bbbf87c Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 6 Aug 2024 11:16:53 -0400 Subject: [PATCH 0079/3246] [BugFix] Fix DeepSeek remote code (#7178) --- .../lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml | 1 + .buildkite/lm-eval-harness/test_lm_eval_correctness.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml b/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml index 15268395ec68..d70ecb2a7e7b 100644 --- a/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml +++ b/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml @@ -9,3 +9,4 @@ tasks: value: 0.664 limit: 1000 num_fewshot: 5 +trust_remote_code: True \ No newline at end of file diff --git a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py index 7fdce7b53bd7..af3226f51f4f 100644 --- a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py +++ b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py @@ -23,9 +23,12 @@ def launch_lm_eval(eval_config): + trust_remote_code = eval_config.get('trust_remote_code', False) + model_args = f"pretrained={eval_config['model_name']}," \ f"tensor_parallel_size={TP_SIZE}," \ - f"add_bos_token=true" + f"add_bos_token=true," \ + f"trust_remote_code={trust_remote_code}" results = lm_eval.simple_evaluate( model="vllm", From 541c1852d37b9502fbc06253def70e901ca0c352 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Tue, 6 Aug 2024 12:26:26 -0400 Subject: [PATCH 0080/3246] [ BugFix ] Fix ZMQ when `VLLM_PORT` is set (#7205) --- vllm/envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/envs.py b/vllm/envs.py index 089a39d8e029..81d2d80e65e4 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -145,7 +145,7 @@ def get_default_config_root(): # used when the 
frontend api server is running in multi-processing mode, # to communicate with the backend engine process over ZMQ. 'VLLM_RPC_PORT': - lambda: int(os.getenv('VLLM_PORT', '5570')), + lambda: int(os.getenv('VLLM_RPC_PORT', '5570')), # If true, will load models from ModelScope instead of Hugging Face Hub. # note that the value is true or false, not numbers From 00afc7859072bdcaba30611c6563f2f7ac7104a3 Mon Sep 17 00:00:00 2001 From: Katarzyna Papis Date: Tue, 6 Aug 2024 19:08:35 +0200 Subject: [PATCH 0081/3246] [Bugfix] add gguf dependency (#7198) Co-authored-by: katarzyna.papis --- requirements-openvino.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-openvino.txt b/requirements-openvino.txt index a86c6cb58048..2dd971d6400b 100644 --- a/requirements-openvino.txt +++ b/requirements-openvino.txt @@ -25,6 +25,7 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 pyzmq +gguf == 0.9.1 # OpenVINO dependencies torch >= 2.1.2 From 5c60c8c423197bcf20fdc5217d79b78532033f04 Mon Sep 17 00:00:00 2001 From: Lily Liu Date: Tue, 6 Aug 2024 10:40:32 -0700 Subject: [PATCH 0082/3246] [SpecDecode] [Minor] Fix spec decode sampler tests (#7183) --- tests/samplers/test_rejection_sampler.py | 14 +++++++------- .../test_typical_acceptance_sampler.py | 18 +++++++++--------- .../layers/spec_decode_base_sampler.py | 9 ++++++--- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py index 8f6c292620c2..3ce4a5f65819 100644 --- a/tests/samplers/test_rejection_sampler.py +++ b/tests/samplers/test_rejection_sampler.py @@ -25,7 +25,7 @@ def mock_causal_accepted_tensor( accepted = (torch.arange(k).expand(batch_size, k) <= last_accepted_indices.unsqueeze(-1).broadcast_to( - batch_size, k)).to(device="cuda") + batch_size, k)) # Sprinkle accepted values after the contiguous initial accepted values. 
# This replicates the behavior of rejection sampling, which may "accept" @@ -33,7 +33,7 @@ def mock_causal_accepted_tensor( sprinkle_candidates = ( torch.arange(k).expand(batch_size, k) > last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) + 1) - sprinkle = torch.rand(batch_size, k, device="cuda") > 0.5 + sprinkle = torch.rand(batch_size, k) > 0.5 accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates] return accepted @@ -86,7 +86,7 @@ def test_correct_output_format(which_tokens_accepted: str, rejection_sampler = RejectionSampler( disable_bonus_tokens=disable_bonus_tokens) - rejection_sampler.init_gpu_tensors(rank=0) + rejection_sampler.init_gpu_tensors(device=device) output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access accepted, recovered_token_ids, @@ -138,7 +138,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, device: str): torch.set_default_device(device) rejection_sampler = RejectionSampler() - rejection_sampler.init_gpu_tensors(rank=0) + rejection_sampler.init_gpu_tensors(device=device) draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) @@ -167,7 +167,7 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int, device: str): torch.set_default_device(device) rejection_sampler = RejectionSampler() - rejection_sampler.init_gpu_tensors(rank=0) + rejection_sampler.init_gpu_tensors(device=device) draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) @@ -211,7 +211,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, torch.set_default_device(device) rejection_sampler = RejectionSampler(strict_mode=True) - rejection_sampler.init_gpu_tensors(rank=0) + rejection_sampler.init_gpu_tensors(device=device) draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) @@ -339,7 +339,7 @@ def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler): self.vocab_size = vocab_size self.vocab_range = (0, vocab_size) - self.rejection_sampler.init_gpu_tensors(rank=0) + self.rejection_sampler.init_gpu_tensors(device=0) # Keep test simple, use k=1 self.k = 1 diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index 4f6290795b2c..aa3c1d29bdb3 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -78,7 +78,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, """ torch.set_default_device(device) typical_acceptance_sampler = get_acceptance_sampler() - typical_acceptance_sampler.init_gpu_tensors(rank=0) + typical_acceptance_sampler.init_gpu_tensors(device=device) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) bonus_token_ids = torch.randint(low=0, high=vocab_size, @@ -111,7 +111,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, vocab_size = 30_000 torch.set_default_device(device) typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True) - typical_acceptance_sampler.init_gpu_tensors(rank=0) + typical_acceptance_sampler.init_gpu_tensors(device=device) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) bonus_token_ids = torch.randint(low=0, high=vocab_size, @@ -171,7 +171,7 @@ 
def test_uniform_target_distribution_accepts_all_tokens( torch.set_default_device(device) typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) - typical_acceptance_sampler.init_gpu_tensors(rank=0) + typical_acceptance_sampler.init_gpu_tensors(device=device) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) draft_token_ids = torch.randint(low=0, high=vocab_size, @@ -225,7 +225,7 @@ def test_temperature_zero_target_distribution(seed: int, typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) - typical_acceptance_sampler.init_gpu_tensors(rank=0) + typical_acceptance_sampler.init_gpu_tensors(device=device) # Simulate temperature 0 probability distribution for target probabilities # and create target probabilities such that only 1 token id has # probability 1.0 @@ -285,7 +285,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, torch.set_default_device(device) typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) - typical_acceptance_sampler.init_gpu_tensors(rank=0) + typical_acceptance_sampler.init_gpu_tensors(device=device) # For sequences 0 and 2 set the distribution to a temperature # zero distribution. For sequences 1 and 3 set it to a uniform # distribution. @@ -352,7 +352,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, torch.set_default_device(device) typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) - typical_acceptance_sampler.init_gpu_tensors(rank=0) + typical_acceptance_sampler.init_gpu_tensors(device=device) # Create a temperature zero target probability distribution and ensure # all draft token ids correspond to the tokens with 1.0 probability. # Verify that all of them are accepted. 
@@ -414,7 +414,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, torch.set_default_device(device) typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) - typical_acceptance_sampler.init_gpu_tensors(rank=0) + typical_acceptance_sampler.init_gpu_tensors(device=device) # Simulate temperature 0 probability distribution for target # probabilities and create target probabilities such that only 1 token # id has probability 1.0 and others have a very low probability of @@ -447,7 +447,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, disable_bonus_tokens=disable_bonus_tokens, posterior_threshold=0.0, posterior_alpha=0.0) - typical_acceptance_sampler.init_gpu_tensors(rank=0) + typical_acceptance_sampler.init_gpu_tensors(device=device) output_token_ids = typical_acceptance_sampler( target_probs, bonus_token_ids, @@ -485,7 +485,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool, torch.set_default_device(device) typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) - typical_acceptance_sampler.init_gpu_tensors(rank=0) + typical_acceptance_sampler.init_gpu_tensors(device=device) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) expected_replacement_tokens = -torch.ones( (batch_size, k), dtype=torch.long) diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 3091e639727b..467c43c41550 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Dict, Optional +from typing import Dict, Optional, Union import torch import torch.jit @@ -36,9 +36,12 @@ def __init__(self, self.num_emitted_tokens: Optional[torch.Tensor] = None self.num_draft_tokens: int = 0 - def init_gpu_tensors(self, rank: int) -> None: + def init_gpu_tensors(self, device: Union[int, str]) -> None: assert self.num_accepted_tokens is None - device = f"cuda:{rank}" + if isinstance(device, int): + device = f"cuda:{device}" + elif not isinstance(device, str): + raise ValueError(f"Device must be int or str, get {type(device)}") self.num_accepted_tokens = torch.tensor(0, dtype=torch.long, device=device) From 8d59dbb00044a588cab96bcdc028006ed922eb06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 6 Aug 2024 14:17:08 -0400 Subject: [PATCH 0083/3246] [Kernel] Add per-tensor and per-token AZP epilogues (#5941) Co-authored-by: Tyler Michael Smith --- .../cutlass_benchmarks/w8a8_benchmarks.py | 185 +++++++------ csrc/ops.h | 8 + csrc/quantization/cutlass_w8a8/Epilogues.md | 147 ++++++++++ .../broadcast_load_epilogue_c2x.hpp | 152 ++++++++++- .../cutlass_w8a8/scaled_mm_c2x.cu | 57 ++++ .../cutlass_w8a8/scaled_mm_c2x.cuh | 253 ++++++++++++++--- .../cutlass_w8a8/scaled_mm_c3x.cu | 258 ++++++++++++++++-- .../cutlass_w8a8/scaled_mm_entry.cu | 111 +++++++- csrc/torch_bindings.cpp | 11 +- tests/kernels/test_cutlass.py | 120 +++++++- vllm/_custom_ops.py | 26 +- 11 files changed, 1175 insertions(+), 153 deletions(-) create mode 100644 csrc/quantization/cutlass_w8a8/Epilogues.md diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 64011b2db239..63cf5d50cac7 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ 
b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -32,7 +32,6 @@ def to_int8(tensor: torch.Tensor) -> torch.Tensor: def make_rand_tensors(dtype: torch.dtype, m: int, n: int, k: int) -> Tuple[torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device='cuda') * 5 b = torch.randn((n, k), device='cuda').t() * 5 @@ -44,59 +43,18 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int, raise ValueError("unsupported dtype") -# impl - - -def pytorch_mm_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype) -> torch.Tensor: - return torch.mm(a, b) - - -def pytorch_fp8_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype) -> torch.Tensor: - return torch._scaled_mm(a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=out_dtype) - - -def pytorch_fp8_impl_fast_accum(a: torch.Tensor, b: torch.Tensor, - scale_a: torch.Tensor, scale_b: torch.Tensor, - out_dtype: torch.dtype) -> torch.Tensor: - return torch._scaled_mm(a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=out_dtype, - use_fast_accum=True) - - -def cutlass_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype) -> torch.Tensor: - return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype) - - # bench -def bench_fn(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, - scale_b: torch.Tensor, out_dtype: torch.dtype, label: str, - sub_label: str, fn: Callable, description: str) -> TMeasurement: - +def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, + **kwargs) -> TMeasurement: min_run_time = 1 globals = { - "a": a, - "b": b, - "scale_a": scale_a, - "scale_b": scale_b, - "out_dtype": out_dtype, + "args": args, + "kwargs": kwargs, "fn": fn, } return TBenchmark.Timer( - stmt="fn(a, b, scale_a, scale_b, out_dtype)", + stmt="fn(*args, **kwargs)", globals=globals, label=label, sub_label=sub_label, @@ -110,26 +68,58 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, a, b = make_rand_tensors(torch.int8, m, n, k) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + azp = torch.zeros((m, ), device="cuda", dtype=torch.int32) + azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32) timers = [] # pytorch impl - bfloat16 timers.append( - bench_fn(a.to(dtype=torch.bfloat16, device="cuda"), - b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b, - torch.bfloat16, label, sub_label, pytorch_mm_impl, - "pytorch_bf16_bf16_bf16_matmul-no-scales")) + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16), + b.to(dtype=torch.bfloat16))) # pytorch impl - float16 timers.append( - bench_fn(a.to(dtype=torch.float16, device="cuda"), - b.to(dtype=torch.float16, device="cuda"), scale_a, scale_b, - torch.float16, label, sub_label, pytorch_mm_impl, - "pytorch_fp16_fp16_fp16_matmul-no-scales")) + bench_fn(label, sub_label, + "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, + a.to(dtype=torch.float16), b.to(dtype=torch.float16))) # cutlass impl timers.append( - bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, - cutlass_impl, "cutlass_i8_i8_bf16_scaled_mm")) + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, + torch.bfloat16)) + + # 
cutlass with bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, + bias)) + + # cutlass with azp per-tensor + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp", + ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, + torch.bfloat16, azp_adj)) + + # cutlass with azp per-tensor + bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_bias", + ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, + torch.bfloat16, azp_adj, None, bias)) + + # cutlass with azp per-token + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt", + ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, + torch.bfloat16, azp_adj, azp)) + + # cutlass with azp per-token + bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias", + ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, + torch.bfloat16, azp_adj, azp, bias)) return timers @@ -140,46 +130,88 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) timers = [] # pytorch impl w. bf16 timers.append( - bench_fn(a.to(dtype=torch.bfloat16, device="cuda"), - b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b, - torch.bfloat16, label, sub_label, pytorch_mm_impl, - "pytorch_bf16_bf16_bf16_matmul-no-scales")) + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), + b.to(dtype=torch.bfloat16, device="cuda"))) # pytorch impl: bf16 output, without fp8 fast accum timers.append( - bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, - pytorch_fp8_impl, "pytorch_fp8_fp8_bf16_scaled_mm")) + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16)) # pytorch impl: bf16 output, with fp8 fast accum timers.append( - bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, - pytorch_fp8_impl_fast_accum, - "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum")) + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True)) # pytorch impl: fp16 output, without fp8 fast accum timers.append( - bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, - pytorch_fp8_impl, "pytorch_fp8_fp8_fp16_scaled_mm")) + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16)) # pytorch impl: fp16 output, with fp8 fast accum timers.append( - bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, - pytorch_fp8_impl_fast_accum, - "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum")) + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16, + use_fast_accum=True)) # cutlass impl: bf16 output timers.append( - bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, - cutlass_impl, "cutlass_fp8_fp8_bf16_scaled_mm")) + bench_fn(label, sub_label, 
"cutlass_fp8_fp8_bf16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, + torch.bfloat16)) # cutlass impl: fp16 output timers.append( - bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, - cutlass_impl, "cutlass_fp8_fp8_fp16_scaled_mm")) + bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16)) + + # cutlass impl: bf16 output, with bias + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm_bias", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, + bias)) + + # cutlass impl: fp16 output, with bias + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm_bias", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16, + bias.to(dtype=torch.float16))) + return timers @@ -200,7 +232,6 @@ def print_timers(timers: Iterable[TMeasurement]): def run(dtype: torch.dtype, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: - results = [] for m, k, n in MKNs: timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", @@ -216,7 +247,6 @@ def make_output(data: Iterable[TMeasurement], MKNs: Iterable[Tuple[int, int, int]], base_description: str, timestamp=None): - print(f"== All Results {base_description} ====") print_timers(data) @@ -251,7 +281,6 @@ def run_range_bench(args): def run_model_bench(args): - print("Benchmarking models:") for i, model in enumerate(args.models): print(f"[{i}] {model}") diff --git a/csrc/ops.h b/csrc/ops.h index e9e5f79a4a6f..023455f8a153 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -128,6 +128,14 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b_scales, c10::optional const& bias); +void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + c10::optional const& azp, + c10::optional const& bias); + torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, torch::Tensor const& b_q_weight, torch::Tensor const& s_tok, diff --git a/csrc/quantization/cutlass_w8a8/Epilogues.md b/csrc/quantization/cutlass_w8a8/Epilogues.md new file mode 100644 index 000000000000..aae04157b10d --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/Epilogues.md @@ -0,0 +1,147 @@ +# CUTLASS Epilogues + +## Introduction +This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs. + +Currently, we only support symmetric quantization for weights, +and symmetric and asymmetric quantization for activations. +Both can be quantized per-tensor or per-channel (weights) / per-token (activations). + +There are 4 epilogues: +1. ScaledEpilogue: symmetric quantization for activations, no bias. +1. ScaledEpilogueBias: symmetric quantization for activations, supports bias. +1. ScaledEpilogueAzp: asymmetric per-tensor quantization for activations, supports bias. +1. ScaledEpilogueAzpPerToken: asymmetric per-token quantization for activations, supports bias. + +We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size. +Instead, if no bias is passed, the epilogue will use 0 as the bias. +That induces a redundant addition operation (and runtime check), but the performance impact is minor. + +## Underlying Linear Algebra + +More details available in the [Activation Quantization RFC](https://github.com/vllm-project/vllm/issues/3975). 
+
+If $` \widehat X `$ is the quantized $` X `$, our matrices become the following:
+
+```math
+A = s_a (\widehat A - J_a z_a)
+```
+```math
+B = s_b \widehat B
+```
+```math
+D = A B + C
+```
+```math
+D = s_a s_b \widehat D + C
+```
+
+Here, $` D `$ is the output of the GEMM and $` C `$ is the bias.
+$` A `$ holds the activations and supports asymmetric quantization,
+while $` B `$ holds the weights and supports only symmetric quantization.
+$` s_a `$ and $` s_b `$ are the scales for activations and weights, respectively.
+$` z_a `$ is the zero-point for activations, and $` J_a `$ is the all-ones matrix with the same dimensions as $` A `$.
+Additional epilogues would be required to support asymmetric quantization for weights.
+
+Expanding further, we can calculate $` \widehat D `$ as follows:
+
+```math
+A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B
+```
+```math
+A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right)
+```
+```math
+\widehat D = \widehat A \widehat B - z_a J_a \widehat B
+```
+
+Note that $` \widehat A \widehat B `$ is the raw output of the GEMM,
+and $` J_a \widehat B `$ is known ahead of time.
+Each row of it is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of column sums of $` \widehat B `$.
+
+## Epilogues
+
+### ScaledEpilogue
+This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$.
+The output of the GEMM is:
+
+```math
+\widehat D = \widehat A \widehat B
+```
+```math
+D = s_a s_b \widehat D
+```
+```math
+D = s_a s_b \widehat A \widehat B
+```
+
+Epilogue parameters:
+- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
+- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
+
+### ScaledEpilogueBias
+This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$.
+The output of the GEMM is:
+
+```math
+\widehat D = \widehat A \widehat B
+```
+```math
+D = s_a s_b \widehat D + C
+```
+```math
+D = s_a s_b \widehat A \widehat B + C
+```
+
+Epilogue parameters:
+- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
+- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
+- `bias` is the bias, is always per-channel (row-vector).
+
+### ScaledEpilogueAzp
+This epilogue computes the asymmetric per-tensor quantization for activations with bias.
+The output of the GEMM is:
+
+```math
+\widehat D = \widehat A \widehat B - z_a J_a \widehat B
+```
+```math
+D = s_a s_b \widehat D + C
+```
+```math
+D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C
+```
+
+Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 \widehat B `$.
+That is precomputed and stored in `azp_with_adj` as a row-vector.
+
+Epilogue parameters:
+- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
+  - Generally this will be per-tensor as the zero-points are per-tensor.
+- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
+- `azp_with_adj` is the precomputed zero-point term ($` z_a J_a \widehat B `$), is per-channel (row-vector).
+- `bias` is the bias, is always per-channel (row-vector).
+
+To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel.
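+
+For intuition, here is a small PyTorch sketch (illustrative only, not part of the kernels; the shapes, scales, zero-point, and variable names below are made up) that checks the ScaledEpilogueAzp math above against a dequantize-then-GEMM reference:
+
+```python
+import torch
+
+m, n, k = 4, 8, 16
+aq = torch.randint(-8, 8, (m, k), dtype=torch.int32)  # quantized activations (asymmetric)
+bq = torch.randint(-8, 8, (k, n), dtype=torch.int32)  # quantized weights (symmetric)
+s_a, s_b, z_a = 0.02, 0.05, 3                          # per-tensor scales and zero-point
+bias = torch.randn(n)
+
+# Offline: azp_with_adj = z_a * (1^T Bq), i.e. z_a times the column sums of Bq (row-vector).
+azp_with_adj = z_a * bq.sum(dim=0)
+
+# Epilogue math: D = s_a * s_b * (Aq @ Bq - azp_with_adj) + bias
+d_epilogue = s_a * s_b * (aq @ bq - azp_with_adj) + bias
+
+# Reference: dequantize first, then run the GEMM in floating point.
+d_ref = (s_a * (aq - z_a).float()) @ (s_b * bq.float()) + bias
+assert torch.allclose(d_epilogue.float(), d_ref, atol=1e-4)
+```
+
+The kernels in this patch operate on quantized `int8` inputs and fuse this arithmetic into the GEMM epilogue; only `azp_with_adj` has to be computed ahead of time.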
+ +### ScaledEpilogueAzpPerToken +This epilogue computes the asymmetric per-token quantization for activations with bias. + +The output of the GEMM is the same as above, but the $` z_a `$ is a column-vector. +That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$. + +Epilogue parameters: +- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). + - Generally this will be per-token as the zero-points are per-token. +- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). +- `azp_adj` is the precomputed zero-point adjustment term ($` \mathbf 1 \widehat B `$), is per-channel (row-vector). +- `azp` is the zero-point (`z_a`), is per-token (column-vector). +- `bias` is the bias, is always per-channel (row-vector). + +To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel. + +The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM): +``` +out = scale_a * scale_b * (Dq - azp_adj * azp) + bias +``` diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp index c4c6b18654ee..d407d66ab2aa 100644 --- a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp +++ b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp @@ -207,6 +207,156 @@ struct VisitorRowOrScalarBroadcast { }; +///////////////////////////////////////////////////////////////////////////////////////////////// + +// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null +template< + class ThreadMap, + class Element, + class StrideMNL +> +struct VisitorRowOrZeroBroadcast { + + // This struct has been modified to remove null_default (because it's always 0) + struct Arguments { + Element const* ptr_row = nullptr; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage {}; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorRowOrZeroBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gRow, + RTensor&& tC_rRow, + CTensor&& tC_cRow, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gRow(cute::forward(tC_gRow)), + tC_rRow(cute::forward(tC_rRow)), + tC_cRow(cute::forward(tC_cRow)), + n(get<1>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gRow; + RTensor tC_rRow; + CTensor tC_cRow; + Params const* params_ptr; + int n; + + // This function is modified from VisitorRowBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rRow); + auto src_v = filter(tC_gRow); + auto coord_v = filter(tC_cRow); + auto dst_v = filter(tC_rRow); + + if (params_ptr->ptr_row != nullptr) { + // In this case we are loading from a row vector and broadcasting + 
CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + bool guard = get<1>(coord_v(i)) < n; + cutlass::arch::global_load( + dst_v(i), (void const*)&src_v(i), guard); + } + } else { + // In this case we are broadcasting 0 + VecType filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VecLength; i++) { + reinterpret_cast(&filled_vec)[i] = Element{0}; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + if (get<1>(coord_v(i)) < n) { + dst_v(i) = filled_vec; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Tensor rRow_frg = recast>(coalesce(tC_rRow)); + return rRow_frg(column_idx); + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mRow = make_tensor( + make_gmem_ptr(params_ptr->ptr_row), + problem_shape, + params_ptr->dRow); + + // VECTOR, FRAGMENT_COLUMN + Tensor tC_gRow = recast( + ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) + )(_,_,_0{},_0{},_0{},_0{}); + Tensor tC_rRow = make_tensor_like(tC_gRow); + + // Generate the pred tensor + Tensor cRow = make_identity_tensor(mRow.shape()); + Tensor tC_cRow = outer_partition( + ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), + Shape>{}, + (_0{}) + ); + + return Callbacks< + decltype(tC_gRow), decltype(tC_rRow), + decltype(tC_cRow), ProblemShape>( + cute::move(tC_gRow), + cute::move(tC_rRow), + cute::move(tC_cRow), + problem_shape, + params_ptr + ); + } + +}; + + ///////////////////////////////////////////////////////////////////////////////////////////////// // Column vector broadcast @@ -217,7 +367,7 @@ template< > struct VisitorColOrScalarBroadcast { - // This struct has been modified to have a bool indicating that ptr_col is a + // This struct has been modified to have a bool indicating that ptr_col is a // scalar that must be broadcast. struct Arguments { Element const* ptr_col = nullptr; diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu index 8d0dfee7bf23..ee801e16573d 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu @@ -50,6 +50,25 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a, } } +void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + c10::optional const& azp, + c10::optional const& bias) { + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + if (azp) { + return cutlass_scaled_mm_sm75_epilogue( + out, a, b, a_scales, b_scales, azp_adj, *azp, bias); + } else { + return cutlass_scaled_mm_sm75_epilogue( + out, a, b, a_scales, b_scales, azp_adj, bias); + } +} + template