diff --git a/pyproject.toml b/pyproject.toml
index 31258c81a..c39a0ea56 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -164,7 +164,7 @@ dev = [
     "pytest-timeout==2.3.1",
     "requests==2.32.3",
     "sentence-transformers==3.4.1",
-    "aiu-fms-testing-utils>=0.2.3",
+    "aiu-fms-testing-utils>=0.4.0",
 ]
 lint = [
     "clang-format==18.1.5",
diff --git a/tests/e2e/test_logits_processors.py b/tests/e2e/test_logits_processors.py
index 21d0e13f8..0e2b37333 100644
--- a/tests/e2e/test_logits_processors.py
+++ b/tests/e2e/test_logits_processors.py
@@ -5,7 +5,8 @@
 from spyre_util import ModelInfo
 from vllm import LLM, SamplingParams
 from vllm.config import VllmConfig
-from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor
+from vllm.v1.sample.logits_processor import (BatchUpdate, LogitsProcessor,
+                                             MoveDirectionality)
 
 
 def test_custom_logits_processor(model: ModelInfo, backend, monkeypatch,
@@ -51,3 +52,101 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
     spyre_model.generate(prompt, params)
 
     assert has_invoked_logits_processor
+
+
+def test_cb_logits_processor(model: ModelInfo, backend, monkeypatch,
+                             warmup_shapes, max_model_len, cb):
+    '''
+    Test that the logits processor state stays correct under continuous
+    batching (CB) across the prefill/decode switches of the engine step.
+    The LLM is initialized with bs=2 and we send 3 requests, so one of
+    them has to wait for the other 2; the first request to complete
+    gives its slot to the last one. The spy logits processor greedily
+    decodes the logits it sees, emulating a stateful logits processor.
+    After generation we assert that the tokens recorded by the spy
+    match the tokens returned by vLLM.
+    '''
+
+    # Run the engine in the same process so the spy can share state
+    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+
+    # Hack to collect outputs from the logits; the key is max_tokens,
+    # which makes it easy to identify each request
+    spy_outputs: dict[int, list[int]] = {}
+
+    class SpyLogitsProcessor(LogitsProcessor):
+        '''
+        This logits processor collects the greedily sampled tokens
+        '''
+
+        def __init__(self, vllm_config: VllmConfig, device: torch.device,
+                     is_pin_memory: bool):
+            self.req_info: dict[int, SamplingParams] = {}
+
+        def is_argmax_invariant(self) -> bool:
+            return False
+
+        def update_state(self, batch_update: Optional[BatchUpdate]):
+            if not batch_update:
+                return
+
+            for index, params, _, _ in batch_update.added:
+                self.req_info[index] = params
+                nonlocal spy_outputs
+                spy_outputs[params.max_tokens] = []
+
+            if self.req_info:
+                # Process removed requests.
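+                # pop() with a default avoids a KeyError when an index
+                # was never tracked by this processor.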
+                for index in batch_update.removed:
+                    self.req_info.pop(index, None)
+
+                # Process moved requests: unidirectional moves (a->b)
+                # and swaps (a<->b)
+                for adx, bdx, direct in batch_update.moved:
+                    a_val = self.req_info.pop(adx, None)
+                    b_val = self.req_info.pop(bdx, None)
+                    if a_val is not None:
+                        self.req_info[bdx] = a_val
+                    if direct == MoveDirectionality.SWAP and b_val is not None:
+                        self.req_info[adx] = b_val
+
+        def apply(self, logits: torch.Tensor) -> torch.Tensor:
+            if not self.req_info:
+                return logits
+            batch_size = logits.shape[0]
+            nonlocal spy_outputs
+            for i in range(batch_size):
+                params = self.req_info[i]
+                token_id = logits[i].argmax(-1).reshape(-1).item()
+                spy_outputs[params.max_tokens].append(token_id)
+            return logits
+
+    patch_environment(True, None, backend, monkeypatch)
+
+    spyre_model = LLM(model=model.name,
+                      revision=model.revision,
+                      max_model_len=max_model_len,
+                      max_num_seqs=2,
+                      logits_processors=[SpyLogitsProcessor])
+    prompt = ["Hello Logits Processors"] * 3
+    params0 = SamplingParams(max_tokens=5,
+                             temperature=0,
+                             logprobs=0,
+                             ignore_eos=True)
+    params1 = SamplingParams(max_tokens=10,
+                             temperature=0,
+                             logprobs=0,
+                             ignore_eos=True)
+    params2 = SamplingParams(max_tokens=7,
+                             temperature=0,
+                             logprobs=0,
+                             ignore_eos=True)
+
+    # Clear the tokens collected during warmup
+    spy_outputs = {}
+    params = [params0, params1, params2]
+    outputs = spyre_model.generate(prompt, params)
+
+    assert spy_outputs[5] == outputs[0].outputs[0].token_ids
+    assert spy_outputs[10] == outputs[1].outputs[0].token_ids
+    assert spy_outputs[7] == outputs[2].outputs[0].token_ids
diff --git a/uv.lock b/uv.lock
index 8a4dabfe2..771195d0b 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,5 +1,5 @@
 version = 1
-revision = 2
+revision = 3
 requires-python = ">=3.11"
 resolution-markers = [
     "python_full_version >= '3.12'",
@@ -3864,7 +3864,7 @@ requires-dist = [
 
 [package.metadata.requires-dev]
 dev = [
-    { name = "aiu-fms-testing-utils", specifier = ">=0.2.3" },
+    { name = "aiu-fms-testing-utils", specifier = ">=0.4.0" },
     { name = "pytest", specifier = "==8.3.4" },
     { name = "pytest-asyncio", specifier = ">=1.0.0" },
     { name = "pytest-forked", specifier = ">=1.6.0" },
diff --git a/vllm_spyre/v1/sample/spyre_logits_processor.py b/vllm_spyre/v1/sample/spyre_logits_processor.py
new file mode 100644
index 000000000..ec6621f7a
--- /dev/null
+++ b/vllm_spyre/v1/sample/spyre_logits_processor.py
@@ -0,0 +1,103 @@
+import itertools
+from typing import Optional, Sequence, Union
+
+import torch
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.v1.sample.logits_processor import (BUILTIN_LOGITS_PROCESSORS,
+                                             STR_POOLING_REJECTS_LOGITSPROCS,
+                                             BatchUpdate, LogitsProcessor,
+                                             _load_custom_logitsprocs)
+from vllm.v1.sample.logits_processor.state import LogitsProcessors
+
+logger = init_logger(__name__)
+
+
+def build_logitsprocs_for_cb(
+        vllm_config: "VllmConfig",
+        device: torch.device,
+        is_pin_memory: bool,
+        is_pooling_model: bool,
+        batch_size: int,
+        custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (),
+) -> LogitsProcessors:
+    if is_pooling_model:
+        if custom_logitsprocs:
+            raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS)
+        logger.debug("Skipping logits processor loading because pooling models"
+                     " do not support logits processors.")
+        return LogitsProcessors()
+    custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs)
+
+    return LogitsProcessors(
+        LogitProcessorWrapper(logit_processor,
+                              vllm_config,
+                              device,
+                              is_pin_memory,
+                              batch_size)
+        for logit_processor in itertools.chain(
+            BUILTIN_LOGITS_PROCESSORS,
+            custom_logitsprocs_classes
+        )
+    )
+
+
+class LogitProcessorWrapper(LogitsProcessor):
+    """Wrapper holding one logits processor instance per batch slot (CB)."""
+
+    def __init__(self, logit_processor: LogitsProcessor,
+                 vllm_config: VllmConfig, device: torch.device,
+                 is_pin_memory: bool, batch_size: int):
+        self.logitprocs: list[LogitsProcessor] = [
+            logit_processor(vllm_config, device, is_pin_memory)
+            for _ in range(batch_size)
+        ]
+
+        self._is_argmax_invariant: bool = \
+            self.logitprocs[0].is_argmax_invariant()
+
+        self._prefill_index: Optional[int] = None
+
+    def is_argmax_invariant(self) -> bool:
+        """Mirror the wrapped logits processor's argmax invariance."""
+        return self._is_argmax_invariant
+
+    def update_state(self, batch_update: Optional[BatchUpdate]):
+        # Keep the per-slot processors consistent with the request
+        # indices while the persistent batch is changing.
+        if not batch_update:
+            return
+
+        # Process added requests.
+        for index, params, prompt_tok_ids, out_tok_ids in batch_update.added:
+            self.logitprocs[index].update_state(
+                BatchUpdate(
+                    batch_size=1,
+                    removed=[],
+                    moved=[],
+                    added=[(0, params, prompt_tok_ids, out_tok_ids)],
+                ))
+
+        for index in batch_update.removed:
+            self.logitprocs[index].update_state(
+                BatchUpdate(batch_size=1, removed=[0], moved=[], added=[]))
+
+        for adx, bdx, _ in batch_update.moved:
+            self.logitprocs[adx], self.logitprocs[bdx] = \
+                self.logitprocs[bdx], self.logitprocs[adx]
+
+    def apply(self, logits: torch.Tensor) -> torch.Tensor:
+
+        if self._prefill_index is not None:
+            logits = self.logitprocs[self._prefill_index].apply(logits)
+            self._prefill_index = None
+            return logits
+
+        batch_size = logits.shape[0]
+        for i in range(batch_size):
+            logits[i] = self.logitprocs[i].apply(logits[i].unsqueeze(0))
+
+        return logits
+
+    def set_prefill_index(self, idx: int) -> None:
+        self._prefill_index = idx
diff --git a/vllm_spyre/v1/worker/spyre_input_batch.py b/vllm_spyre/v1/worker/spyre_input_batch.py
index 9c019ca88..fbdd0d15d 100644
--- a/vllm_spyre/v1/worker/spyre_input_batch.py
+++ b/vllm_spyre/v1/worker/spyre_input_batch.py
@@ -18,6 +18,8 @@
                                              MoveDirectionality)
 from vllm.v1.sample.metadata import SamplingMetadata
 
+from vllm_spyre.v1.sample.spyre_logits_processor import LogitProcessorWrapper
+
 
 @dataclass
 class BaseRequestState:
@@ -223,16 +225,13 @@ class SamplingInputBatch(BaseInputBatch[SamplingRequestState]):
     condense the sampling parameters.
     '''
 
-    def __init__(
-        self,
-        max_num_reqs: int,
-        max_model_len: int,
-        device: torch.device,
-        pin_memory: bool,
-        vocab_size: int,
-        # Type here is any for compatibility reasons
-        logitsprocs: Optional[LogitsProcessors] = None,
-    ):
+    def __init__(self,
+                 max_num_reqs: int,
+                 max_model_len: int,
+                 device: torch.device,
+                 pin_memory: bool,
+                 vocab_size: int,
+                 logitsprocs: Optional[LogitsProcessors] = None):
 
         super().__init__(
             max_num_reqs,
@@ -302,6 +301,9 @@ def __init__(
 
         self.batch_update_builder = BatchUpdateBuilder()
         self.logitsprocs = logitsprocs or LogitsProcessors()
+        self.logitsprocs_wrappers = [
+            lp for lp in self.logitsprocs.all
+            if isinstance(lp, LogitProcessorWrapper)]
 
         self.has_allowed_token_ids: set[str] = set()
         self.allowed_token_ids_mask: Optional[torch.Tensor] = None
@@ -517,13 +519,12 @@ def remove_request(self, req_id: str):
         self.req_indices_mask[req_index] = False
 
         # Remove and move up
-        tmp_dense = dense_index
-        self.batch_update_builder.removed_append(tmp_dense)
+        self.batch_update_builder.removed_append(dense_index)
 
-        while tmp_dense < self._num_requests + 1:
+        end_dense_idx = min(self._num_requests + 1, self.max_num_reqs - 1)
+        for tmp_dense in range(dense_index, end_dense_idx):
             self.batch_update_builder.moved.append(
-                (tmp_dense, tmp_dense + 1, MoveDirectionality.SWAP))
-            tmp_dense = tmp_dense + 1
+                (tmp_dense, tmp_dense + 1, MoveDirectionality.UNIDIRECTIONAL))
 
         # Remove the references
         self.req_output_token_ids.pop(dense_index)
diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py
index 0c66e8b95..ec04d5718 100644
--- a/vllm_spyre/v1/worker/spyre_model_runner.py
+++ b/vllm_spyre/v1/worker/spyre_model_runner.py
@@ -27,6 +27,8 @@
 from vllm_spyre.model_executor.model_loader.spyre import (
     BACKEND_LIST, SpyreAttentionMetadata, SpyreCausalLM)
 from vllm_spyre.platform import SpyrePlatform
+from vllm_spyre.v1.sample.spyre_logits_processor import (
+    build_logitsprocs_for_cb)
 # yapf conflicts with ruff for this block
 # yapf: disable
 from vllm_spyre.v1.worker.spyre_input_batch import (BaseInputBatch,
@@ -979,6 +981,10 @@ def _prepare_prompt(
         prefill_index = self.input_batch.add_request(req_state)
         self.prefill_batch.add_request(req_state)
 
+        # Set the prefill index on the logits processor wrappers
+        for logitsproc in self.input_batch.logitsprocs_wrappers:
+            logitsproc.set_prefill_index(prefill_index)
+
         # Refresh sampling metadata after all request are added to the batch
         self.input_batch.refresh_metadata()
         self.prefill_batch.refresh_metadata()
@@ -1227,8 +1233,13 @@ def build_attn_metadata(
             is_prefill=model_input.is_prompt)
 
     def get_sampling_metadata(self, is_prefill: bool) -> SamplingMetadata:
-        return self.prefill_batch.sampling_metadata \
-            if is_prefill else self.input_batch.sampling_metadata
+
+        if is_prefill:
+            sampling_data = self.prefill_batch.sampling_metadata
+            sampling_data.logitsprocs = self.input_batch.logitsprocs
+            return sampling_data
+        else:
+            return self.input_batch.sampling_metadata
 
     def get_req_id_to_index(self, is_prefill: bool) -> dict[str, int]:
         req_id_to_index = self.prefill_batch.get_unpadded_output_indices() \
@@ -1310,6 +1321,29 @@ def _mark_input_tensors(self, model_input: SamplingForwardInputs) -> None:
 
             torch._dynamo.mark_static(model_input.input_positions,
                                       1)  # always 1
 
+    def build_input_batch(self) -> SamplingInputBatch:
+        # Define logits processors.
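+        # For continuous batching, each logits processor is wrapped in a
+        # LogitProcessorWrapper so that every batch slot keeps its own
+        # per-request processor state (see build_logitsprocs_for_cb).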
+
+        custom_logitsprocs = self.vllm_config.model_config.logits_processors
+
+        batch_size = self.scheduler_config.max_num_seqs
+        logits_processors = build_logitsprocs_for_cb(
+            vllm_config=self.vllm_config,
+            device=self.device,
+            is_pin_memory=self.pin_memory,
+            is_pooling_model=False,
+            custom_logitsprocs=custom_logitsprocs,
+            batch_size=batch_size)
+
+        return SamplingInputBatch(
+            max_num_reqs=batch_size,
+            max_model_len=self.model_config.max_model_len,
+            device=self.device,
+            pin_memory=self.pin_memory,
+            vocab_size=self.model_config.get_vocab_size(),
+            logitsprocs=logits_processors,
+        )
+
 
 class PoolerAdapter(torch.nn.Module):