Merged
Changes from 12 commits
Commits (24)
df7ae76
⬆️ make vllm >=0.10.1.1,<=0.10.2
prashantgupta24 Sep 15, 2025
1584ba4
⬆️ lockfile update to vllm >=0.10.1.1,<=0.10.2
prashantgupta24 Sep 15, 2025
d043a22
⬆️ bump aftu to 0.2.3
prashantgupta24 Sep 15, 2025
2665cda
🚧 changes needed for 0.10.2
prashantgupta24 Sep 16, 2025
1772188
🎨 change the error msg
prashantgupta24 Sep 16, 2025
b00b4e5
🚧 test 0.10.2 instead of main
prashantgupta24 Sep 16, 2025
db5aea7
🐛 fix pooler stuff
prashantgupta24 Sep 17, 2025
703d59d
⏪ revert change for main
prashantgupta24 Sep 17, 2025
0de9b1c
⬆️ bump lowest to 0.10.1.1
prashantgupta24 Sep 17, 2025
6c03694
⬆️ bump default to 0.10.2
prashantgupta24 Sep 17, 2025
5b2f32c
♻️ make platform.py check simple
prashantgupta24 Sep 17, 2025
1c22ec7
🐛 set vllm_config for ClassifierPooler
prashantgupta24 Sep 17, 2025
125a9b3
Merge remote-tracking branch 'upstream/main' into upstream-versions
prashantgupta24 Sep 22, 2025
d63fb01
🚧 add backward compatibility code
prashantgupta24 Sep 22, 2025
d9965a9
✅ add upstream compat tests
prashantgupta24 Sep 22, 2025
e5ab49b
🐛 fix request params
prashantgupta24 Sep 22, 2025
7c2199e
🐛 revert pytest-mock import
prashantgupta24 Sep 22, 2025
480a177
🚧 not needed?
prashantgupta24 Sep 22, 2025
493d727
⏪ yep need those for pooler models
prashantgupta24 Sep 22, 2025
f93eac5
🎨 fix comment
prashantgupta24 Sep 22, 2025
25ba89a
🎨 remove extra assert
prashantgupta24 Sep 22, 2025
4f1a6a2
🎨 typo
prashantgupta24 Sep 23, 2025
1012c08
Merge branch 'main' into upstream-versions
maxdebayser Sep 24, 2025
67cc853
fix pooler adapter
maxdebayser Sep 24, 2025
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -69,7 +69,7 @@ jobs:
include:
- vllm_version:
name: "vLLM:lowest"
repo: "git+https://github.com/vllm-project/vllm --tag v0.10.0"
repo: "git+https://github.com/vllm-project/vllm --tag v0.10.1.1"
test_suite:
name: "backward compat"
markers: "compat or (cpu and basic)"
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -13,7 +13,7 @@ license = {text = "Apache 2"}
dependencies = [
"fms-model-optimizer[fp8]>=0.6.0",
"ibm-fms>=1.2.1",
"vllm>=0.10.0,<=0.10.1.1",
"vllm>=0.10.1.1,<=0.10.2",
]
requires-python = ">=3.11"
dynamic = ["version"]
@@ -163,7 +163,7 @@ dev = [
"pytest-timeout==2.3.1",
"requests==2.32.3",
"sentence-transformers==3.4.1",
"aiu-fms-testing-utils>=0.2.1",
"aiu-fms-testing-utils>=0.2.3",
]
lint = [
"clang-format==18.1.5",
3 changes: 1 addition & 2 deletions tests/e2e/test_spyre_cb.py
@@ -64,8 +64,7 @@ def test_api_cb_rejects_oversized_request(
overflow_prompt = " ".join(["hi"] * max_model_len)
max_tokens = 10

with pytest.raises(BadRequestError,
match="This model's maximum context length is"):
with pytest.raises(BadRequestError, match="maximum context length is"):
Collaborator Author comment:
This is a bug in vllm upstream - opened a PR vllm-project/vllm#24995
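(Illustrative snippet, not part of this diff: `pytest.raises(match=...)` applies `re.search` to the exception text, so anchoring on the stable substring keeps the test passing however upstream words the prefix. The message string below is made up.)

```python
import re

# pytest.raises(..., match=...) does a re.search against str(excinfo.value),
# so a stable substring is enough even if upstream rewords the start of the message.
message = "This model's maximum context length is 2048 tokens."  # made-up text
assert re.search("maximum context length is", message)
```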

client.completions.create(
model=model.name,
prompt=overflow_prompt,
8 changes: 3 additions & 5 deletions tests/spyre_util.py
@@ -372,14 +372,12 @@ def create_random_request(
cache_salt=None,
**kwargs,
)
kwargs = {
"multi_modal_kwargs" if inputs_renamed else "multi_modal_inputs": None
}
kwargs = {}
if inputs_renamed:
kwargs = {"multi_modal_kwargs"}
return Request(
request_id=str(request_id),
prompt_token_ids=prompt_token_ids,
multi_modal_hashes=None,
multi_modal_placeholders=None,
sampling_params=sampling_params,
eos_token_id=None,
arrival_time=0,
205 changes: 104 additions & 101 deletions uv.lock

Large diffs are not rendered by default.

16 changes: 16 additions & 0 deletions vllm_spyre/platform.py
@@ -43,6 +43,13 @@
]


# Minimal stand-in for a stream object: it exposes only a no-op synchronize(),
# which is all that callers of SpyrePlatform.current_stream() need here.
class _StreamPlaceholder:

def __init__(self):
self.synchronize = lambda: None


class classproperty:

def __init__(self, func):
@@ -80,6 +87,9 @@ class SpyrePlatform(Platform):
# See vllm batched_count_greater_than method
# simple_compile_backend: str = "eager"

# vLLM may call Platform.current_stream(); return a placeholder whose
# synchronize() is a no-op, since there is no real device stream to sync.
current_stream = lambda _: _StreamPlaceholder()

@classproperty
def device_type(cls):
# TODO: temporary hack while BertModels
@@ -106,6 +116,12 @@ def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:

@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

# vLLM may pass us a default VllmConfig with no model_config; in that
# case there is nothing to check or update, so return early.
if vllm_config.model_config is None:
return

cls._config = vllm_config
parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config
31 changes: 18 additions & 13 deletions vllm_spyre/v1/worker/spyre_model_runner.py
@@ -12,7 +12,7 @@
from torch import nn
from transformers import (AutoModel, AutoModelForSequenceClassification,
AutoTokenizer)
from vllm.config import DeviceConfig, VllmConfig
from vllm.config import DeviceConfig, VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
@@ -558,7 +558,6 @@ def execute_model(
req_ids=list(req_id_to_index.keys()),
req_id_to_index=req_id_to_index,
sampled_token_ids=output.sampled_token_ids.tolist(),
spec_token_ids=None,
logprobs=(output.logprobs_tensors.tolists()
if output.logprobs_tensors else None),
prompt_logprobs_dict=prompt_logprobs_dicts,
@@ -1448,15 +1447,17 @@ def load_model(self, prompt_lens: Iterable[int],
extra_args['default_pooling_type'] = PoolingType.CLS

if task == "embed":
self.pooler = Pooler.for_embed(pooler_config=pooler_config,
**extra_args)
with set_current_vllm_config(self.vllm_config):
@prashantgupta24 (Collaborator, Author) commented on Sep 24, 2025:
We need to wrap this in set_current_vllm_config because the Pooler class now needs the vLLM config in order to read vllm_config.model_config.head_dtype.
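For illustration only, a minimal sketch of the pattern (the `build_embed_pooler` helper and its argument names are made up, not code from this PR):

```python
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.pooler import Pooler


def build_embed_pooler(vllm_config: VllmConfig, pooler_config, **extra_args):
    # Pooler.for_embed() reads the ambient config (e.g. to resolve
    # model_config.head_dtype), so it has to be constructed inside this block.
    with set_current_vllm_config(vllm_config):
        return Pooler.for_embed(pooler_config=pooler_config, **extra_args)
```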

self.pooler = Pooler.for_embed(pooler_config=pooler_config,
**extra_args)
elif task == "classify":
self.pooler = ClassifierPooler(
pooling=self._pooler,
classifier=ClassifierAdapter(self.classifier),
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
self.model_config),
)
with set_current_vllm_config(self.vllm_config):
self.pooler = ClassifierPooler(
pooling=self._pooler,
classifier=ClassifierAdapter(self.classifier),
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
self.model_config),
)

@property
def vocab_size(self) -> int:
@@ -1630,6 +1631,10 @@ def execute_model(
logger.debug("t_batch: %.2fms", (t1 * 1000))

pooling_metadata = self.input_batch.make_pooling_metadata()
# There is no partial prefill here, so each request's scheduled token
# count is simply its full prompt length.
pooling_metadata.build_pooling_cursor(
num_scheduled_tokens=pooling_metadata.prompt_lens,
device=self.device)

# prepare unpadded output for the pooler
hidden_state_list: list[torch.Tensor] = []
@@ -1638,8 +1643,9 @@
# we're left padding
hidden_state_list.append(hidden_state[-prompt_len:])

raw_pooler_output = self.pooler(hidden_states=hidden_state_list,
pooling_metadata=pooling_metadata)
raw_pooler_output = self.pooler(
hidden_states=torch.cat(hidden_state_list),
pooling_metadata=pooling_metadata)

pooler_output: list[Optional[torch.Tensor]] = []

@@ -1650,7 +1656,6 @@
req_ids=self.input_batch.requests_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=[],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=pooler_output,
10 changes: 5 additions & 5 deletions vllm_spyre/v1/worker/spyre_worker.py
@@ -476,7 +476,7 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@@ -508,7 +508,7 @@ def _cleanup_model_runner(self, request) -> None:
num_common_prefix_blocks=0,
# The requests to be removed
finished_req_ids=set([r.req_id for r in request]),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@@ -590,7 +590,7 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@@ -655,7 +655,7 @@ def _dynamic_warmup(
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@@ -692,7 +692,7 @@
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)