From 5877a7a3a71fe25e4457a6477dceac3dd99ed4ee Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sun, 7 Jul 2024 18:44:35 +0000 Subject: [PATCH 01/12] first draft of SamplingController --- vllm/engine/async_llm_engine.py | 3 ++- vllm/engine/llm_engine.py | 8 +++++--- vllm/sequence.py | 16 +++++++++++++++- vllm/worker/model_runner.py | 3 +++ vllm/worker/model_runner_base.py | 4 +++- vllm/worker/worker_base.py | 5 +++++ 6 files changed, 33 insertions(+), 6 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 33e40c7b3624..07f660f6dcd7 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -237,7 +237,8 @@ async def step_async( virtual_engine=virtual_engine, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, - finished_requests_ids=finished_requests_ids) + finished_requests_ids=finished_requests_ids, + sampling_controller=self.sampling_controller) output = await self.model_executor.execute_model_async( execute_model_req) else: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index de7604ece7c3..7dd7a05b13ea 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -29,8 +29,8 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, - PoolerOutput, SamplerOutput, Sequence, - SequenceGroup, SequenceGroupMetadata, + PoolerOutput, SamplerOutput, SamplingController, + Sequence, SequenceGroup, SequenceGroupMetadata, SequenceStatus) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) @@ -225,6 +225,7 @@ def __init__( self.observability_config = observability_config or ObservabilityConfig( ) self.log_stats = log_stats + self.sampling_controller: Optional[SamplingController] = None if not self.model_config.skip_tokenizer_init: self.tokenizer = self._init_tokenizer() @@ -857,7 +858,8 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: blocks_to_copy=scheduler_outputs.blocks_to_copy, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, - finished_requests_ids=finished_requests_ids) + finished_requests_ids=finished_requests_ids, + sampling_controller=self.sampling_controller) output = self.model_executor.execute_model( execute_model_req=execute_model_req) else: diff --git a/vllm/sequence.py b/vllm/sequence.py index d200115aa092..f9998f9b23b8 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -912,6 +912,17 @@ def prune(self, self.seq_ids = seq_ids +class SamplingController: + + def prepare(self, seq_group_metadata_list: List[SequenceGroupMetadata]): + """Prepare the sampling controller for the next step.""" + pass + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + """Apply the sampling controller to the logits.""" + return logits + + @dataclass class ExecuteModelRequest: """The model execution request, containing CPU metadata only. The LLM @@ -936,6 +947,8 @@ class ExecuteModelRequest: num_steps: int = 1 # Finished request ids since last step. finished_requests_ids: List[str] = field(default_factory=list) + # Sampling controller to use for this step. 
+ sampling_controller: Optional[SamplingController] = None def clone( self, seq_group_metadata_list: List[SequenceGroupMetadata] @@ -951,4 +964,5 @@ def clone( running_queue_size=self.running_queue_size, previous_hidden_states=self.previous_hidden_states, num_steps=self.num_steps, - finished_requests_ids=self.finished_requests_ids) + finished_requests_ids=self.finished_requests_ids, + sampling_controller=self.sampling_controller) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d0c82d6bbedf..d95818d3a0ed 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1259,6 +1259,9 @@ def execute_model( if not self.is_driver_worker: return [] + if (ctrl := model_input.sampling_controller) is not None: + logits = ctrl.apply(logits) + # Sample the next token. output: SamplerOutput = self.model.sample( logits=logits, diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index bc0960fa1622..65f7e85e67ba 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -6,7 +6,7 @@ import torch from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata) + SamplingController, SequenceGroupMetadata) if TYPE_CHECKING: from vllm.attention import AttentionMetadata @@ -92,6 +92,8 @@ class ModelRunnerInputBase(ABC): serialize/deserialize a ModelInput for broadcast between workers. """ + sampling_controller: Optional[SamplingController] = None + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: """ Extract broadcastable fields. Override for fields that require some diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index b082f4534486..8fb7f8e6dea7 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -236,6 +236,11 @@ def execute_model( execute_model_req.seq_group_metadata_list, execute_model_req.virtual_engine, execute_model_req.finished_requests_ids)) + ctrl = execute_model_req.sampling_controller + if ctrl is not None: + ctrl.prepare(execute_model_req.seq_group_metadata_list) + model_input = dataclasses.replace(model_input, + sampling_controller=ctrl) num_steps = execute_model_req.num_steps if self.do_metadata_broadcast: From 0190aef925059fd1ae6bd470cb9bfb797a009b82 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sun, 7 Jul 2024 19:10:17 +0000 Subject: [PATCH 02/12] make SamplingController work with SamplingMetadata --- vllm/model_executor/sampling_metadata.py | 2 ++ vllm/sequence.py | 3 ++- vllm/worker/model_runner.py | 7 ++++++- vllm/worker/worker_base.py | 12 +++++++----- 4 files changed, 17 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index ad5fb13176ed..dd04ca049769 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -24,6 +24,7 @@ class SequenceGroupToSample: # |-- query_len ---| # Sequence ids for the sequence group in a previous step. + request_id: str seq_ids: List[int] sampling_params: SamplingParams # seq_id -> sequence data. 
@@ -273,6 +274,7 @@ def sample(logits): seq_groups.append( SequenceGroupToSample( + request_id=seq_group_metadata.request_id, seq_ids=seq_ids, sampling_params=sampling_params, seq_data=seq_group_metadata.seq_data, diff --git a/vllm/sequence.py b/vllm/sequence.py index f9998f9b23b8..718310349188 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -16,6 +16,7 @@ from vllm.inputs import LLMInputs from vllm.multimodal import MultiModalDataDict from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics + from vllm.model_executor.sampling_metadata import SamplingMetadata @dataclass @@ -914,7 +915,7 @@ def prune(self, class SamplingController: - def prepare(self, seq_group_metadata_list: List[SequenceGroupMetadata]): + def prepare(self, sampling_metadata: "SamplingMetadata"): """Prepare the sampling controller for the next step.""" pass diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d95818d3a0ed..afed2e7628c9 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1240,6 +1240,11 @@ def execute_model( "finished_requests_ids": model_input.finished_requests_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, } if self.has_seqlen_agnostic else {} + + if (ctrl := model_input.sampling_controller) is not None: + assert model_input.sampling_metadata is not None + ctrl.prepare(model_input.sampling_metadata) + hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, @@ -1259,7 +1264,7 @@ def execute_model( if not self.is_driver_worker: return [] - if (ctrl := model_input.sampling_controller) is not None: + if ctrl is not None: logits = ctrl.apply(logits) # Sample the next token. diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 8fb7f8e6dea7..6cd9d39d37e5 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -236,11 +236,6 @@ def execute_model( execute_model_req.seq_group_metadata_list, execute_model_req.virtual_engine, execute_model_req.finished_requests_ids)) - ctrl = execute_model_req.sampling_controller - if ctrl is not None: - ctrl.prepare(execute_model_req.seq_group_metadata_list) - model_input = dataclasses.replace(model_input, - sampling_controller=ctrl) num_steps = execute_model_req.num_steps if self.do_metadata_broadcast: @@ -249,6 +244,13 @@ def execute_model( model_input.as_broadcastable_tensor_dict()) broadcast_data["num_steps"] = num_steps broadcast_tensor_dict(broadcast_data, src=0) + + # SamplingController is only used in the driver worker, so it + # doesn't need to be broadcasted. 
+ ctrl = execute_model_req.sampling_controller + if ctrl is not None: + model_input = dataclasses.replace(model_input, + sampling_controller=ctrl) else: assert self.do_metadata_broadcast broadcast_data = broadcast_tensor_dict(src=0) From 3c6723c2f7869730fb2fef3c2287a50e480cfba1 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sun, 7 Jul 2024 22:29:14 +0000 Subject: [PATCH 03/12] add fast_forward_tokens to SequenceOutput --- vllm/engine/output_processor/single_step.py | 6 ++--- vllm/sequence.py | 28 +++++++++++++++++++-- vllm/worker/model_runner.py | 5 +++- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index fa672e1feda9..c1bdc4d49841 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -102,15 +102,13 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, for child_sample in child_samples[:-1]: new_child_seq_id: int = next(self.seq_counter) child = parent.fork(new_child_seq_id) - child.append_token_id(child_sample.output_token, - child_sample.logprobs) + child_sample.append_to(child) child_seqs.append((child, parent)) # Continue the parent sequence for the last child sample. # We reuse the parent sequence here to reduce redundant memory # copies, especially when using non-beam search sampling methods. last_child_sample = child_samples[-1] - parent.append_token_id(last_child_sample.output_token, - last_child_sample.logprobs) + last_child_sample.append_to(parent) child_seqs.append((parent, parent)) for seq, _ in child_seqs: diff --git a/vllm/sequence.py b/vllm/sequence.py index 718310349188..b10f09642272 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -14,9 +14,9 @@ if TYPE_CHECKING: from vllm.inputs import LLMInputs + from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MultiModalDataDict from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics - from vllm.model_executor.sampling_metadata import SamplingMetadata @dataclass @@ -702,6 +702,26 @@ def __init__( self.parent_seq_id = parent_seq_id self.output_token = output_token self.logprobs = logprobs + # If present, these tokens should appended to the output + # instead of output_token. + self.fast_forward_tokens: Optional[List[int]] = None + + def append_to(self, seq: Sequence) -> None: + if self.fast_forward_tokens is not None: + logprobs = self.logprobs + for token in self.fast_forward_tokens: + # On first iteration, use the existing self.logprobs, provided + # they contain the token. + # On subsequent iterations, logprobs is cleared, so always use + # artificially created logprobs. 
+ if token not in logprobs: + logprobs = { + token: Logprob(logprob=0.0, rank=1, decoded_token=None) + } + seq.append_token_id(token, logprobs) + logprobs = {} + else: + seq.append_token_id(self.output_token, self.logprobs) def __repr__(self) -> str: return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " @@ -919,10 +939,14 @@ def prepare(self, sampling_metadata: "SamplingMetadata"): """Prepare the sampling controller for the next step.""" pass - def apply(self, logits: torch.Tensor) -> torch.Tensor: + def transform_logits(self, logits: torch.Tensor) -> torch.Tensor: """Apply the sampling controller to the logits.""" return logits + def transform_sampler_output(self, output: SamplerOutput) -> SamplerOutput: + """Apply the sampling controller to the sampler output.""" + return output + @dataclass class ExecuteModelRequest: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index afed2e7628c9..660fa538efa1 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1265,7 +1265,7 @@ def execute_model( return [] if ctrl is not None: - logits = ctrl.apply(logits) + logits = ctrl.transform_logits(logits) # Sample the next token. output: SamplerOutput = self.model.sample( @@ -1273,6 +1273,9 @@ def execute_model( sampling_metadata=model_input.sampling_metadata, ) + if ctrl is not None: + ctrl.transform_sampler_output(output) + if self.return_hidden_states: # we only need to pass hidden states of most recent token assert model_input.sampling_metadata is not None From 80c5091842b968fefc7b9eb2d0ab40ed140bdd9c Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sun, 7 Jul 2024 22:42:01 +0000 Subject: [PATCH 04/12] reset to prefill on ff --- vllm/sequence.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index b10f09642272..75db78a6cb58 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -712,14 +712,17 @@ def append_to(self, seq: Sequence) -> None: for token in self.fast_forward_tokens: # On first iteration, use the existing self.logprobs, provided # they contain the token. - # On subsequent iterations, logprobs is cleared, so always use - # artificially created logprobs. if token not in logprobs: logprobs = { token: Logprob(logprob=0.0, rank=1, decoded_token=None) } seq.append_token_id(token, logprobs) + # On subsequent iterations always use artificially created + # logprobs. logprobs = {} + # If more than one token was appended, switch to prefill stage. 
+ if seq.data.get_num_uncomputed_tokens() > 1: + seq.data._stage = SequenceStage.PREFILL else: seq.append_token_id(self.output_token, self.logprobs) From 1273203b4d32837aed502ed6441387842bc25210 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sun, 7 Jul 2024 23:33:07 +0000 Subject: [PATCH 05/12] refactor to be easier to override things --- vllm/entrypoints/openai/api_server.py | 29 +++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 76879c96c31e..73f2e25dcc49 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -42,6 +42,7 @@ openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion openai_serving_embedding: OpenAIServingEmbedding +engine_args: AsyncEngineArgs logger = init_logger('vllm.entrypoints.openai.api_server') @@ -167,9 +168,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): return JSONResponse(content=generator.model_dump()) -if __name__ == "__main__": - args = parse_args() - +def create_engine(args): app.add_middleware( CORSMiddleware, allow_origins=args.allowed_origins, @@ -206,16 +205,21 @@ async def authentication(request: Request, call_next): logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) - if args.served_model_name is not None: - served_model_names = args.served_model_name - else: - served_model_names = [args.model] - + global engine_args engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) + return engine + + +def start_engine(args, engine): + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [args.model] + event_loop: Optional[asyncio.AbstractEventLoop] try: event_loop = asyncio.get_running_loop() @@ -230,6 +234,9 @@ async def authentication(request: Request, call_next): # When using single vLLM without engine_use_ray model_config = asyncio.run(engine.get_model_config()) + global openai_serving_chat, openai_serving_completion + global openai_serving_embedding + openai_serving_chat = OpenAIServingChat(engine, model_config, served_model_names, args.response_role, @@ -249,3 +256,9 @@ async def authentication(request: Request, call_next): ssl_certfile=args.ssl_certfile, ssl_ca_certs=args.ssl_ca_certs, ssl_cert_reqs=args.ssl_cert_reqs) + + +if __name__ == "__main__": + args = parse_args() + engine = create_engine(args) + start_engine(args, engine) From 374414314e0534a55c26571e891d8564bb7c39e0 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 8 Jul 2024 21:06:53 +0000 Subject: [PATCH 06/12] more extensible --- vllm/entrypoints/openai/api_server.py | 40 ++++++++++++++++++--------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 73f2e25dcc49..4db0e48b53c8 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -42,7 +42,8 @@ openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion openai_serving_embedding: OpenAIServingEmbedding -engine_args: AsyncEngineArgs +_engine_args: AsyncEngineArgs +_engine: AsyncLLMEngine logger = init_logger('vllm.entrypoints.openai.api_server') @@ -55,9 +56,9 @@ async def lifespan(app: fastapi.FastAPI): async def _force_log(): while True: await 
asyncio.sleep(10) - await engine.do_log_stats() + await _engine.do_log_stats() - if not engine_args.disable_log_stats: + if not _engine_args.disable_log_stats: task = asyncio.create_task(_force_log()) _running_tasks.add(task) task.add_done_callback(_running_tasks.remove) @@ -205,21 +206,19 @@ async def authentication(request: Request, call_next): logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) - global engine_args engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) + + global _engine_args, _engine + _engine_args = engine_args + _engine = engine return engine -def start_engine(args, engine): - if args.served_model_name is not None: - served_model_names = args.served_model_name - else: - served_model_names = [args.model] - +def get_model_config(args, engine: AsyncLLMEngine): event_loop: Optional[asyncio.AbstractEventLoop] try: event_loop = asyncio.get_running_loop() @@ -234,6 +233,17 @@ def start_engine(args, engine): # When using single vLLM without engine_use_ray model_config = asyncio.run(engine.get_model_config()) + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [args.model] + + return model_config, served_model_names + + +def start_engine(args, engine: AsyncLLMEngine): + model_config, served_model_names = get_model_config(args, engine) + global openai_serving_chat, openai_serving_completion global openai_serving_embedding @@ -258,7 +268,11 @@ def start_engine(args, engine): ssl_cert_reqs=args.ssl_cert_reqs) +def _main(): + _args = parse_args() + _engine = create_engine(_args) + start_engine(_args, _engine) + + if __name__ == "__main__": - args = parse_args() - engine = create_engine(args) - start_engine(args, engine) + _main() From a70c68e1ec2211fa3ce3f40b9d3487148342af85 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 8 Jul 2024 23:46:27 +0000 Subject: [PATCH 07/12] fix assert --- vllm/sequence.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 75db78a6cb58..7076e49b8892 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -469,8 +469,8 @@ def lora_int_id(self) -> int: def get_last_latency(self, now: float) -> Optional[float]: """Sets the last token time for Request level timings.""" - # If still in prefill phase, raise Error. - if self.is_prefill(): + # If still in initial prefill phase, raise Error. 
+ if self.is_prefill() and self.get_seqs()[0].get_output_len() == 0: raise ValueError( "seq_group.get_last_latency() should not be called " "if the seq_group is in prefill phase.") From 039db20d7af2c087623ecd00472a5763aaf106d9 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 9 Jul 2024 18:40:23 +0000 Subject: [PATCH 08/12] add SamplingController.empty_step() callback --- vllm/engine/async_llm_engine.py | 2 ++ vllm/engine/llm_engine.py | 4 ++++ vllm/sequence.py | 5 +++++ 3 files changed, 11 insertions(+) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 07f660f6dcd7..43698f78cd56 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -258,6 +258,8 @@ async def step_async( async def stop_remote_worker_execution_loop_async(self) -> None: """Stop the remote worker execution loop.""" + if ctrl := self.sampling_controller: + ctrl.empty_step() await self.model_executor.stop_remote_worker_execution_loop_async() async def process_model_inputs_async( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7dd7a05b13ea..65e2bb71bb36 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -529,6 +529,8 @@ def _add_processed_request( min_cost_scheduler.add_seq_group(seq_group) def stop_remote_worker_execution_loop(self) -> None: + if ctrl := self.sampling_controller: + ctrl.empty_step() self.model_executor.stop_remote_worker_execution_loop() def process_model_inputs( @@ -863,6 +865,8 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: output = self.model_executor.execute_model( execute_model_req=execute_model_req) else: + if ctrl := self.sampling_controller: + ctrl.empty_step() output = [] request_outputs = self._process_model_outputs( diff --git a/vllm/sequence.py b/vllm/sequence.py index 7076e49b8892..a6d82277b8e7 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -942,6 +942,11 @@ def prepare(self, sampling_metadata: "SamplingMetadata"): """Prepare the sampling controller for the next step.""" pass + def empty_step(self): + """Called instead of prepare() when the scheduler found no sequences + to run.""" + pass + def transform_logits(self, logits: torch.Tensor) -> torch.Tensor: """Apply the sampling controller to the logits.""" return logits From a4db33353639340448232cd5ca6409cb1fd5ca96 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 9 Jul 2024 19:19:15 +0000 Subject: [PATCH 09/12] take result from transform_sampler_output --- vllm/worker/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 660fa538efa1..145820fe5da6 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1274,7 +1274,7 @@ def execute_model( ) if ctrl is not None: - ctrl.transform_sampler_output(output) + output = ctrl.transform_sampler_output(output) if self.return_hidden_states: # we only need to pass hidden states of most recent token From e0eb2da57a84e3c704b6288f6f948e051431c1e0 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 9 Jul 2024 21:49:46 +0000 Subject: [PATCH 10/12] allow appending empty token seq --- vllm/core/block/block_table.py | 9 ++++++--- vllm/sequence.py | 9 +++++++-- vllm/transformers_utils/detokenizer.py | 2 +- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index 49e63c23155b..a2da29a8da86 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ 
-147,10 +147,13 @@ def append_token_ids(self, # Update the blocks with the new tokens first_block_idx = self._num_full_slots // self._block_size - token_blocks = self._chunk_token_blocks_for_append(token_ids) - for i, token_block in enumerate(token_blocks): - self._blocks.append_token_ids(first_block_idx + i, token_block) + # don't bother appending anything, if no new token_ids were generated + if token_ids: + token_blocks = self._chunk_token_blocks_for_append(token_ids) + + for i, token_block in enumerate(token_blocks): + self._blocks.append_token_ids(first_block_idx + i, token_block) self._num_full_slots += len(token_ids) diff --git a/vllm/sequence.py b/vllm/sequence.py index a6d82277b8e7..cefc83a661a9 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -186,9 +186,14 @@ def get_num_computed_tokens(self) -> int: def update_num_computed_tokens(self, num_new_computed_tokens: int): """Update number of tokens computed so far.""" + seq_len = self.get_len() self._num_computed_tokens += num_new_computed_tokens - assert self._num_computed_tokens <= self.get_len(), ( - self._num_computed_tokens, self.get_len()) + # We can overflow by 1 if previous sampling was updated by + # SamplingController to generate an empty sequence of tokens. + if self._num_computed_tokens == seq_len + 1: + self._num_computed_tokens = seq_len + assert self._num_computed_tokens <= seq_len, ( + self._num_computed_tokens, seq_len) # If all tokens are computed, it means it is in decoding phase. if self.get_num_uncomputed_tokens() == 0: self._stage = SequenceStage.DECODE diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index e8e53f4946ef..e31332939453 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -124,7 +124,7 @@ def decode_sequence_inplace(self, seq: Sequence, ) # Decode logprobs - logprobs = seq.output_logprobs[-1] + logprobs = seq.output_logprobs[-1] if seq.output_logprobs else None if logprobs: previous_tokens = all_input_ids[:-1] for token_id, sample_logprob in logprobs.items(): From 6fec2b0c6ffa64d77c638aae31bc64dfae9aa63a Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 9 Jul 2024 16:17:44 -0700 Subject: [PATCH 11/12] revert api_server changes --- vllm/entrypoints/openai/api_server.py | 47 ++++++--------------------- 1 file changed, 10 insertions(+), 37 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 4db0e48b53c8..76879c96c31e 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -42,8 +42,6 @@ openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion openai_serving_embedding: OpenAIServingEmbedding -_engine_args: AsyncEngineArgs -_engine: AsyncLLMEngine logger = init_logger('vllm.entrypoints.openai.api_server') @@ -56,9 +54,9 @@ async def lifespan(app: fastapi.FastAPI): async def _force_log(): while True: await asyncio.sleep(10) - await _engine.do_log_stats() + await engine.do_log_stats() - if not _engine_args.disable_log_stats: + if not engine_args.disable_log_stats: task = asyncio.create_task(_force_log()) _running_tasks.add(task) task.add_done_callback(_running_tasks.remove) @@ -169,7 +167,9 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): return JSONResponse(content=generator.model_dump()) -def create_engine(args): +if __name__ == "__main__": + args = parse_args() + app.add_middleware( CORSMiddleware, allow_origins=args.allowed_origins, @@ 
-206,19 +206,16 @@ async def authentication(request: Request, call_next): logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [args.model] + engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) - - global _engine_args, _engine - _engine_args = engine_args - _engine = engine - - return engine - -def get_model_config(args, engine: AsyncLLMEngine): event_loop: Optional[asyncio.AbstractEventLoop] try: event_loop = asyncio.get_running_loop() @@ -233,20 +230,6 @@ def get_model_config(args, engine: AsyncLLMEngine): # When using single vLLM without engine_use_ray model_config = asyncio.run(engine.get_model_config()) - if args.served_model_name is not None: - served_model_names = args.served_model_name - else: - served_model_names = [args.model] - - return model_config, served_model_names - - -def start_engine(args, engine: AsyncLLMEngine): - model_config, served_model_names = get_model_config(args, engine) - - global openai_serving_chat, openai_serving_completion - global openai_serving_embedding - openai_serving_chat = OpenAIServingChat(engine, model_config, served_model_names, args.response_role, @@ -266,13 +249,3 @@ def start_engine(args, engine: AsyncLLMEngine): ssl_certfile=args.ssl_certfile, ssl_ca_certs=args.ssl_ca_certs, ssl_cert_reqs=args.ssl_cert_reqs) - - -def _main(): - _args = parse_args() - _engine = create_engine(_args) - start_engine(_args, _engine) - - -if __name__ == "__main__": - _main() From 59f2e5ed3df093c605ffa1e433f20c3b061b2de7 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 19 Jul 2024 17:55:34 +0000 Subject: [PATCH 12/12] PR feedback - docstrings --- vllm/sequence.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/vllm/sequence.py b/vllm/sequence.py index cefc83a661a9..8afdc43d1f95 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -712,6 +712,13 @@ def __init__( self.fast_forward_tokens: Optional[List[int]] = None def append_to(self, seq: Sequence) -> None: + """ + Append the sampling output to the sequence. + + If fast forward tokens is set, this appends them, generating appropriate + Logprobs, and switching the sequence to PREFILL if needed. + Otherwise, just the output token is appended. + """ if self.fast_forward_tokens is not None: logprobs = self.logprobs for token in self.fast_forward_tokens: @@ -942,6 +949,33 @@ def prune(self, class SamplingController: + """ + This is used to modify sampling process for a given LLMEngine. + There is only one instance of this class per LLMEngine. 
+
+    In each generation step, one of the following things can happen:
+
+    There are no sequences to run, and empty_step() is called; this can be
+    used to run actions that normally happen in sync with step() when there
+    are no sequences to run.
+
+    Otherwise (the normal case), the following methods run in this exact order:
+    - prepare() starts logit bias preparation for the sequences about to be
+      run; typically the logit indices from sampling_metadata have to be
+      stored in the sampling controller
+    - the forward pass is started
+    - transform_logits() is called after the forward pass has finished, to
+      modify the logits
+    - sampling happens on the biased logits
+    - transform_sampler_output() is called to modify the sampler output
+
+    This base class does nothing in each of these steps. Subclasses can
+    override any of these methods to modify the sampling process; they are
+    expected to be stateful.
+
+    Currently, to use a controller, assign an instance of your subclass to
+    engine.sampling_controller.
+    """

    def prepare(self, sampling_metadata: "SamplingMetadata"):
        """Prepare the sampling controller for the next step."""
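
To make the intended flow concrete, here is a minimal controller sketch that is not part of the patches above: it masks logits to a per-request set of allowed tokens in transform_logits() and splices forced tokens into the output via the new fast_forward_tokens field. The MaskingController name, the allowed_tokens/ff_tokens tables, and the use of SequenceGroupToSample.sample_indices to map requests to logit rows are illustrative assumptions; only the SamplingController hooks, request_id, and fast_forward_tokens come from the patches.

from typing import Dict, List, Tuple

import torch

from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput, SamplingController


class MaskingController(SamplingController):

    def __init__(self) -> None:
        # request_id -> token ids allowed at the next sampling step
        self.allowed_tokens: Dict[str, List[int]] = {}
        # request_id -> tokens to force-append via fast_forward_tokens
        self.ff_tokens: Dict[str, List[int]] = {}
        # per-step mapping built in prepare(): (request_id, logit row indices)
        self._groups: List[Tuple[str, List[int]]] = []

    def prepare(self, sampling_metadata: "SamplingMetadata") -> None:
        # Remember which logit rows belong to which request for this step
        # (assumes sample_indices index rows of the logits tensor passed to
        # transform_logits).
        self._groups = [(g.request_id, list(g.sample_indices))
                        for g in sampling_metadata.seq_groups]

    def transform_logits(self, logits: torch.Tensor) -> torch.Tensor:
        # Push every token outside the allowed set to -inf, row by row.
        for request_id, rows in self._groups:
            allowed = self.allowed_tokens.get(request_id)
            if not allowed:
                continue
            for row in rows:
                mask = torch.full_like(logits[row], float("-inf"))
                mask[allowed] = 0.0
                logits[row] = logits[row] + mask
        return logits

    def transform_sampler_output(self, output: SamplerOutput) -> SamplerOutput:
        # output.outputs is assumed to follow the seq_groups order seen in
        # prepare(); attach fast-forward tokens where requested.
        for group_idx, (request_id, _rows) in enumerate(self._groups):
            tokens = self.ff_tokens.get(request_id)
            if tokens:
                for sample in output.outputs[group_idx].samples:
                    sample.fast_forward_tokens = list(tokens)
        return output

    def empty_step(self) -> None:
        # Nothing was scheduled this step; a real controller could advance
        # timers or background grammar state here.
        pass

Per the docstring above, attaching a controller is a single assignment on the engine, e.g. engine.sampling_controller = MaskingController().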