Merged
Changes from 5 commits
103 changes: 103 additions & 0 deletions vllm_spyre/v1/sample/spyre_logit_processor.py
@@ -0,0 +1,103 @@
import itertools
from typing import Optional, Sequence, Union

import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.sample.logits_processor import (BUILTIN_LOGITS_PROCESSORS,
STR_POOLING_REJECTS_LOGITSPROCS,
BatchUpdate, LogitsProcessor,
_load_custom_logitsprocs)
from vllm.v1.sample.logits_processor.state import LogitsProcessors

logger = init_logger(__name__)


def build_logitsprocs_for_cb(
vllm_config: "VllmConfig",
device: torch.device,
is_pin_memory: bool,
is_pooling_model: bool,
batch_size: int,
custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (),
) -> LogitsProcessors:
if is_pooling_model:
if custom_logitsprocs:
raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS)
logger.debug("Skipping logits processor loading because pooling models"
" do not support logits processors.")
return LogitsProcessors()
custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs)

return LogitsProcessors(
LogitProcessorWrapper(logit_processor,
vllm_config,
device,
is_pin_memory,
                              batch_size)
for logit_processor in itertools.chain(
BUILTIN_LOGITS_PROCESSORS,
custom_logitsprocs_classes
)
)


class LogitProcessorWrapper(LogitsProcessor):
"""Logit processor to inject expected token during generation for tests"""

def __init__(self, logit_processor: LogitsProcessor,
vllm_config: VllmConfig, device: torch.device,
is_pin_memory: bool, batch_size: int):
self.logitprocs: list[LogitsProcessor] = [
            logit_processor(vllm_config, device, is_pin_memory)
for _ in range(batch_size)
]

        self._is_argmax_invariant: bool = \
self.logitprocs[0].is_argmax_invariant()

self._prefill_index: Optional[int] = None

def is_argmax_invariant(self) -> bool:
"""Never impacts greedy sampling"""
return self._is_argmax_invariant

def update_state(self, batch_update: Optional[BatchUpdate]):
        # Keep each request's processor index consistent while the
        # persistent batch is changing.
if not batch_update:
return

# Process added requests.
for index, params, prompt_tok_ids, out_tok_ids in batch_update.added:
self.logitprocs[index].update_state(
BatchUpdate(
batch_size=1,
removed=[],
moved=[],
added=[(0, params, prompt_tok_ids, out_tok_ids)],
))

        # Process removed requests.
        for index in batch_update.removed:
            self.logitprocs[index].update_state(
                BatchUpdate(batch_size=1, removed=[0], moved=[], added=[]))

        # Process moved requests. A plain swap also covers one-directional
        # moves, since the vacated slot's processor is stale and is reset
        # when a new request lands there.
        for adx, bdx, _ in batch_update.moved:
            self.logitprocs[adx], self.logitprocs[bdx] = \
                self.logitprocs[bdx], self.logitprocs[adx]

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        # Prefill: only the newly added request's processor applies; the
        # stored index is consumed so later calls fall through to the
        # per-row loop below.
        if self._prefill_index is not None:
            logits = self.logitprocs[self._prefill_index].apply(logits)
            self._prefill_index = None
            return logits

        # Decode: apply each slot's processor to its own row.
        batch_size = logits.shape[0]
        for i in range(batch_size):
            logits[i] = self.logitprocs[i].apply(logits[i].unsqueeze(0))

        return logits

def set_prefill_index(self, idx: int) -> None:
self._prefill_index = idx
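
For reviewers, a minimal sketch of what the new wrapper does at runtime. DummyProc is a hypothetical processor invented only for illustration; the sketch assumes the three-method LogitsProcessor interface (is_argmax_invariant, update_state, apply) and the BatchUpdate constructor exactly as the diff above uses them.

import torch
from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor

from vllm_spyre.v1.sample.spyre_logit_processor import LogitProcessorWrapper


class DummyProc(LogitsProcessor):
    """Hypothetical processor: masks token 0 once its request is added."""

    def __init__(self, vllm_config, device, is_pin_memory):
        self.active = False

    def is_argmax_invariant(self) -> bool:
        return False

    def update_state(self, batch_update):
        if batch_update and batch_update.added:
            self.active = True

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if self.active:
            logits[:, 0] = float("-inf")
        return logits


wrapper = LogitProcessorWrapper(DummyProc, vllm_config=None,
                                device=torch.device("cpu"),
                                is_pin_memory=False, batch_size=2)
# Only slot 1 receives a request; slot 0 stays idle.
wrapper.update_state(
    BatchUpdate(batch_size=2, removed=[], moved=[],
                added=[(1, None, [1, 2], [])]))
out = wrapper.apply(torch.zeros(2, 8))  # row 1 has token 0 masked
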
31 changes: 16 additions & 15 deletions vllm_spyre/v1/worker/spyre_input_batch.py
@@ -18,6 +18,8 @@
MoveDirectionality)
from vllm.v1.sample.metadata import SamplingMetadata

from vllm_spyre.v1.sample.spyre_logit_processor import LogitProcessorWrapper


@dataclass
class BaseRequestState:
@@ -223,16 +225,13 @@ class SamplingInputBatch(BaseInputBatch[SamplingRequestState]):
condense the sampling parameters.
'''

-    def __init__(
-        self,
-        max_num_reqs: int,
-        max_model_len: int,
-        device: torch.device,
-        pin_memory: bool,
-        vocab_size: int,
-        # Type here is any for compatibility reasons
-        logitsprocs: Optional[LogitsProcessors] = None,
-    ):
+    def __init__(self,
+                 max_num_reqs: int,
+                 max_model_len: int,
+                 device: torch.device,
+                 pin_memory: bool,
+                 vocab_size: int,
+                 logitsprocs: Optional[LogitsProcessors] = None):

super().__init__(
max_num_reqs,
@@ -302,6 +301,9 @@ def __init__(
self.batch_update_builder = BatchUpdateBuilder()

self.logitsprocs = logitsprocs or LogitsProcessors()
        self.logitsprocs_wrappers = [
            lp for lp in self.logitsprocs.all
            if isinstance(lp, LogitProcessorWrapper)
        ]

self.has_allowed_token_ids: set[str] = set()
self.allowed_token_ids_mask: Optional[torch.Tensor] = None
@@ -517,13 +519,12 @@ def remove_request(self, req_id: str):
self.req_indices_mask[req_index] = False

         # Remove and move up
-        tmp_dense = dense_index
-        self.batch_update_builder.removed_append(tmp_dense)
+        self.batch_update_builder.removed_append(dense_index)

-        while tmp_dense < self._num_requests + 1:
+        end_dense_idx = min(self._num_requests + 1, self.max_num_reqs - 1)
+        for tmp_dense in range(dense_index, end_dense_idx):
             self.batch_update_builder.moved.append(
-                (tmp_dense, tmp_dense + 1, MoveDirectionality.SWAP))
-            tmp_dense = tmp_dense + 1
+                (tmp_dense, tmp_dense + 1, MoveDirectionality.UNIDIRECTIONAL))

# Remove the references
self.req_output_token_ids.pop(dense_index)
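
A toy, plain-Python trace (hypothetical sizes) of the batch update the rewritten removal loop above now emits: one removal followed by a chain of one-directional shifts, rather than the old swap-based walk.

# Hypothetical: the request at dense index 1 is removed while three
# requests remain active and the batch holds at most eight.
dense_index, num_requests, max_num_reqs = 1, 3, 8

removed = [dense_index]
end_dense_idx = min(num_requests + 1, max_num_reqs - 1)
moved = [(i, i + 1, "UNIDIRECTIONAL")
         for i in range(dense_index, end_dense_idx)]
print(removed)  # [1]
print(moved)    # [(1, 2, ...), (2, 3, ...), (3, 4, ...)]
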
37 changes: 35 additions & 2 deletions vllm_spyre/v1/worker/spyre_model_runner.py
@@ -27,6 +27,7 @@
from vllm_spyre.model_executor.model_loader.spyre import (
BACKEND_LIST, SpyreAttentionMetadata, SpyreCausalLM)
from vllm_spyre.platform import SpyrePlatform
from vllm_spyre.v1.sample.spyre_logit_processor import build_logitsprocs_for_cb
# yapf conflicts with ruff for this block
# yapf: disable
from vllm_spyre.v1.worker.spyre_input_batch import (BaseInputBatch,
@@ -979,6 +980,10 @@ def _prepare_prompt(
prefill_index = self.input_batch.add_request(req_state)
self.prefill_batch.add_request(req_state)

        # Set the prefill index on the wrapped logits processors
for logitsproc in self.input_batch.logitsprocs_wrappers:
logitsproc.set_prefill_index(prefill_index)

        # Refresh sampling metadata after all requests are added to the batch
self.input_batch.refresh_metadata()
self.prefill_batch.refresh_metadata()
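
A short sketch of what the new loop above buys at prefill time, reusing the hypothetical wrapper from the earlier DummyProc sketch: set_prefill_index routes exactly one apply call to the prefilled request's processor, after which the index is cleared and decodes fall back to the per-row loop.

import torch  # `wrapper` comes from the DummyProc sketch above

wrapper.set_prefill_index(1)         # slot 1 was just prefilled
prefill_logits = torch.zeros(1, 8)   # the prefill batch holds one row
out = wrapper.apply(prefill_logits)  # only slot 1's processor runs;
                                     # the stored index is then cleared
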
@@ -1227,8 +1232,13 @@ def build_attn_metadata(
is_prefill=model_input.is_prompt)

def get_sampling_metadata(self, is_prefill: bool) -> SamplingMetadata:
-        return self.prefill_batch.sampling_metadata \
-            if is_prefill else self.input_batch.sampling_metadata
+
+        if is_prefill:
+            sampling_data = self.prefill_batch.sampling_metadata
+            sampling_data.logitsprocs = self.input_batch.logitsprocs
+            return sampling_data
+        else:
+            return self.input_batch.sampling_metadata

def get_req_id_to_index(self, is_prefill: bool) -> dict[str, int]:
req_id_to_index = self.prefill_batch.get_unpadded_output_indices() \
@@ -1310,6 +1320,29 @@ def _mark_input_tensors(self, model_input: SamplingForwardInputs) -> None:
torch._dynamo.mark_static(model_input.input_positions,
1) # always 1

def build_input_batch(self) -> SamplingInputBatch:
# Define logits processors.

custom_logitsprocs = self.vllm_config.model_config.logits_processors

batch_size = self.scheduler_config.max_num_seqs
logits_processors = \
build_logitsprocs_for_cb(vllm_config=self.vllm_config,
device=self.device,
is_pin_memory=self.pin_memory,
is_pooling_model=False,
custom_logitsprocs=custom_logitsprocs,
batch_size=batch_size)

return SamplingInputBatch(
max_num_reqs=batch_size,
max_model_len=self.model_config.max_model_len,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
logitsprocs=logits_processors,
)


class PoolerAdapter(torch.nn.Module):
