From 9c57bd3cdb79a691760517102739d5db420ccea0 Mon Sep 17 00:00:00 2001
From: Andrew Sansom
Date: Thu, 14 Aug 2025 23:40:28 -0500
Subject: [PATCH 1/6] fix: more robust prompt embeds loading

Signed-off-by: Andrew Sansom
---
 requirements/test.in                      |  2 ++
 requirements/test.txt                     |  5 +++
 .../openai/test_prompt_validation.py      | 34 +++++++++++++++++++
 vllm/entrypoints/openai/serving_engine.py |  6 ++--
 4 files changed, 45 insertions(+), 2 deletions(-)

diff --git a/requirements/test.in b/requirements/test.in
index 6652bfdfe66c..4ab34bd848f4 100644
--- a/requirements/test.in
+++ b/requirements/test.in
@@ -37,6 +37,8 @@ mteb[bm25s]>=1.38.11, <2 # required for mteb test
 transformers==4.55.0
 tokenizers==0.21.1
 schemathesis>=3.39.15 # Required for openai schema test.
+hypothesis>=6.131.0 # Required for prompt embeds tests and openai schema test.
+hypothesis-torch>=1.1.0 # Required for prompt embeds tests
 # quantization
 bitsandbytes==0.46.1
 buildkite-test-collector==0.1.9
diff --git a/requirements/test.txt b/requirements/test.txt
index ff9886a31597..0c4c136cbc39 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -307,13 +307,17 @@ hydra-core==1.3.2
     #   lightning
 hypothesis==6.131.0
     # via
+    #   -r requirements/test.in
     #   hypothesis-graphql
     #   hypothesis-jsonschema
+    #   hypothesis-torch
     #   schemathesis
 hypothesis-graphql==0.11.1
     # via schemathesis
 hypothesis-jsonschema==0.23.1
     # via schemathesis
+hypothesis-torch==1.1.0
+    # via -r requirements/test.in
 idna==3.10
     # via
     #   anyio
@@ -1074,6 +1078,7 @@ torch==2.7.1+cu128
     #   efficientnet-pytorch
     #   encodec
     #   fastsafetensors
+    #   hypothesis-torch
     #   kornia
     #   lightly
     #   lightning
diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py
index e31a1d077608..10e3542032a9 100644
--- a/tests/entrypoints/openai/test_prompt_validation.py
+++ b/tests/entrypoints/openai/test_prompt_validation.py
@@ -1,10 +1,19 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import io
+
+import hypothesis
+import hypothesis_torch
 # imports for guided decoding tests
 import openai
+import pybase64
 import pytest
 import regex as re
+import torch
+from hypothesis import strategies as st
+
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
 
 from ...utils import RemoteOpenAIServer
 
 
@@ -42,3 +51,28 @@ async def test_out_of_vocab_token_ids():
                           prompt=[999999],
                           max_tokens=5,
                           temperature=0.0)
+
+
+@hypothesis.settings(max_examples=10000)
+@hypothesis.given(tensor=hypothesis_torch.tensor_strategy(
+    dtype=hypothesis_torch.dtype_strategy(
+        [torch.float32, torch.bfloat16, torch.float16]),
+    shape=st.tuples(st.integers(min_value=2, max_value=10),
+                    st.integers(min_value=2, max_value=10)),
+    device=hypothesis_torch.device_strategy(),
+    layout=st.just(torch.sparse_coo)))
+def test_load_prompt_embeds(tensor: torch.Tensor):
+    buffer = io.BytesIO()
+    torch.save(tensor, buffer)
+    buffer.seek(0)
+    encoded_tensor = pybase64.b64encode(buffer.getvalue())
+    assert tensor.layout == torch.sparse_coo
+
+    loaded_prompt_embeds = OpenAIServing._load_prompt_embeds(encoded_tensor)
+    assert len(loaded_prompt_embeds) == 1
+    loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"]
+    assert loaded_tensor.device.type == "cpu"
+    assert loaded_tensor.layout == torch.strided
+    torch.testing.assert_close(loaded_tensor,
+                               tensor.to("cpu").to_dense(),
+                               equal_nan=True)
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index d6f92a63301e..0f4a7c0186b6 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -1006,8 +1006,8 @@ async def _generate_with_builtin_tools(
             # OPTIMIZATION
             priority = orig_priority - 1
 
+    @staticmethod
     def _load_prompt_embeds(
-        self,
         prompt_embeds: Optional[Union[bytes, list[bytes]]],
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    ) -> list[EmbedsPrompt]:
@@ -1015,12 +1015,14 @@ def _load_prompt_embeds(
        def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
            tensor = torch.load(io.BytesIO(
                pybase64.b64decode(embed, validate=True)),
-                               weights_only=True)
+                               weights_only=True,
+                               map_location=torch.device("cpu"))
            assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
                torch.float32,
                torch.bfloat16,
                torch.float16,
            )
+           tensor = tensor.to_dense()
            if tensor.dim() > 2:
                tensor = tensor.squeeze(0)
            assert tensor.dim() == 2
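
For anyone who wants to exercise the new loading path by hand, here is a minimal round-trip sketch mirroring the `_load_and_validate_embed` changes in patch 1. The tensor shape and dtype are arbitrary illustrations, not values the patch prescribes:

    import io

    import pybase64
    import torch

    # Client side: serialize a [seq_len, hidden_size] embedding. Sparsifying
    # first can shrink the payload; the patched server handles either layout.
    embeds = torch.randn(4, 8, dtype=torch.float16).to_sparse()
    buffer = io.BytesIO()
    torch.save(embeds, buffer)
    payload = pybase64.b64encode(buffer.getvalue())

    # Server side, mirroring the patched loader: map_location pins
    # deserialization to CPU even for tensors saved from a CUDA device,
    # and to_dense() normalizes sparse layouts to strided.
    tensor = torch.load(io.BytesIO(pybase64.b64decode(payload, validate=True)),
                        weights_only=True,
                        map_location=torch.device("cpu"))
    tensor = tensor.to_dense()
    assert tensor.device.type == "cpu" and tensor.layout == torch.strided
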
From 78e4c89bc973dea22d19429a5603db1983ef6fb7 Mon Sep 17 00:00:00 2001
From: Andrew Sansom
Date: Thu, 14 Aug 2025 23:48:07 -0500
Subject: [PATCH 2/6] test: add all layout types

Signed-off-by: Andrew Sansom
---
 tests/entrypoints/openai/test_prompt_validation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py
index 10e3542032a9..5af93c80971f 100644
--- a/tests/entrypoints/openai/test_prompt_validation.py
+++ b/tests/entrypoints/openai/test_prompt_validation.py
@@ -60,7 +60,7 @@ async def test_out_of_vocab_token_ids():
     shape=st.tuples(st.integers(min_value=2, max_value=10),
                     st.integers(min_value=2, max_value=10)),
     device=hypothesis_torch.device_strategy(),
-    layout=st.just(torch.sparse_coo)))
+    layout=hypothesis_torch.layout_strategy()))
 def test_load_prompt_embeds(tensor: torch.Tensor):
     buffer = io.BytesIO()
     torch.save(tensor, buffer)

From 08143ff5ea63292b2c8962a2e5d742a0275dcaa3 Mon Sep 17 00:00:00 2001
From: Andrew Sansom
Date: Thu, 14 Aug 2025 23:50:03 -0500
Subject: [PATCH 3/6] test: do not require so many examples

Signed-off-by: Andrew Sansom
---
 tests/entrypoints/openai/test_prompt_validation.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py
index 5af93c80971f..6e36eea616ac 100644
--- a/tests/entrypoints/openai/test_prompt_validation.py
+++ b/tests/entrypoints/openai/test_prompt_validation.py
@@ -53,7 +53,6 @@ async def test_out_of_vocab_token_ids():
                           temperature=0.0)
 
 
-@hypothesis.settings(max_examples=10000)
 @hypothesis.given(tensor=hypothesis_torch.tensor_strategy(
     dtype=hypothesis_torch.dtype_strategy(
         [torch.float32, torch.bfloat16, torch.float16]),

From 3f539e6af21146c4f8d683b8483bf9c6a27639da Mon Sep 17 00:00:00 2001
From: Andrew Sansom
Date: Fri, 15 Aug 2025 07:20:01 -0500
Subject: [PATCH 4/6] test: do not assert tensor is sparse

Signed-off-by: Andrew Sansom
---
 tests/entrypoints/openai/test_prompt_validation.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py
index 6e36eea616ac..46b54d6ca468 100644
--- a/tests/entrypoints/openai/test_prompt_validation.py
+++ b/tests/entrypoints/openai/test_prompt_validation.py
@@ -65,7 +65,6 @@ def test_load_prompt_embeds(tensor: torch.Tensor):
     torch.save(tensor, buffer)
     buffer.seek(0)
     encoded_tensor = pybase64.b64encode(buffer.getvalue())
-    assert tensor.layout == torch.sparse_coo
 
     loaded_prompt_embeds = OpenAIServing._load_prompt_embeds(encoded_tensor)
     assert len(loaded_prompt_embeds) == 1
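
A side note on patch 3: with the explicit `@hypothesis.settings(max_examples=10000)` removed, Hypothesis falls back to its default profile of 100 examples per test. If CI ever needs a different budget, a registered settings profile is the usual knob. A minimal sketch — the profile name "ci" is an assumption, not something this series defines:

    import hypothesis

    # Register a lighter profile and opt into it explicitly; tests then run
    # at most 50 generated examples instead of Hypothesis's default 100.
    hypothesis.settings.register_profile("ci", max_examples=50)
    hypothesis.settings.load_profile("ci")
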
From 5275035771c5cdbd1a97fcc1c1a3fd4fdd1753f0 Mon Sep 17 00:00:00 2001
From: Andrew Sansom
Date: Fri, 15 Aug 2025 16:13:14 -0500
Subject: [PATCH 5/6] test: remove hypothesis-torch dependency

Signed-off-by: Andrew Sansom
---
 requirements/test.in                          |  2 --
 requirements/test.txt                         |  5 ---
 .../openai/test_prompt_validation.py          | 40 ++++++++++++++-----
 3 files changed, 29 insertions(+), 18 deletions(-)

diff --git a/requirements/test.in b/requirements/test.in
index 4ab34bd848f4..6652bfdfe66c 100644
--- a/requirements/test.in
+++ b/requirements/test.in
@@ -37,8 +37,6 @@ mteb[bm25s]>=1.38.11, <2 # required for mteb test
 transformers==4.55.0
 tokenizers==0.21.1
 schemathesis>=3.39.15 # Required for openai schema test.
-hypothesis>=6.131.0 # Required for prompt embeds tests and openai schema test.
-hypothesis-torch>=1.1.0 # Required for prompt embeds tests
 # quantization
 bitsandbytes==0.46.1
 buildkite-test-collector==0.1.9
diff --git a/requirements/test.txt b/requirements/test.txt
index 0c4c136cbc39..ff9886a31597 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -307,17 +307,13 @@ hydra-core==1.3.2
     #   lightning
 hypothesis==6.131.0
     # via
-    #   -r requirements/test.in
     #   hypothesis-graphql
     #   hypothesis-jsonschema
-    #   hypothesis-torch
     #   schemathesis
 hypothesis-graphql==0.11.1
     # via schemathesis
 hypothesis-jsonschema==0.23.1
     # via schemathesis
-hypothesis-torch==1.1.0
-    # via -r requirements/test.in
 idna==3.10
     # via
     #   anyio
@@ -1078,7 +1074,6 @@ torch==2.7.1+cu128
     #   efficientnet-pytorch
     #   encodec
     #   fastsafetensors
-    #   hypothesis-torch
     #   kornia
     #   lightly
     #   lightning
diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py
index 46b54d6ca468..b76e68128e32 100644
--- a/tests/entrypoints/openai/test_prompt_validation.py
+++ b/tests/entrypoints/openai/test_prompt_validation.py
@@ -3,15 +3,12 @@
 
 import io
 
-import hypothesis
-import hypothesis_torch
 # imports for guided decoding tests
 import openai
 import pybase64
 import pytest
 import regex as re
 import torch
-from hypothesis import strategies as st
 
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 
 from ...utils import RemoteOpenAIServer
@@ -53,14 +50,35 @@ async def test_out_of_vocab_token_ids():
                           temperature=0.0)
 
 
-@hypothesis.given(tensor=hypothesis_torch.tensor_strategy(
-    dtype=hypothesis_torch.dtype_strategy(
-        [torch.float32, torch.bfloat16, torch.float16]),
-    shape=st.tuples(st.integers(min_value=2, max_value=10),
-                    st.integers(min_value=2, max_value=10)),
-    device=hypothesis_torch.device_strategy(),
-    layout=hypothesis_torch.layout_strategy()))
-def test_load_prompt_embeds(tensor: torch.Tensor):
+@pytest.mark.parametrize("dtype",
+                         [torch.float32, torch.bfloat16, torch.float16])
+@pytest.mark.parametrize(
+    "layout",
+    [torch.strided, torch.sparse_coo, torch.sparse_csc, torch.sparse_csr])
+@pytest.mark.parametrize("seq_len", [2, 10])
+@pytest.mark.parametrize("hidden_size", [2, 10])
+def test_load_prompt_embeds(device: torch.device, dtype: torch.dtype,
+                            layout: torch.layout, seq_len: int,
+                            hidden_size: int):
+    # construct arbitrary tensors of various dtypes, layouts, and sizes.
+    # We need to check against different layouts to make sure that if a user
+    # uses sparse tensors to reduce the transmission size of prompt embeddings,
+    # they are cast to dense/strided before being passed into the engine.
+    # We don't use non-CPU tensors in this test to avoid preemptively
+    # initializing CUDA and breaking other tests in the suite that fork
+    # processes. We also need to make sure that we only use devices that are
+    # actually available in the environment the test is running on. For
+    # simplicity, we just test against CPU.
+    tensor = torch.randn((seq_len, hidden_size), dtype=dtype)
+    if layout == torch.strided:
+        tensor = tensor.contiguous()
+    elif layout == torch.sparse_coo:
+        tensor = tensor.to_sparse_coo()
+    elif layout == torch.sparse_csc:
+        tensor = tensor.to_sparse_csc()
+    elif layout == torch.sparse_csr:
+        tensor = tensor.to_sparse_csr()
+
     buffer = io.BytesIO()
     torch.save(tensor, buffer)
     buffer.seek(0)
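
The layout conversions introduced in patch 5 all preserve values, which is what the test's `assert_close` against `tensor.to("cpu").to_dense()` relies on. A quick standalone check of that invariant (the shape here is arbitrary):

    import torch

    dense = torch.randn(4, 6)
    variants = [
        dense.contiguous(),     # torch.strided
        dense.to_sparse_coo(),  # torch.sparse_coo
        dense.to_sparse_csc(),  # torch.sparse_csc
        dense.to_sparse_csr(),  # torch.sparse_csr
    ]
    for variant in variants:
        # Every layout should recover the original values after to_dense().
        torch.testing.assert_close(variant.to_dense(), dense)
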
From db84d8639a9880e742c8b3af57fdd2b5331f831e Mon Sep 17 00:00:00 2001
From: Andrew Sansom
Date: Fri, 15 Aug 2025 22:58:37 -0500
Subject: [PATCH 6/6] test: remove vestigial device argument

Signed-off-by: Andrew Sansom
---
 tests/entrypoints/openai/test_prompt_validation.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py
index b76e68128e32..4197583074df 100644
--- a/tests/entrypoints/openai/test_prompt_validation.py
+++ b/tests/entrypoints/openai/test_prompt_validation.py
@@ -57,9 +57,8 @@ async def test_out_of_vocab_token_ids():
     [torch.strided, torch.sparse_coo, torch.sparse_csc, torch.sparse_csr])
 @pytest.mark.parametrize("seq_len", [2, 10])
 @pytest.mark.parametrize("hidden_size", [2, 10])
-def test_load_prompt_embeds(device: torch.device, dtype: torch.dtype,
-                            layout: torch.layout, seq_len: int,
-                            hidden_size: int):
+def test_load_prompt_embeds(dtype: torch.dtype, layout: torch.layout,
+                            seq_len: int, hidden_size: int):
     # construct arbitrary tensors of various dtypes, layouts, and sizes.
     # We need to check against different layouts to make sure that if a user
     # uses sparse tensors to reduce the transmission size of prompt embeddings,
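
The comment block in the test motivates sparse layouts as a way to cut transmission size. A rough standalone measurement of that effect — the shape, dtype, and sparsity level below are made up for illustration:

    import io

    import pybase64
    import torch


    def encoded_size(t: torch.Tensor) -> int:
        # Size of the base64 payload a client would actually transmit.
        buf = io.BytesIO()
        torch.save(t, buf)
        return len(pybase64.b64encode(buf.getvalue()))


    # Zero out ~99% of entries so the sparse formats have something to exploit.
    dense = torch.randn(256, 4096, dtype=torch.float16)
    mostly_zero = dense.masked_fill(torch.rand(dense.shape) < 0.99, 0.0)

    print("strided:", encoded_size(mostly_zero))
    print("COO:    ", encoded_size(mostly_zero.to_sparse()))
    print("CSR:    ", encoded_size(mostly_zero.to_sparse_csr()))

For embeddings that are not mostly zeros, the sparse encodings can actually be larger than the dense payload, since they also ship index tensors; either way, the server now normalizes whatever arrives to a dense CPU tensor before it reaches the engine.
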