
Commit 7e0c808

Merge branch 'main' into remove-unused-attn-args

2 parents 29cff77 + b2c3fc5

40 files changed: +1707 -971 lines

.pre-commit-config.yaml

Lines changed: 8 additions & 1 deletion
@@ -22,7 +22,7 @@ repos:
     additional_dependencies: ['tomli']
     args: ['--toml', 'pyproject.toml']
 - repo: https://github.com/PyCQA/isort
-  rev: 5.13.2
+  rev: 0a0b7a830386ba6a31c2ec8316849ae4d1b8240d # 6.0.0
   hooks:
   - id: isort
     exclude: 'vllm/third_party/.*'
@@ -44,6 +44,13 @@ repos:
   hooks:
   - id: actionlint
     exclude: 'vllm/third_party/.*'
+repos:
+- repo: https://github.com/astral-sh/uv-pre-commit
+  rev: 0.6.2
+  hooks:
+    - id: pip-compile
+      args: [requirements-test.in, -o, requirements-test.txt]
+      files: ^requirements-test\.(in|txt)$
 - repo: local
   hooks:
   - id: mypy-local

csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh

Lines changed: 11 additions & 5 deletions
@@ -30,12 +30,18 @@ static inline cute::Shape<int, int, int, int> get_problem_shape(
 }
 
 template <typename GemmKernel>
-void cutlass_gemm_caller(torch::Device device,
-                         cute::Shape<int, int, int, int> prob_shape,
-                         typename GemmKernel::MainloopArguments mainloop_args,
-                         typename GemmKernel::EpilogueArguments epilogue_args) {
+void cutlass_gemm_caller(
+    torch::Device device, cute::Shape<int, int, int, int> prob_shape,
+    typename GemmKernel::MainloopArguments mainloop_args,
+    typename GemmKernel::EpilogueArguments epilogue_args,
+    typename GemmKernel::TileSchedulerArguments scheduler = {}) {
+  cutlass::KernelHardwareInfo hw_info;
   typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
-                                      prob_shape, mainloop_args, epilogue_args};
+                                      prob_shape,
+                                      mainloop_args,
+                                      epilogue_args,
+                                      hw_info,
+                                      scheduler};
 
   // Launch the CUTLASS GEMM kernel.
   using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh

Lines changed: 33 additions & 7 deletions
@@ -22,8 +22,9 @@ namespace vllm {
 
 using namespace cute;
 
-template <typename OutType, int GroupSizeM_, int GroupSizeN_, int GroupSizeK_,
-          int TileSizeM_ = 128, class ClusterShape = Shape<_1, _2, _1>>
+template <typename SchedulerType, typename OutType, int GroupSizeM_,
+          int GroupSizeN_, int GroupSizeK_, int TileSizeM_ = 128,
+          class ClusterShape = Shape<_1, _2, _1>>
 struct cutlass_3x_gemm_fp8_blockwise {
   using GroupSizeM = Int<GroupSizeM_>;
   using GroupSizeN = Int<GroupSizeN_>;
@@ -84,7 +85,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
 
   using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
       Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
-      cutlass::gemm::PersistentScheduler>>;
+      SchedulerType>>;
 
   struct GemmKernel : public KernelType {};
 
@@ -150,8 +151,24 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
   typename GemmKernel::EpilogueArguments epilogue_args{
       {}, c_ptr, c_stride, c_ptr, c_stride};
 
+  typename GemmKernel::TileSchedulerArguments scheduler;
+
+  static constexpr bool UsesStreamKScheduler =
+      cute::is_same_v<typename GemmKernel::TileSchedulerTag,
+                      cutlass::gemm::StreamKScheduler>;
+
+  if constexpr (UsesStreamKScheduler) {
+    using DecompositionMode = typename cutlass::gemm::kernel::detail::
+        PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
+    using ReductionMode = typename cutlass::gemm::kernel::detail::
+        PersistentTileSchedulerSm90StreamKParams::ReductionMode;
+
+    scheduler.decomposition_mode = DecompositionMode::StreamK;
+    scheduler.reduction_mode = ReductionMode::Nondeterministic;
+  }
+
   c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
-                                       epilogue_args);
+                                       epilogue_args, scheduler);
 }
 
 template <typename OutType>
@@ -160,9 +177,18 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
                                               torch::Tensor const& b,
                                               torch::Tensor const& a_scales,
                                               torch::Tensor const& b_scales) {
-  cutlass_gemm_caller_blockwise<
-      cutlass_3x_gemm_fp8_blockwise<OutType, 1, 128, 128>>(out, a, b, a_scales,
-                                                           b_scales);
+  auto k = a.size(1);
+  auto n = b.size(1);
+
+  if (k > 3 * n) {
+    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+        cutlass::gemm::StreamKScheduler, OutType, 1, 128, 128>>(
+        out, a, b, a_scales, b_scales);
+  } else {
+    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+        cutlass::gemm::PersistentScheduler, OutType, 1, 128, 128>>(
+        out, a, b, a_scales, b_scales);
+  }
 }
 
 }  // namespace vllm

csrc/quantization/fp4/nvfp4_quant_kernels.cu

Lines changed: 1 addition & 4 deletions
@@ -348,10 +348,7 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output,
   auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
   auto output_ptr = static_cast<int64_t*>(output.data_ptr());
   at::cuda::CUDAGuard device_guard{(char)input.get_device()};
-  auto stream = at::cuda::getStreamFromPool(false, input.get_device());
-  if (stream == nullptr) {
-    std::cerr << "Warning: Null CUDA stream" << std::endl;
-  }
+  auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
 
   // We don't support e8m0 scales at this moment.
   bool useUE8M0 = false;

csrc/quantization/gguf/vecdotq.cuh

Lines changed: 2 additions & 0 deletions
@@ -37,6 +37,8 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t *
     return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
 }
 
+// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
+// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
 
 #define VDR_Q4_0_Q8_1_MMVQ 2
 #define VDR_Q4_0_Q8_1_MMQ 4

docs/source/deployment/integrations/index.md

Lines changed: 1 addition & 0 deletions
@@ -6,4 +6,5 @@
 kserve
 kubeai
 llamastack
+llmaz
 :::
docs/source/deployment/integrations/llmaz.md

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+(deployment-llmaz)=
+
+# llmaz
+
+[llmaz](https://github.com/InftyAI/llmaz) is an easy-to-use and advanced inference platform for large language models on Kubernetes, aimed for production use. It uses vLLM as the default model serving backend.
+
+Please refer to the [Quick Start](https://github.com/InftyAI/llmaz?tab=readme-ov-file#quick-start) for more details.

docs/source/features/structured_outputs.md

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ The following parameters are supported, which must be added as extra parameters:
 - `guided_json`: the output will follow the JSON schema.
 - `guided_grammar`: the output will follow the context free grammar.
 - `guided_whitespace_pattern`: used to override the default whitespace pattern for guided json decoding.
-- `guided_decoding_backend`: used to select the guided decoding backend to use.
+- `guided_decoding_backend`: used to select the guided decoding backend to use. Additional backend-specific options can be supplied in a comma separated list following a colon after the backend name. For example `"xgrammar:no-fallback"` will not allow vLLM to fallback to a different backend on error.
 
 You can see the complete list of supported parameters on the [OpenAI-Compatible Server](#openai-compatible-server)page.
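The new wording above documents a `backend:options` syntax for `guided_decoding_backend`. A minimal sketch of passing it through the OpenAI-compatible client, assuming a vLLM server already running on localhost:8000; the server URL and model name are illustrative assumptions, not part of this commit:

from openai import OpenAI

# Assumption: a local vLLM OpenAI-compatible server is serving this (illustrative) model.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="Qwen/Qwen2.5-3B-Instruct",
    messages=[{"role": "user", "content": "Classify the sentiment: vLLM is wonderful!"}],
    extra_body={
        "guided_choice": ["positive", "negative"],
        # Backend name, then backend-specific options after the colon:
        # "no-fallback" disables falling back to another backend on error.
        "guided_decoding_backend": "xgrammar:no-fallback",
    },
)
print(completion.choices[0].message.content)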

docs/source/models/pooling_models.md

Lines changed: 1 addition & 2 deletions
@@ -108,8 +108,7 @@ A code example can be found here: <gh-file:examples/offline_inference/basic/clas
 ### `LLM.score`
 
 The {class}`~vllm.LLM.score` method outputs similarity scores between sentence pairs.
-It is primarily designed for [cross-encoder models](https://www.sbert.net/examples/applications/cross-encoder/README.html).
-These types of models serve as rerankers between candidate query-document pairs in RAG systems.
+It is designed for embedding models and cross encoder models. Embedding models use cosine similarity, and [cross-encoder models](https://www.sbert.net/examples/applications/cross-encoder/README.html) serve as rerankers between candidate query-document pairs in RAG systems.
 
 :::{note}
 vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG.
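The updated paragraph says `LLM.score` covers both embedding and cross-encoder models. A minimal offline sketch under that assumption; the model name is illustrative, a cross-encoder returns reranker scores while an embedding model would yield cosine-similarity scores:

from vllm import LLM

# Illustrative cross-encoder; an embedding model with task="score" would also work.
llm = LLM(model="BAAI/bge-reranker-base", task="score")

outputs = llm.score(
    "What is the capital of France?",
    ["Paris is the capital of France.", "The Eiffel Tower is in Paris."],
)
for output in outputs:
    print(output.outputs.score)  # one score per sentence pair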

docs/source/serving/openai_compatible_server.md

Lines changed: 5 additions & 5 deletions
@@ -51,7 +51,7 @@ In addition, we have the following custom APIs:
 - [Pooling API](#pooling-api) (`/pooling`)
   - Applicable to all [pooling models](../models/pooling_models.md).
 - [Score API](#score-api) (`/score`)
-  - Only applicable to [cross-encoder models](../models/pooling_models.md) (`--task score`).
+  - Applicable to embedding models and [cross-encoder models](../models/pooling_models.md) (`--task score`).
 - [Re-rank API](#rerank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`)
   - Implements [Jina AI's v1 re-rank API](https://jina.ai/reranker/)
   - Also compatible with [Cohere's v1 & v2 re-rank APIs](https://docs.cohere.com/v2/reference/rerank)
@@ -333,10 +333,10 @@ Code example: <gh-file:examples/online_serving/openai_pooling_client.py>
 
 ### Score API
 
-Our Score API applies a cross-encoder model to predict scores for sentence pairs.
+Our Score API can apply a cross-encoder model or an embedding model to predict scores for sentence pairs. When using an embedding model the score corresponds to the cosine similarity between each embedding pair.
 Usually, the score for a sentence pair refers to the similarity between two sentences, on a scale of 0 to 1.
 
-You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
+You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
 
 Code example: <gh-file:examples/online_serving/openai_cross_encoder_score.py>
 
@@ -496,11 +496,11 @@ The following extra parameters are supported:
 
 ### Re-rank API
 
-Our Re-rank API applies a cross-encoder model to predict relevant scores between a single query, and
+Our Re-rank API can apply an embedding model or a cross-encoder model to predict relevant scores between a single query, and
 each of a list of documents. Usually, the score for a sentence pair refers to the similarity between two sentences, on
 a scale of 0 to 1.
 
-You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
+You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
 
 The rerank endpoints support popular re-rank models such as `BAAI/bge-reranker-base` and other models supporting the
 `score` task. Additionally, `/rerank`, `/v1/rerank`, and `/v2/rerank`
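As a hedged sketch of the Score API described above, assuming a server started with something like `vllm serve BAAI/bge-reranker-v2-m3 --task score` on port 8000 (the model name and port are assumptions for illustration, not part of the commit); with an embedding model served instead, the returned scores would be cosine similarities:

import requests

# One query sentence scored against several candidate sentences.
response = requests.post(
    "http://localhost:8000/score",
    json={
        "model": "BAAI/bge-reranker-v2-m3",
        "text_1": "What is the capital of France?",
        "text_2": [
            "Paris is the capital of France.",
            "The weather is nice today.",
        ],
    },
)
print(response.json())  # per-pair scores in the response body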
