2 changes: 2 additions & 0 deletions requirements/test.in
@@ -37,6 +37,8 @@ mteb[bm25s]>=1.38.11, <2 # required for mteb test
transformers==4.55.0
tokenizers==0.21.1
schemathesis>=3.39.15 # Required for openai schema test.
hypothesis>=6.131.0 # Required for prompt embeds tests and openai schema test.
hypothesis-torch>=1.1.0 # Required for prompt embeds tests
# quantization
bitsandbytes==0.46.1
buildkite-test-collector==0.1.9
5 changes: 5 additions & 0 deletions requirements/test.txt
@@ -307,13 +307,17 @@ hydra-core==1.3.2
# lightning
hypothesis==6.131.0
# via
# -r requirements/test.in
# hypothesis-graphql
# hypothesis-jsonschema
# hypothesis-torch
# schemathesis
hypothesis-graphql==0.11.1
# via schemathesis
hypothesis-jsonschema==0.23.1
# via schemathesis
hypothesis-torch==1.1.0
# via -r requirements/test.in
idna==3.10
# via
# anyio
@@ -1074,6 +1078,7 @@ torch==2.7.1+cu128
# efficientnet-pytorch
# encodec
# fastsafetensors
# hypothesis-torch
# kornia
# lightly
# lightning
32 changes: 32 additions & 0 deletions tests/entrypoints/openai/test_prompt_validation.py
@@ -1,10 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import io

import hypothesis
import hypothesis_torch
# imports for guided decoding tests
import openai
import pybase64
import pytest
import regex as re
import torch
Collaborator:
this is probably causing the fork issue

@qthequartermasterman (Contributor Author) Aug 15, 2025:

Thanks! I think it wasn't the torch import, but investigating your comment did help me better isolate the issue! Instead, hypothesis-torch had a side effect that initialized CUDA during test collection. Because it has a hypothesis plugin, this side effect of registering strategies for certain types occurred any time pytest was invoked, regardless of whether the collected tests depended on hypothesis or hypothesis-torch.

I submitted a patch to hypothesis-torch, but it didn't seem to fully resolve the issue for the failing tests when run locally. I realized, though, that even if that issue were resolved, simply instantiating a CUDA tensor in this test would be enough to re-initialize CUDA and cause failures in any tests that ran alongside these. So I decided to drop hypothesis-torch altogether for generating arbitrary tensors, and to stop testing against non-CPU tensors.

The test is weaker than I'd like, but side effects make life difficult. 😢

We'll see if CI is happy after this. Thanks for taking a look.
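For context, here is a minimal sketch (illustrative only, not the merged code) of how the property test could be rewritten with plain hypothesis strategies once hypothesis-torch is dropped: the tensor is drawn and built directly on the CPU, so neither test collection nor execution can touch CUDA. The strategy and test names are hypothetical; the loader API is assumed to match the staticmethod added in this PR.

import io

import pybase64
import torch
from hypothesis import given, strategies as st

from vllm.entrypoints.openai.serving_engine import OpenAIServing


@st.composite
def cpu_float_tensors(draw):
    # Hypothetical strategy: draw a dtype and a small 2-D shape, then build
    # the tensor directly on the CPU so CUDA is never initialized.
    dtype = draw(st.sampled_from([torch.float32, torch.bfloat16, torch.float16]))
    rows = draw(st.integers(min_value=2, max_value=10))
    cols = draw(st.integers(min_value=2, max_value=10))
    values = draw(
        st.lists(st.floats(min_value=-1e3, max_value=1e3, width=32),
                 min_size=rows * cols, max_size=rows * cols))
    return torch.tensor(values, dtype=dtype, device="cpu").reshape(rows, cols)


@given(tensor=cpu_float_tensors())
def test_load_prompt_embeds_cpu_only(tensor: torch.Tensor):
    # Round-trip the tensor the same way a client would: torch.save into a
    # buffer, base64-encode, then hand it to the server-side loader.
    buffer = io.BytesIO()
    torch.save(tensor, buffer)
    encoded = pybase64.b64encode(buffer.getvalue())

    loaded = OpenAIServing._load_prompt_embeds(encoded)
    assert len(loaded) == 1
    assert loaded[0]["prompt_embeds"].device.type == "cpu"

The dtype choices and the 2–10 shape bounds mirror the hypothesis-torch version in the diff below; only the tensor construction differs.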

from hypothesis import strategies as st

from vllm.entrypoints.openai.serving_engine import OpenAIServing

from ...utils import RemoteOpenAIServer

@@ -42,3 +51,26 @@ async def test_out_of_vocab_token_ids():
prompt=[999999],
max_tokens=5,
temperature=0.0)


@hypothesis.given(tensor=hypothesis_torch.tensor_strategy(
dtype=hypothesis_torch.dtype_strategy(
[torch.float32, torch.bfloat16, torch.float16]),
shape=st.tuples(st.integers(min_value=2, max_value=10),
st.integers(min_value=2, max_value=10)),
device=hypothesis_torch.device_strategy(),
layout=hypothesis_torch.layout_strategy()))
def test_load_prompt_embeds(tensor: torch.Tensor):
buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)
encoded_tensor = pybase64.b64encode(buffer.getvalue())

loaded_prompt_embeds = OpenAIServing._load_prompt_embeds(encoded_tensor)
assert len(loaded_prompt_embeds) == 1
loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"]
assert loaded_tensor.device.type == "cpu"
assert loaded_tensor.layout == torch.strided
torch.testing.assert_close(loaded_tensor,
tensor.to("cpu").to_dense(),
equal_nan=True)
6 changes: 4 additions & 2 deletions vllm/entrypoints/openai/serving_engine.py
@@ -1006,21 +1006,23 @@ async def _generate_with_builtin_tools(
# OPTIMIZATION
priority = orig_priority - 1

+ @staticmethod
def _load_prompt_embeds(
- self,
prompt_embeds: Optional[Union[bytes, list[bytes]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
) -> list[EmbedsPrompt]:

def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
tensor = torch.load(io.BytesIO(
pybase64.b64decode(embed, validate=True)),
- weights_only=True)
+ weights_only=True,
+ map_location=torch.device("cpu"))
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
torch.float32,
torch.bfloat16,
torch.float16,
)
+ tensor = tensor.to_dense()
if tensor.dim() > 2:
tensor = tensor.squeeze(0)
assert tensor.dim() == 2
Expand Down