pyproject.toml (2 additions, 2 deletions)

@@ -13,7 +13,7 @@ license = {text = "Apache 2"}
 dependencies = [
     "fms-model-optimizer[fp8]>=0.8.0",
     "ibm-fms>=1.7.0,<2.0",
-    "vllm>=0.15.1,<=0.15.1",
+    "vllm>=0.16.0,<=0.16.0",
 ]
 requires-python = ">=3.11"
 dynamic = ["version"]

@@ -76,7 +76,7 @@ environments = [
 ]

 [tool.uv.sources]
-vllm = { git = "https://github.com/vllm-project/vllm", rev = "v0.16.0" }
+vllm = { git = "https://github.com/vllm-project/vllm", rev = "v0.16.0" }
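Note: the pin is exact (>=0.16.0,<=0.16.0 allows only 0.16.0), so a quick import-time check can confirm an environment actually picked up the bumped dependency. A minimal sketch, assuming only that vllm exposes the standard __version__ attribute:

import vllm

# The pyproject pin above allows exactly 0.16.0; fail fast if the
# environment still has an older build installed.
assert vllm.__version__.startswith("0.16"), f"unexpected vLLM {vllm.__version__}"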
tests/utils/test_platform_validation.py (5 additions, 5 deletions)

@@ -24,7 +24,7 @@ def test_strips_structured_outputs(self):

         assert params.structured_outputs is not None

-        SpyrePlatform.validate_request("Test prompt", params)
+        SpyrePlatform.validate_request("Test prompt", params, processed_inputs=None)

         assert params.structured_outputs is None

@@ -34,7 +34,7 @@ def test_logs_warning_when_stripping(self, caplog_vllm_spyre):
             max_tokens=20, structured_outputs=StructuredOutputsParams(json_object=True)
         )

-        SpyrePlatform.validate_request("Test prompt", params)
+        SpyrePlatform.validate_request("Test prompt", params, processed_inputs=None)

         assert len(caplog_vllm_spyre.records) > 0
         warning_record = caplog_vllm_spyre.records[0]

@@ -55,7 +55,7 @@ def test_strips_different_structured_output_types(self, structured_output):

         assert params.structured_outputs is not None

-        SpyrePlatform.validate_request("Test prompt", params)
+        SpyrePlatform.validate_request("Test prompt", params, processed_inputs=None)

         assert params.structured_outputs is None

@@ -77,7 +77,7 @@ def test_preserves_other_sampling_params(self):
             "top_k": params.top_k,
         }

-        SpyrePlatform.validate_request("Test prompt", params)
+        SpyrePlatform.validate_request("Test prompt", params, processed_inputs=None)

         # Verify other params are unchanged
         assert params.max_tokens == original_values["max_tokens"]

@@ -92,7 +92,7 @@ def test_does_not_affect_pooling_params(self):
         pooling_params = PoolingParams()

         # Should not raise any errors and should return early
-        SpyrePlatform.validate_request("Test prompt", pooling_params)
+        SpyrePlatform.validate_request("Test prompt", pooling_params, processed_inputs=None)

         # PoolingParams don't have structured_outputs, so just verify no exception
         assert True  # If we got here, the early return worked
uv.lock (5 additions, 5 deletions)

(Generated lockfile; diff not rendered.)
vllm_spyre/multimodal/mm_mappings/llava_next.py (1 addition, 3 deletions)

@@ -156,9 +156,7 @@ def _build_multimodal_spec(proc_res):
     }
     mm_fields = MultiModalKwargsItem(
         {
-            mm_key: MultiModalFieldElem(
-                modality="image", key=mm_key, data=mm_data, field=MultiModalBatchedField()
-            )
+            mm_key: MultiModalFieldElem(data=mm_data, field=MultiModalBatchedField())
             for mm_key, mm_data in mm_data.items()
         }
     )
vllm_spyre/multimodal/mm_mappings/mistral3.py (1 addition, 3 deletions)

@@ -138,9 +138,7 @@ def get_warmup_inputs(self, req_count: int) -> MMWarmupInputs:
         }
         mm_fields = MultiModalKwargsItem(
             {
-                mm_key: MultiModalFieldElem(
-                    modality="image", key=mm_key, data=mm_data, field=MultiModalBatchedField()
-                )
+                mm_key: MultiModalFieldElem(data=mm_data, field=MultiModalBatchedField())
                 for mm_key, mm_data in mm_data.items()
             }
         )
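Both mapping files make the same change: in the updated vLLM API, MultiModalFieldElem no longer takes modality= or key= keyword arguments, so each field element is built from just its data and field descriptor. A minimal sketch of the new construction pattern, assuming the usual vllm.multimodal.inputs import path and a hypothetical pixel_values tensor standing in for real processor output:

import torch
from vllm.multimodal.inputs import (
    MultiModalBatchedField,
    MultiModalFieldElem,
    MultiModalKwargsItem,
)

# Hypothetical processor output; real values come from the HF processor.
mm_data = {"pixel_values": torch.zeros(1, 3, 336, 336)}

# Each field element now carries only its data and field descriptor;
# the dict key identifies the field.
mm_fields = MultiModalKwargsItem(
    {
        key: MultiModalFieldElem(data=data, field=MultiModalBatchedField())
        for key, data in mm_data.items()
    }
)

(The sketch renames the comprehension variables to avoid the mm_data shadowing present in the original comprehensions, which works but reads poorly.)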
vllm_spyre/platform.py (11 additions, 5 deletions)

@@ -13,10 +13,9 @@
 import math
 import operator
 import os
-from typing import TYPE_CHECKING, Union, cast
+from typing import TYPE_CHECKING, cast

 import torch
-from vllm.inputs import ProcessorInputs, PromptType, TokenInputs
 from vllm.logger import init_logger
 from vllm.utils.argparse_utils import FlexibleArgumentParser

@@ -26,11 +25,18 @@
     from vllm.config import ModelConfig, VllmConfig
     from vllm.pooling_params import PoolingParams
     from vllm.sampling_params import SamplingParams
+    from vllm.renderers.inputs import DictPrompt, TokPrompt
+    from vllm.inputs import ProcessorInputs, PromptType, TokenInputs
 else:
     ModelConfig = None
     VllmConfig = None
     SamplingParams = None
     PoolingParams = None
+    DictPrompt = None
+    TokPrompt = None
+    ProcessorInputs = None
+    PromptType = None
+    TokenInputs = None
 from vllm.platforms import Platform, PlatformEnum

 import vllm_spyre.envs as envs_spyre

@@ -341,9 +347,9 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
     @classmethod
     def validate_request(
         cls,
-        prompt: PromptType,
-        params: Union[SamplingParams, PoolingParams],
-        processed_inputs: ProcessorInputs | None = None,
+        prompt: "PromptType | DictPrompt | TokPrompt",
+        params: "SamplingParams | PoolingParams",
+        processed_inputs: "ProcessorInputs",
     ) -> None:
         """Raises if this request is unsupported on this platform"""
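With this change, processed_inputs is required rather than defaulting to None, and the annotations become string forms so that the now TYPE_CHECKING-only imports still resolve. Callers must pass the argument explicitly, as the updated tests do; a minimal sketch of the new call shape:

from vllm.sampling_params import SamplingParams
from vllm_spyre.platform import SpyrePlatform

params = SamplingParams(max_tokens=20)
# processed_inputs must now be supplied by the caller; the tests above
# pass None for requests that have no processed inputs.
SpyrePlatform.validate_request("Test prompt", params, processed_inputs=None)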
vllm_spyre/v1/worker/spyre_worker.py (30 additions, 30 deletions)

@@ -17,6 +17,7 @@
 import vllm.envs as envs
 from huggingface_hub import hf_hub_download
 from vllm.config import VllmConfig
+from vllm.profiler.wrapper import TorchProfilerWrapper
 from vllm.distributed import ensure_model_parallel_initialized, init_distributed_environment
 from vllm.logger import init_logger
 from vllm.pooling_params import PoolingParams

@@ -261,14 +262,23 @@ def __init__(
         )

         self._env_initialized = False
-        # Torch profiler. Enabled and configured through env vars:
-        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
-        if envs.VLLM_TORCH_PROFILER_DIR:
-            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
-            logger.info("Profiling enabled. Traces will be saved to: %s", torch_profiler_trace_dir)
+        # Torch profiler. Enabled and configured through ProfilerConfig. Set via:
+        #   --profiler-config.profiler=torch
+        #   --profiler-config.torch_profiler_dir=/path/to/save/trace
+        # OR
+        #   --profiler-config '{"profiler": "torch", "torch_profiler_dir": "/path/to/save/trace"}'
+        profiler_config = vllm_config.profiler_config
+        if profiler_config.profiler == "torch":
+            worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
+            self.profiler: TorchProfilerWrapper | None = TorchProfilerWrapper(
+                profiler_config,
+                worker_name=worker_name,
+                local_rank=self.local_rank,
+                activities=["CPU"],
+            )

             if SpyrePlatform.is_backend_sendnn_enabled():
-                logger.info(
+                logger.info_once(
                     "Traces will contain AIU events if PyTorch with"
                     " AIU profiling support is installed."
                 )
@@ ... @@
                 options = dict(opt.split("=") for opt in dt_opt.split(",") if "=" in opt)
                 autopilot_opt = options.get("autopilot", "1")  # autopilot defaults to 1 if not set
                 if autopilot_opt == "1":
-                    logger.warning(
+                    logger.warning_once(
                         "autopilot on detected with profiling enabled. Add "
-                        "autpilot=0 to DT_OPT to see individual AIU-kernel "
+                        "autopilot=0 to DT_OPT to see individual AIU-kernel "
                         "execution in the trace."
                     )

-            logger.debug(
-                "Profiler config: record_shapes=%s,profile_memory=%s,with_stack=%s,with_flops=%s",
-                envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
-                envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
-                envs.VLLM_TORCH_PROFILER_WITH_STACK,
-                envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
-            )
-
-            # TODO: These flags should be set as bools, but are passed through as strings.
-            # This is probably a bug.
-            self.profiler = torch.profiler.profile(
-                activities=[torch.profiler.ProfilerActivity.CPU],
-                record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,  # ty: ignore
-                profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,  # ty: ignore
-                with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,  # ty: ignore
-                with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,  # ty: ignore
-                on_trace_ready=torch.profiler.tensorboard_trace_handler(
-                    torch_profiler_trace_dir, use_gzip=True
-                ),
-            )
         else:
             self.profiler = None

@@ -723,12 +713,5 @@ def _warmup_model_forward_pass(
         }
         self.execute_model(scheduler_output)  # Prefill

-    def profile(self, is_start=True):
+    def profile(self, is_start: bool = True):
         if self.profiler is None:
-            raise RuntimeError("Profiler is not enabled.")
+            raise RuntimeError(
+                "Profiling is not enabled. Please set --profiler-config to enable "
+                "profiling. Example: "
+                "'--profiler-config.profiler=torch --profiler-config.torch_profiler_dir"
+                "=YOUR_DIR_PATH_TO_DUMP_TRACE'"
+            )
         if is_start:
             self.profiler.start()
         else:
+            if self.profiler is None:
+                logger.warning("Profiler was not started, nothing to stop.")
+                return
             self.profiler.stop()

     @property

@@ -752,6 +750,8 @@ def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
     ) -> ModelRunnerOutput | None:
+        if self.profiler is not None:
+            self.profiler.step()
         output = self.model_runner.execute_model(scheduler_output)
         return output if self.is_driver_worker else None
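The net effect is that the worker now delegates trace handling to vLLM's TorchProfilerWrapper and drives a start/step/stop lifecycle, instead of owning a raw torch.profiler.profile object configured from VLLM_TORCH_PROFILER_* env vars. A minimal sketch of the construction logic the new __init__ performs, using only the calls visible in this diff (the helper name build_profiler is illustrative):

from vllm.config import VllmConfig
from vllm.profiler.wrapper import TorchProfilerWrapper

def build_profiler(
    vllm_config: VllmConfig, rank: int, local_rank: int
) -> TorchProfilerWrapper | None:
    profiler_config = vllm_config.profiler_config
    if profiler_config.profiler != "torch":
        # Profiling disabled: downstream code must tolerate a None profiler.
        return None
    return TorchProfilerWrapper(
        profiler_config,
        worker_name=f"{vllm_config.instance_id}-rank-{rank}",
        local_rank=local_rank,
        activities=["CPU"],
    )

When a profiler exists, execute_model calls profiler.step() once per scheduler step, which is what gives the resulting trace its per-step structure.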