49 changes: 49 additions & 0 deletions tests/entrypoints/openai/test_prompt_validation.py
@@ -1,10 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import io

# imports for guided decoding tests
import openai
import pybase64
import pytest
import regex as re
import torch
Collaborator:

this is probably causing the fork issue

@qthequartermasterman (Contributor, Author) replied on Aug 15, 2025:

Thanks! I think it wasn't the torch import, but investigating your comment did help me better isolate the issue! Instead, hypothesis-torch had a side effect that initialized CUDA during test collection. Because it ships a hypothesis plugin, this side effect of registering strategies for certain types occurred any time pytest was invoked, regardless of whether the collected tests depended on hypothesis or hypothesis-torch.

I submitted a patch to hypothesis-torch, but it didn't seem to fully resolve the issue for the failing tests when run locally. I realized, though, that even if that issue were resolved, simply instantiating a CUDA tensor in this test would be enough to re-initialize CUDA and cause failures in any tests that ran alongside these. So I decided to drop hypothesis-torch altogether for generating arbitrary tensors, and to stop testing against non-CPU tensors.

The test is weaker than I'd like, but side effects make life difficult. 😢

We'll see if CI is happy after this. Thanks for taking a look.
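
A minimal sketch of the kind of guard that can surface this class of problem early, assuming a project-level conftest.py; the hook name `pytest_collection_finish` is standard pytest, but the guard itself is illustrative and not part of this PR:

```python
# Illustrative guard, not part of this PR: fail fast if an import-time side
# effect has already created a CUDA context during collection, since tests
# that fork worker processes would then break in hard-to-diagnose ways.
import torch


def pytest_collection_finish(session):
    # torch.cuda.is_initialized() stays False until a CUDA context exists
    # (e.g. after constructing a CUDA tensor or calling torch.cuda.init()).
    assert not torch.cuda.is_initialized(), (
        "CUDA was initialized during test collection; "
        "tests that fork worker processes will fail.")
```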


from vllm.entrypoints.openai.serving_engine import OpenAIServing

from ...utils import RemoteOpenAIServer

@@ -42,3 +48,46 @@ async def test_out_of_vocab_token_ids():
prompt=[999999],
max_tokens=5,
temperature=0.0)


@pytest.mark.parametrize("dtype",
[torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize(
"layout",
[torch.strided, torch.sparse_coo, torch.sparse_csc, torch.sparse_csr])
@pytest.mark.parametrize("seq_len", [2, 10])
@pytest.mark.parametrize("hidden_size", [2, 10])
def test_load_prompt_embeds(dtype: torch.dtype, layout: torch.layout,
seq_len: int, hidden_size: int):
    # Construct arbitrary tensors of various dtypes, layouts, and sizes.
    # We check against different layouts to make sure that if a user sends
    # sparse tensors to reduce the transmission size of prompt embeddings,
    # they are cast to dense/strided before being passed into the engine.
    # We don't use non-CPU tensors in this test to avoid preemptively
    # initializing CUDA and breaking other tests in the suite that fork
    # processes. We also need to make sure we only use devices that are
    # actually available in the environment the test runs on. For simplicity,
    # we just test against CPU.
tensor = torch.randn((seq_len, hidden_size), dtype=dtype)
if layout == torch.strided:
tensor = tensor.contiguous()
elif layout == torch.sparse_coo:
tensor = tensor.to_sparse_coo()
elif layout == torch.sparse_csc:
tensor = tensor.to_sparse_csc()
elif layout == torch.sparse_csr:
tensor = tensor.to_sparse_csr()

buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)
encoded_tensor = pybase64.b64encode(buffer.getvalue())

loaded_prompt_embeds = OpenAIServing._load_prompt_embeds(encoded_tensor)
assert len(loaded_prompt_embeds) == 1
loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"]
assert loaded_tensor.device.type == "cpu"
assert loaded_tensor.layout == torch.strided
torch.testing.assert_close(loaded_tensor,
tensor.to("cpu").to_dense(),
equal_nan=True)
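
The test above only checks that sparse inputs are densified; the motivation it mentions (sparse layouts shrinking the encoded payload) can be seen with a quick sketch like the following, which is illustrative and not part of the PR:

```python
# Rough illustration of why a client might send sparse prompt embeddings:
# for mostly-zero embeddings, the serialized payload is much smaller in a
# sparse layout than in the dense/strided layout.
import io

import torch


def serialized_size(t: torch.Tensor) -> int:
    buf = io.BytesIO()
    torch.save(t, buf)
    return buf.getbuffer().nbytes


dense = torch.zeros((10, 4096), dtype=torch.float32)
dense[0, 0] = 1.0  # a single non-zero entry
sparse = dense.to_sparse_coo()

print(serialized_size(dense))   # roughly 10 * 4096 * 4 bytes plus overhead
print(serialized_size(sparse))  # only the indices and non-zero values
```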
6 changes: 4 additions & 2 deletions vllm/entrypoints/openai/serving_engine.py
@@ -1006,21 +1006,23 @@ async def _generate_with_builtin_tools(
# OPTIMIZATION
priority = orig_priority - 1

+    @staticmethod
    def _load_prompt_embeds(
-        self,
        prompt_embeds: Optional[Union[bytes, list[bytes]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    ) -> list[EmbedsPrompt]:

        def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
            tensor = torch.load(io.BytesIO(
                pybase64.b64decode(embed, validate=True)),
-                                weights_only=True)
+                                weights_only=True,
+                                map_location=torch.device("cpu"))
            assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
                torch.float32,
                torch.bfloat16,
                torch.float16,
            )
+            tensor = tensor.to_dense()
            if tensor.dim() > 2:
                tensor = tensor.squeeze(0)
            assert tensor.dim() == 2
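
For reference, a minimal sketch of the client-side encoding this loader expects; `encode_prompt_embeds` is a hypothetical helper, not part of vLLM, and how the resulting bytes are attached to a request is outside this diff:

```python
# Illustrative client-side counterpart to _load_prompt_embeds: serialize a
# 2-D float tensor with torch.save and base64-encode it, which is the format
# the server-side loader decodes above. Helper name and shapes are assumptions.
import io

import pybase64
import torch


def encode_prompt_embeds(embeds: torch.Tensor) -> bytes:
    # The loader asserts a floating dtype and ends up with a 2-D
    # (seq_len, hidden_size) tensor, so keep the client side in that shape.
    assert embeds.dim() == 2 and embeds.dtype in (torch.float32,
                                                  torch.bfloat16,
                                                  torch.float16)
    buffer = io.BytesIO()
    torch.save(embeds, buffer)
    return pybase64.b64encode(buffer.getvalue())


encoded = encode_prompt_embeds(torch.randn(10, 4096, dtype=torch.float32))
```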