From fa8705de390cc727acc5a094abbba2f070de27dd Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sat, 6 Apr 2024 22:29:36 -0700 Subject: [PATCH 001/120] wip --- vllm/executor/gpu_executor.py | 71 +++++++++++++++++++++++++- vllm/spec_decode/spec_decode_worker.py | 4 ++ 2 files changed, 73 insertions(+), 2 deletions(-) diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 80ca5cb7367c..ac7e4c5dda74 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -33,14 +33,81 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.vision_language_config = vision_language_config + self.speculative_config = speculative_config - assert (not speculative_config - ), "Speculative decoding not yet supported for GPU backend" + #assert (not speculative_config + # ), "Speculative decoding not yet supported for GPU backend" # Instantiate the worker and load the model to GPU. self._init_worker() def _init_worker(self): + if self.speculative_config is None: + self._init_non_spec_worker() + else: + self._init_spec_worker() + + def _init_spec_worker(self): + from vllm.worker.worker import Worker + from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker + from vllm.spec_decode.multi_step_worker import MultiStepWorker + + #from vllm.worker.multi_step_worker import MultiStepWorker # pylint: disable=import-outside-toplevel + #from vllm.worker.single_tp_worker import SingleTpWorker # pylint: disable=import-outside-toplevel + #from vllm.worker.draft_target_worker import DraftTargetWorker # pylint: disable=import-outside-toplevel + + #scheduler_config: "SchedulerConfig" = worker_kwargs.pop( + # "scheduler_config") + + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + + target_worker = Worker( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + is_driver_worker=True, + ) + + from vllm.spec_decode.multi_step_worker import MultiStepWorker + draft_worker = MultiStepWorker( + model_config=self.speculative_config.draft_model_config, + parallel_config=self.speculative_config.draft_parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + is_driver_worker=True, + ) + + from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker + from vllm.model_executor.layers.rejection_sampler import RejectionSampler + spec_decode_worker = SpecDecodeWorker( + proposer_worker=draft_worker, + scorer_worker=target_worker, + rejection_sampler=RejectionSampler(), + ) + + assert self.parallel_config.world_size == 1, ( + "GPUExecutor only supports single GPU.") + + self.driver_worker = spec_decode_worker + + self.driver_worker.init_device() + #self.driver_worker.load_model() + + def _init_non_spec_worker(self): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker from vllm.worker.worker import Worker diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 885bf537568e..d555f27650e1 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -87,6 +87,10 @@ def init_device(self) -> None: self.scorer_worker.init_device() self.proposer_worker.init_device() + # TODO separate from init_device? + self.scorer_worker.load_model() + self.proposer_worker.load_model() + self._metrics.init_gpu_tensors(self.rank) self.rejection_sampler.init_gpu_tensors(self.rank) self.scorer = BatchExpansionTop1Scorer( From 84953210e527c011704974435ae1b61ed7296a26 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sat, 6 Apr 2024 22:36:27 -0700 Subject: [PATCH 002/120] wip --- tests/spec_decode/e2e/test_correctness.py | 3 +++ vllm/engine/llm_engine.py | 10 ++++++---- vllm/executor/gpu_executor.py | 5 ++++- vllm/spec_decode/spec_decode_worker.py | 9 +++++---- 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index b5a6fcb7900a..c427fbc7a05b 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -11,6 +11,9 @@ "speculative_model": "facebook/opt-125m", "num_speculative_tokens": 5, + # Skip cuda graph recording for fast test. + "enforce_eager": True, + # Required for spec decode. "use_v2_block_manager": True }]) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1c639af69654..9ca809f51d0f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -709,12 +709,14 @@ def step(self) -> List[RequestOutput]: if not scheduler_outputs.is_empty(): output = self.model_executor.execute_model( - seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in, - scheduler_outputs.blocks_to_swap_out, - scheduler_outputs.blocks_to_copy) + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + num_lookahead_slots=scheduler_outputs.num_lookahead_slots) else: output = [] - + return self._process_model_outputs(output, scheduler_outputs) def do_log_stats(self) -> None: diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index ac7e4c5dda74..80ec79ba3c3c 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -154,12 +154,15 @@ def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: + blocks_to_copy: Dict[int, List[int]], + num_lookahead_slots: int, + ) -> SamplerOutput: output = self.driver_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, + num_lookahead_slots=num_lookahead_slots, ) return output diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index d555f27650e1..a2c9a9944af5 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -135,7 +135,7 @@ def execute_model( blocks_to_swap_in: Optional[Dict[int, int]], blocks_to_swap_out: Optional[Dict[int, int]], blocks_to_copy: Optional[Dict[int, List[int]]], - num_spec_tokens: int, + num_lookahead_slots: int, ) -> List[SamplerOutput]: """Perform speculative decoding on the input batch. """ @@ -146,7 +146,7 @@ def execute_model( # If no spec tokens, call the proposer and scorer workers normally. # Used for prefill. - if num_spec_tokens == 0 or len(seq_group_metadata_list) == 0: + if num_lookahead_slots == 0 or len(seq_group_metadata_list) == 0: return self._run_no_spec( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, @@ -159,7 +159,7 @@ def execute_model( blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, - k=num_spec_tokens, + k=num_lookahead_slots, ) @nvtx_range("spec_decode_worker._run_no_spec") @@ -180,7 +180,8 @@ def _run_no_spec( blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, - return_python_output=False) + #return_python_output=False + ) sampler_output = self.scorer_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list, From b63975bd45ea1a1770a8c742dc732b91e6f3cbf9 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 00:06:56 -0700 Subject: [PATCH 003/120] wip --- tests/spec_decode/e2e/test_correctness.py | 18 ++-- vllm/core/scheduler.py | 14 +-- vllm/engine/llm_engine.py | 121 +++++++++++++++++++++- vllm/model_executor/layers/sampler.py | 8 +- vllm/spec_decode/batch_expansion.py | 3 +- vllm/spec_decode/multi_step_worker.py | 2 +- vllm/spec_decode/spec_decode_worker.py | 4 +- 7 files changed, 145 insertions(+), 25 deletions(-) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index c427fbc7a05b..782bd9d0cecb 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -21,14 +21,14 @@ @pytest.mark.parametrize("test_llm_kwargs", [{}]) @pytest.mark.parametrize("seed", [1]) def test_spec_decode_config(test_llm_generator): - output_len = 1024 + output_len = 128 temperature = 0.0 prompts = [ "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", + #"The president of the United States is", + #"The capital of France is", + #"The future of AI is", ] sampling_params = SamplingParams( @@ -37,11 +37,11 @@ def test_spec_decode_config(test_llm_generator): temperature=temperature, ) - with pytest.raises( - AssertionError, - match="Speculative decoding not yet supported for GPU backend"): - get_token_ids_from_llm_generator(test_llm_generator, prompts, - sampling_params) + #with pytest.raises( + # AssertionError, + # match="Speculative decoding not yet supported for GPU backend"): + get_token_ids_from_llm_generator(test_llm_generator, prompts, + sampling_params) def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 0ae53f937496..e176848c0490 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -753,9 +753,10 @@ def _schedule_default(self) -> SchedulerOutputs: blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, swapped_in.blocks_to_copy), ignored_seq_groups=prefills.ignored_seq_groups, - num_lookahead_slots=(prefills.num_lookahead_slots + - running_scheduled.num_lookahead_slots + - swapped_in.num_lookahead_slots), + num_lookahead_slots=running_scheduled.num_lookahead_slots, + #num_lookahead_slots=(prefills.num_lookahead_slots + + # running_scheduled.num_lookahead_slots + + # swapped_in.num_lookahead_slots), ) def _schedule_chunked_prefill(self): @@ -842,9 +843,10 @@ def _schedule_chunked_prefill(self): blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, swapped_in.blocks_to_copy), ignored_seq_groups=prefills.ignored_seq_groups, - num_lookahead_slots=(prefills.num_lookahead_slots + - running_scheduled.num_lookahead_slots + - swapped_in.num_lookahead_slots), + num_lookahead_slots=running_scheduled.num_lookahead_slots, + #num_lookahead_slots=(prefills.num_lookahead_slots + + # running_scheduled.num_lookahead_slots + + # swapped_in.num_lookahead_slots), ) def _schedule(self) -> SchedulerOutputs: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 9ca809f51d0f..1bd4129090c2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -626,14 +626,38 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, def _process_model_outputs( self, output: SamplerOutput, scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: + + + if not isinstance(output, list): + all_output = [output] + else: + all_output = output + + scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups + + # Organize list of sampler output by sequence group. + output_by_sequence_group: List[List[SequenceGroupOutputs]] = [ + [] for _ in scheduled_seq_groups + ] + for step in output: + for i, sequence_group_output in enumerate(step): + output_by_sequence_group[i].append(sequence_group_output) + now = time.time() + # Update the scheduled sequence groups with the model outputs. - scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups - for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output): + for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output_by_sequence_group): + seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) - self._process_sequence_group_outputs(seq_group, outputs) + + assert len(outputs) > 0 + # TODO can spec decode go through second path? + if len(outputs) > 1: + self._process_sequence_group_outputs_multi_step(seq_group, outputs) + else: + self._process_sequence_group_outputs(seq_group, outputs[0]) # Free the finished sequence groups. self.scheduler.free_finished_seq_groups() @@ -654,6 +678,91 @@ def _process_model_outputs( self.stat_logger.log(self._get_stats(scheduler_outputs)) return request_outputs + def _process_sequence_group_outputs_multi_step(self, seq_group, outputs): + seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + + assert seqs + #if not seqs: + # return [] + + assert len(seqs) == 1, ("Beam search not supported in speculative " + "decoding.") + seq = seqs[0] + + # Since there's only one sequence per sequence group, we can take the + # first sample. + samples = [outputs[step].samples[0] for step in range(len(outputs))] + + # -1 means the output token is not valid (eg. due to spec decode + # rejecting tokens). + valid_samples = [ + sample for sample in samples if sample.output_token != -1 + ] + + # Draft target worker pads all outputs with -1 to have same length. + output_token_ids = [sample.output_token for sample in valid_samples] + #successes = [sample.success for sample in samples] + + ## Truncate to max_tokens if necessary. + #remaining_tokens = seq_group.sampling_params.max_tokens - ( + # seq.get_output_len() + len(output_token_ids)) + #if remaining_tokens < 0: + # valid_samples = valid_samples[:remaining_tokens] + # output_token_ids = output_token_ids[:remaining_tokens] + + ## Truncate any tokens after EOS. This is required as spec decode + ## generates tokens in fixed blocks, which may go beyond the EOS token. + #if not seq_group.sampling_params.ignore_eos: + # eos_token_id = self.tokenizer.get_lora_tokenizer( + # seq.lora_request).eos_token_id + # # Avoiding .index calls as exception throwing in the happy path + # # is expensive. + # for i in range(len(output_token_ids)): + # if output_token_ids[i] == eos_token_id: + # output_token_ids = output_token_ids[:i + 1] + # valid_samples = valid_samples[:i + 1] + # break + + #output_logprobs = [sample.logprobs for sample in valid_samples] + + ## Use the last sample for the sequence as it will have + ## the speculation and num_unprocessed_tokens for all the + ## previous samples (they are cumulative when it comes + ## to those two attributes). + #speculation = valid_samples[-1].speculation + #num_unprocessed_tokens = valid_samples[-1].num_unprocessed_tokens + + for output_token_id in output_token_ids: + from vllm.sequence import Logprob + seq.append_token_id( + token_id=output_token_id, + logprobs={output_token_id: Logprob(0.0)}, + ) + print(f'Appended token id {output_token_id=}') + + #seq.append_token_ids(output_token_ids, + # output_logprobs, + # ) + # #num_unprocessed_tokens=num_unprocessed_tokens) + ##seq.set_last_speculation(speculation) + + #if not all(successes): + # seq.set_status_to_failed() + + #if decode: + # self._decode_sequence(seq, + # seq_group.sampling_params, + # token_ids=seq.get_token_ids(), + # unseen_token_ids=output_token_ids, + # prefix_offset=seq.prefix_offset, + # read_offset=seq.read_offset) + #self._check_stop(seq, seq_group.sampling_params, seq.lora_request, + # output_token_ids) + # TODO pass output token ids + self._check_stop(seq, seq_group.sampling_params) + if seq.is_finished(): + self.scheduler.free_seq(seq) + def step(self) -> List[RequestOutput]: """Performs one decoding iteration and returns newly generated results. @@ -804,9 +913,11 @@ def _check_stop(self, seq: Sequence, if seq.get_len() > self.scheduler_config.max_model_len: seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED return - + + breakpoint() # Check if the sequence has reached max_tokens. - if seq.get_output_len() == sampling_params.max_tokens: + if seq.get_output_len() >= sampling_params.max_tokens: + # TODO should cap block seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED return diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index cb1480de03e3..4f0cc4405e81 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -684,4 +684,10 @@ def _build_sampler_output( SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)) sampler_output.append( SequenceGroupOutput(seq_outputs, group_prompt_logprobs)) - return SamplerOutput(outputs=sampler_output) + + return SamplerOutput( + outputs=sampler_output, + # TODO + sampled_token_probs=torch.empty((len(sampler_output), 50_272), device='cuda', dtype=torch.float32), + sampled_token_ids=torch.empty((len(sampler_output), 1), device='cuda', dtype=torch.long), + ) diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index e0b75837e8a3..89be25252c2c 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -83,7 +83,8 @@ def score_proposals( blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, - return_python_output=False) + #return_python_output=False + ) all_tokens, all_probs = self._contract_batch( original_bs=len(seq_group_metadata_list), diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 73b6e201c67a..c817f54d7fe3 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -340,7 +340,7 @@ def _merge_outputs( return proposal_tokens, proposal_probs, proposal_lens sampler_output = maybe_sampler_output - + proposal_tokens, proposal_probs = sampler_output_to_torch( sampler_output) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index a2c9a9944af5..85667a6c3dd4 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -5,7 +5,7 @@ from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, - SequenceGroupOutput, SequenceOutput) + SequenceGroupOutput, SequenceOutput, Logprob) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) @@ -316,7 +316,7 @@ def _create_output_sampler_list( parent_seq_id=seq_id, output_token=token_id, # TODO Add verifier logprobs. - logprobs={token_id: 0.0}, + logprobs={token_id: Logprob(0.0)}, ) ], prompt_logprobs=None, From cb23e8ca4e6ff3c667b44e9ce4f179f629740008 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 00:07:10 -0700 Subject: [PATCH 004/120] wip --- vllm/engine/llm_engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1bd4129090c2..15ef7df26b0b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -914,7 +914,6 @@ def _check_stop(self, seq: Sequence, seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED return - breakpoint() # Check if the sequence has reached max_tokens. if seq.get_output_len() >= sampling_params.max_tokens: # TODO should cap block From 143ca28e5de41f1d32e730bc3e9da2a954a2024e Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 00:14:02 -0700 Subject: [PATCH 005/120] wip --- vllm/executor/cpu_executor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 2bf97338da0e..835ba18ab756 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -80,7 +80,8 @@ def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: + blocks_to_copy: Dict[int, List[int]], + num_lookahead_slots: int) -> SamplerOutput: output = self.driver_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, From d8d4725d3365e25c67cbb115e5a437fd7e574fd0 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 13:41:20 -0700 Subject: [PATCH 006/120] fix --- tests/spec_decode/e2e/test_correctness.py | 7 +++++-- vllm/model_executor/layers/sampler.py | 2 +- vllm/spec_decode/spec_decode_worker.py | 11 +++++++++++ 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index 782bd9d0cecb..fc5640d23ab5 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -7,10 +7,13 @@ "common_llm_kwargs", [{ # Use a small model for a fast test. - "model": "facebook/opt-125m", - "speculative_model": "facebook/opt-125m", + "model": "JackFram/llama-68m", + "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, + # Skip real loading for fast test. + "load_format": "dummy", + # Skip cuda graph recording for fast test. "enforce_eager": True, diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 4f0cc4405e81..9540a3d89bd8 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -688,6 +688,6 @@ def _build_sampler_output( return SamplerOutput( outputs=sampler_output, # TODO - sampled_token_probs=torch.empty((len(sampler_output), 50_272), device='cuda', dtype=torch.float32), + sampled_token_probs=torch.empty((len(sampler_output), 32_000), device='cuda', dtype=torch.float32), sampled_token_ids=torch.empty((len(sampler_output), 1), device='cuda', dtype=torch.long), ) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 85667a6c3dd4..f665c3b72219 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -15,7 +15,9 @@ split_batch_by_proposal_len) from vllm.worker.worker import Worker from vllm.worker.worker_base import LoraNotSupportedWorkerBase +from vllm.logger import init_logger +logger = init_logger(__name__) class SpecDecodeWorker(LoraNotSupportedWorkerBase): """Worker which implements speculative decoding. @@ -144,6 +146,8 @@ def execute_model( "speculative decoding " "requires non-None seq_group_metadata_list") + logger.info(f"spec_decode_worker.execute_model {num_lookahead_slots=}") + # If no spec tokens, call the proposer and scorer workers normally. # Used for prefill. if num_lookahead_slots == 0 or len(seq_group_metadata_list) == 0: @@ -174,6 +178,7 @@ def _run_no_spec( proposer and scorer model so that the KV cache is consistent between the two. """ + logger.info("run proposer worker no spec") self.proposer_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list, @@ -183,6 +188,7 @@ def _run_no_spec( #return_python_output=False ) + logger.info("run target worker no spec") sampler_output = self.scorer_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, @@ -214,11 +220,14 @@ def _run_speculative_decoding_step( sequence. """ + logger.info("get spec proposals") # Generate proposals using draft worker. proposals = self.proposer_worker.get_spec_proposals( seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy, k) + #logger.info(f"score proposals {proposals=}") + logger.info(f"score proposals") proposal_scores = self.scorer.score_proposals( seq_group_metadata_list, blocks_to_swap_in, @@ -228,9 +237,11 @@ def _run_speculative_decoding_step( proposals, ) + logger.info("verify proposals") accepted_token_ids = self._verify_tokens(seq_group_metadata_list, proposal_scores, proposals, k) + logger.info("create output list") return self._create_output_sampler_list(seq_group_metadata_list, accepted_token_ids, k) From b2728e03de0703d9e479bd9e0e4aa3f158f426f6 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 14:03:53 -0700 Subject: [PATCH 007/120] wip --- tests/spec_decode/e2e/test_correctness.py | 54 +++++++++++++++++++++- vllm/spec_decode/spec_decode_worker.py | 55 ++++++++++++++++++++++- vllm/worker/worker.py | 3 ++ 3 files changed, 109 insertions(+), 3 deletions(-) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index fc5640d23ab5..28a88a750edb 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -20,10 +20,14 @@ # Required for spec decode. "use_v2_block_manager": True }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "tensor_parallel_size": 1, + }, +]) @pytest.mark.parametrize("test_llm_kwargs", [{}]) @pytest.mark.parametrize("seed", [1]) -def test_spec_decode_config(test_llm_generator): +def test_spec_decode(test_llm_generator): output_len = 128 temperature = 0.0 @@ -46,6 +50,51 @@ def test_spec_decode_config(test_llm_generator): get_token_ids_from_llm_generator(test_llm_generator, prompts, sampling_params) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + "model": "JackFram/llama-68m", + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + + # Skip real loading for fast test. + "load_format": "dummy", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + # Expect failure as spec decode not supported by + # Ray backend. + "tensor_parallel_size": 2, + }, +]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_xfail(test_llm_generator): + output_len = 128 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + ] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + with pytest.raises( + AssertionError, + match="Speculative decoding not yet supported for "): + get_token_ids_from_llm_generator(test_llm_generator, prompts, + sampling_params) def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): for llm in llm_generator: @@ -54,3 +103,4 @@ def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): del llm return token_ids + diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index f665c3b72219..3802ed42f786 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -19,6 +19,60 @@ logger = init_logger(__name__) +def create_spec_decode_worker(): + + from vllm.worker.worker import Worker + from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker + from vllm.spec_decode.multi_step_worker import MultiStepWorker + + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + + target_worker = Worker( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + is_driver_worker=True, + ) + + from vllm.spec_decode.multi_step_worker import MultiStepWorker + draft_worker = MultiStepWorker( + model_config=self.speculative_config.draft_model_config, + parallel_config=self.speculative_config.draft_parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + is_driver_worker=True, + ) + + from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker + from vllm.model_executor.layers.rejection_sampler import RejectionSampler + spec_decode_worker = SpecDecodeWorker( + proposer_worker=draft_worker, + scorer_worker=target_worker, + rejection_sampler=RejectionSampler(), + ) + + assert self.parallel_config.world_size == 1, ( + "GPUExecutor only supports single GPU.") + + self.driver_worker = spec_decode_worker + + self.driver_worker.init_device() + #self.driver_worker.load_model() + class SpecDecodeWorker(LoraNotSupportedWorkerBase): """Worker which implements speculative decoding. @@ -226,7 +280,6 @@ def _run_speculative_decoding_step( seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy, k) - #logger.info(f"score proposals {proposals=}") logger.info(f"score proposals") proposal_scores = self.scorer.score_proposals( seq_group_metadata_list, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index b46229c5b694..5d9a9acd763e 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -205,7 +205,10 @@ def execute_model( blocks_to_swap_in: Optional[Dict[int, int]] = None, blocks_to_swap_out: Optional[Dict[int, int]] = None, blocks_to_copy: Optional[Dict[int, List[int]]] = None, + num_lookahead_slots: int = 0, ) -> Optional[SamplerOutput]: + assert (num_lookahead_slots == 0), "worker does not support lookahead slots" + if self.is_driver_worker: assert seq_group_metadata_list is not None num_seq_groups = len(seq_group_metadata_list) From 6250f6cf32842de588edfe58f93e942a64cfd5b6 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 14:12:50 -0700 Subject: [PATCH 008/120] assertion --- tests/spec_decode/e2e/test_correctness.py | 26 ++++++++++++++++------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index 28a88a750edb..92076d88ea83 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -1,4 +1,5 @@ import pytest +from itertools import cycle from vllm import SamplingParams @@ -26,30 +27,39 @@ }, ]) @pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("seed", [1]) -def test_spec_decode(test_llm_generator): +def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): + """Run generation with speculative decoding on a batch. Verify the number + of output tokens is equal to the expected number. + """ output_len = 128 temperature = 0.0 prompts = [ "Hello, my name is", - #"The president of the United States is", - #"The capital of France is", - #"The future of AI is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", ] + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + sampling_params = SamplingParams( max_tokens=output_len, ignore_eos=True, temperature=temperature, ) - #with pytest.raises( - # AssertionError, - # match="Speculative decoding not yet supported for GPU backend"): - get_token_ids_from_llm_generator(test_llm_generator, prompts, + batch_token_ids = get_token_ids_from_llm_generator(test_llm_generator, prompts, sampling_params) + # Expect a generation for each prompt in the batch. + assert len(batch_token_ids) == len(prompts) + + # TODO(cadedaniel) check for equality once block truncation is implemented. + assert all(len(token_ids) >= output_len for token_ids in batch_token_ids) + @pytest.mark.parametrize( "common_llm_kwargs", [{ From a930755de760545726cfcc9de5fc8d51a4b6fb71 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 14:18:19 -0700 Subject: [PATCH 009/120] fix --- vllm/model_executor/layers/sampler.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 9540a3d89bd8..7c7148b12229 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -78,8 +78,15 @@ def forward( # Get the logprobs query results. prompt_logprobs, sample_logprobs = _get_logprobs( logprobs, sampling_metadata, sample_results) + + breakpoint() + + return _build_sampler_output(sample_results, sampling_metadata, - prompt_logprobs, sample_logprobs) + prompt_logprobs, sample_logprobs, + sampled_token_probs=probs, + sampled_token_ids=torch.empty((len(sampling_metadata.seq_groups), 1), device=probs.device, dtype=torch.long), + ) def _get_bin_counts_and_mask( @@ -668,6 +675,8 @@ def _build_sampler_output( sampling_metadata: SamplingMetadata, prompt_logprobs: List[Optional[PromptLogprobs]], sample_logprobs: List[SampleLogprobs], + sampled_token_ids: Optional[torch.Tensor] = None, + sampled_token_probs: Optional[torch.Tensor] = None, ) -> SamplerOutput: sampler_output = [] for (seq_group, sample_result, group_prompt_logprobs, @@ -687,7 +696,6 @@ def _build_sampler_output( return SamplerOutput( outputs=sampler_output, - # TODO - sampled_token_probs=torch.empty((len(sampler_output), 32_000), device='cuda', dtype=torch.float32), - sampled_token_ids=torch.empty((len(sampler_output), 1), device='cuda', dtype=torch.long), + sampled_token_probs=sampled_token_probs, + sampled_token_ids=sampled_token_ids, ) From 5b896a3fe4e9614ee2557a9361cb381f88eeb15d Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 14:18:43 -0700 Subject: [PATCH 010/120] fix --- vllm/model_executor/layers/sampler.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 7c7148b12229..71807b25834a 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -79,9 +79,7 @@ def forward( prompt_logprobs, sample_logprobs = _get_logprobs( logprobs, sampling_metadata, sample_results) - breakpoint() - - + # TODO gate by config return _build_sampler_output(sample_results, sampling_metadata, prompt_logprobs, sample_logprobs, sampled_token_probs=probs, From bb43b530ce2eeecaa29a8108dc17e0f24b80b099 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 14:19:23 -0700 Subject: [PATCH 011/120] lint --- tests/spec_decode/e2e/test_correctness.py | 29 +++++++++++++---------- vllm/engine/llm_engine.py | 11 +++++---- vllm/executor/gpu_executor.py | 19 ++++++++------- vllm/model_executor/layers/sampler.py | 14 +++++++---- vllm/spec_decode/batch_expansion.py | 2 +- vllm/spec_decode/multi_step_worker.py | 2 +- vllm/spec_decode/spec_decode_worker.py | 10 ++++---- vllm/worker/worker.py | 3 ++- 8 files changed, 52 insertions(+), 38 deletions(-) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index 92076d88ea83..36a66ea2ec38 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -51,8 +51,9 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): temperature=temperature, ) - batch_token_ids = get_token_ids_from_llm_generator(test_llm_generator, prompts, - sampling_params) + batch_token_ids = get_token_ids_from_llm_generator(test_llm_generator, + prompts, + sampling_params) # Expect a generation for each prompt in the batch. assert len(batch_token_ids) == len(prompts) @@ -60,6 +61,7 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): # TODO(cadedaniel) check for equality once block truncation is implemented. assert all(len(token_ids) >= output_len for token_ids in batch_token_ids) + @pytest.mark.parametrize( "common_llm_kwargs", [{ @@ -77,13 +79,15 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): # Required for spec decode. "use_v2_block_manager": True }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [ - { - # Expect failure as spec decode not supported by - # Ray backend. - "tensor_parallel_size": 2, - }, -]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + { + # Expect failure as spec decode not supported by + # Ray backend. + "tensor_parallel_size": 2, + }, + ]) @pytest.mark.parametrize("test_llm_kwargs", [{}]) @pytest.mark.parametrize("seed", [1]) def test_spec_decode_xfail(test_llm_generator): @@ -100,12 +104,12 @@ def test_spec_decode_xfail(test_llm_generator): temperature=temperature, ) - with pytest.raises( - AssertionError, - match="Speculative decoding not yet supported for "): + with pytest.raises(AssertionError, + match="Speculative decoding not yet supported for "): get_token_ids_from_llm_generator(test_llm_generator, prompts, sampling_params) + def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): for llm in llm_generator: outputs = llm.generate(prompts, sampling_params, use_tqdm=True) @@ -113,4 +117,3 @@ def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): del llm return token_ids - diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 15ef7df26b0b..1ca447890d4c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -627,7 +627,6 @@ def _process_model_outputs( self, output: SamplerOutput, scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: - if not isinstance(output, list): all_output = [output] else: @@ -646,7 +645,8 @@ def _process_model_outputs( now = time.time() # Update the scheduled sequence groups with the model outputs. - for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output_by_sequence_group): + for scheduled_seq_group, outputs in zip(scheduled_seq_groups, + output_by_sequence_group): seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( @@ -655,7 +655,8 @@ def _process_model_outputs( assert len(outputs) > 0 # TODO can spec decode go through second path? if len(outputs) > 1: - self._process_sequence_group_outputs_multi_step(seq_group, outputs) + self._process_sequence_group_outputs_multi_step( + seq_group, outputs) else: self._process_sequence_group_outputs(seq_group, outputs[0]) @@ -825,7 +826,7 @@ def step(self) -> List[RequestOutput]: num_lookahead_slots=scheduler_outputs.num_lookahead_slots) else: output = [] - + return self._process_model_outputs(output, scheduler_outputs) def do_log_stats(self) -> None: @@ -913,7 +914,7 @@ def _check_stop(self, seq: Sequence, if seq.get_len() > self.scheduler_config.max_model_len: seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED return - + # Check if the sequence has reached max_tokens. if seq.get_output_len() >= sampling_params.max_tokens: # TODO should cap block diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 80ec79ba3c3c..60c9a9ca3c78 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -75,7 +75,7 @@ def _init_spec_worker(self): vision_language_config=self.vision_language_config, is_driver_worker=True, ) - + from vllm.spec_decode.multi_step_worker import MultiStepWorker draft_worker = MultiStepWorker( model_config=self.speculative_config.draft_model_config, @@ -90,7 +90,7 @@ def _init_spec_worker(self): vision_language_config=self.vision_language_config, is_driver_worker=True, ) - + from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker from vllm.model_executor.layers.rejection_sampler import RejectionSampler spec_decode_worker = SpecDecodeWorker( @@ -150,13 +150,14 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) - def execute_model(self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - num_lookahead_slots: int, - ) -> SamplerOutput: + def execute_model( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + num_lookahead_slots: int, + ) -> SamplerOutput: output = self.driver_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 71807b25834a..5c1017207878 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -80,11 +80,17 @@ def forward( logprobs, sampling_metadata, sample_results) # TODO gate by config - return _build_sampler_output(sample_results, sampling_metadata, - prompt_logprobs, sample_logprobs, + return _build_sampler_output( + sample_results, + sampling_metadata, + prompt_logprobs, + sample_logprobs, sampled_token_probs=probs, - sampled_token_ids=torch.empty((len(sampling_metadata.seq_groups), 1), device=probs.device, dtype=torch.long), - ) + sampled_token_ids=torch.empty( + (len(sampling_metadata.seq_groups), 1), + device=probs.device, + dtype=torch.long), + ) def _get_bin_counts_and_mask( diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 89be25252c2c..6be8c843cf7a 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -84,7 +84,7 @@ def score_proposals( blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, #return_python_output=False - ) + ) all_tokens, all_probs = self._contract_batch( original_bs=len(seq_group_metadata_list), diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index c817f54d7fe3..73b6e201c67a 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -340,7 +340,7 @@ def _merge_outputs( return proposal_tokens, proposal_probs, proposal_lens sampler_output = maybe_sampler_output - + proposal_tokens, proposal_probs = sampler_output_to_torch( sampler_output) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 3802ed42f786..12a70d402e98 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -19,8 +19,9 @@ logger = init_logger(__name__) + def create_spec_decode_worker(): - + from vllm.worker.worker import Worker from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker @@ -41,7 +42,7 @@ def create_spec_decode_worker(): vision_language_config=self.vision_language_config, is_driver_worker=True, ) - + from vllm.spec_decode.multi_step_worker import MultiStepWorker draft_worker = MultiStepWorker( model_config=self.speculative_config.draft_model_config, @@ -56,7 +57,7 @@ def create_spec_decode_worker(): vision_language_config=self.vision_language_config, is_driver_worker=True, ) - + from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker from vllm.model_executor.layers.rejection_sampler import RejectionSampler spec_decode_worker = SpecDecodeWorker( @@ -73,6 +74,7 @@ def create_spec_decode_worker(): self.driver_worker.init_device() #self.driver_worker.load_model() + class SpecDecodeWorker(LoraNotSupportedWorkerBase): """Worker which implements speculative decoding. @@ -240,7 +242,7 @@ def _run_no_spec( blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, #return_python_output=False - ) + ) logger.info("run target worker no spec") sampler_output = self.scorer_worker.execute_model( diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 5d9a9acd763e..941c06208129 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -207,7 +207,8 @@ def execute_model( blocks_to_copy: Optional[Dict[int, List[int]]] = None, num_lookahead_slots: int = 0, ) -> Optional[SamplerOutput]: - assert (num_lookahead_slots == 0), "worker does not support lookahead slots" + assert (num_lookahead_slots == 0 + ), "worker does not support lookahead slots" if self.is_driver_worker: assert seq_group_metadata_list is not None From cde3160fdd542b80abba0d9855c98d8a12d959ac Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 14:57:45 -0700 Subject: [PATCH 012/120] fix --- vllm/executor/gpu_executor.py | 2 +- vllm/model_executor/layers/sampler.py | 11 ++++++----- vllm/sequence.py | 10 ++++++++++ vllm/spec_decode/batch_expansion.py | 10 +++++++++- vllm/spec_decode/multi_step_worker.py | 10 +++++++++- vllm/spec_decode/util.py | 7 +++++++ 6 files changed, 42 insertions(+), 8 deletions(-) diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 60c9a9ca3c78..ac445cd51a7e 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -96,7 +96,7 @@ def _init_spec_worker(self): spec_decode_worker = SpecDecodeWorker( proposer_worker=draft_worker, scorer_worker=target_worker, - rejection_sampler=RejectionSampler(), + rejection_sampler=RejectionSampler(strict_mode=True), ) assert self.parallel_config.world_size == 1, ( diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 5c1017207878..135bc13e8d7c 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -85,11 +85,12 @@ def forward( sampling_metadata, prompt_logprobs, sample_logprobs, - sampled_token_probs=probs, - sampled_token_ids=torch.empty( - (len(sampling_metadata.seq_groups), 1), - device=probs.device, - dtype=torch.long), + #sampled_token_probs=probs, + ## TODO + #sampled_token_ids=torch.empty( + # (len(sampling_metadata.seq_groups), 1), + # device=probs.device, + # dtype=torch.long), ) diff --git a/vllm/sequence.py b/vllm/sequence.py index 576bbe8c4f6c..223a7cf80232 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -686,3 +686,13 @@ def __len__(self): def __eq__(self, other: object): return isinstance(other, self.__class__) and self.outputs == other.outputs + + def __repr__(self) -> str: + """Show the shape of a tensor instead of its values to reduce noise. + """ + sampled_token_probs_repr = ("None" if self.sampled_token_probs is None else self.sampled_token_probs.shape) + sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else self.sampled_token_ids.shape) + return (f"SamplerOutput(outputs={self.outputs}, " + f"sampled_token_probs={sampled_token_probs_repr}, " + f"sampled_token_ids={sampled_token_ids_repr}, " + f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 6be8c843cf7a..701324c16dfe 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -8,7 +8,7 @@ SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, sampler_output_to_torch, - split_batch_by_proposal_len) + split_batch_by_proposal_len, mock_device_tensors) from vllm.worker.worker import Worker SeqId = int @@ -143,6 +143,14 @@ def _contract_batch(self, original_bs: int, This maps the scores of speculative tokens back to their original sequences. """ + + mock_device_tensors( + sampler_output=target_sampler_output, + batch_size=len(non_spec_indices) + num_scoring_tokens, + vocab_size=self._vocab_size, + device=self._device, + ) + (target_token_ids, target_probs, non_spec_target_token_ids, non_spec_target_probs) = self._split_scoring_output( target_sampler_output, num_scoring_tokens) diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 73b6e201c67a..262bab162649 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -6,7 +6,7 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) -from vllm.spec_decode.util import sampler_output_to_torch +from vllm.spec_decode.util import (sampler_output_to_torch, mock_device_tensors) from vllm.worker.worker import Worker @@ -341,6 +341,14 @@ def _merge_outputs( sampler_output = maybe_sampler_output + for step_output in sampler_output: + mock_device_tensors( + sampler_output=step_output, + batch_size=len(proposal_lens), + vocab_size=self._vocab_size, + device=self._device, + ) + proposal_tokens, proposal_probs = sampler_output_to_torch( sampler_output) diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index 406568a4bc08..234ed9e44f4e 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -82,6 +82,13 @@ def sampler_output_to_torch( return sampled_token_ids, sampled_token_probs +def mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, vocab_size: int, device: str) -> None: + assert sampler_output.sampled_token_probs is None + assert sampler_output.sampled_token_ids is None + + sampler_output.sampled_token_probs = torch.nn.functional.softmax(torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device), dim=-1) + sampler_output.sampled_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size,), dtype=torch.long, device=device) + @contextmanager def nvtx_range(msg, *args, **kwargs): """ From dd8aeff307f7c035b7db4a5184d00172cad6c3e9 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 15:00:34 -0700 Subject: [PATCH 013/120] fix --- vllm/engine/llm_engine.py | 1 - vllm/sequence.py | 9 +++-- vllm/spec_decode/batch_expansion.py | 3 +- vllm/spec_decode/multi_step_worker.py | 5 ++- vllm/spec_decode/spec_decode_worker.py | 55 -------------------------- vllm/spec_decode/util.py | 14 +++++-- 6 files changed, 22 insertions(+), 65 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1ca447890d4c..9d65ec1a2faa 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -739,7 +739,6 @@ def _process_sequence_group_outputs_multi_step(self, seq_group, outputs): token_id=output_token_id, logprobs={output_token_id: Logprob(0.0)}, ) - print(f'Appended token id {output_token_id=}') #seq.append_token_ids(output_token_ids, # output_logprobs, diff --git a/vllm/sequence.py b/vllm/sequence.py index 223a7cf80232..fa51483301a3 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -690,9 +690,12 @@ def __eq__(self, other: object): def __repr__(self) -> str: """Show the shape of a tensor instead of its values to reduce noise. """ - sampled_token_probs_repr = ("None" if self.sampled_token_probs is None else self.sampled_token_probs.shape) - sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else self.sampled_token_ids.shape) - return (f"SamplerOutput(outputs={self.outputs}, " + sampled_token_probs_repr = ("None" if self.sampled_token_probs is None + else self.sampled_token_probs.shape) + sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else + self.sampled_token_ids.shape) + return ( + f"SamplerOutput(outputs={self.outputs}, " f"sampled_token_probs={sampled_token_probs_repr}, " f"sampled_token_ids={sampled_token_ids_repr}, " f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 701324c16dfe..bba3c4733e4f 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -8,7 +8,8 @@ SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, sampler_output_to_torch, - split_batch_by_proposal_len, mock_device_tensors) + split_batch_by_proposal_len, + mock_device_tensors) from vllm.worker.worker import Worker SeqId = int diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 262bab162649..0ac189a7bacc 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -6,7 +6,8 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) -from vllm.spec_decode.util import (sampler_output_to_torch, mock_device_tensors) +from vllm.spec_decode.util import (sampler_output_to_torch, + mock_device_tensors) from vllm.worker.worker import Worker @@ -343,7 +344,7 @@ def _merge_outputs( for step_output in sampler_output: mock_device_tensors( - sampler_output=step_output, + sampler_output=step_output, batch_size=len(proposal_lens), vocab_size=self._vocab_size, device=self._device, diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 12a70d402e98..3e33371edadf 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -20,61 +20,6 @@ logger = init_logger(__name__) -def create_spec_decode_worker(): - - from vllm.worker.worker import Worker - from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker - from vllm.spec_decode.multi_step_worker import MultiStepWorker - - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - - target_worker = Worker( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - local_rank=0, - rank=0, - distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - is_driver_worker=True, - ) - - from vllm.spec_decode.multi_step_worker import MultiStepWorker - draft_worker = MultiStepWorker( - model_config=self.speculative_config.draft_model_config, - parallel_config=self.speculative_config.draft_parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - local_rank=0, - rank=0, - distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - is_driver_worker=True, - ) - - from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker - from vllm.model_executor.layers.rejection_sampler import RejectionSampler - spec_decode_worker = SpecDecodeWorker( - proposer_worker=draft_worker, - scorer_worker=target_worker, - rejection_sampler=RejectionSampler(), - ) - - assert self.parallel_config.world_size == 1, ( - "GPUExecutor only supports single GPU.") - - self.driver_worker = spec_decode_worker - - self.driver_worker.init_device() - #self.driver_worker.load_model() - - class SpecDecodeWorker(LoraNotSupportedWorkerBase): """Worker which implements speculative decoding. diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index 234ed9e44f4e..7129f47d65f6 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -82,12 +82,20 @@ def sampler_output_to_torch( return sampled_token_ids, sampled_token_probs -def mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, vocab_size: int, device: str) -> None: +def mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, + vocab_size: int, device: str) -> None: assert sampler_output.sampled_token_probs is None assert sampler_output.sampled_token_ids is None - sampler_output.sampled_token_probs = torch.nn.functional.softmax(torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device), dim=-1) - sampler_output.sampled_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size,), dtype=torch.long, device=device) + sampler_output.sampled_token_probs = torch.nn.functional.softmax( + torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device), + dim=-1) + sampler_output.sampled_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, ), + dtype=torch.long, + device=device) + @contextmanager def nvtx_range(msg, *args, **kwargs): From 46e48474ab355254f4d831b86f2b3303abde0d22 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 15:10:22 -0700 Subject: [PATCH 014/120] test --- tests/spec_decode/e2e/test_correctness.py | 8 +++++--- vllm/engine/llm_engine.py | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index 36a66ea2ec38..a1df4dccbe3b 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -9,8 +9,6 @@ [{ # Use a small model for a fast test. "model": "JackFram/llama-68m", - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, # Skip real loading for fast test. "load_format": "dummy", @@ -23,7 +21,11 @@ }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [ { - "tensor_parallel_size": 1, + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, + { + # No spec decode. }, ]) @pytest.mark.parametrize("test_llm_kwargs", [{}]) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 9d65ec1a2faa..a08a883539a9 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -627,7 +627,7 @@ def _process_model_outputs( self, output: SamplerOutput, scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: - if not isinstance(output, list): + if self.speculative_config is None: all_output = [output] else: all_output = output @@ -638,7 +638,7 @@ def _process_model_outputs( output_by_sequence_group: List[List[SequenceGroupOutputs]] = [ [] for _ in scheduled_seq_groups ] - for step in output: + for step in all_output: for i, sequence_group_output in enumerate(step): output_by_sequence_group[i].append(sequence_group_output) From 8454edc8bf13cb04936b7f552f7e6ec368a6693f Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 15:41:02 -0700 Subject: [PATCH 015/120] test fixes --- tests/spec_decode/test_spec_decode_worker.py | 14 +++++++------- vllm/engine/llm_engine.py | 2 +- vllm/executor/ray_gpu_executor.py | 3 ++- vllm/worker/worker.py | 2 -- 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 47aff8f57541..bd06d5b17d07 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -37,7 +37,7 @@ def test_correctly_calls_draft_model(k: int, batch_size: int): execute_model_data, _, _ = create_batch(batch_size, k) with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k) + worker.execute_model(**execute_model_data.to_dict(), num_lookahead_slots=k) call_args_list = draft_worker.get_spec_proposals.call_args_list assert len(call_args_list) == 1 @@ -102,7 +102,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int): target_worker.execute_model.side_effect = ValueError(exception_secret) with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k) + worker.execute_model(**execute_model_data.to_dict(), num_lookahead_slots=k) seen_contexts = [] @@ -195,7 +195,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): rejection_sampler.side_effect = ValueError(exception_secret) with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k) + worker.execute_model(**execute_model_data.to_dict(), num_lookahead_slots=k) assert len(rejection_sampler.call_args_list) == 1 args, _ = rejection_sampler.call_args_list[0] @@ -283,7 +283,7 @@ def test_correctly_formats_output(k: int, batch_size: int): rejection_sampler.return_value = rejection_sampler_output output = worker.execute_model(**execute_model_data.to_dict(), - num_spec_tokens=k) + num_lookahead_slots=k) expected_output = create_sampler_output_list( rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)]) @@ -400,7 +400,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): mock_rejsample_metrics) output = worker.execute_model(**execute_model_data.to_dict(), - num_spec_tokens=k) + num_lookahead_slots=k) assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics call_args_list = ( @@ -435,7 +435,7 @@ def test_k_equals_zero(k: int, batch_size: int): batch_size, k, prev_output_token_len=0) out = worker.execute_model(**execute_model_data.to_dict(), - num_spec_tokens=k) + num_lookahead_slots=k) assert len(out) == 1, f"expected only one token output when {k=}" assert out[0].probs is None, "expect gpu tensor references to be None" @@ -474,7 +474,7 @@ def test_empty_input_batch(k: int, batch_size: int): batch_size, k, prev_output_token_len=0) out = worker.execute_model(**execute_model_data.to_dict(), - num_spec_tokens=k) + num_lookahead_slots=k) assert len(out) == 1, f"expected only one token output when {k=}" assert out[0].probs is None, "expect gpu tensor references to be None" diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a08a883539a9..e47af8dfcf9e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -915,7 +915,7 @@ def _check_stop(self, seq: Sequence, return # Check if the sequence has reached max_tokens. - if seq.get_output_len() >= sampling_params.max_tokens: + if seq.get_output_len() >= int(sampling_params.max_tokens): # TODO should cap block seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED return diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index a508d1e8fe60..226183855708 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -238,7 +238,8 @@ def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: + blocks_to_copy: Dict[int, List[int]], + num_lookahead_slots: int = 0) -> SamplerOutput: all_outputs = self._run_workers( "execute_model", driver_kwargs={ diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 941c06208129..cb30f658482b 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -207,8 +207,6 @@ def execute_model( blocks_to_copy: Optional[Dict[int, List[int]]] = None, num_lookahead_slots: int = 0, ) -> Optional[SamplerOutput]: - assert (num_lookahead_slots == 0 - ), "worker does not support lookahead slots" if self.is_driver_worker: assert seq_group_metadata_list is not None From 819e65695455e9d63e4ed306f313b1d96f6b2c9a Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 15:41:22 -0700 Subject: [PATCH 016/120] lint --- tests/spec_decode/e2e/test_correctness.py | 20 +++++++++++--------- tests/spec_decode/test_spec_decode_worker.py | 9 ++++++--- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index a1df4dccbe3b..d8b09ce5b77a 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -19,15 +19,17 @@ # Required for spec decode. "use_v2_block_manager": True }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [ - { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - }, - { - # No spec decode. - }, -]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, + { + # No spec decode. + }, + ]) @pytest.mark.parametrize("test_llm_kwargs", [{}]) @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("seed", [1]) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index bd06d5b17d07..3725924ea89c 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -37,7 +37,8 @@ def test_correctly_calls_draft_model(k: int, batch_size: int): execute_model_data, _, _ = create_batch(batch_size, k) with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(**execute_model_data.to_dict(), num_lookahead_slots=k) + worker.execute_model(**execute_model_data.to_dict(), + num_lookahead_slots=k) call_args_list = draft_worker.get_spec_proposals.call_args_list assert len(call_args_list) == 1 @@ -102,7 +103,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int): target_worker.execute_model.side_effect = ValueError(exception_secret) with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(**execute_model_data.to_dict(), num_lookahead_slots=k) + worker.execute_model(**execute_model_data.to_dict(), + num_lookahead_slots=k) seen_contexts = [] @@ -195,7 +197,8 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): rejection_sampler.side_effect = ValueError(exception_secret) with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(**execute_model_data.to_dict(), num_lookahead_slots=k) + worker.execute_model(**execute_model_data.to_dict(), + num_lookahead_slots=k) assert len(rejection_sampler.call_args_list) == 1 args, _ = rejection_sampler.call_args_list[0] From d0fbe47bdb778b9ba32bda2b0d9a621d9ecd1134 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 16:01:35 -0700 Subject: [PATCH 017/120] clean --- vllm/model_executor/layers/sampler.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 135bc13e8d7c..bed915faf3fb 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -79,19 +79,7 @@ def forward( prompt_logprobs, sample_logprobs = _get_logprobs( logprobs, sampling_metadata, sample_results) - # TODO gate by config - return _build_sampler_output( - sample_results, - sampling_metadata, - prompt_logprobs, - sample_logprobs, - #sampled_token_probs=probs, - ## TODO - #sampled_token_ids=torch.empty( - # (len(sampling_metadata.seq_groups), 1), - # device=probs.device, - # dtype=torch.long), - ) + return _build_sampler_output(sample_results, sampling_metadata, prompt_logprobs, sample_logprobs) def _get_bin_counts_and_mask( @@ -699,8 +687,4 @@ def _build_sampler_output( sampler_output.append( SequenceGroupOutput(seq_outputs, group_prompt_logprobs)) - return SamplerOutput( - outputs=sampler_output, - sampled_token_probs=sampled_token_probs, - sampled_token_ids=sampled_token_ids, - ) + return SamplerOutput(outputs=sampler_output) From 5445af6ddf43cf9b1b82dc53260627e455d0ae81 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 16:45:19 -0700 Subject: [PATCH 018/120] refactor out beam search model processor --- vllm/engine/llm_engine.py | 537 ++++++++++--------- vllm/engine/output_processor/__init__.py | 0 vllm/engine/output_processor/beam_search.py | 321 +++++++++++ vllm/engine/output_processor/block_decode.py | 186 +++++++ vllm/engine/output_processor/interfaces.py | 36 ++ 5 files changed, 817 insertions(+), 263 deletions(-) create mode 100644 vllm/engine/output_processor/__init__.py create mode 100644 vllm/engine/output_processor/beam_search.py create mode 100644 vllm/engine/output_processor/block_decode.py create mode 100644 vllm/engine/output_processor/interfaces.py diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e47af8dfcf9e..1ac73bc874de 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -25,6 +25,7 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) from vllm.utils import Counter +from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 @@ -180,6 +181,14 @@ def __init__( labels=dict(model_name=model_config.model)) self.stat_logger.info("cache_config", self.cache_config) + self.output_processor = SequenceGroupOutputProcessor.create_output_processor( + self.scheduler_config, + self.detokenizer, + self.scheduler, + self.seq_counter, + self.get_tokenizer_for_seq, + ) + def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -449,179 +458,179 @@ def _check_beam_search_early_stopping( eos_token_id=best_running_seq.eos_token_id)) return current_worst_score >= highest_attainable_score - def _process_sequence_group_outputs(self, seq_group: SequenceGroup, - outputs: SequenceGroupOutput) -> None: - - # Process prompt logprobs - prompt_logprobs = outputs.prompt_logprobs - if prompt_logprobs is not None and seq_group.sampling_params.detokenize: - self.detokenizer.decode_prompt_logprobs_inplace( - seq_group, prompt_logprobs) - seq_group.prompt_logprobs = prompt_logprobs - - # Process samples - samples = outputs.samples - parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) - existing_finished_seqs = seq_group.get_finished_seqs() - parent_child_dict = { - parent_seq.seq_id: [] - for parent_seq in parent_seqs - } - for sample in samples: - parent_child_dict[sample.parent_seq_id].append(sample) - # List of (child, parent) - child_seqs: List[Tuple[Sequence, Sequence]] = [] - - # Process the child samples for each parent sequence - for parent in parent_seqs: - child_samples: List[SequenceOutput] = parent_child_dict[ - parent.seq_id] - if len(child_samples) == 0: - # This parent sequence has no children samples. Remove - # the parent sequence from the sequence group since it will - # not be used in the future iterations. - parent.status = SequenceStatus.FINISHED_ABORTED - seq_group.remove(parent.seq_id) - self.scheduler.free_seq(parent) - continue - # Fork the parent sequence if there are multiple child samples. - for child_sample in child_samples[:-1]: - new_child_seq_id = next(self.seq_counter) - child = parent.fork(new_child_seq_id) - child.append_token_id(child_sample.output_token, - child_sample.logprobs) - child_seqs.append((child, parent)) - # Continue the parent sequence for the last child sample. - # We reuse the parent sequence here to reduce redundant memory - # copies, especially when using non-beam search sampling methods. - last_child_sample = child_samples[-1] - parent.append_token_id(last_child_sample.output_token, - last_child_sample.logprobs) - child_seqs.append((parent, parent)) - - for seq, _ in child_seqs: - if seq_group.sampling_params.detokenize: - self.detokenizer.decode_sequence_inplace( - seq, seq_group.sampling_params) - self._check_stop(seq, seq_group.sampling_params) - - # Non-beam search case - if not seq_group.sampling_params.use_beam_search: - # For newly created child sequences, add them to the sequence group - # and fork them in block manager if they are not finished. - for seq, parent in child_seqs: - if seq is not parent: - seq_group.add(seq) - if not seq.is_finished(): - self.scheduler.fork_seq(parent, seq) - - # Free the finished and selected parent sequences' memory in block - # manager. Keep them in the sequence group as candidate output. - # NOTE: we need to fork the new sequences before freeing the - # old sequences. - for seq, parent in child_seqs: - if seq is parent and seq.is_finished(): - self.scheduler.free_seq(seq) - return - - # Beam search case - # Select the child sequences to keep in the sequence group. - selected_child_seqs = [] - unselected_child_seqs = [] - beam_width = seq_group.sampling_params.best_of - length_penalty = seq_group.sampling_params.length_penalty - - # Select the newly finished sequences with the highest scores - # to replace existing finished sequences. - # Tuple of (seq, parent, is_new) - existing_finished_seqs = [(seq, None, False) - for seq in existing_finished_seqs] - new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs - if seq.is_finished()] - all_finished_seqs = existing_finished_seqs + new_finished_seqs - # Sort the finished sequences by their scores. - all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( - length_penalty=length_penalty, eos_token_id=x[0].eos_token_id), - reverse=True) - for seq, parent, is_new in all_finished_seqs[:beam_width]: - if is_new: - # A newly generated child sequence finishes and has a high - # score, so we will add it into the sequence group. - selected_child_seqs.append((seq, parent)) - for seq, parent, is_new in all_finished_seqs[beam_width:]: - if is_new: - # A newly generated child sequence finishes but has a low - # score, so we will not add it into the sequence group. - # Additionally, if this sequence is a continuation of a - # parent sequence, we will need remove the parent sequence - # from the sequence group. - unselected_child_seqs.append((seq, parent)) - else: - # An existing finished sequence has a low score, so we will - # remove it from the sequence group. - seq_group.remove(seq.seq_id) - - # select the top beam_width sequences from the running - # sequences for the next iteration to continue the beam - # search. - running_child_seqs = [(seq, parent) for seq, parent in child_seqs - if not seq.is_finished()] - # Sort the running sequences by their scores. - running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( - length_penalty=length_penalty, eos_token_id=x[0].eos_token_id), - reverse=True) - - # Check if we can stop the beam search. - if len(running_child_seqs) == 0: - # No running sequences, stop the beam search. - stop_beam_search = True - elif len(all_finished_seqs) < beam_width: - # Not enough finished sequences, continue the beam search. - stop_beam_search = False - else: - # Check the early stopping criteria - best_running_seq = running_child_seqs[0][0] - current_worst_seq = all_finished_seqs[beam_width - 1][0] - stop_beam_search = self._check_beam_search_early_stopping( - seq_group.sampling_params.early_stopping, - seq_group.sampling_params, best_running_seq, current_worst_seq) - - if stop_beam_search: - # Stop the beam search and remove all the running sequences from - # the sequence group. - unselected_child_seqs.extend(running_child_seqs) - else: - # Continue the beam search and select the top beam_width sequences - # to continue the beam search. - selected_child_seqs.extend(running_child_seqs[:beam_width]) - # The remaining running sequences will not be used in the next - # iteration. Again, if these sequences are continuations of - # parent sequences, we will need to remove the parent sequences - # from the sequence group. - unselected_child_seqs.extend(running_child_seqs[beam_width:]) - - # For newly created child sequences, add them to the sequence group - # and fork them in block manager if they are not finished. - for seq, parent in selected_child_seqs: - if seq is not parent: - seq_group.add(seq) - if not seq.is_finished(): - self.scheduler.fork_seq(parent, seq) - - # Free the finished and selected parent sequences' memory in block - # manager. Keep them in the sequence group as candidate output. - for seq, parent in selected_child_seqs: - if seq is parent and seq.is_finished(): - self.scheduler.free_seq(seq) - - # Remove the unselected parent sequences from the sequence group and - # free their memory in block manager. - for seq, parent in unselected_child_seqs: - if seq is parent: - # Remove the parent sequence if it is not selected for next - # iteration - seq_group.remove(seq.seq_id) - self.scheduler.free_seq(seq) + #def _process_sequence_group_outputs(self, seq_group: SequenceGroup, + # outputs: SequenceGroupOutput) -> None: + + # # Process prompt logprobs + # prompt_logprobs = outputs.prompt_logprobs + # if prompt_logprobs is not None and seq_group.sampling_params.detokenize: + # self.detokenizer.decode_prompt_logprobs_inplace( + # seq_group, prompt_logprobs) + # seq_group.prompt_logprobs = prompt_logprobs + + # # Process samples + # samples = outputs.samples + # parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + # existing_finished_seqs = seq_group.get_finished_seqs() + # parent_child_dict = { + # parent_seq.seq_id: [] + # for parent_seq in parent_seqs + # } + # for sample in samples: + # parent_child_dict[sample.parent_seq_id].append(sample) + # # List of (child, parent) + # child_seqs: List[Tuple[Sequence, Sequence]] = [] + + # # Process the child samples for each parent sequence + # for parent in parent_seqs: + # child_samples: List[SequenceOutput] = parent_child_dict[ + # parent.seq_id] + # if len(child_samples) == 0: + # # This parent sequence has no children samples. Remove + # # the parent sequence from the sequence group since it will + # # not be used in the future iterations. + # parent.status = SequenceStatus.FINISHED_ABORTED + # seq_group.remove(parent.seq_id) + # self.scheduler.free_seq(parent) + # continue + # # Fork the parent sequence if there are multiple child samples. + # for child_sample in child_samples[:-1]: + # new_child_seq_id = next(self.seq_counter) + # child = parent.fork(new_child_seq_id) + # child.append_token_id(child_sample.output_token, + # child_sample.logprobs) + # child_seqs.append((child, parent)) + # # Continue the parent sequence for the last child sample. + # # We reuse the parent sequence here to reduce redundant memory + # # copies, especially when using non-beam search sampling methods. + # last_child_sample = child_samples[-1] + # parent.append_token_id(last_child_sample.output_token, + # last_child_sample.logprobs) + # child_seqs.append((parent, parent)) + + # for seq, _ in child_seqs: + # if seq_group.sampling_params.detokenize: + # self.detokenizer.decode_sequence_inplace( + # seq, seq_group.sampling_params) + # self._check_stop(seq, seq_group.sampling_params) + + # # Non-beam search case + # if not seq_group.sampling_params.use_beam_search: + # # For newly created child sequences, add them to the sequence group + # # and fork them in block manager if they are not finished. + # for seq, parent in child_seqs: + # if seq is not parent: + # seq_group.add(seq) + # if not seq.is_finished(): + # self.scheduler.fork_seq(parent, seq) + + # # Free the finished and selected parent sequences' memory in block + # # manager. Keep them in the sequence group as candidate output. + # # NOTE: we need to fork the new sequences before freeing the + # # old sequences. + # for seq, parent in child_seqs: + # if seq is parent and seq.is_finished(): + # self.scheduler.free_seq(seq) + # return + + # # Beam search case + # # Select the child sequences to keep in the sequence group. + # selected_child_seqs = [] + # unselected_child_seqs = [] + # beam_width = seq_group.sampling_params.best_of + # length_penalty = seq_group.sampling_params.length_penalty + + # # Select the newly finished sequences with the highest scores + # # to replace existing finished sequences. + # # Tuple of (seq, parent, is_new) + # existing_finished_seqs = [(seq, None, False) + # for seq in existing_finished_seqs] + # new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs + # if seq.is_finished()] + # all_finished_seqs = existing_finished_seqs + new_finished_seqs + # # Sort the finished sequences by their scores. + # all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( + # length_penalty=length_penalty, eos_token_id=x[0].eos_token_id), + # reverse=True) + # for seq, parent, is_new in all_finished_seqs[:beam_width]: + # if is_new: + # # A newly generated child sequence finishes and has a high + # # score, so we will add it into the sequence group. + # selected_child_seqs.append((seq, parent)) + # for seq, parent, is_new in all_finished_seqs[beam_width:]: + # if is_new: + # # A newly generated child sequence finishes but has a low + # # score, so we will not add it into the sequence group. + # # Additionally, if this sequence is a continuation of a + # # parent sequence, we will need remove the parent sequence + # # from the sequence group. + # unselected_child_seqs.append((seq, parent)) + # else: + # # An existing finished sequence has a low score, so we will + # # remove it from the sequence group. + # seq_group.remove(seq.seq_id) + + # # select the top beam_width sequences from the running + # # sequences for the next iteration to continue the beam + # # search. + # running_child_seqs = [(seq, parent) for seq, parent in child_seqs + # if not seq.is_finished()] + # # Sort the running sequences by their scores. + # running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( + # length_penalty=length_penalty, eos_token_id=x[0].eos_token_id), + # reverse=True) + + # # Check if we can stop the beam search. + # if len(running_child_seqs) == 0: + # # No running sequences, stop the beam search. + # stop_beam_search = True + # elif len(all_finished_seqs) < beam_width: + # # Not enough finished sequences, continue the beam search. + # stop_beam_search = False + # else: + # # Check the early stopping criteria + # best_running_seq = running_child_seqs[0][0] + # current_worst_seq = all_finished_seqs[beam_width - 1][0] + # stop_beam_search = self._check_beam_search_early_stopping( + # seq_group.sampling_params.early_stopping, + # seq_group.sampling_params, best_running_seq, current_worst_seq) + + # if stop_beam_search: + # # Stop the beam search and remove all the running sequences from + # # the sequence group. + # unselected_child_seqs.extend(running_child_seqs) + # else: + # # Continue the beam search and select the top beam_width sequences + # # to continue the beam search. + # selected_child_seqs.extend(running_child_seqs[:beam_width]) + # # The remaining running sequences will not be used in the next + # # iteration. Again, if these sequences are continuations of + # # parent sequences, we will need to remove the parent sequences + # # from the sequence group. + # unselected_child_seqs.extend(running_child_seqs[beam_width:]) + + # # For newly created child sequences, add them to the sequence group + # # and fork them in block manager if they are not finished. + # for seq, parent in selected_child_seqs: + # if seq is not parent: + # seq_group.add(seq) + # if not seq.is_finished(): + # self.scheduler.fork_seq(parent, seq) + + # # Free the finished and selected parent sequences' memory in block + # # manager. Keep them in the sequence group as candidate output. + # for seq, parent in selected_child_seqs: + # if seq is parent and seq.is_finished(): + # self.scheduler.free_seq(seq) + + # # Remove the unselected parent sequences from the sequence group and + # # free their memory in block manager. + # for seq, parent in unselected_child_seqs: + # if seq is parent: + # # Remove the parent sequence if it is not selected for next + # # iteration + # seq_group.remove(seq.seq_id) + # self.scheduler.free_seq(seq) def _process_model_outputs( self, output: SamplerOutput, @@ -651,14 +660,16 @@ def _process_model_outputs( seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) + + self.output_processor.process_outputs(seq_group, outputs) - assert len(outputs) > 0 - # TODO can spec decode go through second path? - if len(outputs) > 1: - self._process_sequence_group_outputs_multi_step( - seq_group, outputs) - else: - self._process_sequence_group_outputs(seq_group, outputs[0]) + #assert len(outputs) > 0 + ## TODO can spec decode go through second path? + #if len(outputs) > 1: + # self._process_sequence_group_outputs_multi_step( + # seq_group, outputs) + #else: + # self._process_sequence_group_outputs(seq_group, outputs[0]) # Free the finished sequence groups. self.scheduler.free_finished_seq_groups() @@ -679,89 +690,89 @@ def _process_model_outputs( self.stat_logger.log(self._get_stats(scheduler_outputs)) return request_outputs - def _process_sequence_group_outputs_multi_step(self, seq_group, outputs): - seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) - - assert seqs - #if not seqs: - # return [] - - assert len(seqs) == 1, ("Beam search not supported in speculative " - "decoding.") - seq = seqs[0] - - # Since there's only one sequence per sequence group, we can take the - # first sample. - samples = [outputs[step].samples[0] for step in range(len(outputs))] - - # -1 means the output token is not valid (eg. due to spec decode - # rejecting tokens). - valid_samples = [ - sample for sample in samples if sample.output_token != -1 - ] - - # Draft target worker pads all outputs with -1 to have same length. - output_token_ids = [sample.output_token for sample in valid_samples] - #successes = [sample.success for sample in samples] - - ## Truncate to max_tokens if necessary. - #remaining_tokens = seq_group.sampling_params.max_tokens - ( - # seq.get_output_len() + len(output_token_ids)) - #if remaining_tokens < 0: - # valid_samples = valid_samples[:remaining_tokens] - # output_token_ids = output_token_ids[:remaining_tokens] - - ## Truncate any tokens after EOS. This is required as spec decode - ## generates tokens in fixed blocks, which may go beyond the EOS token. - #if not seq_group.sampling_params.ignore_eos: - # eos_token_id = self.tokenizer.get_lora_tokenizer( - # seq.lora_request).eos_token_id - # # Avoiding .index calls as exception throwing in the happy path - # # is expensive. - # for i in range(len(output_token_ids)): - # if output_token_ids[i] == eos_token_id: - # output_token_ids = output_token_ids[:i + 1] - # valid_samples = valid_samples[:i + 1] - # break - - #output_logprobs = [sample.logprobs for sample in valid_samples] - - ## Use the last sample for the sequence as it will have - ## the speculation and num_unprocessed_tokens for all the - ## previous samples (they are cumulative when it comes - ## to those two attributes). - #speculation = valid_samples[-1].speculation - #num_unprocessed_tokens = valid_samples[-1].num_unprocessed_tokens - - for output_token_id in output_token_ids: - from vllm.sequence import Logprob - seq.append_token_id( - token_id=output_token_id, - logprobs={output_token_id: Logprob(0.0)}, - ) - - #seq.append_token_ids(output_token_ids, - # output_logprobs, - # ) - # #num_unprocessed_tokens=num_unprocessed_tokens) - ##seq.set_last_speculation(speculation) - - #if not all(successes): - # seq.set_status_to_failed() - - #if decode: - # self._decode_sequence(seq, - # seq_group.sampling_params, - # token_ids=seq.get_token_ids(), - # unseen_token_ids=output_token_ids, - # prefix_offset=seq.prefix_offset, - # read_offset=seq.read_offset) - #self._check_stop(seq, seq_group.sampling_params, seq.lora_request, - # output_token_ids) - # TODO pass output token ids - self._check_stop(seq, seq_group.sampling_params) - if seq.is_finished(): - self.scheduler.free_seq(seq) + #def _process_sequence_group_outputs_multi_step(self, seq_group, outputs): + # seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + + # assert seqs + # #if not seqs: + # # return [] + + # assert len(seqs) == 1, ("Beam search not supported in speculative " + # "decoding.") + # seq = seqs[0] + + # # Since there's only one sequence per sequence group, we can take the + # # first sample. + # samples = [outputs[step].samples[0] for step in range(len(outputs))] + + # # -1 means the output token is not valid (eg. due to spec decode + # # rejecting tokens). + # valid_samples = [ + # sample for sample in samples if sample.output_token != -1 + # ] + + # # Draft target worker pads all outputs with -1 to have same length. + # output_token_ids = [sample.output_token for sample in valid_samples] + # #successes = [sample.success for sample in samples] + + # ## Truncate to max_tokens if necessary. + # #remaining_tokens = seq_group.sampling_params.max_tokens - ( + # # seq.get_output_len() + len(output_token_ids)) + # #if remaining_tokens < 0: + # # valid_samples = valid_samples[:remaining_tokens] + # # output_token_ids = output_token_ids[:remaining_tokens] + + # ## Truncate any tokens after EOS. This is required as spec decode + # ## generates tokens in fixed blocks, which may go beyond the EOS token. + # #if not seq_group.sampling_params.ignore_eos: + # # eos_token_id = self.tokenizer.get_lora_tokenizer( + # # seq.lora_request).eos_token_id + # # # Avoiding .index calls as exception throwing in the happy path + # # # is expensive. + # # for i in range(len(output_token_ids)): + # # if output_token_ids[i] == eos_token_id: + # # output_token_ids = output_token_ids[:i + 1] + # # valid_samples = valid_samples[:i + 1] + # # break + + # #output_logprobs = [sample.logprobs for sample in valid_samples] + + # ## Use the last sample for the sequence as it will have + # ## the speculation and num_unprocessed_tokens for all the + # ## previous samples (they are cumulative when it comes + # ## to those two attributes). + # #speculation = valid_samples[-1].speculation + # #num_unprocessed_tokens = valid_samples[-1].num_unprocessed_tokens + + # for output_token_id in output_token_ids: + # from vllm.sequence import Logprob + # seq.append_token_id( + # token_id=output_token_id, + # logprobs={output_token_id: Logprob(0.0)}, + # ) + + # #seq.append_token_ids(output_token_ids, + # # output_logprobs, + # # ) + # # #num_unprocessed_tokens=num_unprocessed_tokens) + # ##seq.set_last_speculation(speculation) + + # #if not all(successes): + # # seq.set_status_to_failed() + + # #if decode: + # # self._decode_sequence(seq, + # # seq_group.sampling_params, + # # token_ids=seq.get_token_ids(), + # # unseen_token_ids=output_token_ids, + # # prefix_offset=seq.prefix_offset, + # # read_offset=seq.read_offset) + # #self._check_stop(seq, seq_group.sampling_params, seq.lora_request, + # # output_token_ids) + # # TODO pass output token ids + # self._check_stop(seq, seq_group.sampling_params) + # if seq.is_finished(): + # self.scheduler.free_seq(seq) def step(self) -> List[RequestOutput]: """Performs one decoding iteration and returns newly generated results. diff --git a/vllm/engine/output_processor/__init__.py b/vllm/engine/output_processor/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/engine/output_processor/beam_search.py b/vllm/engine/output_processor/beam_search.py new file mode 100644 index 000000000000..5f823b5c5c72 --- /dev/null +++ b/vllm/engine/output_processor/beam_search.py @@ -0,0 +1,321 @@ +import time +from typing import Iterable, List, Optional, Tuple, Type, Union + +from transformers import PreTrainedTokenizer + +import vllm +from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig, SpeculativeConfig, + VisionLanguageConfig) +from vllm.core.scheduler import Scheduler, SchedulerOutputs +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.metrics import StatLogger, Stats +from vllm.engine.ray_utils import initialize_ray_cluster +from vllm.executor.executor_base import ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.outputs import RequestOutput +from vllm.sampling_params import SamplingParams +from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, + SequenceGroup, SequenceGroupOutput, SequenceOutput, + SequenceStatus) +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, + get_tokenizer_group) +from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, + usage_message) +from vllm.utils import Counter +from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor + +logger = init_logger(__name__) + + +class BeamSearchOutputProcessor(SequenceGroupOutputProcessor): + + def __init__( + self, + scheduler_config: SchedulerConfig, + detokenizer, + scheduler, + seq_counter, + get_tokenizer_for_seq, + ): + self.scheduler_config = scheduler_config + self.detokenizer = detokenizer + self.scheduler = scheduler + self.seq_counter = seq_counter + self.get_tokenizer_for_seq = get_tokenizer_for_seq + + def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: + assert (len(outputs) == 1), f"{type(self)} does not support multiple outputs per step" + return self._process_sequence_group_outputs(sequence_group, outputs[0]) + + def _check_beam_search_early_stopping( + self, + early_stopping: Union[bool, str], + sampling_params: SamplingParams, + best_running_seq: Sequence, + current_worst_seq: Sequence, + ) -> bool: + assert sampling_params.use_beam_search + length_penalty = sampling_params.length_penalty + if early_stopping is True: + return True + + current_worst_score = current_worst_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=current_worst_seq.eos_token_id) + if early_stopping is False: + highest_attainable_score = best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=best_running_seq.eos_token_id) + else: + assert early_stopping == "never" + if length_penalty > 0.0: + # If length_penalty > 0.0, beam search will prefer longer + # sequences. The highest attainable score calculation is + # based on the longest possible sequence length in this case. + max_possible_length = max( + best_running_seq.get_prompt_len() + + sampling_params.max_tokens, + self.scheduler_config.max_model_len) + highest_attainable_score = ( + best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=best_running_seq.eos_token_id, + seq_len=max_possible_length)) + else: + # Otherwise, beam search will prefer shorter sequences. The + # highest attainable score calculation is based on the current + # sequence length. + highest_attainable_score = ( + best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=best_running_seq.eos_token_id)) + return current_worst_score >= highest_attainable_score + + def _process_sequence_group_outputs(self, seq_group: SequenceGroup, + outputs: SequenceGroupOutput) -> None: + + # Process prompt logprobs + prompt_logprobs = outputs.prompt_logprobs + if prompt_logprobs is not None and seq_group.sampling_params.detokenize: + self.detokenizer.decode_prompt_logprobs_inplace( + seq_group, prompt_logprobs) + seq_group.prompt_logprobs = prompt_logprobs + + # Process samples + samples = outputs.samples + parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + existing_finished_seqs = seq_group.get_finished_seqs() + parent_child_dict = { + parent_seq.seq_id: [] + for parent_seq in parent_seqs + } + for sample in samples: + parent_child_dict[sample.parent_seq_id].append(sample) + # List of (child, parent) + child_seqs: List[Tuple[Sequence, Sequence]] = [] + + # Process the child samples for each parent sequence + for parent in parent_seqs: + child_samples: List[SequenceOutput] = parent_child_dict[ + parent.seq_id] + if len(child_samples) == 0: + # This parent sequence has no children samples. Remove + # the parent sequence from the sequence group since it will + # not be used in the future iterations. + parent.status = SequenceStatus.FINISHED_ABORTED + seq_group.remove(parent.seq_id) + self.scheduler.free_seq(parent) + continue + # Fork the parent sequence if there are multiple child samples. + for child_sample in child_samples[:-1]: + new_child_seq_id = next(self.seq_counter) + child = parent.fork(new_child_seq_id) + child.append_token_id(child_sample.output_token, + child_sample.logprobs) + child_seqs.append((child, parent)) + # Continue the parent sequence for the last child sample. + # We reuse the parent sequence here to reduce redundant memory + # copies, especially when using non-beam search sampling methods. + last_child_sample = child_samples[-1] + parent.append_token_id(last_child_sample.output_token, + last_child_sample.logprobs) + child_seqs.append((parent, parent)) + + for seq, _ in child_seqs: + if seq_group.sampling_params.detokenize: + self.detokenizer.decode_sequence_inplace( + seq, seq_group.sampling_params) + self._check_stop(seq, seq_group.sampling_params) + + # Non-beam search case + if not seq_group.sampling_params.use_beam_search: + # For newly created child sequences, add them to the sequence group + # and fork them in block manager if they are not finished. + for seq, parent in child_seqs: + if seq is not parent: + seq_group.add(seq) + if not seq.is_finished(): + self.scheduler.fork_seq(parent, seq) + + # Free the finished and selected parent sequences' memory in block + # manager. Keep them in the sequence group as candidate output. + # NOTE: we need to fork the new sequences before freeing the + # old sequences. + for seq, parent in child_seqs: + if seq is parent and seq.is_finished(): + self.scheduler.free_seq(seq) + return + + # Beam search case + # Select the child sequences to keep in the sequence group. + selected_child_seqs = [] + unselected_child_seqs = [] + beam_width = seq_group.sampling_params.best_of + length_penalty = seq_group.sampling_params.length_penalty + + # Select the newly finished sequences with the highest scores + # to replace existing finished sequences. + # Tuple of (seq, parent, is_new) + existing_finished_seqs = [(seq, None, False) + for seq in existing_finished_seqs] + new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs + if seq.is_finished()] + all_finished_seqs = existing_finished_seqs + new_finished_seqs + # Sort the finished sequences by their scores. + all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, eos_token_id=x[0].eos_token_id), + reverse=True) + for seq, parent, is_new in all_finished_seqs[:beam_width]: + if is_new: + # A newly generated child sequence finishes and has a high + # score, so we will add it into the sequence group. + selected_child_seqs.append((seq, parent)) + for seq, parent, is_new in all_finished_seqs[beam_width:]: + if is_new: + # A newly generated child sequence finishes but has a low + # score, so we will not add it into the sequence group. + # Additionally, if this sequence is a continuation of a + # parent sequence, we will need remove the parent sequence + # from the sequence group. + unselected_child_seqs.append((seq, parent)) + else: + # An existing finished sequence has a low score, so we will + # remove it from the sequence group. + seq_group.remove(seq.seq_id) + + # select the top beam_width sequences from the running + # sequences for the next iteration to continue the beam + # search. + running_child_seqs = [(seq, parent) for seq, parent in child_seqs + if not seq.is_finished()] + # Sort the running sequences by their scores. + running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, eos_token_id=x[0].eos_token_id), + reverse=True) + + # Check if we can stop the beam search. + if len(running_child_seqs) == 0: + # No running sequences, stop the beam search. + stop_beam_search = True + elif len(all_finished_seqs) < beam_width: + # Not enough finished sequences, continue the beam search. + stop_beam_search = False + else: + # Check the early stopping criteria + best_running_seq = running_child_seqs[0][0] + current_worst_seq = all_finished_seqs[beam_width - 1][0] + stop_beam_search = self._check_beam_search_early_stopping( + seq_group.sampling_params.early_stopping, + seq_group.sampling_params, best_running_seq, current_worst_seq) + + if stop_beam_search: + # Stop the beam search and remove all the running sequences from + # the sequence group. + unselected_child_seqs.extend(running_child_seqs) + else: + # Continue the beam search and select the top beam_width sequences + # to continue the beam search. + selected_child_seqs.extend(running_child_seqs[:beam_width]) + # The remaining running sequences will not be used in the next + # iteration. Again, if these sequences are continuations of + # parent sequences, we will need to remove the parent sequences + # from the sequence group. + unselected_child_seqs.extend(running_child_seqs[beam_width:]) + + # For newly created child sequences, add them to the sequence group + # and fork them in block manager if they are not finished. + for seq, parent in selected_child_seqs: + if seq is not parent: + seq_group.add(seq) + if not seq.is_finished(): + self.scheduler.fork_seq(parent, seq) + + # Free the finished and selected parent sequences' memory in block + # manager. Keep them in the sequence group as candidate output. + for seq, parent in selected_child_seqs: + if seq is parent and seq.is_finished(): + self.scheduler.free_seq(seq) + + # Remove the unselected parent sequences from the sequence group and + # free their memory in block manager. + for seq, parent in unselected_child_seqs: + if seq is parent: + # Remove the parent sequence if it is not selected for next + # iteration + seq_group.remove(seq.seq_id) + self.scheduler.free_seq(seq) + + def _check_stop(self, seq: Sequence, + sampling_params: SamplingParams) -> None: + """Stop the finished sequences.""" + # Check if the sequence has reached max_model_len. + if seq.get_len() > self.scheduler_config.max_model_len: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the sequence has reached max_tokens. + if seq.get_output_len() >= int(sampling_params.max_tokens): + # TODO should cap block + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the minimum number of tokens has been generated yet; + # skip the stop string/token checks if not + if seq.get_output_len() < sampling_params.min_tokens: + return + + if sampling_params.detokenize: + for stop_str in sampling_params.stop: + if seq.output_text.endswith(stop_str): + self._finalize_sequence(seq, sampling_params, stop_str) + seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = stop_str + return + last_token_id = seq.get_last_token_id() + if last_token_id in sampling_params.stop_token_ids: + stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens( + last_token_id) + self._finalize_sequence(seq, sampling_params, stop_str) + seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = last_token_id + return + + # Check if the sequence has generated the EOS token. + if ((not sampling_params.ignore_eos) + and seq.get_last_token_id() == seq.eos_token_id): + seq.status = SequenceStatus.FINISHED_STOPPED + return + + def _finalize_sequence(self, seq: Sequence, + sampling_params: SamplingParams, + stop_string: str) -> None: + if sampling_params.include_stop_str_in_output: + return + + if stop_string and seq.output_text.endswith(stop_string): + # Truncate the output text so that the stop string is + # not included in the output. + seq.output_text = seq.output_text[:-len(stop_string)] diff --git a/vllm/engine/output_processor/block_decode.py b/vllm/engine/output_processor/block_decode.py new file mode 100644 index 000000000000..f11520d3a7e9 --- /dev/null +++ b/vllm/engine/output_processor/block_decode.py @@ -0,0 +1,186 @@ +import time +from typing import Iterable, List, Optional, Tuple, Type, Union + +from transformers import PreTrainedTokenizer + +import vllm +from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig, SpeculativeConfig, + VisionLanguageConfig) +from vllm.core.scheduler import Scheduler, SchedulerOutputs +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.metrics import StatLogger, Stats +from vllm.engine.ray_utils import initialize_ray_cluster +from vllm.executor.executor_base import ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.outputs import RequestOutput +from vllm.sampling_params import SamplingParams +from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, + SequenceGroup, SequenceGroupOutput, SequenceOutput, + SequenceStatus) +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, + get_tokenizer_group) +from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, + usage_message) +from vllm.utils import Counter +from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor + +logger = init_logger(__name__) + + +class BlockDecodeOutputProcessor(SequenceGroupOutputProcessor): + + def __init__( + self, + scheduler_config: SchedulerConfig, + detokenizer, + scheduler, + seq_counter, + get_tokenizer_for_seq, + ): + self.scheduler_config = scheduler_config + self.detokenizer = detokenizer + self.scheduler = scheduler + self.seq_counter = seq_counter + self.get_tokenizer_for_seq = get_tokenizer_for_seq + + def process_outputs(self, sequence_group: SequenceGroup, outputs: SequenceGroupOutput) -> None: + return self._process_sequence_group_outputs_multi_step(sequence_group, outputs) + + def _process_sequence_group_outputs_multi_step(self, seq_group, outputs): + seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + + assert seqs + #if not seqs: + # return [] + + assert len(seqs) == 1, ("Beam search not supported in speculative " + "decoding.") + seq = seqs[0] + + # Since there's only one sequence per sequence group, we can take the + # first sample. + samples = [outputs[step].samples[0] for step in range(len(outputs))] + + # -1 means the output token is not valid (eg. due to spec decode + # rejecting tokens). + valid_samples = [ + sample for sample in samples if sample.output_token != -1 + ] + + # Draft target worker pads all outputs with -1 to have same length. + output_token_ids = [sample.output_token for sample in valid_samples] + #successes = [sample.success for sample in samples] + + ## Truncate to max_tokens if necessary. + #remaining_tokens = seq_group.sampling_params.max_tokens - ( + # seq.get_output_len() + len(output_token_ids)) + #if remaining_tokens < 0: + # valid_samples = valid_samples[:remaining_tokens] + # output_token_ids = output_token_ids[:remaining_tokens] + + ## Truncate any tokens after EOS. This is required as spec decode + ## generates tokens in fixed blocks, which may go beyond the EOS token. + #if not seq_group.sampling_params.ignore_eos: + # eos_token_id = self.tokenizer.get_lora_tokenizer( + # seq.lora_request).eos_token_id + # # Avoiding .index calls as exception throwing in the happy path + # # is expensive. + # for i in range(len(output_token_ids)): + # if output_token_ids[i] == eos_token_id: + # output_token_ids = output_token_ids[:i + 1] + # valid_samples = valid_samples[:i + 1] + # break + + #output_logprobs = [sample.logprobs for sample in valid_samples] + + ## Use the last sample for the sequence as it will have + ## the speculation and num_unprocessed_tokens for all the + ## previous samples (they are cumulative when it comes + ## to those two attributes). + #speculation = valid_samples[-1].speculation + #num_unprocessed_tokens = valid_samples[-1].num_unprocessed_tokens + + for output_token_id in output_token_ids: + from vllm.sequence import Logprob + seq.append_token_id( + token_id=output_token_id, + logprobs={output_token_id: Logprob(0.0)}, + ) + + #seq.append_token_ids(output_token_ids, + # output_logprobs, + # ) + # #num_unprocessed_tokens=num_unprocessed_tokens) + ##seq.set_last_speculation(speculation) + + #if not all(successes): + # seq.set_status_to_failed() + + #if decode: + # self._decode_sequence(seq, + # seq_group.sampling_params, + # token_ids=seq.get_token_ids(), + # unseen_token_ids=output_token_ids, + # prefix_offset=seq.prefix_offset, + # read_offset=seq.read_offset) + #self._check_stop(seq, seq_group.sampling_params, seq.lora_request, + # output_token_ids) + # TODO pass output token ids + self._check_stop(seq, seq_group.sampling_params) + if seq.is_finished(): + self.scheduler.free_seq(seq) + + def _check_stop(self, seq: Sequence, + sampling_params: SamplingParams) -> None: + """Stop the finished sequences.""" + # Check if the sequence has reached max_model_len. + if seq.get_len() > self.scheduler_config.max_model_len: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the sequence has reached max_tokens. + if seq.get_output_len() >= int(sampling_params.max_tokens): + # TODO should cap block + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the minimum number of tokens has been generated yet; + # skip the stop string/token checks if not + if seq.get_output_len() < sampling_params.min_tokens: + return + + if sampling_params.detokenize: + for stop_str in sampling_params.stop: + if seq.output_text.endswith(stop_str): + self._finalize_sequence(seq, sampling_params, stop_str) + seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = stop_str + return + last_token_id = seq.get_last_token_id() + if last_token_id in sampling_params.stop_token_ids: + stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens( + last_token_id) + self._finalize_sequence(seq, sampling_params, stop_str) + seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = last_token_id + return + + # Check if the sequence has generated the EOS token. + if ((not sampling_params.ignore_eos) + and seq.get_last_token_id() == seq.eos_token_id): + seq.status = SequenceStatus.FINISHED_STOPPED + return + + def _finalize_sequence(self, seq: Sequence, + sampling_params: SamplingParams, + stop_string: str) -> None: + if sampling_params.include_stop_str_in_output: + return + + if stop_string and seq.output_text.endswith(stop_string): + # Truncate the output text so that the stop string is + # not included in the output. + seq.output_text = seq.output_text[:-len(stop_string)] diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py new file mode 100644 index 000000000000..4d1da960dc41 --- /dev/null +++ b/vllm/engine/output_processor/interfaces.py @@ -0,0 +1,36 @@ +from abc import ABC, abstractmethod +from vllm.config import SchedulerConfig +from vllm.sequence import SequenceGroup, SequenceGroupOutput + +class SequenceGroupOutputProcessor(ABC): + + @staticmethod + def create_output_processor( + scheduler_config: SchedulerConfig, + detokenizer, + scheduler, + seq_counter, + get_tokenizer_for_seq, + ): + if scheduler_config.num_lookahead_slots == 0: + from vllm.engine.output_processor.beam_search import BeamSearchOutputProcessor + return BeamSearchOutputProcessor( + scheduler_config, + detokenizer, + scheduler, + seq_counter, + get_tokenizer_for_seq, + ) + else: + from vllm.engine.output_processor.block_decode import BlockDecodeOutputProcessor + return BlockDecodeOutputProcessor( + scheduler_config, + detokenizer, + scheduler, + seq_counter, + get_tokenizer_for_seq, + ) + + @abstractmethod + def process_outputs(self, sequence_group: SequenceGroup, outputs: SequenceGroupOutput) -> None: + pass From 632b439541021309fbc0f83b78210532e1a94606 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 16:46:14 -0700 Subject: [PATCH 019/120] fix --- vllm/engine/llm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1ac73bc874de..60b0f46b2318 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -926,7 +926,7 @@ def _check_stop(self, seq: Sequence, return # Check if the sequence has reached max_tokens. - if seq.get_output_len() >= int(sampling_params.max_tokens): + if (sampling_params.max_tokens is not None) and (seq.get_output_len() >= sampling_params.max_tokens): # TODO should cap block seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED return From 26e7368e95f824fdce6cac30f476529d270ac6ed Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 16:53:01 -0700 Subject: [PATCH 020/120] dedup stop check --- vllm/engine/llm_engine.py | 104 +++++++------- vllm/engine/output_processor/beam_search.py | 140 ++++++------------- vllm/engine/output_processor/block_decode.py | 57 +------- vllm/engine/output_processor/interfaces.py | 3 + vllm/engine/output_processor/stop_checker.py | 89 ++++++++++++ 5 files changed, 194 insertions(+), 199 deletions(-) create mode 100644 vllm/engine/output_processor/stop_checker.py diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 60b0f46b2318..570b5eff581d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -26,6 +26,7 @@ usage_message) from vllm.utils import Counter from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor +from vllm.engine.output_processor.stop_checker import StopChecker logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 @@ -187,6 +188,7 @@ def __init__( self.scheduler, self.seq_counter, self.get_tokenizer_for_seq, + stop_checker=StopChecker(scheduler, self.get_tokenizer_for_seq), ) def _initialize_kv_caches(self) -> None: @@ -917,57 +919,57 @@ def _get_stats(self, time_e2e_requests=time_e2e_requests, ) - def _check_stop(self, seq: Sequence, - sampling_params: SamplingParams) -> None: - """Stop the finished sequences.""" - # Check if the sequence has reached max_model_len. - if seq.get_len() > self.scheduler_config.max_model_len: - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return - - # Check if the sequence has reached max_tokens. - if (sampling_params.max_tokens is not None) and (seq.get_output_len() >= sampling_params.max_tokens): - # TODO should cap block - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return - - # Check if the minimum number of tokens has been generated yet; - # skip the stop string/token checks if not - if seq.get_output_len() < sampling_params.min_tokens: - return - - if sampling_params.detokenize: - for stop_str in sampling_params.stop: - if seq.output_text.endswith(stop_str): - self._finalize_sequence(seq, sampling_params, stop_str) - seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = stop_str - return - last_token_id = seq.get_last_token_id() - if last_token_id in sampling_params.stop_token_ids: - stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens( - last_token_id) - self._finalize_sequence(seq, sampling_params, stop_str) - seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = last_token_id - return - - # Check if the sequence has generated the EOS token. - if ((not sampling_params.ignore_eos) - and seq.get_last_token_id() == seq.eos_token_id): - seq.status = SequenceStatus.FINISHED_STOPPED - return - - def _finalize_sequence(self, seq: Sequence, - sampling_params: SamplingParams, - stop_string: str) -> None: - if sampling_params.include_stop_str_in_output: - return - - if stop_string and seq.output_text.endswith(stop_string): - # Truncate the output text so that the stop string is - # not included in the output. - seq.output_text = seq.output_text[:-len(stop_string)] + #def _check_stop(self, seq: Sequence, + # sampling_params: SamplingParams) -> None: + # """Stop the finished sequences.""" + # # Check if the sequence has reached max_model_len. + # if seq.get_len() > self.scheduler_config.max_model_len: + # seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + # return + + # # Check if the sequence has reached max_tokens. + # if (sampling_params.max_tokens is not None) and (seq.get_output_len() >= sampling_params.max_tokens): + # # TODO should cap block + # seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + # return + + # # Check if the minimum number of tokens has been generated yet; + # # skip the stop string/token checks if not + # if seq.get_output_len() < sampling_params.min_tokens: + # return + + # if sampling_params.detokenize: + # for stop_str in sampling_params.stop: + # if seq.output_text.endswith(stop_str): + # self._finalize_sequence(seq, sampling_params, stop_str) + # seq.status = SequenceStatus.FINISHED_STOPPED + # seq.stop_reason = stop_str + # return + # last_token_id = seq.get_last_token_id() + # if last_token_id in sampling_params.stop_token_ids: + # stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens( + # last_token_id) + # self._finalize_sequence(seq, sampling_params, stop_str) + # seq.status = SequenceStatus.FINISHED_STOPPED + # seq.stop_reason = last_token_id + # return + + # # Check if the sequence has generated the EOS token. + # if ((not sampling_params.ignore_eos) + # and seq.get_last_token_id() == seq.eos_token_id): + # seq.status = SequenceStatus.FINISHED_STOPPED + # return + + #def _finalize_sequence(self, seq: Sequence, + # sampling_params: SamplingParams, + # stop_string: str) -> None: + # if sampling_params.include_stop_str_in_output: + # return + + # if stop_string and seq.output_text.endswith(stop_string): + # # Truncate the output text so that the stop string is + # # not included in the output. + # seq.output_text = seq.output_text[:-len(stop_string)] def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_executor.add_lora(lora_request) diff --git a/vllm/engine/output_processor/beam_search.py b/vllm/engine/output_processor/beam_search.py index 5f823b5c5c72..c9ded1171151 100644 --- a/vllm/engine/output_processor/beam_search.py +++ b/vllm/engine/output_processor/beam_search.py @@ -39,61 +39,19 @@ def __init__( scheduler, seq_counter, get_tokenizer_for_seq, + stop_checker, ): self.scheduler_config = scheduler_config self.detokenizer = detokenizer self.scheduler = scheduler self.seq_counter = seq_counter self.get_tokenizer_for_seq = get_tokenizer_for_seq + self.stop_checker = stop_checker def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: assert (len(outputs) == 1), f"{type(self)} does not support multiple outputs per step" return self._process_sequence_group_outputs(sequence_group, outputs[0]) - def _check_beam_search_early_stopping( - self, - early_stopping: Union[bool, str], - sampling_params: SamplingParams, - best_running_seq: Sequence, - current_worst_seq: Sequence, - ) -> bool: - assert sampling_params.use_beam_search - length_penalty = sampling_params.length_penalty - if early_stopping is True: - return True - - current_worst_score = current_worst_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=current_worst_seq.eos_token_id) - if early_stopping is False: - highest_attainable_score = best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=best_running_seq.eos_token_id) - else: - assert early_stopping == "never" - if length_penalty > 0.0: - # If length_penalty > 0.0, beam search will prefer longer - # sequences. The highest attainable score calculation is - # based on the longest possible sequence length in this case. - max_possible_length = max( - best_running_seq.get_prompt_len() + - sampling_params.max_tokens, - self.scheduler_config.max_model_len) - highest_attainable_score = ( - best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=best_running_seq.eos_token_id, - seq_len=max_possible_length)) - else: - # Otherwise, beam search will prefer shorter sequences. The - # highest attainable score calculation is based on the current - # sequence length. - highest_attainable_score = ( - best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=best_running_seq.eos_token_id)) - return current_worst_score >= highest_attainable_score - def _process_sequence_group_outputs(self, seq_group: SequenceGroup, outputs: SequenceGroupOutput) -> None: @@ -148,7 +106,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, if seq_group.sampling_params.detokenize: self.detokenizer.decode_sequence_inplace( seq, seq_group.sampling_params) - self._check_stop(seq, seq_group.sampling_params) + self.stop_checker.check_stop(seq, seq_group.sampling_params) # Non-beam search case if not seq_group.sampling_params.use_beam_search: @@ -268,54 +226,46 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, seq_group.remove(seq.seq_id) self.scheduler.free_seq(seq) - def _check_stop(self, seq: Sequence, - sampling_params: SamplingParams) -> None: - """Stop the finished sequences.""" - # Check if the sequence has reached max_model_len. - if seq.get_len() > self.scheduler_config.max_model_len: - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return - - # Check if the sequence has reached max_tokens. - if seq.get_output_len() >= int(sampling_params.max_tokens): - # TODO should cap block - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return - - # Check if the minimum number of tokens has been generated yet; - # skip the stop string/token checks if not - if seq.get_output_len() < sampling_params.min_tokens: - return - - if sampling_params.detokenize: - for stop_str in sampling_params.stop: - if seq.output_text.endswith(stop_str): - self._finalize_sequence(seq, sampling_params, stop_str) - seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = stop_str - return - last_token_id = seq.get_last_token_id() - if last_token_id in sampling_params.stop_token_ids: - stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens( - last_token_id) - self._finalize_sequence(seq, sampling_params, stop_str) - seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = last_token_id - return - - # Check if the sequence has generated the EOS token. - if ((not sampling_params.ignore_eos) - and seq.get_last_token_id() == seq.eos_token_id): - seq.status = SequenceStatus.FINISHED_STOPPED - return - - def _finalize_sequence(self, seq: Sequence, - sampling_params: SamplingParams, - stop_string: str) -> None: - if sampling_params.include_stop_str_in_output: - return + def _check_beam_search_early_stopping( + self, + early_stopping: Union[bool, str], + sampling_params: SamplingParams, + best_running_seq: Sequence, + current_worst_seq: Sequence, + ) -> bool: + assert sampling_params.use_beam_search + length_penalty = sampling_params.length_penalty + if early_stopping is True: + return True - if stop_string and seq.output_text.endswith(stop_string): - # Truncate the output text so that the stop string is - # not included in the output. - seq.output_text = seq.output_text[:-len(stop_string)] + current_worst_score = current_worst_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=current_worst_seq.eos_token_id) + if early_stopping is False: + highest_attainable_score = best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=best_running_seq.eos_token_id) + else: + assert early_stopping == "never" + if length_penalty > 0.0: + # If length_penalty > 0.0, beam search will prefer longer + # sequences. The highest attainable score calculation is + # based on the longest possible sequence length in this case. + max_possible_length = max( + best_running_seq.get_prompt_len() + + sampling_params.max_tokens, + self.scheduler_config.max_model_len) + highest_attainable_score = ( + best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=best_running_seq.eos_token_id, + seq_len=max_possible_length)) + else: + # Otherwise, beam search will prefer shorter sequences. The + # highest attainable score calculation is based on the current + # sequence length. + highest_attainable_score = ( + best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=best_running_seq.eos_token_id)) + return current_worst_score >= highest_attainable_score diff --git a/vllm/engine/output_processor/block_decode.py b/vllm/engine/output_processor/block_decode.py index f11520d3a7e9..90ad03df32dd 100644 --- a/vllm/engine/output_processor/block_decode.py +++ b/vllm/engine/output_processor/block_decode.py @@ -39,12 +39,14 @@ def __init__( scheduler, seq_counter, get_tokenizer_for_seq, + stop_checker, ): self.scheduler_config = scheduler_config self.detokenizer = detokenizer self.scheduler = scheduler self.seq_counter = seq_counter self.get_tokenizer_for_seq = get_tokenizer_for_seq + self.stop_checker = stop_checker def process_outputs(self, sequence_group: SequenceGroup, outputs: SequenceGroupOutput) -> None: return self._process_sequence_group_outputs_multi_step(sequence_group, outputs) @@ -129,58 +131,7 @@ def _process_sequence_group_outputs_multi_step(self, seq_group, outputs): #self._check_stop(seq, seq_group.sampling_params, seq.lora_request, # output_token_ids) # TODO pass output token ids - self._check_stop(seq, seq_group.sampling_params) + self.stop_checker.check_stop(seq, seq_group.sampling_params) + if seq.is_finished(): self.scheduler.free_seq(seq) - - def _check_stop(self, seq: Sequence, - sampling_params: SamplingParams) -> None: - """Stop the finished sequences.""" - # Check if the sequence has reached max_model_len. - if seq.get_len() > self.scheduler_config.max_model_len: - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return - - # Check if the sequence has reached max_tokens. - if seq.get_output_len() >= int(sampling_params.max_tokens): - # TODO should cap block - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return - - # Check if the minimum number of tokens has been generated yet; - # skip the stop string/token checks if not - if seq.get_output_len() < sampling_params.min_tokens: - return - - if sampling_params.detokenize: - for stop_str in sampling_params.stop: - if seq.output_text.endswith(stop_str): - self._finalize_sequence(seq, sampling_params, stop_str) - seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = stop_str - return - last_token_id = seq.get_last_token_id() - if last_token_id in sampling_params.stop_token_ids: - stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens( - last_token_id) - self._finalize_sequence(seq, sampling_params, stop_str) - seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = last_token_id - return - - # Check if the sequence has generated the EOS token. - if ((not sampling_params.ignore_eos) - and seq.get_last_token_id() == seq.eos_token_id): - seq.status = SequenceStatus.FINISHED_STOPPED - return - - def _finalize_sequence(self, seq: Sequence, - sampling_params: SamplingParams, - stop_string: str) -> None: - if sampling_params.include_stop_str_in_output: - return - - if stop_string and seq.output_text.endswith(stop_string): - # Truncate the output text so that the stop string is - # not included in the output. - seq.output_text = seq.output_text[:-len(stop_string)] diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index 4d1da960dc41..d2368fc811a0 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -11,6 +11,7 @@ def create_output_processor( scheduler, seq_counter, get_tokenizer_for_seq, + stop_checker, ): if scheduler_config.num_lookahead_slots == 0: from vllm.engine.output_processor.beam_search import BeamSearchOutputProcessor @@ -20,6 +21,7 @@ def create_output_processor( scheduler, seq_counter, get_tokenizer_for_seq, + stop_checker, ) else: from vllm.engine.output_processor.block_decode import BlockDecodeOutputProcessor @@ -29,6 +31,7 @@ def create_output_processor( scheduler, seq_counter, get_tokenizer_for_seq, + stop_checker, ) @abstractmethod diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py new file mode 100644 index 000000000000..feeef1c0f24a --- /dev/null +++ b/vllm/engine/output_processor/stop_checker.py @@ -0,0 +1,89 @@ +import time +from typing import Iterable, List, Optional, Tuple, Type, Union + +from transformers import PreTrainedTokenizer + +import vllm +from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig, SpeculativeConfig, + VisionLanguageConfig) +from vllm.core.scheduler import Scheduler, SchedulerOutputs +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.metrics import StatLogger, Stats +from vllm.engine.ray_utils import initialize_ray_cluster +from vllm.executor.executor_base import ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.outputs import RequestOutput +from vllm.sampling_params import SamplingParams +from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, + SequenceGroup, SequenceGroupOutput, SequenceOutput, + SequenceStatus) +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, + get_tokenizer_group) +from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, + usage_message) +from vllm.utils import Counter +from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor + +logger = init_logger(__name__) +_LOCAL_LOGGING_INTERVAL_SEC = 5 + +class StopChecker: + + def __init__(self, scheduler, get_tokenizer_for_seq): + self.scheduler = scheduler + self.get_tokenizer_for_seq = get_tokenizer_for_seq + + def check_stop(self, seq: Sequence, + sampling_params: SamplingParams) -> None: + """Stop the finished sequences.""" + # Check if the sequence has reached max_model_len. + if seq.get_len() > self.scheduler_config.max_model_len: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the sequence has reached max_tokens. + if (sampling_params.max_tokens is not None) and (seq.get_output_len() >= sampling_params.max_tokens): + # TODO should cap block + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the minimum number of tokens has been generated yet; + # skip the stop string/token checks if not + if seq.get_output_len() < sampling_params.min_tokens: + return + + if sampling_params.detokenize: + for stop_str in sampling_params.stop: + if seq.output_text.endswith(stop_str): + self._finalize_sequence(seq, sampling_params, stop_str) + seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = stop_str + return + last_token_id = seq.get_last_token_id() + if last_token_id in sampling_params.stop_token_ids: + stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens( + last_token_id) + self._finalize_sequence(seq, sampling_params, stop_str) + seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = last_token_id + return + + # Check if the sequence has generated the EOS token. + if ((not sampling_params.ignore_eos) + and seq.get_last_token_id() == seq.eos_token_id): + seq.status = SequenceStatus.FINISHED_STOPPED + return + + def _finalize_sequence(self, seq: Sequence, + sampling_params: SamplingParams, + stop_string: str) -> None: + if sampling_params.include_stop_str_in_output: + return + + if stop_string and seq.output_text.endswith(stop_string): + # Truncate the output text so that the stop string is + # not included in the output. + seq.output_text = seq.output_text[:-len(stop_string)] From 06e7c01d3867439289e8f5958cf1bc00be0c305a Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 16:55:20 -0700 Subject: [PATCH 021/120] wip --- vllm/engine/llm_engine.py | 6 +++++- vllm/engine/output_processor/stop_checker.py | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 570b5eff581d..036709a414c2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -188,7 +188,11 @@ def __init__( self.scheduler, self.seq_counter, self.get_tokenizer_for_seq, - stop_checker=StopChecker(scheduler, self.get_tokenizer_for_seq), + stop_checker=StopChecker( + self.scheduler, + self.scheduler_config, + self.get_tokenizer_for_seq, + ), ) def _initialize_kv_caches(self) -> None: diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index feeef1c0f24a..cc6655b7aaa7 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -32,8 +32,9 @@ class StopChecker: - def __init__(self, scheduler, get_tokenizer_for_seq): + def __init__(self, scheduler, scheduler_config, get_tokenizer_for_seq): self.scheduler = scheduler + self.scheduler_config = scheduler_config self.get_tokenizer_for_seq = get_tokenizer_for_seq def check_stop(self, seq: Sequence, From 184a52c166ec6eeb75dfedbb544c65188322ece7 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 16:56:18 -0700 Subject: [PATCH 022/120] del --- vllm/engine/llm_engine.py | 352 -------------------------------------- 1 file changed, 352 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 036709a414c2..2be4a260f164 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -420,224 +420,6 @@ def has_unfinished_requests(self) -> bool: """Returns True if there are unfinished requests.""" return self.scheduler.has_unfinished_seqs() - def _check_beam_search_early_stopping( - self, - early_stopping: Union[bool, str], - sampling_params: SamplingParams, - best_running_seq: Sequence, - current_worst_seq: Sequence, - ) -> bool: - assert sampling_params.use_beam_search - length_penalty = sampling_params.length_penalty - if early_stopping is True: - return True - - current_worst_score = current_worst_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=current_worst_seq.eos_token_id) - if early_stopping is False: - highest_attainable_score = best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=best_running_seq.eos_token_id) - else: - assert early_stopping == "never" - if length_penalty > 0.0: - # If length_penalty > 0.0, beam search will prefer longer - # sequences. The highest attainable score calculation is - # based on the longest possible sequence length in this case. - max_possible_length = max( - best_running_seq.get_prompt_len() + - sampling_params.max_tokens, - self.scheduler_config.max_model_len) - highest_attainable_score = ( - best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=best_running_seq.eos_token_id, - seq_len=max_possible_length)) - else: - # Otherwise, beam search will prefer shorter sequences. The - # highest attainable score calculation is based on the current - # sequence length. - highest_attainable_score = ( - best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=best_running_seq.eos_token_id)) - return current_worst_score >= highest_attainable_score - - #def _process_sequence_group_outputs(self, seq_group: SequenceGroup, - # outputs: SequenceGroupOutput) -> None: - - # # Process prompt logprobs - # prompt_logprobs = outputs.prompt_logprobs - # if prompt_logprobs is not None and seq_group.sampling_params.detokenize: - # self.detokenizer.decode_prompt_logprobs_inplace( - # seq_group, prompt_logprobs) - # seq_group.prompt_logprobs = prompt_logprobs - - # # Process samples - # samples = outputs.samples - # parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) - # existing_finished_seqs = seq_group.get_finished_seqs() - # parent_child_dict = { - # parent_seq.seq_id: [] - # for parent_seq in parent_seqs - # } - # for sample in samples: - # parent_child_dict[sample.parent_seq_id].append(sample) - # # List of (child, parent) - # child_seqs: List[Tuple[Sequence, Sequence]] = [] - - # # Process the child samples for each parent sequence - # for parent in parent_seqs: - # child_samples: List[SequenceOutput] = parent_child_dict[ - # parent.seq_id] - # if len(child_samples) == 0: - # # This parent sequence has no children samples. Remove - # # the parent sequence from the sequence group since it will - # # not be used in the future iterations. - # parent.status = SequenceStatus.FINISHED_ABORTED - # seq_group.remove(parent.seq_id) - # self.scheduler.free_seq(parent) - # continue - # # Fork the parent sequence if there are multiple child samples. - # for child_sample in child_samples[:-1]: - # new_child_seq_id = next(self.seq_counter) - # child = parent.fork(new_child_seq_id) - # child.append_token_id(child_sample.output_token, - # child_sample.logprobs) - # child_seqs.append((child, parent)) - # # Continue the parent sequence for the last child sample. - # # We reuse the parent sequence here to reduce redundant memory - # # copies, especially when using non-beam search sampling methods. - # last_child_sample = child_samples[-1] - # parent.append_token_id(last_child_sample.output_token, - # last_child_sample.logprobs) - # child_seqs.append((parent, parent)) - - # for seq, _ in child_seqs: - # if seq_group.sampling_params.detokenize: - # self.detokenizer.decode_sequence_inplace( - # seq, seq_group.sampling_params) - # self._check_stop(seq, seq_group.sampling_params) - - # # Non-beam search case - # if not seq_group.sampling_params.use_beam_search: - # # For newly created child sequences, add them to the sequence group - # # and fork them in block manager if they are not finished. - # for seq, parent in child_seqs: - # if seq is not parent: - # seq_group.add(seq) - # if not seq.is_finished(): - # self.scheduler.fork_seq(parent, seq) - - # # Free the finished and selected parent sequences' memory in block - # # manager. Keep them in the sequence group as candidate output. - # # NOTE: we need to fork the new sequences before freeing the - # # old sequences. - # for seq, parent in child_seqs: - # if seq is parent and seq.is_finished(): - # self.scheduler.free_seq(seq) - # return - - # # Beam search case - # # Select the child sequences to keep in the sequence group. - # selected_child_seqs = [] - # unselected_child_seqs = [] - # beam_width = seq_group.sampling_params.best_of - # length_penalty = seq_group.sampling_params.length_penalty - - # # Select the newly finished sequences with the highest scores - # # to replace existing finished sequences. - # # Tuple of (seq, parent, is_new) - # existing_finished_seqs = [(seq, None, False) - # for seq in existing_finished_seqs] - # new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs - # if seq.is_finished()] - # all_finished_seqs = existing_finished_seqs + new_finished_seqs - # # Sort the finished sequences by their scores. - # all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( - # length_penalty=length_penalty, eos_token_id=x[0].eos_token_id), - # reverse=True) - # for seq, parent, is_new in all_finished_seqs[:beam_width]: - # if is_new: - # # A newly generated child sequence finishes and has a high - # # score, so we will add it into the sequence group. - # selected_child_seqs.append((seq, parent)) - # for seq, parent, is_new in all_finished_seqs[beam_width:]: - # if is_new: - # # A newly generated child sequence finishes but has a low - # # score, so we will not add it into the sequence group. - # # Additionally, if this sequence is a continuation of a - # # parent sequence, we will need remove the parent sequence - # # from the sequence group. - # unselected_child_seqs.append((seq, parent)) - # else: - # # An existing finished sequence has a low score, so we will - # # remove it from the sequence group. - # seq_group.remove(seq.seq_id) - - # # select the top beam_width sequences from the running - # # sequences for the next iteration to continue the beam - # # search. - # running_child_seqs = [(seq, parent) for seq, parent in child_seqs - # if not seq.is_finished()] - # # Sort the running sequences by their scores. - # running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( - # length_penalty=length_penalty, eos_token_id=x[0].eos_token_id), - # reverse=True) - - # # Check if we can stop the beam search. - # if len(running_child_seqs) == 0: - # # No running sequences, stop the beam search. - # stop_beam_search = True - # elif len(all_finished_seqs) < beam_width: - # # Not enough finished sequences, continue the beam search. - # stop_beam_search = False - # else: - # # Check the early stopping criteria - # best_running_seq = running_child_seqs[0][0] - # current_worst_seq = all_finished_seqs[beam_width - 1][0] - # stop_beam_search = self._check_beam_search_early_stopping( - # seq_group.sampling_params.early_stopping, - # seq_group.sampling_params, best_running_seq, current_worst_seq) - - # if stop_beam_search: - # # Stop the beam search and remove all the running sequences from - # # the sequence group. - # unselected_child_seqs.extend(running_child_seqs) - # else: - # # Continue the beam search and select the top beam_width sequences - # # to continue the beam search. - # selected_child_seqs.extend(running_child_seqs[:beam_width]) - # # The remaining running sequences will not be used in the next - # # iteration. Again, if these sequences are continuations of - # # parent sequences, we will need to remove the parent sequences - # # from the sequence group. - # unselected_child_seqs.extend(running_child_seqs[beam_width:]) - - # # For newly created child sequences, add them to the sequence group - # # and fork them in block manager if they are not finished. - # for seq, parent in selected_child_seqs: - # if seq is not parent: - # seq_group.add(seq) - # if not seq.is_finished(): - # self.scheduler.fork_seq(parent, seq) - - # # Free the finished and selected parent sequences' memory in block - # # manager. Keep them in the sequence group as candidate output. - # for seq, parent in selected_child_seqs: - # if seq is parent and seq.is_finished(): - # self.scheduler.free_seq(seq) - - # # Remove the unselected parent sequences from the sequence group and - # # free their memory in block manager. - # for seq, parent in unselected_child_seqs: - # if seq is parent: - # # Remove the parent sequence if it is not selected for next - # # iteration - # seq_group.remove(seq.seq_id) - # self.scheduler.free_seq(seq) - def _process_model_outputs( self, output: SamplerOutput, scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: @@ -696,89 +478,6 @@ def _process_model_outputs( self.stat_logger.log(self._get_stats(scheduler_outputs)) return request_outputs - #def _process_sequence_group_outputs_multi_step(self, seq_group, outputs): - # seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) - - # assert seqs - # #if not seqs: - # # return [] - - # assert len(seqs) == 1, ("Beam search not supported in speculative " - # "decoding.") - # seq = seqs[0] - - # # Since there's only one sequence per sequence group, we can take the - # # first sample. - # samples = [outputs[step].samples[0] for step in range(len(outputs))] - - # # -1 means the output token is not valid (eg. due to spec decode - # # rejecting tokens). - # valid_samples = [ - # sample for sample in samples if sample.output_token != -1 - # ] - - # # Draft target worker pads all outputs with -1 to have same length. - # output_token_ids = [sample.output_token for sample in valid_samples] - # #successes = [sample.success for sample in samples] - - # ## Truncate to max_tokens if necessary. - # #remaining_tokens = seq_group.sampling_params.max_tokens - ( - # # seq.get_output_len() + len(output_token_ids)) - # #if remaining_tokens < 0: - # # valid_samples = valid_samples[:remaining_tokens] - # # output_token_ids = output_token_ids[:remaining_tokens] - - # ## Truncate any tokens after EOS. This is required as spec decode - # ## generates tokens in fixed blocks, which may go beyond the EOS token. - # #if not seq_group.sampling_params.ignore_eos: - # # eos_token_id = self.tokenizer.get_lora_tokenizer( - # # seq.lora_request).eos_token_id - # # # Avoiding .index calls as exception throwing in the happy path - # # # is expensive. - # # for i in range(len(output_token_ids)): - # # if output_token_ids[i] == eos_token_id: - # # output_token_ids = output_token_ids[:i + 1] - # # valid_samples = valid_samples[:i + 1] - # # break - - # #output_logprobs = [sample.logprobs for sample in valid_samples] - - # ## Use the last sample for the sequence as it will have - # ## the speculation and num_unprocessed_tokens for all the - # ## previous samples (they are cumulative when it comes - # ## to those two attributes). - # #speculation = valid_samples[-1].speculation - # #num_unprocessed_tokens = valid_samples[-1].num_unprocessed_tokens - - # for output_token_id in output_token_ids: - # from vllm.sequence import Logprob - # seq.append_token_id( - # token_id=output_token_id, - # logprobs={output_token_id: Logprob(0.0)}, - # ) - - # #seq.append_token_ids(output_token_ids, - # # output_logprobs, - # # ) - # # #num_unprocessed_tokens=num_unprocessed_tokens) - # ##seq.set_last_speculation(speculation) - - # #if not all(successes): - # # seq.set_status_to_failed() - - # #if decode: - # # self._decode_sequence(seq, - # # seq_group.sampling_params, - # # token_ids=seq.get_token_ids(), - # # unseen_token_ids=output_token_ids, - # # prefix_offset=seq.prefix_offset, - # # read_offset=seq.read_offset) - # #self._check_stop(seq, seq_group.sampling_params, seq.lora_request, - # # output_token_ids) - # # TODO pass output token ids - # self._check_stop(seq, seq_group.sampling_params) - # if seq.is_finished(): - # self.scheduler.free_seq(seq) def step(self) -> List[RequestOutput]: """Performs one decoding iteration and returns newly generated results. @@ -923,57 +622,6 @@ def _get_stats(self, time_e2e_requests=time_e2e_requests, ) - #def _check_stop(self, seq: Sequence, - # sampling_params: SamplingParams) -> None: - # """Stop the finished sequences.""" - # # Check if the sequence has reached max_model_len. - # if seq.get_len() > self.scheduler_config.max_model_len: - # seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - # return - - # # Check if the sequence has reached max_tokens. - # if (sampling_params.max_tokens is not None) and (seq.get_output_len() >= sampling_params.max_tokens): - # # TODO should cap block - # seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - # return - - # # Check if the minimum number of tokens has been generated yet; - # # skip the stop string/token checks if not - # if seq.get_output_len() < sampling_params.min_tokens: - # return - - # if sampling_params.detokenize: - # for stop_str in sampling_params.stop: - # if seq.output_text.endswith(stop_str): - # self._finalize_sequence(seq, sampling_params, stop_str) - # seq.status = SequenceStatus.FINISHED_STOPPED - # seq.stop_reason = stop_str - # return - # last_token_id = seq.get_last_token_id() - # if last_token_id in sampling_params.stop_token_ids: - # stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens( - # last_token_id) - # self._finalize_sequence(seq, sampling_params, stop_str) - # seq.status = SequenceStatus.FINISHED_STOPPED - # seq.stop_reason = last_token_id - # return - - # # Check if the sequence has generated the EOS token. - # if ((not sampling_params.ignore_eos) - # and seq.get_last_token_id() == seq.eos_token_id): - # seq.status = SequenceStatus.FINISHED_STOPPED - # return - - #def _finalize_sequence(self, seq: Sequence, - # sampling_params: SamplingParams, - # stop_string: str) -> None: - # if sampling_params.include_stop_str_in_output: - # return - - # if stop_string and seq.output_text.endswith(stop_string): - # # Truncate the output text so that the stop string is - # # not included in the output. - # seq.output_text = seq.output_text[:-len(stop_string)] def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_executor.add_lora(lora_request) From 34468fe8af84d0a2bd313e9b4dc06582e17c1458 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 16:57:51 -0700 Subject: [PATCH 023/120] rename --- vllm/engine/output_processor/beam_search.py | 2 +- vllm/engine/output_processor/block_decode.py | 2 +- vllm/engine/output_processor/stop_checker.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/engine/output_processor/beam_search.py b/vllm/engine/output_processor/beam_search.py index c9ded1171151..829c5ecd7839 100644 --- a/vllm/engine/output_processor/beam_search.py +++ b/vllm/engine/output_processor/beam_search.py @@ -106,7 +106,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, if seq_group.sampling_params.detokenize: self.detokenizer.decode_sequence_inplace( seq, seq_group.sampling_params) - self.stop_checker.check_stop(seq, seq_group.sampling_params) + self.stop_checker.maybe_stop_sequence(seq, seq_group.sampling_params) # Non-beam search case if not seq_group.sampling_params.use_beam_search: diff --git a/vllm/engine/output_processor/block_decode.py b/vllm/engine/output_processor/block_decode.py index 90ad03df32dd..44b4efba6372 100644 --- a/vllm/engine/output_processor/block_decode.py +++ b/vllm/engine/output_processor/block_decode.py @@ -131,7 +131,7 @@ def _process_sequence_group_outputs_multi_step(self, seq_group, outputs): #self._check_stop(seq, seq_group.sampling_params, seq.lora_request, # output_token_ids) # TODO pass output token ids - self.stop_checker.check_stop(seq, seq_group.sampling_params) + self.stop_checker.maybe_stop_sequence(seq, seq_group.sampling_params) if seq.is_finished(): self.scheduler.free_seq(seq) diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index cc6655b7aaa7..82973e304202 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -37,7 +37,7 @@ def __init__(self, scheduler, scheduler_config, get_tokenizer_for_seq): self.scheduler_config = scheduler_config self.get_tokenizer_for_seq = get_tokenizer_for_seq - def check_stop(self, seq: Sequence, + def maybe_stop_sequence(self, seq: Sequence, sampling_params: SamplingParams) -> None: """Stop the finished sequences.""" # Check if the sequence has reached max_model_len. From 208c4671593534e9a2f9ed7f64da80c5a74a4fb4 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 17:10:05 -0700 Subject: [PATCH 024/120] wip --- vllm/engine/llm_engine.py | 23 +++++------------------ vllm/engine/output_processor/util.py | 12 ++++++++++++ 2 files changed, 17 insertions(+), 18 deletions(-) create mode 100644 vllm/engine/output_processor/util.py diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 2be4a260f164..86ba02023627 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -27,6 +27,7 @@ from vllm.utils import Counter from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.engine.output_processor.util import create_output_by_sequence_group logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 @@ -424,6 +425,9 @@ def _process_model_outputs( self, output: SamplerOutput, scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: + now = time.time() + + # TODO if self.speculative_config is None: all_output = [output] else: @@ -431,34 +435,17 @@ def _process_model_outputs( scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups - # Organize list of sampler output by sequence group. - output_by_sequence_group: List[List[SequenceGroupOutputs]] = [ - [] for _ in scheduled_seq_groups - ] - for step in all_output: - for i, sequence_group_output in enumerate(step): - output_by_sequence_group[i].append(sequence_group_output) - - now = time.time() + output_by_sequence_group = create_output_by_sequence_group(sampler_outputs=all_output, num_seq_groups=len(scheduled_seq_groups)) # Update the scheduled sequence groups with the model outputs. for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output_by_sequence_group): - seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) self.output_processor.process_outputs(seq_group, outputs) - #assert len(outputs) > 0 - ## TODO can spec decode go through second path? - #if len(outputs) > 1: - # self._process_sequence_group_outputs_multi_step( - # seq_group, outputs) - #else: - # self._process_sequence_group_outputs(seq_group, outputs[0]) - # Free the finished sequence groups. self.scheduler.free_finished_seq_groups() diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py new file mode 100644 index 000000000000..1fcd651deef1 --- /dev/null +++ b/vllm/engine/output_processor/util.py @@ -0,0 +1,12 @@ +from vllm.sequence import SequenceGroupOutput, SamplerOutput +from typing import List + +def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput], num_seq_groups: int): + output_by_sequence_group = [ + [] for _ in range(num_seq_groups) + ] + for step in sampler_outputs: + for i, sequence_group_output in enumerate(step): + output_by_sequence_group[i].append(sequence_group_output) + + return output_by_sequence_group From 3c6abcc564bafc242316797ccbed1e10db54dff7 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 17:14:22 -0700 Subject: [PATCH 025/120] wip --- vllm/engine/llm_engine.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 86ba02023627..72af9c3da9f7 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -423,7 +423,8 @@ def has_unfinished_requests(self) -> bool: def _process_model_outputs( self, output: SamplerOutput, - scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: + scheduled_seq_groups: List[SequenceGroup], + ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]: now = time.time() @@ -433,8 +434,6 @@ def _process_model_outputs( else: all_output = output - scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups - output_by_sequence_group = create_output_by_sequence_group(sampler_outputs=all_output, num_seq_groups=len(scheduled_seq_groups)) # Update the scheduled sequence groups with the model outputs. @@ -456,13 +455,9 @@ def _process_model_outputs( seq_group.maybe_set_first_token_time(now) request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) - for seq_group in scheduler_outputs.ignored_seq_groups: + for seq_group in ignored_seq_groups: request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) - - # Log stats. - if self.log_stats: - self.stat_logger.log(self._get_stats(scheduler_outputs)) return request_outputs @@ -529,7 +524,13 @@ def step(self) -> List[RequestOutput]: else: output = [] - return self._process_model_outputs(output, scheduler_outputs) + request_outputs = self._process_model_outputs(output, scheduler_outputs.scheduled_seq_groups, scheduler_outputs.ignored_seq_groups) + + # Log stats. + if self.log_stats: + self.stat_logger.log(self._get_stats(scheduler_outputs)) + + return request_outputs def do_log_stats(self) -> None: """Forced log when no requests active.""" From bbbcef70d603ab791ecc62336a56ef25b1566d33 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 17:27:24 -0700 Subject: [PATCH 026/120] wip --- tests/spec_decode/e2e/test_correctness.py | 2 +- vllm/engine/llm_engine.py | 11 +++-------- vllm/executor/cpu_executor.py | 2 +- vllm/executor/executor_base.py | 5 +++-- vllm/executor/gpu_executor.py | 2 +- vllm/spec_decode/multi_step_worker.py | 2 ++ vllm/spec_decode/spec_decode_worker.py | 2 ++ vllm/worker/cpu_worker.py | 8 +++++--- vllm/worker/neuron_worker.py | 9 ++++++--- vllm/worker/worker.py | 9 ++++++--- vllm/worker/worker_base.py | 5 +++-- 11 files changed, 33 insertions(+), 24 deletions(-) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index d8b09ce5b77a..eb6d1e1c5ddd 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -89,7 +89,7 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): { # Expect failure as spec decode not supported by # Ray backend. - "tensor_parallel_size": 2, + "worker_use_ray": True, }, ]) @pytest.mark.parametrize("test_llm_kwargs", [{}]) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 72af9c3da9f7..bce36ddccc81 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -422,19 +422,14 @@ def has_unfinished_requests(self) -> bool: return self.scheduler.has_unfinished_seqs() def _process_model_outputs( - self, output: SamplerOutput, + self, + output: List[SamplerOutput], scheduled_seq_groups: List[SequenceGroup], ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]: now = time.time() - # TODO - if self.speculative_config is None: - all_output = [output] - else: - all_output = output - - output_by_sequence_group = create_output_by_sequence_group(sampler_outputs=all_output, num_seq_groups=len(scheduled_seq_groups)) + output_by_sequence_group = create_output_by_sequence_group(sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups)) # Update the scheduled sequence groups with the model outputs. for scheduled_seq_group, outputs in zip(scheduled_seq_groups, diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 835ba18ab756..f308f9149475 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -81,7 +81,7 @@ def execute_model(self, blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], - num_lookahead_slots: int) -> SamplerOutput: + num_lookahead_slots: int) -> List[SamplerOutput]: output = self.driver_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index c18edd75d7a4..23927c113744 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -58,8 +58,9 @@ def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: - """Executes one model step on the given sequences.""" + blocks_to_copy: Dict[int, List[int]], + num_lookahead_slots: int) -> List[SamplerOutput]: + """Executes at least one model step on the given sequences.""" raise NotImplementedError @abstractmethod diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index ac445cd51a7e..90a534dc1271 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -157,7 +157,7 @@ def execute_model( blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], num_lookahead_slots: int, - ) -> SamplerOutput: + ) -> List[SamplerOutput]: output = self.driver_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 0ac189a7bacc..4cdbe0923455 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -70,6 +70,8 @@ def execute_model_multi_step( blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ) + assert (len(model_output) == 1), "composing multistep workers not supported" + model_output = model_output[0] self._append_new_tokens(model_output, copied_seq_group_metadata_list) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 3e33371edadf..894377c9421e 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -196,6 +196,8 @@ def _run_no_spec( blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ) + assert len(sampler_output) == 1, "expected single output from scorer worker" + sampler_output = sampler_output[0] # Clear device tensors from sampler output. This reduces communication # overhead when the engine runs in a different process than the workers. diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index bd67f9f8850a..09a37c25783a 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -257,7 +257,7 @@ def execute_model( blocks_to_swap_in: Optional[Dict[int, int]] = None, blocks_to_swap_out: Optional[Dict[int, int]] = None, blocks_to_copy: Optional[Dict[int, List[int]]] = None, - ) -> Optional[SamplerOutput]: + ) -> List[SamplerOutput]: if self.is_driver_worker: assert seq_group_metadata_list is not None num_seq_groups = len(seq_group_metadata_list) @@ -280,11 +280,13 @@ def execute_model( # If there is no input, we don't need to execute the model. if num_seq_groups == 0: - return {} + return [] output = self.model_runner.execute_model(seq_group_metadata_list, self.cpu_cache) - return output + + # CPU worker only supports single-step execution. + return [output] def init_distributed_environment(self) -> None: """Initialize the distributed environment.""" diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 6136d50d0c06..d0f01b893bc6 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -73,15 +73,18 @@ def initialize_cache(self, num_gpu_blocks: int, def execute_model( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Optional[SamplerOutput]: + ) -> List[SamplerOutput]: num_seq_groups = len(seq_group_metadata_list) # If there is no input, we don't need to execute the model. if num_seq_groups == 0: - return {} + return [] output = self.model_runner.execute_model(seq_group_metadata_list) - return output + + # Neuron worker only supports single-step output. Wrap the output in a + # list to conform to interface. + return [output] def get_cache_block_size_bytes(self) -> int: """Determine the size in bytes of a cache block. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index cb30f658482b..95e62b9e6a75 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -206,7 +206,7 @@ def execute_model( blocks_to_swap_out: Optional[Dict[int, int]] = None, blocks_to_copy: Optional[Dict[int, List[int]]] = None, num_lookahead_slots: int = 0, - ) -> Optional[SamplerOutput]: + ) -> List[SamplerOutput]: if self.is_driver_worker: assert seq_group_metadata_list is not None @@ -232,11 +232,14 @@ def execute_model( # If there is no input, we don't need to execute the model. if num_seq_groups == 0: - return {} + return [] output = self.model_runner.execute_model(seq_group_metadata_list, self.gpu_cache) - return output + + # Worker only supports single-step execution. Wrap the output in a list + # to conform to interface. + return [output] def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index e3027c406ffe..1481a4c2eef4 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -44,8 +44,9 @@ def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: - """Executes one model step on the given sequences.""" + blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]: + """Executes at least one model step on the given sequences, unless no + sequences are provided.""" raise NotImplementedError @abstractmethod From b58762d4fa0f64eb29af5a649650d6293c5d988f Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 17:29:07 -0700 Subject: [PATCH 027/120] fix --- vllm/spec_decode/batch_expansion.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index bba3c4733e4f..f7bac45861a7 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -86,6 +86,8 @@ def score_proposals( blocks_to_copy=blocks_to_copy, #return_python_output=False ) + assert len(target_sampler_output) == 1, "expected single-step output" + target_sampler_output = target_sampler_output[0] all_tokens, all_probs = self._contract_batch( original_bs=len(seq_group_metadata_list), From 8b500d404b81b10857f75503e312ecf44ee9dd9f Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 17:43:04 -0700 Subject: [PATCH 028/120] wip --- vllm/engine/output_processor/block_decode.py | 67 ++++++-------------- 1 file changed, 18 insertions(+), 49 deletions(-) diff --git a/vllm/engine/output_processor/block_decode.py b/vllm/engine/output_processor/block_decode.py index 44b4efba6372..3fb2b7ee3235 100644 --- a/vllm/engine/output_processor/block_decode.py +++ b/vllm/engine/output_processor/block_decode.py @@ -18,7 +18,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, - SequenceStatus) + SequenceStatus, Logprob) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, get_tokenizer_group) @@ -49,17 +49,10 @@ def __init__( self.stop_checker = stop_checker def process_outputs(self, sequence_group: SequenceGroup, outputs: SequenceGroupOutput) -> None: - return self._process_sequence_group_outputs_multi_step(sequence_group, outputs) + seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING) - def _process_sequence_group_outputs_multi_step(self, seq_group, outputs): - seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) - - assert seqs - #if not seqs: - # return [] - - assert len(seqs) == 1, ("Beam search not supported in speculative " - "decoding.") + assert seqs, "expected running sequences" + assert len(seqs) == 1, ("Beam search not supported in block decoding.") seq = seqs[0] # Since there's only one sequence per sequence group, we can take the @@ -71,21 +64,23 @@ def _process_sequence_group_outputs_multi_step(self, seq_group, outputs): valid_samples = [ sample for sample in samples if sample.output_token != -1 ] + assert valid_samples + + self._process_seq_outputs(seq, valid_samples, sequence_group.sampling_params) - # Draft target worker pads all outputs with -1 to have same length. + def _process_seq_outputs(self, seq: Sequence, valid_samples: List[SequenceOutput], sampling_params: SamplingParams) -> None: output_token_ids = [sample.output_token for sample in valid_samples] - #successes = [sample.success for sample in samples] - ## Truncate to max_tokens if necessary. - #remaining_tokens = seq_group.sampling_params.max_tokens - ( - # seq.get_output_len() + len(output_token_ids)) - #if remaining_tokens < 0: - # valid_samples = valid_samples[:remaining_tokens] - # output_token_ids = output_token_ids[:remaining_tokens] + # Truncate to max_tokens if necessary. + remaining_tokens = sampling_params.max_tokens - ( + seq.get_output_len() + len(output_token_ids)) + if remaining_tokens < 0: + valid_samples = valid_samples[:remaining_tokens] + output_token_ids = output_token_ids[:remaining_tokens] ## Truncate any tokens after EOS. This is required as spec decode ## generates tokens in fixed blocks, which may go beyond the EOS token. - #if not seq_group.sampling_params.ignore_eos: + #if not sampling_params.ignore_eos: # eos_token_id = self.tokenizer.get_lora_tokenizer( # seq.lora_request).eos_token_id # # Avoiding .index calls as exception throwing in the happy path @@ -96,42 +91,16 @@ def _process_sequence_group_outputs_multi_step(self, seq_group, outputs): # valid_samples = valid_samples[:i + 1] # break - #output_logprobs = [sample.logprobs for sample in valid_samples] - - ## Use the last sample for the sequence as it will have - ## the speculation and num_unprocessed_tokens for all the - ## previous samples (they are cumulative when it comes - ## to those two attributes). - #speculation = valid_samples[-1].speculation - #num_unprocessed_tokens = valid_samples[-1].num_unprocessed_tokens - for output_token_id in output_token_ids: - from vllm.sequence import Logprob seq.append_token_id( token_id=output_token_id, + # TODO emit logprobs in block decoding. logprobs={output_token_id: Logprob(0.0)}, ) - #seq.append_token_ids(output_token_ids, - # output_logprobs, - # ) - # #num_unprocessed_tokens=num_unprocessed_tokens) - ##seq.set_last_speculation(speculation) - - #if not all(successes): - # seq.set_status_to_failed() - - #if decode: - # self._decode_sequence(seq, - # seq_group.sampling_params, - # token_ids=seq.get_token_ids(), - # unseen_token_ids=output_token_ids, - # prefix_offset=seq.prefix_offset, - # read_offset=seq.read_offset) - #self._check_stop(seq, seq_group.sampling_params, seq.lora_request, - # output_token_ids) + # TODO detokenize # TODO pass output token ids - self.stop_checker.maybe_stop_sequence(seq, seq_group.sampling_params) + self.stop_checker.maybe_stop_sequence(seq, sampling_params) if seq.is_finished(): self.scheduler.free_seq(seq) From 782ce22d604291a64ac6dce3efbb9b4c662c0557 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 18:26:30 -0700 Subject: [PATCH 029/120] unit tests for block decode --- tests/core/utils.py | 16 +- .../output_processor/test_block_decode.py | 238 ++++++++++++++++++ vllm/engine/output_processor/beam_search.py | 2 - vllm/engine/output_processor/block_decode.py | 27 +- vllm/engine/output_processor/interfaces.py | 5 +- 5 files changed, 262 insertions(+), 26 deletions(-) create mode 100644 tests/engine/output_processor/test_block_decode.py diff --git a/tests/core/utils.py b/tests/core/utils.py index fbbdb07cb8e6..d9d2eeaee1b9 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -1,5 +1,5 @@ import time -from typing import Optional, Tuple +from typing import Optional, Tuple, Iterable from vllm import SamplingParams from vllm.lora.request import LoRARequest @@ -31,14 +31,18 @@ def create_dummy_prompt( def create_seq_group( - seq_prompt_len=1024, - seq_output_lens=(128, ), - request_id='0', - seq_id_start=0, + seq_prompt_len: int=1024, + seq_output_lens: Iterable[int]=(128, ), + request_id: str='0', + seq_id_start: int=0, + sampling_params: Optional[SamplingParams] = None ) -> SequenceGroup: assert len(seq_output_lens) > 0 + if sampling_params is None: + sampling_params = SamplingParams() + prompt_token_ids = [0] * seq_prompt_len seqs = [] @@ -60,7 +64,7 @@ def create_seq_group( seq_group = SequenceGroup( request_id=request_id, seqs=seqs, - sampling_params=SamplingParams(), + sampling_params=sampling_params, arrival_time=time.time(), ) diff --git a/tests/engine/output_processor/test_block_decode.py b/tests/engine/output_processor/test_block_decode.py new file mode 100644 index 000000000000..aae184c16447 --- /dev/null +++ b/tests/engine/output_processor/test_block_decode.py @@ -0,0 +1,238 @@ +import pytest +from unittest.mock import MagicMock +import random + +from transformers import PreTrainedTokenizer + +from vllm.engine.output_processor.block_decode import BlockDecodeOutputProcessor +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.core.scheduler import Scheduler +from vllm.utils import Counter +from vllm.sequence import SequenceStatus, SequenceGroupOutput, SequenceOutput, Logprob +from vllm.sampling_params import SamplingParams +from tests.core.utils import create_seq_group + +@pytest.mark.parametrize("seq_output_len", [128]) +@pytest.mark.parametrize("num_new_tokens", [1, 12]) +@pytest.mark.skip_global_cleanup +def test_appends_token_ids(num_new_tokens: int, seq_output_len: int): + detokenizer = MagicMock(spec=Detokenizer) + scheduler = MagicMock(spec=Scheduler) + stop_checker = MagicMock(spec=StopChecker) + seq_counter = Counter() + + output_processor = BlockDecodeOutputProcessor( + detokenizer=detokenizer, + scheduler=scheduler, + seq_counter=seq_counter, + get_tokenizer_for_seq=lambda _: mock_tokenizer(), + stop_checker=stop_checker, + ) + + seq_group = create_seq_group( + seq_prompt_len=1024, + seq_output_lens=[seq_output_len], + sampling_params=SamplingParams( + max_tokens=seq_output_len + num_new_tokens, + ), + ) + + seq = seq_group.get_seqs()[0] + seq.status = SequenceStatus.RUNNING + + new_token_ids = list(range(num_new_tokens)) + + outputs = [SequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq.seq_id, + output_token=output_token, + logprobs={output_token: Logprob(0.0)}, + ) + ], + prompt_logprobs=None, + ) for output_token in new_token_ids] + + assert seq.get_token_ids()[-len(new_token_ids):] != new_token_ids + output_processor.process_outputs(seq_group, outputs) + assert seq.get_token_ids()[-len(new_token_ids):] == new_token_ids + +@pytest.mark.parametrize("seq_prompt_len", [1024]) +@pytest.mark.parametrize("seq_output_len", [128]) +@pytest.mark.parametrize("num_new_tokens", [5, 6, 7, 8]) +@pytest.mark.parametrize("max_tokens", [128 + 3]) +@pytest.mark.skip_global_cleanup +def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, seq_output_len: int, max_tokens: int): + detokenizer = MagicMock(spec=Detokenizer) + scheduler = MagicMock(spec=Scheduler) + stop_checker = MagicMock(spec=StopChecker) + seq_counter = Counter() + + output_processor = BlockDecodeOutputProcessor( + detokenizer=detokenizer, + scheduler=scheduler, + seq_counter=seq_counter, + get_tokenizer_for_seq=lambda _: mock_tokenizer(), + stop_checker=stop_checker, + ) + + seq_group = create_seq_group( + seq_prompt_len=seq_prompt_len, + seq_output_lens=[seq_output_len], + sampling_params=SamplingParams( + max_tokens=max_tokens, + ), + ) + + seq = seq_group.get_seqs()[0] + seq.status = SequenceStatus.RUNNING + + new_token_ids = list(range(num_new_tokens)) + + outputs = [SequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq.seq_id, + output_token=output_token, + logprobs={output_token: Logprob(0.0)}, + ) + ], + prompt_logprobs=None, + ) for output_token in new_token_ids] + + assert seq.get_len() == seq_prompt_len + seq_output_len + output_processor.process_outputs(seq_group, outputs) + + # Expect the processed sequence to not go over max tokens in len. + assert seq.get_len() == seq_prompt_len + max_tokens + + # Expect the correct tokens were appended. + expected_appended_tokens = new_token_ids[:max_tokens - seq_output_len] + assert seq.get_token_ids()[-len(expected_appended_tokens):] == expected_appended_tokens + +@pytest.mark.parametrize("seq_prompt_len", [1024]) +@pytest.mark.parametrize("seq_output_len", [128]) +@pytest.mark.parametrize("num_new_tokens", [12]) +@pytest.mark.parametrize("seed", list(range(6))) +@pytest.mark.skip_global_cleanup +def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, seq_output_len: int, seed: int): + random.seed(seed) + detokenizer = MagicMock(spec=Detokenizer) + scheduler = MagicMock(spec=Scheduler) + stop_checker = MagicMock(spec=StopChecker) + seq_counter = Counter() + + eos_token_id = 100 + + output_processor = BlockDecodeOutputProcessor( + detokenizer=detokenizer, + scheduler=scheduler, + seq_counter=seq_counter, + get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id), + stop_checker=stop_checker, + ) + + seq_group = create_seq_group( + seq_prompt_len=seq_prompt_len, + seq_output_lens=[seq_output_len], + sampling_params=SamplingParams( + # Ensure enough space. + max_tokens=seq_output_len + num_new_tokens, + ), + ) + + seq = seq_group.get_seqs()[0] + seq.status = SequenceStatus.RUNNING + + new_token_ids = list(range(num_new_tokens)) + assert eos_token_id not in new_token_ids + eos_index = random.randint(0, len(new_token_ids) - 1) + new_token_ids[eos_index] = eos_token_id + + outputs = [SequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq.seq_id, + output_token=output_token, + logprobs={output_token: Logprob(0.0)}, + ) + ], + prompt_logprobs=None, + ) for output_token in new_token_ids] + + assert seq.get_len() == seq_prompt_len + seq_output_len + output_processor.process_outputs(seq_group, outputs) + + # Expect the processed sequence to not go beyond provided eos. + assert seq.get_len() == seq_prompt_len + seq_output_len + (eos_index + 1) + + # Expect the correct tokens were appended. + expected_appended_tokens = new_token_ids[:eos_index+1] + assert seq.get_token_ids()[-len(expected_appended_tokens):] == expected_appended_tokens + +@pytest.mark.parametrize("seq_prompt_len", [1024]) +@pytest.mark.parametrize("seq_output_len", [128]) +@pytest.mark.parametrize("num_new_tokens", [12]) +@pytest.mark.parametrize("seed", list(range(6))) +@pytest.mark.skip_global_cleanup +def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int, seq_output_len: int, seed: int): + random.seed(seed) + detokenizer = MagicMock(spec=Detokenizer) + scheduler = MagicMock(spec=Scheduler) + stop_checker = MagicMock(spec=StopChecker) + seq_counter = Counter() + + eos_token_id = 100 + + output_processor = BlockDecodeOutputProcessor( + detokenizer=detokenizer, + scheduler=scheduler, + seq_counter=seq_counter, + get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id), + stop_checker=stop_checker, + ) + + seq_group = create_seq_group( + seq_prompt_len=seq_prompt_len, + seq_output_lens=[seq_output_len], + sampling_params=SamplingParams( + # Ensure enough space. + max_tokens=seq_output_len + num_new_tokens, + ignore_eos=True, + ), + ) + + seq = seq_group.get_seqs()[0] + seq.status = SequenceStatus.RUNNING + + new_token_ids = list(range(num_new_tokens)) + assert eos_token_id not in new_token_ids + eos_index = random.randint(0, len(new_token_ids) - 1) + new_token_ids[eos_index] = eos_token_id + + outputs = [SequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq.seq_id, + output_token=output_token, + logprobs={output_token: Logprob(0.0)}, + ) + ], + prompt_logprobs=None, + ) for output_token in new_token_ids] + + assert seq.get_len() == seq_prompt_len + seq_output_len + output_processor.process_outputs(seq_group, outputs) + + # Expect the processed sequence to go beyond eos. + assert seq.get_len() == seq_prompt_len + seq_output_len + num_new_tokens + + # Expect the correct tokens were appended. + expected_appended_tokens = new_token_ids[:seq_output_len + num_new_tokens - seq_output_len] + assert seq.get_token_ids()[-len(expected_appended_tokens):] == expected_appended_tokens + +def mock_tokenizer(eos_token_id=1000): + tokenizer = MagicMock(spec=PreTrainedTokenizer) + tokenizer.eos_token_id = eos_token_id + return tokenizer diff --git a/vllm/engine/output_processor/beam_search.py b/vllm/engine/output_processor/beam_search.py index 829c5ecd7839..827142bd4bf5 100644 --- a/vllm/engine/output_processor/beam_search.py +++ b/vllm/engine/output_processor/beam_search.py @@ -38,14 +38,12 @@ def __init__( detokenizer, scheduler, seq_counter, - get_tokenizer_for_seq, stop_checker, ): self.scheduler_config = scheduler_config self.detokenizer = detokenizer self.scheduler = scheduler self.seq_counter = seq_counter - self.get_tokenizer_for_seq = get_tokenizer_for_seq self.stop_checker = stop_checker def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: diff --git a/vllm/engine/output_processor/block_decode.py b/vllm/engine/output_processor/block_decode.py index 3fb2b7ee3235..06d3ee9306ef 100644 --- a/vllm/engine/output_processor/block_decode.py +++ b/vllm/engine/output_processor/block_decode.py @@ -34,21 +34,19 @@ class BlockDecodeOutputProcessor(SequenceGroupOutputProcessor): def __init__( self, - scheduler_config: SchedulerConfig, detokenizer, scheduler, seq_counter, get_tokenizer_for_seq, stop_checker, ): - self.scheduler_config = scheduler_config self.detokenizer = detokenizer self.scheduler = scheduler self.seq_counter = seq_counter self.get_tokenizer_for_seq = get_tokenizer_for_seq self.stop_checker = stop_checker - def process_outputs(self, sequence_group: SequenceGroup, outputs: SequenceGroupOutput) -> None: + def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING) assert seqs, "expected running sequences" @@ -78,18 +76,17 @@ def _process_seq_outputs(self, seq: Sequence, valid_samples: List[SequenceOutput valid_samples = valid_samples[:remaining_tokens] output_token_ids = output_token_ids[:remaining_tokens] - ## Truncate any tokens after EOS. This is required as spec decode - ## generates tokens in fixed blocks, which may go beyond the EOS token. - #if not sampling_params.ignore_eos: - # eos_token_id = self.tokenizer.get_lora_tokenizer( - # seq.lora_request).eos_token_id - # # Avoiding .index calls as exception throwing in the happy path - # # is expensive. - # for i in range(len(output_token_ids)): - # if output_token_ids[i] == eos_token_id: - # output_token_ids = output_token_ids[:i + 1] - # valid_samples = valid_samples[:i + 1] - # break + # Truncate any tokens after EOS. This is required as spec decode + # generates tokens in fixed blocks, which may go beyond the EOS token. + if not sampling_params.ignore_eos: + eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id + # Avoiding .index calls as exception throwing in the happy path + # is expensive. + for i in range(len(output_token_ids)): + if output_token_ids[i] == eos_token_id: + output_token_ids = output_token_ids[:i + 1] + valid_samples = valid_samples[:i + 1] + break for output_token_id in output_token_ids: seq.append_token_id( diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index d2368fc811a0..8a7e27645b4d 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from vllm.config import SchedulerConfig from vllm.sequence import SequenceGroup, SequenceGroupOutput +from typing import List class SequenceGroupOutputProcessor(ABC): @@ -20,13 +21,11 @@ def create_output_processor( detokenizer, scheduler, seq_counter, - get_tokenizer_for_seq, stop_checker, ) else: from vllm.engine.output_processor.block_decode import BlockDecodeOutputProcessor return BlockDecodeOutputProcessor( - scheduler_config, detokenizer, scheduler, seq_counter, @@ -35,5 +34,5 @@ def create_output_processor( ) @abstractmethod - def process_outputs(self, sequence_group: SequenceGroup, outputs: SequenceGroupOutput) -> None: + def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: pass From 3062e1cbeb11d66a8904d05c6ef935784caf44ef Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 18:34:53 -0700 Subject: [PATCH 030/120] stop token ids --- vllm/engine/output_processor/beam_search.py | 2 +- vllm/engine/output_processor/block_decode.py | 3 +-- vllm/engine/output_processor/stop_checker.py | 20 ++++++++++++-------- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/vllm/engine/output_processor/beam_search.py b/vllm/engine/output_processor/beam_search.py index 827142bd4bf5..2b5657d37ccd 100644 --- a/vllm/engine/output_processor/beam_search.py +++ b/vllm/engine/output_processor/beam_search.py @@ -104,7 +104,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, if seq_group.sampling_params.detokenize: self.detokenizer.decode_sequence_inplace( seq, seq_group.sampling_params) - self.stop_checker.maybe_stop_sequence(seq, seq_group.sampling_params) + self.stop_checker.maybe_stop_sequence(seq, seq_group.sampling_params, [seq.get_last_token_id()]) # Non-beam search case if not seq_group.sampling_params.use_beam_search: diff --git a/vllm/engine/output_processor/block_decode.py b/vllm/engine/output_processor/block_decode.py index 06d3ee9306ef..e218fa99b0e6 100644 --- a/vllm/engine/output_processor/block_decode.py +++ b/vllm/engine/output_processor/block_decode.py @@ -96,8 +96,7 @@ def _process_seq_outputs(self, seq: Sequence, valid_samples: List[SequenceOutput ) # TODO detokenize - # TODO pass output token ids - self.stop_checker.maybe_stop_sequence(seq, sampling_params) + self.stop_checker.maybe_stop_sequence(seq, sampling_params, new_token_ids=output_token_ids) if seq.is_finished(): self.scheduler.free_seq(seq) diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index 82973e304202..4d8f3730e9f6 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -38,7 +38,7 @@ def __init__(self, scheduler, scheduler_config, get_tokenizer_for_seq): self.get_tokenizer_for_seq = get_tokenizer_for_seq def maybe_stop_sequence(self, seq: Sequence, - sampling_params: SamplingParams) -> None: + sampling_params: SamplingParams, new_token_ids: List[int]) -> None: """Stop the finished sequences.""" # Check if the sequence has reached max_model_len. if seq.get_len() > self.scheduler_config.max_model_len: @@ -46,8 +46,7 @@ def maybe_stop_sequence(self, seq: Sequence, return # Check if the sequence has reached max_tokens. - if (sampling_params.max_tokens is not None) and (seq.get_output_len() >= sampling_params.max_tokens): - # TODO should cap block + if seq.get_output_len() == sampling_params.max_tokens: seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED return @@ -63,18 +62,23 @@ def maybe_stop_sequence(self, seq: Sequence, seq.status = SequenceStatus.FINISHED_STOPPED seq.stop_reason = stop_str return - last_token_id = seq.get_last_token_id() - if last_token_id in sampling_params.stop_token_ids: + + # Determine if any stop_token_ids are in new_token_ids. + intersection = set(new_token_ids).intersection(sampling_params.stop_token_ids) + if intersection: + # Get arbitrary token id that caused the stop. + stop_token_id = next(iter(intersection)) + stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens( - last_token_id) + stop_token_id) self._finalize_sequence(seq, sampling_params, stop_str) seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = last_token_id + seq.stop_reason = stop_token_id return # Check if the sequence has generated the EOS token. if ((not sampling_params.ignore_eos) - and seq.get_last_token_id() == seq.eos_token_id): + and seq.eos_token_id in new_token_ids): seq.status = SequenceStatus.FINISHED_STOPPED return From fba3b300f66e047750eb3a392e0b2f3aee0e0cd8 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 18:35:16 -0700 Subject: [PATCH 031/120] format --- tests/core/utils.py | 11 +- .../output_processor/test_block_decode.py | 136 ++++++++++-------- vllm/engine/llm_engine.py | 14 +- vllm/engine/output_processor/beam_search.py | 12 +- vllm/engine/output_processor/block_decode.py | 20 ++- vllm/engine/output_processor/interfaces.py | 6 +- vllm/engine/output_processor/stop_checker.py | 7 +- vllm/engine/output_processor/util.py | 10 +- vllm/model_executor/layers/sampler.py | 3 +- vllm/spec_decode/multi_step_worker.py | 3 +- vllm/spec_decode/spec_decode_worker.py | 3 +- vllm/worker/worker_base.py | 10 +- 12 files changed, 134 insertions(+), 101 deletions(-) diff --git a/tests/core/utils.py b/tests/core/utils.py index d9d2eeaee1b9..39f8e507d0f1 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -31,12 +31,11 @@ def create_dummy_prompt( def create_seq_group( - seq_prompt_len: int=1024, - seq_output_lens: Iterable[int]=(128, ), - request_id: str='0', - seq_id_start: int=0, - sampling_params: Optional[SamplingParams] = None -) -> SequenceGroup: + seq_prompt_len: int = 1024, + seq_output_lens: Iterable[int] = (128, ), + request_id: str = '0', + seq_id_start: int = 0, + sampling_params: Optional[SamplingParams] = None) -> SequenceGroup: assert len(seq_output_lens) > 0 diff --git a/tests/engine/output_processor/test_block_decode.py b/tests/engine/output_processor/test_block_decode.py index aae184c16447..f426f1d32d7a 100644 --- a/tests/engine/output_processor/test_block_decode.py +++ b/tests/engine/output_processor/test_block_decode.py @@ -13,6 +13,7 @@ from vllm.sampling_params import SamplingParams from tests.core.utils import create_seq_group + @pytest.mark.parametrize("seq_output_len", [128]) @pytest.mark.parametrize("num_new_tokens", [1, 12]) @pytest.mark.skip_global_cleanup @@ -33,37 +34,40 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int): seq_group = create_seq_group( seq_prompt_len=1024, seq_output_lens=[seq_output_len], - sampling_params=SamplingParams( - max_tokens=seq_output_len + num_new_tokens, - ), + sampling_params=SamplingParams(max_tokens=seq_output_len + + num_new_tokens, ), ) - + seq = seq_group.get_seqs()[0] seq.status = SequenceStatus.RUNNING new_token_ids = list(range(num_new_tokens)) - outputs = [SequenceGroupOutput( - samples=[ - SequenceOutput( - parent_seq_id=seq.seq_id, - output_token=output_token, - logprobs={output_token: Logprob(0.0)}, - ) - ], - prompt_logprobs=None, - ) for output_token in new_token_ids] + outputs = [ + SequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq.seq_id, + output_token=output_token, + logprobs={output_token: Logprob(0.0)}, + ) + ], + prompt_logprobs=None, + ) for output_token in new_token_ids + ] assert seq.get_token_ids()[-len(new_token_ids):] != new_token_ids output_processor.process_outputs(seq_group, outputs) assert seq.get_token_ids()[-len(new_token_ids):] == new_token_ids + @pytest.mark.parametrize("seq_prompt_len", [1024]) @pytest.mark.parametrize("seq_output_len", [128]) @pytest.mark.parametrize("num_new_tokens", [5, 6, 7, 8]) @pytest.mark.parametrize("max_tokens", [128 + 3]) @pytest.mark.skip_global_cleanup -def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, seq_output_len: int, max_tokens: int): +def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, + seq_output_len: int, max_tokens: int): detokenizer = MagicMock(spec=Detokenizer) scheduler = MagicMock(spec=Scheduler) stop_checker = MagicMock(spec=StopChecker) @@ -80,26 +84,26 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, seq_outpu seq_group = create_seq_group( seq_prompt_len=seq_prompt_len, seq_output_lens=[seq_output_len], - sampling_params=SamplingParams( - max_tokens=max_tokens, - ), + sampling_params=SamplingParams(max_tokens=max_tokens, ), ) - + seq = seq_group.get_seqs()[0] seq.status = SequenceStatus.RUNNING new_token_ids = list(range(num_new_tokens)) - outputs = [SequenceGroupOutput( - samples=[ - SequenceOutput( - parent_seq_id=seq.seq_id, - output_token=output_token, - logprobs={output_token: Logprob(0.0)}, - ) - ], - prompt_logprobs=None, - ) for output_token in new_token_ids] + outputs = [ + SequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq.seq_id, + output_token=output_token, + logprobs={output_token: Logprob(0.0)}, + ) + ], + prompt_logprobs=None, + ) for output_token in new_token_ids + ] assert seq.get_len() == seq_prompt_len + seq_output_len output_processor.process_outputs(seq_group, outputs) @@ -109,14 +113,17 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, seq_outpu # Expect the correct tokens were appended. expected_appended_tokens = new_token_ids[:max_tokens - seq_output_len] - assert seq.get_token_ids()[-len(expected_appended_tokens):] == expected_appended_tokens + assert seq.get_token_ids( + )[-len(expected_appended_tokens):] == expected_appended_tokens + @pytest.mark.parametrize("seq_prompt_len", [1024]) @pytest.mark.parametrize("seq_output_len", [128]) @pytest.mark.parametrize("num_new_tokens", [12]) @pytest.mark.parametrize("seed", list(range(6))) @pytest.mark.skip_global_cleanup -def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, seq_output_len: int, seed: int): +def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, + seq_output_len: int, seed: int): random.seed(seed) detokenizer = MagicMock(spec=Detokenizer) scheduler = MagicMock(spec=Scheduler) @@ -138,10 +145,9 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, seq_out seq_output_lens=[seq_output_len], sampling_params=SamplingParams( # Ensure enough space. - max_tokens=seq_output_len + num_new_tokens, - ), + max_tokens=seq_output_len + num_new_tokens, ), ) - + seq = seq_group.get_seqs()[0] seq.status = SequenceStatus.RUNNING @@ -150,16 +156,18 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, seq_out eos_index = random.randint(0, len(new_token_ids) - 1) new_token_ids[eos_index] = eos_token_id - outputs = [SequenceGroupOutput( - samples=[ - SequenceOutput( - parent_seq_id=seq.seq_id, - output_token=output_token, - logprobs={output_token: Logprob(0.0)}, - ) - ], - prompt_logprobs=None, - ) for output_token in new_token_ids] + outputs = [ + SequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq.seq_id, + output_token=output_token, + logprobs={output_token: Logprob(0.0)}, + ) + ], + prompt_logprobs=None, + ) for output_token in new_token_ids + ] assert seq.get_len() == seq_prompt_len + seq_output_len output_processor.process_outputs(seq_group, outputs) @@ -168,15 +176,18 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, seq_out assert seq.get_len() == seq_prompt_len + seq_output_len + (eos_index + 1) # Expect the correct tokens were appended. - expected_appended_tokens = new_token_ids[:eos_index+1] - assert seq.get_token_ids()[-len(expected_appended_tokens):] == expected_appended_tokens + expected_appended_tokens = new_token_ids[:eos_index + 1] + assert seq.get_token_ids( + )[-len(expected_appended_tokens):] == expected_appended_tokens + @pytest.mark.parametrize("seq_prompt_len", [1024]) @pytest.mark.parametrize("seq_output_len", [128]) @pytest.mark.parametrize("num_new_tokens", [12]) @pytest.mark.parametrize("seed", list(range(6))) @pytest.mark.skip_global_cleanup -def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int, seq_output_len: int, seed: int): +def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int, + seq_output_len: int, seed: int): random.seed(seed) detokenizer = MagicMock(spec=Detokenizer) scheduler = MagicMock(spec=Scheduler) @@ -202,7 +213,7 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int, seq_outp ignore_eos=True, ), ) - + seq = seq_group.get_seqs()[0] seq.status = SequenceStatus.RUNNING @@ -211,16 +222,18 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int, seq_outp eos_index = random.randint(0, len(new_token_ids) - 1) new_token_ids[eos_index] = eos_token_id - outputs = [SequenceGroupOutput( - samples=[ - SequenceOutput( - parent_seq_id=seq.seq_id, - output_token=output_token, - logprobs={output_token: Logprob(0.0)}, - ) - ], - prompt_logprobs=None, - ) for output_token in new_token_ids] + outputs = [ + SequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq.seq_id, + output_token=output_token, + logprobs={output_token: Logprob(0.0)}, + ) + ], + prompt_logprobs=None, + ) for output_token in new_token_ids + ] assert seq.get_len() == seq_prompt_len + seq_output_len output_processor.process_outputs(seq_group, outputs) @@ -229,8 +242,11 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int, seq_outp assert seq.get_len() == seq_prompt_len + seq_output_len + num_new_tokens # Expect the correct tokens were appended. - expected_appended_tokens = new_token_ids[:seq_output_len + num_new_tokens - seq_output_len] - assert seq.get_token_ids()[-len(expected_appended_tokens):] == expected_appended_tokens + expected_appended_tokens = new_token_ids[:seq_output_len + num_new_tokens - + seq_output_len] + assert seq.get_token_ids( + )[-len(expected_appended_tokens):] == expected_appended_tokens + def mock_tokenizer(eos_token_id=1000): tokenizer = MagicMock(spec=PreTrainedTokenizer) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index bce36ddccc81..9936eb18c032 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -422,14 +422,14 @@ def has_unfinished_requests(self) -> bool: return self.scheduler.has_unfinished_seqs() def _process_model_outputs( - self, - output: List[SamplerOutput], + self, output: List[SamplerOutput], scheduled_seq_groups: List[SequenceGroup], ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]: now = time.time() - output_by_sequence_group = create_output_by_sequence_group(sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups)) + output_by_sequence_group = create_output_by_sequence_group( + sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups)) # Update the scheduled sequence groups with the model outputs. for scheduled_seq_group, outputs in zip(scheduled_seq_groups, @@ -437,7 +437,7 @@ def _process_model_outputs( seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) - + self.output_processor.process_outputs(seq_group, outputs) # Free the finished sequence groups. @@ -455,7 +455,6 @@ def _process_model_outputs( request_outputs.append(request_output) return request_outputs - def step(self) -> List[RequestOutput]: """Performs one decoding iteration and returns newly generated results. @@ -519,7 +518,9 @@ def step(self) -> List[RequestOutput]: else: output = [] - request_outputs = self._process_model_outputs(output, scheduler_outputs.scheduled_seq_groups, scheduler_outputs.ignored_seq_groups) + request_outputs = self._process_model_outputs( + output, scheduler_outputs.scheduled_seq_groups, + scheduler_outputs.ignored_seq_groups) # Log stats. if self.log_stats: @@ -605,7 +606,6 @@ def _get_stats(self, time_e2e_requests=time_e2e_requests, ) - def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_executor.add_lora(lora_request) diff --git a/vllm/engine/output_processor/beam_search.py b/vllm/engine/output_processor/beam_search.py index 2b5657d37ccd..94af809e2673 100644 --- a/vllm/engine/output_processor/beam_search.py +++ b/vllm/engine/output_processor/beam_search.py @@ -31,7 +31,7 @@ class BeamSearchOutputProcessor(SequenceGroupOutputProcessor): - + def __init__( self, scheduler_config: SchedulerConfig, @@ -46,8 +46,10 @@ def __init__( self.seq_counter = seq_counter self.stop_checker = stop_checker - def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: - assert (len(outputs) == 1), f"{type(self)} does not support multiple outputs per step" + def process_outputs(self, sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + assert (len(outputs) == 1 + ), f"{type(self)} does not support multiple outputs per step" return self._process_sequence_group_outputs(sequence_group, outputs[0]) def _process_sequence_group_outputs(self, seq_group: SequenceGroup, @@ -104,7 +106,9 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, if seq_group.sampling_params.detokenize: self.detokenizer.decode_sequence_inplace( seq, seq_group.sampling_params) - self.stop_checker.maybe_stop_sequence(seq, seq_group.sampling_params, [seq.get_last_token_id()]) + self.stop_checker.maybe_stop_sequence(seq, + seq_group.sampling_params, + [seq.get_last_token_id()]) # Non-beam search case if not seq_group.sampling_params.use_beam_search: diff --git a/vllm/engine/output_processor/block_decode.py b/vllm/engine/output_processor/block_decode.py index e218fa99b0e6..3b6a60e857fa 100644 --- a/vllm/engine/output_processor/block_decode.py +++ b/vllm/engine/output_processor/block_decode.py @@ -31,7 +31,7 @@ class BlockDecodeOutputProcessor(SequenceGroupOutputProcessor): - + def __init__( self, detokenizer, @@ -46,7 +46,8 @@ def __init__( self.get_tokenizer_for_seq = get_tokenizer_for_seq self.stop_checker = stop_checker - def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: + def process_outputs(self, sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING) assert seqs, "expected running sequences" @@ -64,14 +65,17 @@ def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceG ] assert valid_samples - self._process_seq_outputs(seq, valid_samples, sequence_group.sampling_params) + self._process_seq_outputs(seq, valid_samples, + sequence_group.sampling_params) - def _process_seq_outputs(self, seq: Sequence, valid_samples: List[SequenceOutput], sampling_params: SamplingParams) -> None: + def _process_seq_outputs(self, seq: Sequence, + valid_samples: List[SequenceOutput], + sampling_params: SamplingParams) -> None: output_token_ids = [sample.output_token for sample in valid_samples] # Truncate to max_tokens if necessary. - remaining_tokens = sampling_params.max_tokens - ( - seq.get_output_len() + len(output_token_ids)) + remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() + + len(output_token_ids)) if remaining_tokens < 0: valid_samples = valid_samples[:remaining_tokens] output_token_ids = output_token_ids[:remaining_tokens] @@ -96,7 +100,9 @@ def _process_seq_outputs(self, seq: Sequence, valid_samples: List[SequenceOutput ) # TODO detokenize - self.stop_checker.maybe_stop_sequence(seq, sampling_params, new_token_ids=output_token_ids) + self.stop_checker.maybe_stop_sequence(seq, + sampling_params, + new_token_ids=output_token_ids) if seq.is_finished(): self.scheduler.free_seq(seq) diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index 8a7e27645b4d..2b931a0b2f41 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -3,8 +3,9 @@ from vllm.sequence import SequenceGroup, SequenceGroupOutput from typing import List + class SequenceGroupOutputProcessor(ABC): - + @staticmethod def create_output_processor( scheduler_config: SchedulerConfig, @@ -34,5 +35,6 @@ def create_output_processor( ) @abstractmethod - def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: + def process_outputs(self, sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: pass diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index 4d8f3730e9f6..3f03373f2698 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -30,6 +30,7 @@ logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 + class StopChecker: def __init__(self, scheduler, scheduler_config, get_tokenizer_for_seq): @@ -38,7 +39,8 @@ def __init__(self, scheduler, scheduler_config, get_tokenizer_for_seq): self.get_tokenizer_for_seq = get_tokenizer_for_seq def maybe_stop_sequence(self, seq: Sequence, - sampling_params: SamplingParams, new_token_ids: List[int]) -> None: + sampling_params: SamplingParams, + new_token_ids: List[int]) -> None: """Stop the finished sequences.""" # Check if the sequence has reached max_model_len. if seq.get_len() > self.scheduler_config.max_model_len: @@ -64,7 +66,8 @@ def maybe_stop_sequence(self, seq: Sequence, return # Determine if any stop_token_ids are in new_token_ids. - intersection = set(new_token_ids).intersection(sampling_params.stop_token_ids) + intersection = set(new_token_ids).intersection( + sampling_params.stop_token_ids) if intersection: # Get arbitrary token id that caused the stop. stop_token_id = next(iter(intersection)) diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py index 1fcd651deef1..b49bbb2fab32 100644 --- a/vllm/engine/output_processor/util.py +++ b/vllm/engine/output_processor/util.py @@ -1,12 +1,12 @@ from vllm.sequence import SequenceGroupOutput, SamplerOutput from typing import List -def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput], num_seq_groups: int): - output_by_sequence_group = [ - [] for _ in range(num_seq_groups) - ] + +def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput], + num_seq_groups: int): + output_by_sequence_group = [[] for _ in range(num_seq_groups)] for step in sampler_outputs: for i, sequence_group_output in enumerate(step): - output_by_sequence_group[i].append(sequence_group_output) + output_by_sequence_group[i].append(sequence_group_output) return output_by_sequence_group diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index bed915faf3fb..be970e56b611 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -79,7 +79,8 @@ def forward( prompt_logprobs, sample_logprobs = _get_logprobs( logprobs, sampling_metadata, sample_results) - return _build_sampler_output(sample_results, sampling_metadata, prompt_logprobs, sample_logprobs) + return _build_sampler_output(sample_results, sampling_metadata, + prompt_logprobs, sample_logprobs) def _get_bin_counts_and_mask( diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 4cdbe0923455..85060ccf2b15 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -70,7 +70,8 @@ def execute_model_multi_step( blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ) - assert (len(model_output) == 1), "composing multistep workers not supported" + assert (len(model_output) == 1 + ), "composing multistep workers not supported" model_output = model_output[0] self._append_new_tokens(model_output, diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 894377c9421e..b9824937a944 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -196,7 +196,8 @@ def _run_no_spec( blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ) - assert len(sampler_output) == 1, "expected single output from scorer worker" + assert len( + sampler_output) == 1, "expected single output from scorer worker" sampler_output = sampler_output[0] # Clear device tensors from sampler output. This reduces communication diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 1481a4c2eef4..d5d3ffda1f43 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -40,11 +40,11 @@ def initialize_cache(self, num_gpu_blocks: int, raise NotImplementedError @abstractmethod - def execute_model(self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]: + def execute_model( + self, seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, + int], + blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]: """Executes at least one model step on the given sequences, unless no sequences are provided.""" raise NotImplementedError From bda141fe4dca51b53edf0bafb97882155b2b6839 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 18:56:05 -0700 Subject: [PATCH 032/120] fixing spec tests --- tests/spec_decode/test_multi_step_worker.py | 5 +++-- tests/spec_decode/test_spec_decode_worker.py | 16 +++++++++++----- tests/spec_decode/utils.py | 4 ++-- vllm/engine/async_llm_engine.py | 2 +- vllm/spec_decode/batch_expansion.py | 4 ++-- vllm/spec_decode/multi_step_worker.py | 5 +++-- vllm/spec_decode/spec_decode_worker.py | 3 +-- vllm/spec_decode/util.py | 17 ++++++++++------- 8 files changed, 33 insertions(+), 23 deletions(-) diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index f4d44108b47c..f9840d6157c3 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -125,7 +125,7 @@ def test_same_output_for_single_step(): zero_kv_cache(worker.cache_engine) set_random_seed(seed) expected_output = worker.execute_model( - **single_step_execute_model_data.to_dict(), ) + **single_step_execute_model_data.to_dict(), )[0] actual_token_ids = [ output.samples[0].output_token for output in actual_output @@ -219,7 +219,7 @@ def test_same_output_for_multi_step(): continuations=continuations, final_seq_lens=final_seq_lens)) - single_step_output.append( + single_step_output.extend( worker.execute_model(**execute_model_data.to_dict(), )) # Append output tokens to new sequence data. @@ -352,6 +352,7 @@ def test_draft_proposals_no_speculations(): @torch.inference_mode() +#@pytest.skip("Broken because output is padded.") def test_draft_proposals_mixed_k(): """Verify DraftModelTop1Proposer correctly handles case some sequences can speculate and some can't. diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 3725924ea89c..889712fb9360 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -12,6 +12,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker, split_num_cache_blocks_evenly) +from vllm.sequence import SamplerOutput from .utils import (ExecuteModelData, create_batch, create_sampler_output_list, mock_worker) @@ -191,7 +192,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): target_output = create_sampler_output_list(target_token_ids, target_token_probs) - target_worker.execute_model.return_value = target_output[0] + target_worker.execute_model.return_value = [target_output[0]] exception_secret = 'artifical stop' rejection_sampler.side_effect = ValueError(exception_secret) @@ -271,7 +272,7 @@ def test_correctly_formats_output(k: int, batch_size: int): target_output = create_sampler_output_list(target_token_ids, target_token_probs) - target_worker.execute_model.return_value = target_output[0] + target_worker.execute_model.return_value = [target_output[0]] rejection_sampler_output = torch.randint(low=0, high=vocab_size, @@ -340,6 +341,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) + draft_worker.device = 'cuda' target_worker.device = 'cuda' @@ -383,7 +385,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): target_output = create_sampler_output_list(target_token_ids, target_token_probs) - target_worker.execute_model.return_value = target_output[0] + target_worker.execute_model.return_value = [target_output[0]] rejection_sampler_output = torch.randint(low=0, high=vocab_size, @@ -426,6 +428,8 @@ def test_k_equals_zero(k: int, batch_size: int): rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) + target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)] + draft_worker.device = 'cuda' target_worker.device = 'cuda' @@ -446,7 +450,7 @@ def test_k_equals_zero(k: int, batch_size: int): 0].sampled_tokens is None, "expect gpu tensor references to be None" draft_worker.execute_model.assert_called_once_with( - **execute_model_data.to_dict(), return_python_output=False) + **execute_model_data.to_dict()) target_worker.execute_model.assert_called_once_with( **execute_model_data.to_dict()) @@ -465,6 +469,8 @@ def test_empty_input_batch(k: int, batch_size: int): rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) + target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)] + draft_worker.device = 'cuda' target_worker.device = 'cuda' @@ -485,7 +491,7 @@ def test_empty_input_batch(k: int, batch_size: int): 0].sampled_tokens is None, "expect gpu tensor references to be None" draft_worker.execute_model.assert_called_once_with( - **execute_model_data.to_dict(), return_python_output=False) + **execute_model_data.to_dict()) target_worker.execute_model.assert_called_once_with( **execute_model_data.to_dict()) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 4637826f254d..3914af945eff 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -10,7 +10,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import (Logprob, SamplerOutput, SequenceData, SequenceGroupMetadata, SequenceGroupOutput, - SequenceOutput) + SequenceOutput, Logprob) from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.worker.cache_engine import CacheEngine from vllm.worker.worker import Worker @@ -211,7 +211,7 @@ def create_sampler_output_list( SequenceOutput( output_token=token_id, parent_seq_id=seq_ids[seq_index], - logprobs={token_id: 0}, + logprobs={token_id: Logprob(0)}, ) ], prompt_logprobs=None, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index f61049513512..378484510247 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -217,7 +217,7 @@ async def step_async(self) -> List[RequestOutput]: else: output = [] - return self._process_model_outputs(output, scheduler_outputs) + return self._process_model_outputs(output, scheduler_outputs.scheduled_seq_groups, scheduler_outputs.ignored_seq_groups) async def encode_request_async( self, diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index f7bac45861a7..1011dd970ebc 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -9,7 +9,7 @@ from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, sampler_output_to_torch, split_batch_by_proposal_len, - mock_device_tensors) + maybe_mock_device_tensors) from vllm.worker.worker import Worker SeqId = int @@ -147,7 +147,7 @@ def _contract_batch(self, original_bs: int, sequences. """ - mock_device_tensors( + maybe_mock_device_tensors( sampler_output=target_sampler_output, batch_size=len(non_spec_indices) + num_scoring_tokens, vocab_size=self._vocab_size, diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 85060ccf2b15..4182b8758465 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -7,7 +7,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) from vllm.spec_decode.util import (sampler_output_to_torch, - mock_device_tensors) + maybe_mock_device_tensors) from vllm.worker.worker import Worker @@ -346,7 +346,7 @@ def _merge_outputs( sampler_output = maybe_sampler_output for step_output in sampler_output: - mock_device_tensors( + maybe_mock_device_tensors( sampler_output=step_output, batch_size=len(proposal_lens), vocab_size=self._vocab_size, @@ -364,6 +364,7 @@ def _merge_outputs( fill_value=-1, dtype=torch.long, device=self._device) + entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens entire_proposal_probs = torch.zeros(batch_size, *proposal_probs.shape[1:], diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index b9824937a944..c221f0421f53 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -196,8 +196,7 @@ def _run_no_spec( blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ) - assert len( - sampler_output) == 1, "expected single output from scorer worker" + assert len(sampler_output) == 1 sampler_output = sampler_output[0] # Clear device tensors from sampler output. This reduces communication diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index 7129f47d65f6..c47d5b878153 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -82,19 +82,22 @@ def sampler_output_to_torch( return sampled_token_ids, sampled_token_probs -def mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, +def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, vocab_size: int, device: str) -> None: - assert sampler_output.sampled_token_probs is None - assert sampler_output.sampled_token_ids is None + values = [sampler_output.sampled_token_probs, sampler_output.sampled_token_ids] + assert all(v is None for v in values) or not any(v is None for v in values) + if not any(v is None for v in values): + return sampler_output.sampled_token_probs = torch.nn.functional.softmax( torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device), dim=-1) + sampler_output.sampled_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, ), - dtype=torch.long, - device=device) + high=vocab_size, + size=(batch_size, ), + dtype=torch.long, + device=device) @contextmanager From 49865fba9be8aeb19735b3b08ec9a830bf9caee7 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 19:05:55 -0700 Subject: [PATCH 033/120] lint --- vllm/engine/async_llm_engine.py | 4 +++- vllm/spec_decode/multi_step_worker.py | 2 +- vllm/spec_decode/util.py | 16 +++++++++------- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 378484510247..4bab116dcb14 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -217,7 +217,9 @@ async def step_async(self) -> List[RequestOutput]: else: output = [] - return self._process_model_outputs(output, scheduler_outputs.scheduled_seq_groups, scheduler_outputs.ignored_seq_groups) + return self._process_model_outputs( + output, scheduler_outputs.scheduled_seq_groups, + scheduler_outputs.ignored_seq_groups) async def encode_request_async( self, diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 4182b8758465..c79d79930a18 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -364,7 +364,7 @@ def _merge_outputs( fill_value=-1, dtype=torch.long, device=self._device) - + entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens entire_proposal_probs = torch.zeros(batch_size, *proposal_probs.shape[1:], diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index c47d5b878153..efc54c4de4cf 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -83,8 +83,10 @@ def sampler_output_to_torch( def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, - vocab_size: int, device: str) -> None: - values = [sampler_output.sampled_token_probs, sampler_output.sampled_token_ids] + vocab_size: int, device: str) -> None: + values = [ + sampler_output.sampled_token_probs, sampler_output.sampled_token_ids + ] assert all(v is None for v in values) or not any(v is None for v in values) if not any(v is None for v in values): return @@ -92,12 +94,12 @@ def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, sampler_output.sampled_token_probs = torch.nn.functional.softmax( torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device), dim=-1) - + sampler_output.sampled_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, ), - dtype=torch.long, - device=device) + high=vocab_size, + size=(batch_size, ), + dtype=torch.long, + device=device) @contextmanager From 1a17ed14a57c13def30b6d7e99236ffa92cdfb61 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 19:15:03 -0700 Subject: [PATCH 034/120] clean up gpu executor --- vllm/executor/gpu_executor.py | 70 +++++++++++--------------- vllm/spec_decode/spec_decode_worker.py | 9 ++++ 2 files changed, 37 insertions(+), 42 deletions(-) diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 90a534dc1271..18be6da10ce9 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -47,18 +47,37 @@ def _init_worker(self): else: self._init_spec_worker() + def _init_non_spec_worker(self): + # Lazy import the Worker to avoid importing torch.cuda/xformers + # before CUDA_VISIBLE_DEVICES is set in the Worker + from vllm.worker.worker import Worker + + assert self.parallel_config.world_size == 1, ( + "GPUExecutor only supports single GPU.") + + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + self.driver_worker = Worker( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + is_driver_worker=True, + ) + self.driver_worker.init_device() + self.driver_worker.load_model() + def _init_spec_worker(self): from vllm.worker.worker import Worker from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker - #from vllm.worker.multi_step_worker import MultiStepWorker # pylint: disable=import-outside-toplevel - #from vllm.worker.single_tp_worker import SingleTpWorker # pylint: disable=import-outside-toplevel - #from vllm.worker.draft_target_worker import DraftTargetWorker # pylint: disable=import-outside-toplevel - - #scheduler_config: "SchedulerConfig" = worker_kwargs.pop( - # "scheduler_config") - distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) @@ -76,7 +95,6 @@ def _init_spec_worker(self): is_driver_worker=True, ) - from vllm.spec_decode.multi_step_worker import MultiStepWorker draft_worker = MultiStepWorker( model_config=self.speculative_config.draft_model_config, parallel_config=self.speculative_config.draft_parallel_config, @@ -91,47 +109,15 @@ def _init_spec_worker(self): is_driver_worker=True, ) - from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker - from vllm.model_executor.layers.rejection_sampler import RejectionSampler - spec_decode_worker = SpecDecodeWorker( - proposer_worker=draft_worker, - scorer_worker=target_worker, - rejection_sampler=RejectionSampler(strict_mode=True), - ) + spec_decode_worker = SpecDecodeWorker.from_workers(proposer_worker=draft_worker, scorer_worker=target_worker) assert self.parallel_config.world_size == 1, ( "GPUExecutor only supports single GPU.") self.driver_worker = spec_decode_worker + # Load model handled in spec decode worker. self.driver_worker.init_device() - #self.driver_worker.load_model() - - def _init_non_spec_worker(self): - # Lazy import the Worker to avoid importing torch.cuda/xformers - # before CUDA_VISIBLE_DEVICES is set in the Worker - from vllm.worker.worker import Worker - - assert self.parallel_config.world_size == 1, ( - "GPUExecutor only supports single GPU.") - - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - self.driver_worker = Worker( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - local_rank=0, - rank=0, - distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - is_driver_worker=True, - ) - self.driver_worker.init_device() - self.driver_worker.load_model() def determine_num_available_blocks(self) -> tuple[int, int]: """Determine the number of available KV blocks by invoking the diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index c221f0421f53..91bc530084e7 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -48,6 +48,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit. """ + @classmethod + def from_workers(cls, proposer_worker: MultiStepWorker, scorer_worker: WorkerBase) -> "SpecDecodeWorker": + return SpecDecodeWorker( + proposer_worker, + scorer_worker, + # TODO(cade) disable strict mode for speedup. + rejection_sampler=RejectionSampler(strict_mode=True), + ) + def __init__( self, proposer_worker: MultiStepWorker, From dea67bbd6fb1f0278ee4c605d8be77991c8657ae Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 19:16:16 -0700 Subject: [PATCH 035/120] wip --- vllm/spec_decode/batch_expansion.py | 4 ++-- vllm/spec_decode/spec_decode_worker.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 1011dd970ebc..4dc34f1ab7c7 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -10,7 +10,7 @@ sampler_output_to_torch, split_batch_by_proposal_len, maybe_mock_device_tensors) -from vllm.worker.worker import Worker +from vllm.worker.worker_base import WorkerBase SeqId = int TargetSeqId = int @@ -32,7 +32,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): of topk/tree. """ - def __init__(self, scorer_worker: Worker, device: str, vocab_size: int): + def __init__(self, scorer_worker: WorkerBase, device: str, vocab_size: int): self._scorer_worker = scorer_worker self._device = device self._vocab_size = vocab_size diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 91bc530084e7..e5b493c46c6c 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -14,7 +14,7 @@ from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, split_batch_by_proposal_len) from vllm.worker.worker import Worker -from vllm.worker.worker_base import LoraNotSupportedWorkerBase +from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase from vllm.logger import init_logger logger = init_logger(__name__) @@ -60,7 +60,7 @@ def from_workers(cls, proposer_worker: MultiStepWorker, scorer_worker: WorkerBas def __init__( self, proposer_worker: MultiStepWorker, - scorer_worker: Worker, + scorer_worker: WorkerBase, rejection_sampler: RejectionSampler, metrics_collector: Optional[AsyncMetricsCollector] = None, ): From 189d7ebab4a783cb651fb339b2fba88fd8b1f019 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 19:17:59 -0700 Subject: [PATCH 036/120] fix --- tests/spec_decode/e2e/test_correctness.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index eb6d1e1c5ddd..6b01936e8178 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -62,8 +62,7 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): # Expect a generation for each prompt in the batch. assert len(batch_token_ids) == len(prompts) - # TODO(cadedaniel) check for equality once block truncation is implemented. - assert all(len(token_ids) >= output_len for token_ids in batch_token_ids) + assert all(len(token_ids) == output_len for token_ids in batch_token_ids) @pytest.mark.parametrize( From a70a0408b12631ca00a78e7cbbcf1db7ef211f33 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 19:18:47 -0700 Subject: [PATCH 037/120] wip --- vllm/executor/gpu_executor.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 18be6da10ce9..22cd2797282e 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -35,9 +35,6 @@ def __init__( self.vision_language_config = vision_language_config self.speculative_config = speculative_config - #assert (not speculative_config - # ), "Speculative decoding not yet supported for GPU backend" - # Instantiate the worker and load the model to GPU. self._init_worker() From 3e1b8f5c17e8ac0a96a1ddc05300b4eeb1996e66 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 19:42:50 -0700 Subject: [PATCH 038/120] detokenization --- tests/spec_decode/e2e/test_correctness.py | 20 ++++++++++++++++---- vllm/engine/output_processor/block_decode.py | 2 +- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index 6b01936e8178..d2f07f729f5a 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -1,13 +1,16 @@ import pytest from itertools import cycle +from typing import Tuple, List from vllm import SamplingParams +from transformers import AutoTokenizer @pytest.mark.parametrize( "common_llm_kwargs", [{ # Use a small model for a fast test. + # Note this is repeated in the test body; to initialize a tokenizer. "model": "JackFram/llama-68m", # Skip real loading for fast test. @@ -55,15 +58,23 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): temperature=temperature, ) - batch_token_ids = get_token_ids_from_llm_generator(test_llm_generator, + batch_tokens, batch_token_ids = get_output_from_llm_generator(test_llm_generator, prompts, sampling_params) # Expect a generation for each prompt in the batch. assert len(batch_token_ids) == len(prompts) + # Expect each generation to have expected number of tokens (note + # ignore_eos=True). assert all(len(token_ids) == output_len for token_ids in batch_token_ids) + # Expect detokenized string to match. + tok = AutoTokenizer.from_pretrained("JackFram/llama-68m") + for actual_tokens, actual_token_ids in zip(batch_tokens, batch_token_ids): + expected_tokens = tok.decode(actual_token_ids) + assert actual_tokens == expected_tokens + @pytest.mark.parametrize( "common_llm_kwargs", @@ -109,14 +120,15 @@ def test_spec_decode_xfail(test_llm_generator): with pytest.raises(AssertionError, match="Speculative decoding not yet supported for "): - get_token_ids_from_llm_generator(test_llm_generator, prompts, + get_output_from_llm_generator(test_llm_generator, prompts, sampling_params) -def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): +def get_output_from_llm_generator(llm_generator, prompts, sampling_params) -> Tuple[List[str], List[List[int]]]: for llm in llm_generator: outputs = llm.generate(prompts, sampling_params, use_tqdm=True) token_ids = [output.outputs[0].token_ids for output in outputs] + tokens = [output.outputs[0].text for output in outputs] del llm - return token_ids + return tokens, token_ids diff --git a/vllm/engine/output_processor/block_decode.py b/vllm/engine/output_processor/block_decode.py index 3b6a60e857fa..99963111e219 100644 --- a/vllm/engine/output_processor/block_decode.py +++ b/vllm/engine/output_processor/block_decode.py @@ -98,8 +98,8 @@ def _process_seq_outputs(self, seq: Sequence, # TODO emit logprobs in block decoding. logprobs={output_token_id: Logprob(0.0)}, ) + self.detokenizer.decode_sequence_inplace(seq, sampling_params) - # TODO detokenize self.stop_checker.maybe_stop_sequence(seq, sampling_params, new_token_ids=output_token_ids) From b9777a6ea80e4d0340e406dfe0748a32d5d34138 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 19:48:20 -0700 Subject: [PATCH 039/120] lint --- tests/core/utils.py | 2 +- .../output_processor/test_block_decode.py | 16 +++++---- tests/spec_decode/e2e/test_correctness.py | 18 +++++----- tests/spec_decode/test_spec_decode_worker.py | 2 +- tests/spec_decode/utils.py | 2 +- vllm/engine/llm_engine.py | 33 ++++++++++--------- vllm/engine/output_processor/beam_search.py | 31 ++++------------- vllm/engine/output_processor/block_decode.py | 32 ++++-------------- vllm/engine/output_processor/interfaces.py | 11 +++++-- vllm/engine/output_processor/stop_checker.py | 27 ++------------- vllm/engine/output_processor/util.py | 3 +- vllm/executor/gpu_executor.py | 7 ++-- vllm/spec_decode/batch_expansion.py | 10 +++--- vllm/spec_decode/multi_step_worker.py | 4 +-- vllm/spec_decode/spec_decode_worker.py | 12 +++---- vllm/worker/neuron_worker.py | 2 +- 16 files changed, 81 insertions(+), 131 deletions(-) diff --git a/tests/core/utils.py b/tests/core/utils.py index 39f8e507d0f1..22c1d3826dff 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -1,5 +1,5 @@ import time -from typing import Optional, Tuple, Iterable +from typing import Iterable, Optional, Tuple from vllm import SamplingParams from vllm.lora.request import LoRARequest diff --git a/tests/engine/output_processor/test_block_decode.py b/tests/engine/output_processor/test_block_decode.py index f426f1d32d7a..87f451da7c29 100644 --- a/tests/engine/output_processor/test_block_decode.py +++ b/tests/engine/output_processor/test_block_decode.py @@ -1,17 +1,19 @@ -import pytest -from unittest.mock import MagicMock import random +from unittest.mock import MagicMock +import pytest from transformers import PreTrainedTokenizer -from vllm.engine.output_processor.block_decode import BlockDecodeOutputProcessor +from tests.core.utils import create_seq_group +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.block_decode import ( + BlockDecodeOutputProcessor) from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.sampling_params import SamplingParams +from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput, + SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.core.scheduler import Scheduler from vllm.utils import Counter -from vllm.sequence import SequenceStatus, SequenceGroupOutput, SequenceOutput, Logprob -from vllm.sampling_params import SamplingParams -from tests.core.utils import create_seq_group @pytest.mark.parametrize("seq_output_len", [128]) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index d2f07f729f5a..fe543dfda552 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -1,10 +1,11 @@ -import pytest from itertools import cycle -from typing import Tuple, List +from typing import List, Tuple -from vllm import SamplingParams +import pytest from transformers import AutoTokenizer +from vllm import SamplingParams + @pytest.mark.parametrize( "common_llm_kwargs", @@ -58,9 +59,8 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): temperature=temperature, ) - batch_tokens, batch_token_ids = get_output_from_llm_generator(test_llm_generator, - prompts, - sampling_params) + batch_tokens, batch_token_ids = get_output_from_llm_generator( + test_llm_generator, prompts, sampling_params) # Expect a generation for each prompt in the batch. assert len(batch_token_ids) == len(prompts) @@ -121,10 +121,12 @@ def test_spec_decode_xfail(test_llm_generator): with pytest.raises(AssertionError, match="Speculative decoding not yet supported for "): get_output_from_llm_generator(test_llm_generator, prompts, - sampling_params) + sampling_params) -def get_output_from_llm_generator(llm_generator, prompts, sampling_params) -> Tuple[List[str], List[List[int]]]: +def get_output_from_llm_generator( + llm_generator, prompts, + sampling_params) -> Tuple[List[str], List[List[int]]]: for llm in llm_generator: outputs = llm.generate(prompts, sampling_params, use_tqdm=True) token_ids = [output.outputs[0].token_ids for output in outputs] diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 889712fb9360..4470cee78eed 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -6,13 +6,13 @@ from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.utils import set_random_seed +from vllm.sequence import SamplerOutput from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.metrics import (AsyncMetricsCollector, SpecDecodeWorkerMetrics) from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker, split_num_cache_blocks_evenly) -from vllm.sequence import SamplerOutput from .utils import (ExecuteModelData, create_batch, create_sampler_output_list, mock_worker) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 3914af945eff..c428c4258c14 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -10,7 +10,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import (Logprob, SamplerOutput, SequenceData, SequenceGroupMetadata, SequenceGroupOutput, - SequenceOutput, Logprob) + SequenceOutput) from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.worker.cache_engine import CacheEngine from vllm.worker.worker import Worker diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 9936eb18c032..8c3786354f40 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,5 +1,5 @@ import time -from typing import Iterable, List, Optional, Tuple, Type, Union +from typing import Iterable, List, Optional, Type, Union from transformers import PreTrainedTokenizer @@ -10,6 +10,10 @@ from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import StatLogger, Stats +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.engine.ray_utils import initialize_ray_cluster from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger @@ -17,17 +21,13 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, - SequenceGroup, SequenceGroupOutput, SequenceOutput, - SequenceStatus) + SequenceGroup) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, get_tokenizer_group) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) from vllm.utils import Counter -from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor -from vllm.engine.output_processor.stop_checker import StopChecker -from vllm.engine.output_processor.util import create_output_by_sequence_group logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 @@ -183,18 +183,19 @@ def __init__( labels=dict(model_name=model_config.model)) self.stat_logger.info("cache_config", self.cache_config) - self.output_processor = SequenceGroupOutputProcessor.create_output_processor( - self.scheduler_config, - self.detokenizer, - self.scheduler, - self.seq_counter, - self.get_tokenizer_for_seq, - stop_checker=StopChecker( - self.scheduler, + self.output_processor = ( + SequenceGroupOutputProcessor.create_output_processor( self.scheduler_config, + self.detokenizer, + self.scheduler, + self.seq_counter, self.get_tokenizer_for_seq, - ), - ) + stop_checker=StopChecker( + self.scheduler, + self.scheduler_config, + self.get_tokenizer_for_seq, + ), + )) def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). diff --git a/vllm/engine/output_processor/beam_search.py b/vllm/engine/output_processor/beam_search.py index 94af809e2673..885a241f7b2d 100644 --- a/vllm/engine/output_processor/beam_search.py +++ b/vllm/engine/output_processor/beam_search.py @@ -1,31 +1,12 @@ -import time -from typing import Iterable, List, Optional, Tuple, Type, Union +from typing import List, Tuple, Union -from transformers import PreTrainedTokenizer - -import vllm -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) -from vllm.core.scheduler import Scheduler, SchedulerOutputs -from vllm.engine.arg_utils import EngineArgs -from vllm.engine.metrics import StatLogger, Stats -from vllm.engine.ray_utils import initialize_ray_cluster -from vllm.executor.executor_base import ExecutorBase +from vllm.config import SchedulerConfig +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, - SequenceGroup, SequenceGroupOutput, SequenceOutput, - SequenceStatus) -from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, - get_tokenizer_group) -from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, - usage_message) -from vllm.utils import Counter -from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor +from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, + SequenceOutput, SequenceStatus) logger = init_logger(__name__) diff --git a/vllm/engine/output_processor/block_decode.py b/vllm/engine/output_processor/block_decode.py index 99963111e219..f63ce7d0ef41 100644 --- a/vllm/engine/output_processor/block_decode.py +++ b/vllm/engine/output_processor/block_decode.py @@ -1,31 +1,11 @@ -import time -from typing import Iterable, List, Optional, Tuple, Type, Union - -from transformers import PreTrainedTokenizer - -import vllm -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) -from vllm.core.scheduler import Scheduler, SchedulerOutputs -from vllm.engine.arg_utils import EngineArgs -from vllm.engine.metrics import StatLogger, Stats -from vllm.engine.ray_utils import initialize_ray_cluster -from vllm.executor.executor_base import ExecutorBase +from typing import List + +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, - SequenceGroup, SequenceGroupOutput, SequenceOutput, - SequenceStatus, Logprob) -from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, - get_tokenizer_group) -from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, - usage_message) -from vllm.utils import Counter -from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor +from vllm.sequence import (Logprob, Sequence, SequenceGroup, + SequenceGroupOutput, SequenceOutput, SequenceStatus) logger = init_logger(__name__) diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index 2b931a0b2f41..5596bc3f3d67 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod +from typing import List + from vllm.config import SchedulerConfig from vllm.sequence import SequenceGroup, SequenceGroupOutput -from typing import List class SequenceGroupOutputProcessor(ABC): @@ -16,7 +17,9 @@ def create_output_processor( stop_checker, ): if scheduler_config.num_lookahead_slots == 0: - from vllm.engine.output_processor.beam_search import BeamSearchOutputProcessor + # Importing here to avoid cycle. + from vllm.engine.output_processor.beam_search import ( + BeamSearchOutputProcessor) return BeamSearchOutputProcessor( scheduler_config, detokenizer, @@ -25,7 +28,9 @@ def create_output_processor( stop_checker, ) else: - from vllm.engine.output_processor.block_decode import BlockDecodeOutputProcessor + # Importing here to avoid cycle. + from vllm.engine.output_processor.block_decode import ( + BlockDecodeOutputProcessor) return BlockDecodeOutputProcessor( detokenizer, scheduler, diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index 3f03373f2698..b55e47ab3c12 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -1,31 +1,8 @@ -import time -from typing import Iterable, List, Optional, Tuple, Type, Union +from typing import List -from transformers import PreTrainedTokenizer - -import vllm -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) -from vllm.core.scheduler import Scheduler, SchedulerOutputs -from vllm.engine.arg_utils import EngineArgs -from vllm.engine.metrics import StatLogger, Stats -from vllm.engine.ray_utils import initialize_ray_cluster -from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, - SequenceGroup, SequenceGroupOutput, SequenceOutput, - SequenceStatus) -from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, - get_tokenizer_group) -from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, - usage_message) -from vllm.utils import Counter -from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor +from vllm.sequence import Sequence, SequenceStatus logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py index b49bbb2fab32..e4939b9be445 100644 --- a/vllm/engine/output_processor/util.py +++ b/vllm/engine/output_processor/util.py @@ -1,6 +1,7 @@ -from vllm.sequence import SequenceGroupOutput, SamplerOutput from typing import List +from vllm.sequence import SamplerOutput + def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput], num_seq_groups: int): diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 22cd2797282e..b5e64843213a 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -71,9 +71,9 @@ def _init_non_spec_worker(self): self.driver_worker.load_model() def _init_spec_worker(self): - from vllm.worker.worker import Worker - from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker + from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker + from vllm.worker.worker import Worker distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) @@ -106,7 +106,8 @@ def _init_spec_worker(self): is_driver_worker=True, ) - spec_decode_worker = SpecDecodeWorker.from_workers(proposer_worker=draft_worker, scorer_worker=target_worker) + spec_decode_worker = SpecDecodeWorker.from_workers( + proposer_worker=draft_worker, scorer_worker=target_worker) assert self.parallel_config.world_size == 1, ( "GPUExecutor only supports single GPU.") diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 4dc34f1ab7c7..6945877fbf34 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -6,10 +6,9 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) -from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, - sampler_output_to_torch, - split_batch_by_proposal_len, - maybe_mock_device_tensors) +from vllm.spec_decode.util import (get_all_seq_ids, maybe_mock_device_tensors, + nvtx_range, sampler_output_to_torch, + split_batch_by_proposal_len) from vllm.worker.worker_base import WorkerBase SeqId = int @@ -32,7 +31,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): of topk/tree. """ - def __init__(self, scorer_worker: WorkerBase, device: str, vocab_size: int): + def __init__(self, scorer_worker: WorkerBase, device: str, + vocab_size: int): self._scorer_worker = scorer_worker self._device = device self._vocab_size = vocab_size diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index c79d79930a18..6fdc3b294295 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -6,8 +6,8 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) -from vllm.spec_decode.util import (sampler_output_to_torch, - maybe_mock_device_tensors) +from vllm.spec_decode.util import (maybe_mock_device_tensors, + sampler_output_to_torch) from vllm.worker.worker import Worker diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index e5b493c46c6c..84aa562eba50 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -3,9 +3,10 @@ import torch +from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler -from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, - SequenceGroupOutput, SequenceOutput, Logprob) +from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata, + SequenceGroupOutput, SequenceOutput) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) @@ -13,9 +14,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, split_batch_by_proposal_len) -from vllm.worker.worker import Worker from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase -from vllm.logger import init_logger logger = init_logger(__name__) @@ -49,7 +48,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): """ @classmethod - def from_workers(cls, proposer_worker: MultiStepWorker, scorer_worker: WorkerBase) -> "SpecDecodeWorker": + def from_workers(cls, proposer_worker: MultiStepWorker, + scorer_worker: WorkerBase) -> "SpecDecodeWorker": return SpecDecodeWorker( proposer_worker, scorer_worker, @@ -238,7 +238,7 @@ def _run_speculative_decoding_step( seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy, k) - logger.info(f"score proposals") + logger.info("score proposals") proposal_scores = self.scorer.score_proposals( seq_group_metadata_list, blocks_to_swap_in, diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index d0f01b893bc6..7472a795fb51 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -1,5 +1,5 @@ """A Neuron worker class.""" -from typing import List, Optional +from typing import List import torch import torch.distributed From 29b4f12dc07a1c4d5238d9e5cc6fe9211d57b4d9 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 20:21:51 -0700 Subject: [PATCH 040/120] docstrings --- .../output_processor/test_block_decode.py | 17 ++++++++- tests/spec_decode/e2e/test_correctness.py | 7 +++- tests/spec_decode/test_multi_step_worker.py | 1 - tests/spec_decode/test_spec_decode_worker.py | 1 - vllm/core/scheduler.py | 6 --- vllm/engine/llm_engine.py | 9 ++++- vllm/engine/output_processor/beam_search.py | 28 +++++++++++--- vllm/engine/output_processor/block_decode.py | 36 +++++++++++++++--- vllm/engine/output_processor/interfaces.py | 37 +++++++++++++++---- vllm/engine/output_processor/stop_checker.py | 14 ++++--- vllm/engine/output_processor/util.py | 3 ++ vllm/executor/gpu_executor.py | 2 + vllm/model_executor/layers/sampler.py | 4 -- vllm/spec_decode/batch_expansion.py | 3 +- vllm/spec_decode/multi_step_worker.py | 3 +- vllm/spec_decode/spec_decode_worker.py | 3 +- vllm/spec_decode/util.py | 6 +++ 17 files changed, 137 insertions(+), 43 deletions(-) diff --git a/tests/engine/output_processor/test_block_decode.py b/tests/engine/output_processor/test_block_decode.py index 87f451da7c29..c4a88d67cabc 100644 --- a/tests/engine/output_processor/test_block_decode.py +++ b/tests/engine/output_processor/test_block_decode.py @@ -20,6 +20,11 @@ @pytest.mark.parametrize("num_new_tokens", [1, 12]) @pytest.mark.skip_global_cleanup def test_appends_token_ids(num_new_tokens: int, seq_output_len: int): + """Verify block decoding appends token ids correctly. + + We append token ids and verify all the token ids were appended correctly. + Note that ignore_eos=True. + """ detokenizer = MagicMock(spec=Detokenizer) scheduler = MagicMock(spec=Scheduler) stop_checker = MagicMock(spec=StopChecker) @@ -37,7 +42,8 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int): seq_prompt_len=1024, seq_output_lens=[seq_output_len], sampling_params=SamplingParams(max_tokens=seq_output_len + - num_new_tokens, ), + num_new_tokens, + ignore_eos=True), ) seq = seq_group.get_seqs()[0] @@ -70,6 +76,9 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int): @pytest.mark.skip_global_cleanup def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, seq_output_len: int, max_tokens: int): + """Verify tokens after max_tokens are dropped and not appended to the + sequence. + """ detokenizer = MagicMock(spec=Detokenizer) scheduler = MagicMock(spec=Scheduler) stop_checker = MagicMock(spec=StopChecker) @@ -126,6 +135,9 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, @pytest.mark.skip_global_cleanup def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, seq_output_len: int, seed: int): + """Verify the eos token id is included in the sequence, but subsequent + tokens are dropped (not appended to sequence). + """ random.seed(seed) detokenizer = MagicMock(spec=Detokenizer) scheduler = MagicMock(spec=Scheduler) @@ -190,6 +202,9 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, @pytest.mark.skip_global_cleanup def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int, seq_output_len: int, seed: int): + """When sampling parameters dictate that we should ignore the eos token id, + ensure all token ids are appended even if the eos token id is emitted. + """ random.seed(seed) detokenizer = MagicMock(spec=Detokenizer) scheduler = MagicMock(spec=Scheduler) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index fe543dfda552..160510e6c0c0 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -38,8 +38,9 @@ @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("seed", [1]) def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): - """Run generation with speculative decoding on a batch. Verify the number - of output tokens is equal to the expected number. + """Run generation with speculative decoding on a batch. Verify the engine + generates the correct number of tokens (via ignore_eos=True), and that the + detokenization matches HF transformers. """ output_len = 128 temperature = 0.0 @@ -105,6 +106,8 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): @pytest.mark.parametrize("test_llm_kwargs", [{}]) @pytest.mark.parametrize("seed", [1]) def test_spec_decode_xfail(test_llm_generator): + """Verify that speculative decoding with Ray fails. + """ output_len = 128 temperature = 0.0 diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index f9840d6157c3..d6edbab579af 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -352,7 +352,6 @@ def test_draft_proposals_no_speculations(): @torch.inference_mode() -#@pytest.skip("Broken because output is padded.") def test_draft_proposals_mixed_k(): """Verify DraftModelTop1Proposer correctly handles case some sequences can speculate and some can't. diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 4470cee78eed..0a3110775e2d 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -341,7 +341,6 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) - draft_worker.device = 'cuda' target_worker.device = 'cuda' diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index e176848c0490..db48a1f7f0d2 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -754,9 +754,6 @@ def _schedule_default(self) -> SchedulerOutputs: swapped_in.blocks_to_copy), ignored_seq_groups=prefills.ignored_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, - #num_lookahead_slots=(prefills.num_lookahead_slots + - # running_scheduled.num_lookahead_slots + - # swapped_in.num_lookahead_slots), ) def _schedule_chunked_prefill(self): @@ -844,9 +841,6 @@ def _schedule_chunked_prefill(self): swapped_in.blocks_to_copy), ignored_seq_groups=prefills.ignored_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, - #num_lookahead_slots=(prefills.num_lookahead_slots + - # running_scheduled.num_lookahead_slots + - # swapped_in.num_lookahead_slots), ) def _schedule(self) -> SchedulerOutputs: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8c3786354f40..e6e75ee59c76 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -183,6 +183,8 @@ def __init__( labels=dict(model_name=model_config.model)) self.stat_logger.info("cache_config", self.cache_config) + # Create sequence output processor, e.g. for beam search or + # speculative decoding. self.output_processor = ( SequenceGroupOutputProcessor.create_output_processor( self.scheduler_config, @@ -426,9 +428,15 @@ def _process_model_outputs( self, output: List[SamplerOutput], scheduled_seq_groups: List[SequenceGroup], ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]: + """Apply the model output to the sequences in the scheduled seq groups. + + Returns RequestOutputs that can be returned to the client. + """ now = time.time() + # Organize outputs by [sequence group][step] instead of + # [step][sequence group]. output_by_sequence_group = create_output_by_sequence_group( sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups)) @@ -438,7 +446,6 @@ def _process_model_outputs( seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) - self.output_processor.process_outputs(seq_group, outputs) # Free the finished sequence groups. diff --git a/vllm/engine/output_processor/beam_search.py b/vllm/engine/output_processor/beam_search.py index 885a241f7b2d..330eeced21cf 100644 --- a/vllm/engine/output_processor/beam_search.py +++ b/vllm/engine/output_processor/beam_search.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Union +from typing import List, Tuple, Union, Iterable from vllm.config import SchedulerConfig from vllm.engine.output_processor.interfaces import ( @@ -7,19 +7,31 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.stop_checker import StopChecker logger = init_logger(__name__) class BeamSearchOutputProcessor(SequenceGroupOutputProcessor): + """SequenceGroupOutputProcessor which handles logic related to beam search + sequence management and coupled logic like detokenization and stop logic. + + This class is in charge of sorting out which sequences survive after beam + sampling. It manages forking and freeing of sequences. + + It does not support lookahead decoding, e.g. where the model generates >1 + token per scheduling invocation. + """ def __init__( self, scheduler_config: SchedulerConfig, - detokenizer, - scheduler, - seq_counter, - stop_checker, + detokenizer: Detokenizer, + scheduler: Scheduler, + seq_counter: Iterable[int], + stop_checker: StopChecker, ): self.scheduler_config = scheduler_config self.detokenizer = detokenizer @@ -29,6 +41,12 @@ def __init__( def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: + """Append all new tokens to sequences in the sequence group. Fork any + surviving beam candidates; free any unsurviving ones. + + Invokes detokenizer to detokenize new tokens, and also marks sequences + as finished if they meet stop conditions. + """ assert (len(outputs) == 1 ), f"{type(self)} does not support multiple outputs per step" return self._process_sequence_group_outputs(sequence_group, outputs[0]) diff --git a/vllm/engine/output_processor/block_decode.py b/vllm/engine/output_processor/block_decode.py index f63ce7d0ef41..8c9b3e25598f 100644 --- a/vllm/engine/output_processor/block_decode.py +++ b/vllm/engine/output_processor/block_decode.py @@ -1,24 +1,39 @@ -from typing import List +from typing import List, Iterable, Callable from vllm.engine.output_processor.interfaces import ( SequenceGroupOutputProcessor) +from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.sequence import (Logprob, Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) +from vllm.core.scheduler import Scheduler +from vllm.transformers_utils.detokenizer import Detokenizer +from transformers import PreTrainedTokenizer logger = init_logger(__name__) class BlockDecodeOutputProcessor(SequenceGroupOutputProcessor): + """SequenceGroupOutputProcessor which handles logic related to + detokenization and stopping conditions. Besides not supporting beam search, + this differs from BeamSearchOutputProcessor in that it supports lookahead + scheduling (where the model may generate >1 token per scheduler invocation). + + This allows it to support speculative decoding and cases where the model + runs more than once. We generalize these cases as "block decoding", where + the model emits a block of tokens at the same time. In this case, this class + is responsible for correctly appending all token ids to sequences and + detokenizing new token ids. + """ def __init__( self, - detokenizer, - scheduler, - seq_counter, - get_tokenizer_for_seq, - stop_checker, + detokenizer: Detokenizer, + scheduler: Scheduler, + seq_counter: Iterable[int], + get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], + stop_checker: StopChecker, ): self.detokenizer = detokenizer self.scheduler = scheduler @@ -28,6 +43,15 @@ def __init__( def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: + """Append new tokens in the outputs to sequences in the sequence group. + + This only supports sequence groups of size 1. It supports greater than + one new token per sequence. + + This applies logic like stop condition checking and detokenization, + including freeing finished sequences. It also handles cases where there + are tokens emitted after the EOS token. + """ seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING) assert seqs, "expected running sequences" diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index 5596bc3f3d67..1f940f292406 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -1,21 +1,40 @@ from abc import ABC, abstractmethod -from typing import List +from typing import List, Callable, Iterable from vllm.config import SchedulerConfig -from vllm.sequence import SequenceGroup, SequenceGroupOutput +from vllm.sequence import SequenceGroup, SequenceGroupOutput, Sequence +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.stop_checker import StopChecker class SequenceGroupOutputProcessor(ABC): + """Interface for logic that processes new token ids in sequence groups, + managing detokenization, stop checking, and freeing/forking sequences with + the scheduler. + + This is highly coupled with the LLMEngine and should be seen as an extension + of it. The logic is separated out to simplify the LLMEngine class and to + allow a beam search implementation (which handles forking, etc) and a block + decode implementation (which handles decoding >1 token per step). + """ @staticmethod def create_output_processor( scheduler_config: SchedulerConfig, - detokenizer, - scheduler, - seq_counter, - get_tokenizer_for_seq, - stop_checker, + detokenizer: Detokenizer, + scheduler: Scheduler, + seq_counter: Iterable[int], + get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], + stop_checker: "StopChecker", ): + """Create an output processor. + + This returns an output processor compatible with beam search if the + scheduler is not configured to scheduler lookahead slots. Otherwise, it + returns an output processor that is incompatible with beam search but + which supports decoding more than one token per scheduling invocation. + """ if scheduler_config.num_lookahead_slots == 0: # Importing here to avoid cycle. from vllm.engine.output_processor.beam_search import ( @@ -42,4 +61,8 @@ def create_output_processor( @abstractmethod def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: + """Process new token ids for the sequence group. Handles logic such as + detokenization, stop checking, and freeing/forking sequences in the + scheduler. + """ pass diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index b55e47ab3c12..2a6c79d2dc02 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -1,14 +1,15 @@ from typing import List -from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.sequence import Sequence, SequenceStatus -logger = init_logger(__name__) -_LOCAL_LOGGING_INTERVAL_SEC = 5 - class StopChecker: + """LLMEngine helper class which separates out the logic involving stop + checking. This checks things such as: whether the eos token was emitted, + whether the max_tokens has been consumed, whether a stop string has been + emitted, or if we have exceeded the max model len. + """ def __init__(self, scheduler, scheduler_config, get_tokenizer_for_seq): self.scheduler = scheduler @@ -18,7 +19,9 @@ def __init__(self, scheduler, scheduler_config, get_tokenizer_for_seq): def maybe_stop_sequence(self, seq: Sequence, sampling_params: SamplingParams, new_token_ids: List[int]) -> None: - """Stop the finished sequences.""" + """Check if the sequences should be stopped. If so, mark it as finished. + """ + # Check if the sequence has reached max_model_len. if seq.get_len() > self.scheduler_config.max_model_len: seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED @@ -36,6 +39,7 @@ def maybe_stop_sequence(self, seq: Sequence, if sampling_params.detokenize: for stop_str in sampling_params.stop: + # TODO(cade) Fix this for speculative decoding. if seq.output_text.endswith(stop_str): self._finalize_sequence(seq, sampling_params, stop_str) seq.status = SequenceStatus.FINISHED_STOPPED diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py index e4939b9be445..5fbb09a857a4 100644 --- a/vllm/engine/output_processor/util.py +++ b/vllm/engine/output_processor/util.py @@ -5,6 +5,9 @@ def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput], num_seq_groups: int): + """Helper method which transforms a 2d list organized by + [step][sequence group] into [sequence group][step]. + """ output_by_sequence_group = [[] for _ in range(num_seq_groups)] for step in sampler_outputs: for i, sequence_group_output in enumerate(step): diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index b5e64843213a..9330d754d5d7 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -71,6 +71,8 @@ def _init_non_spec_worker(self): self.driver_worker.load_model() def _init_spec_worker(self): + """Initialize a SpecDecodeWorker, using a draft model for proposals. + """ from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker from vllm.worker.worker import Worker diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index be970e56b611..cb1480de03e3 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -78,7 +78,6 @@ def forward( # Get the logprobs query results. prompt_logprobs, sample_logprobs = _get_logprobs( logprobs, sampling_metadata, sample_results) - return _build_sampler_output(sample_results, sampling_metadata, prompt_logprobs, sample_logprobs) @@ -669,8 +668,6 @@ def _build_sampler_output( sampling_metadata: SamplingMetadata, prompt_logprobs: List[Optional[PromptLogprobs]], sample_logprobs: List[SampleLogprobs], - sampled_token_ids: Optional[torch.Tensor] = None, - sampled_token_probs: Optional[torch.Tensor] = None, ) -> SamplerOutput: sampler_output = [] for (seq_group, sample_result, group_prompt_logprobs, @@ -687,5 +684,4 @@ def _build_sampler_output( SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)) sampler_output.append( SequenceGroupOutput(seq_outputs, group_prompt_logprobs)) - return SamplerOutput(outputs=sampler_output) diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 6945877fbf34..88af1dd36015 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -84,7 +84,6 @@ def score_proposals( blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, - #return_python_output=False ) assert len(target_sampler_output) == 1, "expected single-step output" target_sampler_output = target_sampler_output[0] @@ -147,6 +146,8 @@ def _contract_batch(self, original_bs: int, sequences. """ + # We mock the device tensors until PR 7/9 is merged (e2e correctness). + # https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer maybe_mock_device_tensors( sampler_output=target_sampler_output, batch_size=len(non_spec_indices) + num_scoring_tokens, diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 6fdc3b294295..ce63c329a40a 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -345,6 +345,8 @@ def _merge_outputs( sampler_output = maybe_sampler_output + # We mock the device tensors until PR 7/9 is merged (e2e correctness). + # https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer for step_output in sampler_output: maybe_mock_device_tensors( sampler_output=step_output, @@ -364,7 +366,6 @@ def _merge_outputs( fill_value=-1, dtype=torch.long, device=self._device) - entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens entire_proposal_probs = torch.zeros(batch_size, *proposal_probs.shape[1:], diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 84aa562eba50..be3af7be9386 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -99,7 +99,7 @@ def init_device(self) -> None: self.scorer_worker.init_device() self.proposer_worker.init_device() - # TODO separate from init_device? + # NOTE(cade): load_model is not part of the WorkerBase interface. self.scorer_worker.load_model() self.proposer_worker.load_model() @@ -195,7 +195,6 @@ def _run_no_spec( blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, - #return_python_output=False ) logger.info("run target worker no spec") diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index efc54c4de4cf..85aee137dcbc 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -84,13 +84,19 @@ def sampler_output_to_torch( def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, vocab_size: int, device: str) -> None: + """Helper method which mocks out the GPU tensors in SamplerOutput with dummy + values. This will be removed in PR 7/9. + https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer + """ values = [ sampler_output.sampled_token_probs, sampler_output.sampled_token_ids ] assert all(v is None for v in values) or not any(v is None for v in values) if not any(v is None for v in values): + # Do nothing if the tensors are already created (usually in unit tests). return + # Softmax to ensure valid probs. sampler_output.sampled_token_probs = torch.nn.functional.softmax( torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device), dim=-1) From 42aa0bc45900b49ca5ae7878f90e371a123e0e66 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 20:30:23 -0700 Subject: [PATCH 041/120] fix --- vllm/engine/output_processor/beam_search.py | 6 +++--- vllm/engine/output_processor/block_decode.py | 7 ++++--- vllm/engine/output_processor/interfaces.py | 8 +++++--- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/vllm/engine/output_processor/beam_search.py b/vllm/engine/output_processor/beam_search.py index 330eeced21cf..b0c0246b9935 100644 --- a/vllm/engine/output_processor/beam_search.py +++ b/vllm/engine/output_processor/beam_search.py @@ -1,15 +1,15 @@ -from typing import List, Tuple, Union, Iterable +from typing import Iterable, List, Tuple, Union from vllm.config import SchedulerConfig +from vllm.core.scheduler import Scheduler from vllm.engine.output_processor.interfaces import ( SequenceGroupOutputProcessor) +from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.core.scheduler import Scheduler -from vllm.engine.output_processor.stop_checker import StopChecker logger = init_logger(__name__) diff --git a/vllm/engine/output_processor/block_decode.py b/vllm/engine/output_processor/block_decode.py index 8c9b3e25598f..e309b57af6de 100644 --- a/vllm/engine/output_processor/block_decode.py +++ b/vllm/engine/output_processor/block_decode.py @@ -1,5 +1,8 @@ -from typing import List, Iterable, Callable +from typing import Callable, Iterable, List +from transformers import PreTrainedTokenizer + +from vllm.core.scheduler import Scheduler from vllm.engine.output_processor.interfaces import ( SequenceGroupOutputProcessor) from vllm.engine.output_processor.stop_checker import StopChecker @@ -7,9 +10,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import (Logprob, Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) -from vllm.core.scheduler import Scheduler from vllm.transformers_utils.detokenizer import Detokenizer -from transformers import PreTrainedTokenizer logger = init_logger(__name__) diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index 1f940f292406..26ec982cc13f 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -1,11 +1,13 @@ from abc import ABC, abstractmethod -from typing import List, Callable, Iterable +from typing import Callable, Iterable, List + +from transformers import PreTrainedTokenizer from vllm.config import SchedulerConfig -from vllm.sequence import SequenceGroup, SequenceGroupOutput, Sequence -from vllm.transformers_utils.detokenizer import Detokenizer from vllm.core.scheduler import Scheduler from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput +from vllm.transformers_utils.detokenizer import Detokenizer class SequenceGroupOutputProcessor(ABC): From 0ebd93b98f1c334aca3f4f4f6b651a7301a4f427 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 20:31:51 -0700 Subject: [PATCH 042/120] more spec test --- tests/spec_decode/e2e/test_correctness.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index 160510e6c0c0..c9665ee5bbc2 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -26,6 +26,10 @@ @pytest.mark.parametrize( "per_test_common_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 1, + }, { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, From 33a3d7230b1e6f6a699b3863046494ecf5aca365 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 20:37:05 -0700 Subject: [PATCH 043/120] remove --- tests/spec_decode/e2e/test_correctness.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index c9665ee5bbc2..160510e6c0c0 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -26,10 +26,6 @@ @pytest.mark.parametrize( "per_test_common_llm_kwargs", [ - { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 1, - }, { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, From 15c942dfc8a49e294d803a1088bd8776bfd69aa2 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 20:37:29 -0700 Subject: [PATCH 044/120] wip --- tests/spec_decode/e2e/test_correctness.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index 160510e6c0c0..ac79f977ce39 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -26,6 +26,7 @@ @pytest.mark.parametrize( "per_test_common_llm_kwargs", [ + # TODO(cade) handle output { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, From 063e34b370e0dcd8080faa3e397f303f0e4d3795 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 21:24:25 -0700 Subject: [PATCH 045/120] strip --- tests/spec_decode/e2e/test_correctness.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index ac79f977ce39..173f96c4de60 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -75,7 +75,7 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): tok = AutoTokenizer.from_pretrained("JackFram/llama-68m") for actual_tokens, actual_token_ids in zip(batch_tokens, batch_token_ids): expected_tokens = tok.decode(actual_token_ids) - assert actual_tokens == expected_tokens + assert actual_tokens.strip() == expected_tokens.strip() @pytest.mark.parametrize( From 672a855bb1ca4a074a9158d79eb99253fe3b2540 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Sun, 7 Apr 2024 22:57:11 -0700 Subject: [PATCH 046/120] print --- tests/spec_decode/e2e/test_correctness.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index 173f96c4de60..d76dbc50c872 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -75,6 +75,7 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): tok = AutoTokenizer.from_pretrained("JackFram/llama-68m") for actual_tokens, actual_token_ids in zip(batch_tokens, batch_token_ids): expected_tokens = tok.decode(actual_token_ids) + print(f"{actual_token_ids=}") assert actual_tokens.strip() == expected_tokens.strip() From b4543c8f6bf67a7f1a0d6d0fd6cf5697c7eeaabb Mon Sep 17 00:00:00 2001 From: ywfang <47963924+SUDA-HLT-ywfang@users.noreply.github.com> Date: Mon, 8 Apr 2024 18:28:36 +0800 Subject: [PATCH 047/120] [Model] add minicpm (#3893) --- README.md | 1 + docs/source/models/supported_models.rst | 4 + vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/minicpm.py | 537 ++++++++++++++++++++++++ 4 files changed, 543 insertions(+) create mode 100644 vllm/model_executor/models/minicpm.py diff --git a/README.md b/README.md index 2a070b9e2064..d53227b82d87 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.) - Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.) - LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) +- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.) - Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) - Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.) - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 9c2f5ba458eb..e7bfdcb65316 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -83,6 +83,10 @@ Alongside each architecture, we include some popular models that use it. - LLaMA, LLaMA-2, Vicuna, Alpaca, Yi - :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc. - ✅︎ + * - :code:`MiniCPMForCausalLM` + - MiniCPM + - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc. + - * - :code:`MistralForCausalLM` - Mistral, Mistral-Instruct - :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc. diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 4647947f695a..17fc97056804 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -41,6 +41,7 @@ # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), + "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"), diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py new file mode 100644 index 000000000000..99d1b4eb97bb --- /dev/null +++ b/vllm/model_executor/models/minicpm.py @@ -0,0 +1,537 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only MiniCPM model compatible with HuggingFace weights.""" +import math +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import LoRAConfig +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + + +class MiniCPMMoE(nn.Module): + """A tensor-parallel MoE implementation that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + ): + super().__init__() + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.num_total_experts = num_experts + self.top_k = top_k + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size // self.tp_size + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + self.gate = ReplicatedLinear(self.hidden_size, + self.num_total_experts, + bias=False, + params_dtype=self.params_dtype, + linear_method=None) + + self.ws = nn.Parameter( + torch.empty(self.num_total_experts, + 2 * self.intermediate_size, + self.hidden_size, + device="cuda", + dtype=self.params_dtype)) + self.w2s = nn.Parameter( + torch.empty(self.num_total_experts, + self.hidden_size, + self.intermediate_size, + device="cuda", + dtype=self.params_dtype)) + + set_weight_attrs(self.ws, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2s, { + "weight_loader": self.weight_loader, + }) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str, expert_id: int): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + if weight_name.endswith("w1.weight"): + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w3.weight"): + param_data[expert_id, + shard_size:2 * shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w2.weight"): + param_data[expert_id, :, :] = loaded_weight[:, shard] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = fused_moe(hidden_states, + self.ws, + self.w2s, + router_logits, + self.top_k, + renormalize=True, + inplace=True) + + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_size) + + +class MiniCPMMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class MiniCPMAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + linear_method=linear_method, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + linear_method=linear_method, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + # set rope as fp32 instead of bf16 + self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache( + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + orig_dtype = q.dtype + q, k = q.float(), k.float() + q, k = self.rotary_emb(positions, q, k) + q, k = q.to(orig_dtype), k.to(orig_dtype) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class MiniCPMDecoderLayer(nn.Module): + + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = MiniCPMAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + linear_method=linear_method, + ) + self.num_experts = getattr(self.config, "num_experts", 0) + if self.num_experts == 0: + self.mlp = MiniCPMMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + linear_method=linear_method, + ) + else: + self.mlp = MiniCPMMoE(num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states * \ + (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * \ + (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)) + + return hidden_states, None + + +class MiniCPMModel(nn.Module): + + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.layers = nn.ModuleList([ + MiniCPMDecoderLayer(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + embedding = self.embed_tokens(input_ids) + return embedding * self.config.scale_emb + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i], + attn_metadata, + residual, + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class MiniCPMForCausalLM(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.num_experts = getattr(self.config, "num_experts", 0) + self.linear_method = linear_method + self.model = MiniCPMModel(config, + linear_method, + lora_config=lora_config) + unpadded_vocab_size = config.vocab_size + if lora_config: + unpadded_vocab_size += lora_config.lora_extra_vocab_size + if not self.config.tie_word_embeddings: + self.lm_head = ParallelLMHead( + unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + self.scale_width = self.config.hidden_size / self.config.dim_model_base + + self.logits_processor = LogitsProcessor(unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + hidden_states = hidden_states / self.scale_width + if self.config.tie_word_embeddings: + lm_head_weight = self.model.embed_tokens.weight + else: + lm_head_weight = self.lm_head.weight + logits = self.logits_processor(lm_head_weight, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + expert_params_mapping = [ + # (param_name, weight_name, expert_id) + ("ws" if weight_name in ["w1", "w3"] else "w2s", + f"experts.{expert_id}.{weight_name}.weight", expert_id) + for expert_id in range(self.num_experts) + for weight_name in ["w1", "w2", "w3"] + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, expert_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From f46864d68dfb46ff88f574e6844f10fdb14cd3b5 Mon Sep 17 00:00:00 2001 From: egortolmachev <150433814+egortolmachev@users.noreply.github.com> Date: Mon, 8 Apr 2024 17:59:38 +0300 Subject: [PATCH 048/120] [Bugfix] Added Command-R GPTQ support (#3849) Co-authored-by: Egor Tolmachev --- vllm/model_executor/models/commandr.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 620d63135190..4674dcbc14da 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -349,11 +349,21 @@ def load_weights( if shard_name not in name: continue name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: + # lm_head is not used in vllm as it is tied with embed_token. + # To prevent errors, skip loading lm_head.weight. + if "lm_head.weight" in name: + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) From bc0c0192d13ca6ea4aeea4725f752a89483895bc Mon Sep 17 00:00:00 2001 From: Kiran R Date: Tue, 9 Apr 2024 01:12:35 +0530 Subject: [PATCH 049/120] [Bugfix] Enable Proper `attention_bias` Usage in Llama Model Configuration (#3767) Co-authored-by: roy --- vllm/model_executor/models/llama.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index ef19c41e67ae..72fe21df67d8 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -184,6 +184,10 @@ def __init__( max_position_embeddings = getattr(config, "max_position_embeddings", 8192) sliding_window = getattr(config, "sliding_window", None) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) self.self_attn = LlamaAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -193,7 +197,7 @@ def __init__( rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, linear_method=linear_method, - bias=getattr(config, "bias", False), + bias=attention_bias, sliding_window=sliding_window, ) self.mlp = LlamaMLP( From 59a6abf3c99ee4fed5312d357f6ecbf857f24433 Mon Sep 17 00:00:00 2001 From: Matt Wong <156021403+mawong-amd@users.noreply.github.com> Date: Mon, 8 Apr 2024 14:31:02 -0700 Subject: [PATCH 050/120] [Hotfix][CI/Build][Kernel] CUDA 11.8 does not support layernorm optimizations (#3782) --- cmake/utils.cmake | 2 ++ csrc/layernorm_kernels.cu | 6 ++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 4cb8a69f93de..7c71673e36f2 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -100,6 +100,8 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2") + endif() + if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) list(REMOVE_ITEM GPU_FLAGS "-D__CUDA_NO_HALF_OPERATORS__" "-D__CUDA_NO_HALF_CONVERSIONS__" diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index ea30fa274783..e56b4d220400 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -59,6 +59,8 @@ __global__ void rms_norm_kernel( template struct _typeConvert { static constexpr bool exists = false; }; +#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) +// CUDA < 12.0 runs into issues with packed type conversion template<> struct _typeConvert { static constexpr bool exists = true; @@ -85,8 +87,8 @@ struct _typeConvert { __device__ static inline hip_type convert(float x) { return __float2bfloat16(x); } __device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); } }; -#endif - +#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) /* Vector POD struct to generate vectorized and packed FP16/BF16 ops for appropriate specializations of fused_add_rms_norm_kernel. From d036198e23345f3c25438f082396f7487028e8b6 Mon Sep 17 00:00:00 2001 From: Roy Date: Tue, 9 Apr 2024 06:17:21 +0800 Subject: [PATCH 051/120] [BugFix][Model] Fix commandr RoPE max_position_embeddings (#3919) --- vllm/model_executor/models/commandr.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 4674dcbc14da..29ba3844eb11 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -140,7 +140,9 @@ def __init__( self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 - self.max_position_embeddings = config.max_position_embeddings + self.max_position_embeddings = getattr( + config, "model_max_length", None) or getattr( + config, "max_position_embeddings", 8192) self.rope_theta = config.rope_theta self.rope_scaling = getattr(config, "rope_scaling", None) self.use_qk_norm = getattr(config, "use_qk_norm", False) From 8021b38ab38f85e187c6462fa804f8e55a18f8c2 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Mon, 8 Apr 2024 15:25:49 -0700 Subject: [PATCH 052/120] fix flaky test --- tests/spec_decode/e2e/test_correctness.py | 16 +++++++++++++--- vllm/spec_decode/util.py | 4 ++-- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index d76dbc50c872..1041a5ddac12 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -26,17 +26,25 @@ @pytest.mark.parametrize( "per_test_common_llm_kwargs", [ - # TODO(cade) handle output { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, }, + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 1, + }, { # No spec decode. }, ]) @pytest.mark.parametrize("test_llm_kwargs", [{}]) -@pytest.mark.parametrize("batch_size", [1, 10]) +@pytest.mark.parametrize("batch_size", [1]) +# NOTE: We should run more permutations of this test (more BS, more seeds). But +# because our spec decode generates gibberish token ids, the likelihood of +# emitting an invalid token combination is nontrivial. This causes divergence in +# behavior of vLLM detokenization vs. hf tokenizer, for example when two "utf- +# start" bytes are emitted. @pytest.mark.parametrize("seed", [1]) def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): """Run generation with speculative decoding on a batch. Verify the engine @@ -59,6 +67,8 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): max_tokens=output_len, ignore_eos=True, temperature=temperature, + skip_special_tokens=True, + spaces_between_special_tokens=False, ) batch_tokens, batch_token_ids = get_output_from_llm_generator( @@ -76,7 +86,7 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): for actual_tokens, actual_token_ids in zip(batch_tokens, batch_token_ids): expected_tokens = tok.decode(actual_token_ids) print(f"{actual_token_ids=}") - assert actual_tokens.strip() == expected_tokens.strip() + assert actual_tokens == expected_tokens @pytest.mark.parametrize( diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index 85aee137dcbc..eb6d4ca1da8e 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -101,8 +101,8 @@ def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device), dim=-1) - sampler_output.sampled_token_ids = torch.randint(low=0, - high=vocab_size, + sampler_output.sampled_token_ids = torch.randint(low=10, + high=100, size=(batch_size, ), dtype=torch.long, device=device) From 8e93fff38628411da884e35290f547f42c6f3d27 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Mon, 8 Apr 2024 15:55:21 -0700 Subject: [PATCH 053/120] reduce output len --- tests/spec_decode/e2e/test_correctness.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index 1041a5ddac12..c8b6cf0d7df7 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -51,7 +51,7 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): generates the correct number of tokens (via ignore_eos=True), and that the detokenization matches HF transformers. """ - output_len = 128 + output_len = 32 temperature = 0.0 prompts = [ From d06e9a482125150d7d94ac8095203e86481c4c55 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Mon, 8 Apr 2024 16:44:05 -0700 Subject: [PATCH 054/120] strip --- tests/spec_decode/e2e/test_correctness.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index c8b6cf0d7df7..a8ebd66841eb 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -86,7 +86,7 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): for actual_tokens, actual_token_ids in zip(batch_tokens, batch_token_ids): expected_tokens = tok.decode(actual_token_ids) print(f"{actual_token_ids=}") - assert actual_tokens == expected_tokens + assert actual_tokens.strip() == expected_tokens.strip() @pytest.mark.parametrize( From 6d592eb430a37a7f8f5f9beb2dbc014bf3aa76bc Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 9 Apr 2024 01:49:02 -0700 Subject: [PATCH 055/120] [Core] separate distributed_init from worker (#3904) --- .../parallel_utils/parallel_state.py | 63 ++++++++++++++++++- vllm/test_utils.py | 13 ++-- vllm/worker/cpu_worker.py | 28 +++------ vllm/worker/worker.py | 39 ++++-------- 4 files changed, 85 insertions(+), 58 deletions(-) diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/model_executor/parallel_utils/parallel_state.py index bcda5ebf8548..3bbfa1bd5443 100644 --- a/vllm/model_executor/parallel_utils/parallel_state.py +++ b/vllm/model_executor/parallel_utils/parallel_state.py @@ -4,6 +4,7 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. """Tensor and pipeline parallel groups.""" import contextlib +from typing import Optional import torch @@ -14,14 +15,59 @@ # Pipeline model parallel group that the current rank belongs to. _PIPELINE_MODEL_PARALLEL_GROUP = None +# when people blindly call `torch.distributed.all_reduce` etc, +# it will use this group. It is initialized with the `backend` +# parameter of `init_distributed_environment` below. +# Essentially, this is `torch.distributed.group.WORLD`. +# We leave a line here to note that this is device-specific. +# Note that this variable is not safe to use, because when users +# call `init_distributed_environment` first, and then destroy +# the process group themselves, this variable will keep a reference to the +# destroyed process group, which is not useful. +_DEVICE_WORLD_GROUP = None + +# duing `init_distributed_environment`, we will also initialize a +# group with `gloo` backend, to allow direct coordination between +# processes through the CPU. +_CPU_WORLD_GROUP = None + +# In summary, after calling `init_distributed_environment`, we will +# always have two groups: one for device-specific (and is the default) +# and one for CPU. All processes will be part of both groups. + # A list of global ranks for each pipeline group to ease calculation of the # source rank when broadcasting from the first or last pipeline stage. _PIPELINE_GLOBAL_RANKS = None +def init_distributed_environment( + world_size: int, + rank: int, + distributed_init_method: Optional[str] = None, + local_rank: int = -1, + backend: str = "nccl", +): + if not torch.distributed.is_initialized(): + assert distributed_init_method is not None, ( + "distributed_init_method must be provided when initializing " + "distributed environment") + # this backend is used for WORLD + torch.distributed.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank) + global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP + _DEVICE_WORLD_GROUP = torch.distributed.group.WORLD + ranks = list(range(torch.distributed.get_world_size())) + _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks, + backend="gloo") + + def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, ) -> None: """ Initialize model parallel groups. @@ -48,6 +94,8 @@ def initialize_model_parallel( # Get world size and rank. Ensure some consistencies. assert torch.distributed.is_initialized() world_size: int = torch.distributed.get_world_size() + # get the backend of _DEVICE_WORLD_GROUP + backend = backend or torch.distributed.get_backend() if (world_size != tensor_model_parallel_size * pipeline_model_parallel_size): @@ -69,7 +117,7 @@ def initialize_model_parallel( for i in range(num_tensor_model_parallel_groups): ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) - group = torch.distributed.new_group(ranks) + group = torch.distributed.new_group(ranks, backend=backend) if rank in ranks: _TENSOR_MODEL_PARALLEL_GROUP = group @@ -80,7 +128,7 @@ def initialize_model_parallel( "pipeline model parallel group is already initialized") for i in range(num_pipeline_model_parallel_groups): ranks = range(i, world_size, num_pipeline_model_parallel_groups) - group = torch.distributed.new_group(ranks) + group = torch.distributed.new_group(ranks, backend=backend) if rank in ranks: _PIPELINE_MODEL_PARALLEL_GROUP = group _PIPELINE_GLOBAL_RANKS = ranks @@ -89,14 +137,17 @@ def initialize_model_parallel( def ensure_model_parallel_initialized( tensor_model_parallel_size: int, pipeline_model_parallel_size: int, + backend: Optional[str] = None, ) -> None: """Helper to initialize model parallel groups if they are not initialized, or ensure tensor-parallel and pipeline-parallel sizes are equal to expected values if the model parallel groups are initialized. """ + # get the backend of _DEVICE_WORLD_GROUP + backend = backend or torch.distributed.get_backend() if not model_parallel_is_initialized(): initialize_model_parallel(tensor_model_parallel_size, - pipeline_model_parallel_size) + pipeline_model_parallel_size, backend) return assert ( @@ -117,6 +168,12 @@ def model_parallel_is_initialized(): and _PIPELINE_MODEL_PARALLEL_GROUP is not None) +def get_cpu_world_group(): + """Get the CPU world group.""" + assert _CPU_WORLD_GROUP is not None, ("CPU world group is not initialized") + return _CPU_WORLD_GROUP + + def get_tensor_model_parallel_group(): """Get the tensor model parallel group the caller rank belongs to.""" assert _TENSOR_MODEL_PARALLEL_GROUP is not None, ( diff --git a/vllm/test_utils.py b/vllm/test_utils.py index 94e962e12e87..bc220d3b8a43 100644 --- a/vllm/test_utils.py +++ b/vllm/test_utils.py @@ -1,8 +1,8 @@ import ray -from vllm.config import ParallelConfig +from vllm.model_executor.parallel_utils.parallel_state import ( + ensure_model_parallel_initialized, init_distributed_environment) from vllm.utils import get_open_port -from vllm.worker.worker import init_distributed_environment def init_test_distributed_environment( @@ -12,15 +12,14 @@ def init_test_distributed_environment( distributed_init_port: str, local_rank: int = -1, ) -> None: - parallel_config = ParallelConfig(pipeline_parallel_size, - tensor_parallel_size, - worker_use_ray=True) distributed_init_method = f"tcp://localhost:{distributed_init_port}" init_distributed_environment( - parallel_config, - rank, + world_size=pipeline_parallel_size * tensor_parallel_size, + rank=rank, distributed_init_method=distributed_init_method, local_rank=local_rank) + ensure_model_parallel_initialized(tensor_parallel_size, + pipeline_parallel_size) def multi_process_tensor_parallel( diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 262ed9abd36b..e1daa64346a9 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -13,7 +13,7 @@ from vllm.model_executor.parallel_utils.communication_op import ( broadcast_tensor_dict) from vllm.model_executor.parallel_utils.parallel_state import ( - ensure_model_parallel_initialized) + ensure_model_parallel_initialized, init_distributed_environment) from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.worker.model_runner import ModelRunner @@ -251,26 +251,12 @@ def init_distributed_environment(self) -> None: parallel_config = self.parallel_config rank = self.rank distributed_init_method = self.distributed_init_method - - if torch.distributed.is_initialized(): - torch_world_size = torch.distributed.get_world_size() - if torch_world_size != parallel_config.world_size: - raise RuntimeError( - "torch.distributed is already initialized but the torch " - "world size does not match parallel_config.world_size " - f"({torch_world_size} vs. {parallel_config.world_size}).") - elif not distributed_init_method: - raise ValueError( - "distributed_init_method must be set if torch.distributed " - "is not already initialized") - else: - backend = "gloo" - torch.distributed.init_process_group( - backend=backend, - world_size=parallel_config.world_size, - rank=rank, - init_method=distributed_init_method, - ) + init_distributed_environment( + world_size=parallel_config.world_size, + rank=rank, + distributed_init_method=distributed_init_method, + backend="gloo", + ) # A small all_reduce for warmup. torch.distributed.all_reduce(torch.zeros(1).cpu()) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 48facb57de19..bf0c6073ea9a 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -15,7 +15,7 @@ broadcast_tensor_dict) from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar from vllm.model_executor.parallel_utils.parallel_state import ( - ensure_model_parallel_initialized) + ensure_model_parallel_initialized, init_distributed_environment) from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner @@ -97,9 +97,9 @@ def init_device(self) -> None: raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. - init_distributed_environment(self.parallel_config, self.rank, - self.distributed_init_method, - self.local_rank) + init_worker_distributed_environment(self.parallel_config, self.rank, + self.distributed_init_method, + self.local_rank) # Set random seed. set_random_seed(self.model_config.seed) @@ -248,31 +248,15 @@ def get_cache_block_size_bytes(self, block_size: int, self.parallel_config) -def init_distributed_environment( +def init_worker_distributed_environment( parallel_config: ParallelConfig, rank: int, distributed_init_method: Optional[str] = None, local_rank: int = -1, ) -> None: """Initialize the distributed environment.""" - if torch.distributed.is_initialized(): - torch_world_size = torch.distributed.get_world_size() - if torch_world_size != parallel_config.world_size: - raise RuntimeError( - "torch.distributed is already initialized but the torch world " - "size does not match parallel_config.world_size " - f"({torch_world_size} vs. {parallel_config.world_size}).") - elif not distributed_init_method: - raise ValueError( - "distributed_init_method must be set if torch.distributed " - "is not already initialized") - else: - torch.distributed.init_process_group( - backend="nccl", - world_size=parallel_config.world_size, - rank=rank, - init_method=distributed_init_method, - ) + init_distributed_environment(parallel_config.world_size, rank, + distributed_init_method, local_rank) if pynccl_utils.is_initialized(): pynccl_world_size = pynccl_utils.get_world_size() @@ -291,10 +275,6 @@ def init_distributed_environment( init_method=distributed_init_method, ) - # A small all_reduce for warmup. - torch.distributed.all_reduce(torch.zeros(1).cuda()) - if pynccl_utils.is_initialized(): - pynccl_utils.all_reduce(torch.zeros(1).cuda()) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) @@ -302,6 +282,11 @@ def init_distributed_environment( if not parallel_config.disable_custom_all_reduce: init_custom_ar() + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1).cuda()) + if pynccl_utils.is_initialized(): + pynccl_utils.all_reduce(torch.zeros(1).cuda()) + def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype. From e7c7067b45c9f604e0c68015ee6e0fe345288111 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Tue, 9 Apr 2024 11:44:15 -0700 Subject: [PATCH 056/120] [Misc] [Core] Implement RFC "Augment BaseExecutor interfaces to enable hardware-agnostic speculative decoding" (#3837) --- tests/core/block/e2e/test_correctness.py | 6 +- tests/lora/test_worker.py | 8 +- tests/spec_decode/test_spec_decode_worker.py | 35 +++---- tests/spec_decode/utils.py | 6 +- tests/worker/test_swap.py | 10 +- vllm/config.py | 6 +- vllm/engine/arg_utils.py | 6 +- vllm/engine/llm_engine.py | 22 +++++ vllm/executor/cpu_executor.py | 58 +++++------- vllm/executor/executor_base.py | 23 +++++ vllm/executor/gpu_executor.py | 60 ++++-------- vllm/executor/neuron_executor.py | 28 +++--- vllm/executor/ray_gpu_executor.py | 96 ++++++++------------ vllm/executor/utils.py | 13 --- vllm/spec_decode/spec_decode_worker.py | 38 ++++---- vllm/worker/cache_engine.py | 9 +- vllm/worker/cpu_worker.py | 89 ++++++++++++++---- vllm/worker/neuron_worker.py | 45 ++++++++- vllm/worker/worker.py | 89 ++++++++++++------ vllm/worker/worker_base.py | 83 +++++++++++++++++ 20 files changed, 453 insertions(+), 277 deletions(-) delete mode 100644 vllm/executor/utils.py create mode 100644 vllm/worker/worker_base.py diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index 5a7f828456e2..94b65401e1dd 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -16,7 +16,7 @@ # Allow only 5 sequences of ~1024 tokens in worst case. "block_size": 16, - "forced_num_gpu_blocks": 5 * (64 + 1), + "num_gpu_blocks_override": 5 * (64 + 1), }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{ @@ -162,14 +162,14 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator, # Allow only 2 sequences of ~128 tokens in worst case. # Note 8 = 128/block_size - "forced_num_gpu_blocks": 2 * (8 + 1), + "num_gpu_blocks_override": 2 * (8 + 1), }, { "block_size": 8, # Allow only 2 sequences of ~128 tokens in worst case. # Note 16 = 128/block_size - "forced_num_gpu_blocks": 2 * (16 + 1), + "num_gpu_blocks_override": 2 * (16 + 1), } ]) @pytest.mark.parametrize("baseline_llm_kwargs", [{ diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 60aa90fe4ee8..54594690f792 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -3,8 +3,8 @@ import tempfile from unittest.mock import patch -from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) from vllm.lora.models import LoRAMapping from vllm.lora.request import LoRARequest from vllm.worker.worker import Worker @@ -27,6 +27,10 @@ def test_worker_apply_lora(sql_lora_files): parallel_config=ParallelConfig(1, 1, False), scheduler_config=SchedulerConfig(32, 32, 32), device_config=DeviceConfig("cuda"), + cache_config=CacheConfig(block_size=16, + gpu_memory_utilization=1., + swap_space=0, + cache_dtype="auto"), local_rank=0, rank=0, lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32, diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 825d36067196..47aff8f57541 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -512,8 +512,8 @@ def test_init_device(): @torch.inference_mode() -def test_init_cache_engine(): - """Verify SpecDecodeWorker invokes init_cache_engine on proposer/scorer +def test_initialize_cache(): + """Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer workers. """ draft_worker = mock_worker(cls=MultiStepWorker) @@ -525,12 +525,11 @@ def test_init_cache_engine(): worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, metrics_collector) - cache_config = MagicMock() + kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023} + worker.initialize_cache(**kwargs) - worker.init_cache_engine(cache_config) - - draft_worker.init_cache_engine.assert_called_once_with(cache_config) - target_worker.init_cache_engine.assert_called_once_with(cache_config) + draft_worker.initialize_cache.assert_called_once_with(**kwargs) + target_worker.initialize_cache.assert_called_once_with(**kwargs) @pytest.mark.parametrize('available_gpu_blocks', [1, 1024]) @@ -538,10 +537,10 @@ def test_init_cache_engine(): @pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096]) @pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096]) @pytest.mark.skip_global_cleanup -def test_profile_num_available_blocks(available_gpu_blocks: int, - available_cpu_blocks: int, - target_cache_block_size_bytes: int, - draft_kv_size_bytes: int): +def test_determine_num_available_blocks(available_gpu_blocks: int, + available_cpu_blocks: int, + target_cache_block_size_bytes: int, + draft_kv_size_bytes: int): """Verify SpecDecodeWorker correctly profiles num available GPU blocks. Specifically, it should run profiling in the scorer worker, and then evenly split the blocks between proposer and scorer worker. @@ -552,7 +551,7 @@ def test_profile_num_available_blocks(available_gpu_blocks: int, rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) - target_worker.profile_num_available_blocks.return_value = ( + target_worker.determine_num_available_blocks.return_value = ( available_gpu_blocks, available_cpu_blocks) target_worker.get_cache_block_size_bytes.return_value = ( target_cache_block_size_bytes) @@ -561,17 +560,9 @@ def test_profile_num_available_blocks(available_gpu_blocks: int, worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, metrics_collector) - # These values do not directly impact the adjusted block size calculation, - # so they can be fixed. - gpu_memory_utilization = 0.9 - cpu_swap_space = 100 - block_size = 16 - - num_gpu_blocks, num_cpu_blocks = worker.profile_num_available_blocks( - block_size, gpu_memory_utilization, cpu_swap_space, cache_dtype="auto") + num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks() - target_worker.profile_num_available_blocks.assert_called_once_with( - block_size, gpu_memory_utilization, cpu_swap_space, "auto") + target_worker.determine_num_available_blocks.assert_called_once() assert num_cpu_blocks == available_cpu_blocks assert num_gpu_blocks == split_num_cache_blocks_evenly( diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 5ef1cc28253e..4637826f254d 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -117,6 +117,7 @@ def create_worker(cls: type, parallel_config=engine_config.parallel_config, scheduler_config=engine_config.scheduler_config, device_config=engine_config.device_config, + cache_config=engine_config.cache_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, @@ -128,8 +129,9 @@ def create_worker(cls: type, engine_config.cache_config.num_gpu_blocks = num_gpu_blocks engine_config.cache_config.num_cpu_blocks = 0 - worker.init_cache_engine(engine_config.cache_config) - worker.warm_up_model() + worker.initialize_cache( + num_gpu_blocks=engine_config.cache_config.num_gpu_blocks, + num_cpu_blocks=engine_config.cache_config.num_cpu_blocks) return worker diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py index 5d6ba51ea0f0..8edb1cf05c08 100644 --- a/tests/worker/test_swap.py +++ b/tests/worker/test_swap.py @@ -11,8 +11,8 @@ def test_swap() -> None: dtype="half", load_format="dummy") engine_config = engine_args.create_engine_config() - engine_config.cache_config.num_gpu_blocks = 100 - engine_config.cache_config.num_cpu_blocks = 100 + engine_config.cache_config.num_gpu_blocks = 1000 + engine_config.cache_config.num_cpu_blocks = 1000 # Create the worker. distributed_init_method = get_distributed_init_method( @@ -22,6 +22,7 @@ def test_swap() -> None: parallel_config=engine_config.parallel_config, scheduler_config=engine_config.scheduler_config, device_config=engine_config.device_config, + cache_config=engine_config.cache_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, @@ -31,8 +32,9 @@ def test_swap() -> None: # Initialize the worker. worker.init_device() worker.load_model() - worker.init_cache_engine(engine_config.cache_config) - worker.warm_up_model() + worker.initialize_cache( + num_gpu_blocks=engine_config.cache_config.num_gpu_blocks, + num_cpu_blocks=engine_config.cache_config.num_cpu_blocks) # Randomly initialize the cache. gpu_cache = worker.cache_engine.gpu_cache diff --git a/vllm/config.py b/vllm/config.py index 6762a75f25f2..753fc33e9b71 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -334,7 +334,7 @@ class CacheConfig: vLLM execution. swap_space: Size of the CPU swap space per GPU (in GiB). cache_dtype: Data type for kv cache storage. - forced_num_gpu_blocks: Number of GPU blocks to use. This overrides the + num_gpu_blocks_override: Number of GPU blocks to use. This overrides the profiled num_gpu_blocks if specified. Does nothing if None. """ @@ -344,14 +344,14 @@ def __init__( gpu_memory_utilization: float, swap_space: int, cache_dtype: str, - forced_num_gpu_blocks: Optional[int] = None, + num_gpu_blocks_override: Optional[int] = None, sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization self.swap_space_bytes = swap_space * _GB - self.forced_num_gpu_blocks = forced_num_gpu_blocks + self.num_gpu_blocks_override = num_gpu_blocks_override self.cache_dtype = cache_dtype self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a6197942645e..d4b573992c06 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -55,7 +55,7 @@ class EngineArgs: max_cpu_loras: Optional[int] = None device: str = 'auto' ray_workers_use_nsight: bool = False - forced_num_gpu_blocks: Optional[int] = None + num_gpu_blocks_override: Optional[int] = None num_lookahead_slots: int = 0 # Related to Vision-language models such as llava @@ -246,7 +246,7 @@ def add_cli_args( 'the model executor, which can range from 0 to 1.' 'If unspecified, will use the default value of 0.9.') parser.add_argument( - '--forced-num-gpu-blocks', + '--num-gpu-blocks-override', type=int, default=None, help='If specified, ignore GPU profiling result and use this number' @@ -426,7 +426,7 @@ def create_engine_config(self, ) -> EngineConfig: cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, - self.forced_num_gpu_blocks, + self.num_gpu_blocks_override, model_config.get_sliding_window(), self.enable_prefix_caching) parallel_config = ParallelConfig( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a9a4a7b83d93..1c639af69654 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -127,6 +127,8 @@ def __init__( speculative_config=speculative_config, ) + self._initialize_kv_caches() + # If usage stat is enabled, collect relevant info. if is_usage_stats_enabled(): from vllm.model_executor.model_loader import ( @@ -178,6 +180,26 @@ def __init__( labels=dict(model_name=model_config.model)) self.stat_logger.info("cache_config", self.cache_config) + def _initialize_kv_caches(self) -> None: + """Initialize the KV cache in the worker(s). + + The workers will determine the number of blocks in both the GPU cache + and the swap CPU cache. + """ + num_gpu_blocks, num_cpu_blocks = ( + self.model_executor.determine_num_available_blocks()) + + if self.cache_config.num_gpu_blocks_override is not None: + num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override + logger.info(f"Overriding {num_gpu_blocks=} with " + f"{num_gpu_blocks_override=}") + num_gpu_blocks = num_gpu_blocks_override + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) + @classmethod def from_engine_args( cls, diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 7b3cc784c98e..2bf97338da0e 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -35,7 +35,6 @@ def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, # Instantiate the worker and load the model to CPU. self._init_worker() - self._init_cache() def _init_worker(self): from vllm.worker.cpu_worker import CPUWorker @@ -46,10 +45,11 @@ def _init_worker(self): distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) self.driver_worker = CPUWorker( - self.model_config, - self.parallel_config, - self.scheduler_config, - self.device_config, + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, @@ -60,35 +60,21 @@ def _init_worker(self): self.driver_worker.init_device() self.driver_worker.load_model() - def _init_cache(self) -> None: - num_cpu_blocks = self.driver_worker.get_cpu_cache_block_num( - block_size=self.cache_config.block_size, - cache_space=self.cache_config.cpu_kvcache_space_bytes, - cache_dtype=self.cache_config.cache_dtype, - ) - + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. + """ + return self.driver_worker.determine_num_available_blocks() + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache by invoking the underlying worker. + """ + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. logger.info(f"# CPU blocks: {num_cpu_blocks}") - if num_cpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `VLLM_CPU_KVCACHE_SPACE` when " - "initializing the engine.") - - max_seq_len = self.cache_config.block_size * num_cpu_blocks - if self.model_config.max_model_len > max_seq_len: - raise ValueError( - f"The model's max seq len ({self.model_config.max_model_len}) " - "is larger than the maximum number of tokens that can be " - f"stored in KV cache ({max_seq_len}). Try increasing " - "`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when " - "initializing the engine.") - - # Note: To reuse the cache management procedure, - # use cpu cache as 'gpu cache'. - self.cache_config.num_gpu_blocks = num_cpu_blocks # type: ignore - self.cache_config.num_cpu_blocks = 0 # type: ignore - - # Initialize the cache. - self.driver_worker.init_cache_engine(cache_config=self.cache_config) + self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -104,13 +90,13 @@ def execute_model(self, return output def add_lora(self, lora_request: LoRARequest) -> bool: - raise NotImplementedError("LoRA is not implemented for cpu backend.") + return self.driver_worker.add_lora(lora_request) def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError("LoRA is not implemented for cpu backend.") + return self.driver_worker.remove_lora(lora_id) def list_loras(self) -> List[int]: - raise NotImplementedError("LoRA is not implemented for cpu backend.") + return self.driver_worker.list_loras() def check_health(self) -> None: # CPUExecutor will always be healthy as long as diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 8ec5dfe1e00e..c18edd75d7a4 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -30,6 +30,29 @@ def __init__( ) -> None: raise NotImplementedError + @abstractmethod + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available blocks for the GPU KV cache and + swappable CPU KV cache. + + Normally, this should simply delegate to the underlying Worker. Some + ExecutorBase may require modification of the result, e.g. to ensure the + selected cache sizes are compatible with all workers. + + Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + are blocks that are "active" on the device and can be appended to. + num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be + appended to. + """ + raise NotImplementedError + + @abstractmethod + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache with the given size in blocks. + """ + raise NotImplementedError + @abstractmethod def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 7b683107d30e..80ca5cb7367c 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -4,7 +4,6 @@ ParallelConfig, SchedulerConfig, SpeculativeConfig, VisionLanguageConfig) from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase -from vllm.executor.utils import check_block_size_valid from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -41,9 +40,6 @@ def __init__( # Instantiate the worker and load the model to GPU. self._init_worker() - # Profile the memory usage and initialize the cache. - self._init_cache() - def _init_worker(self): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker @@ -55,61 +51,37 @@ def _init_worker(self): distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) self.driver_worker = Worker( - self.model_config, - self.parallel_config, - self.scheduler_config, - self.device_config, + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, - kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, ) self.driver_worker.init_device() self.driver_worker.load_model() - def _init_cache(self) -> None: - """Profiles the memory usage and initializes the KV cache. - - The engine first profiles the existing memory usage. - Then, it allocates the remaining memory for KV blocks. - - .. tip:: - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. """ - # Get the maximum number of blocks that can be allocated on GPU and CPU. - num_gpu_blocks, num_cpu_blocks = ( - self.driver_worker.profile_num_available_blocks( - block_size=self.cache_config.block_size, - gpu_memory_utilization=self.cache_config. - gpu_memory_utilization, - cpu_swap_space=self.cache_config.swap_space_bytes, - cache_dtype=self.cache_config.cache_dtype, - )) - - if self.cache_config.forced_num_gpu_blocks is not None: - forced_num_gpu_blocks = self.cache_config.forced_num_gpu_blocks - logger.info(f"Replacing profiled {num_gpu_blocks=} with " - f"{forced_num_gpu_blocks=}") - num_gpu_blocks = forced_num_gpu_blocks + return self.driver_worker.determine_num_available_blocks() + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: + """Initialize the KV cache by invoking the underlying worker. + """ + # NOTE: This is logged in the executor because there can be >1 worker + # with other executors. We could log in the engine level, but work + # remains to abstract away the device for non-GPU configurations. logger.info(f"# GPU blocks: {num_gpu_blocks}, " f"# CPU blocks: {num_cpu_blocks}") - check_block_size_valid(num_gpu_blocks, self.cache_config.block_size, - self.model_config.max_model_len) - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - # Initialize the cache. - self.driver_worker.init_cache_engine(cache_config=self.cache_config) - # Warm up the model. This includes capturing the model into CUDA graph - # if enforce_eager is False. - self.driver_worker.warm_up_model() + self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index c0af058cb90b..57436a85cfa2 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -25,7 +25,6 @@ def __init__( speculative_config: Optional[SpeculativeConfig], ) -> None: self.model_config = model_config - self.cache_config = cache_config assert lora_config is None, "LoRA is not supported for Neuron backend." self.parallel_config = parallel_config self.scheduler_config = scheduler_config @@ -33,12 +32,6 @@ def __init__( assert (not speculative_config ), "Speculative decoding not yet supported for Neuron backend." - # Set the number of GPU blocks to be the same as the maximum number of - # sequences that can be processed in a single batch. This is equivalent - # to schedule without PagedAttention. - self.cache_config.num_gpu_blocks = self.scheduler_config.max_num_seqs - self.cache_config.num_cpu_blocks = 0 - # Instantiate the worker and load the model to the device. self._init_worker() @@ -54,6 +47,18 @@ def _init_worker(self): self.driver_worker.init_device() self.driver_worker.load_model() + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. + """ + return self.driver_worker.determine_num_available_blocks() + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache by invoking the underlying worker. + """ + self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], blocks_to_swap_in: Dict[int, int], @@ -68,16 +73,13 @@ def execute_model(self, return output def add_lora(self, lora_request: LoRARequest) -> bool: - raise NotImplementedError( - "LoRA is not implemented for neuron backend.") + return self.driver_worker.add_lora(lora_request) def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError( - "LoRA is not implemented for neuron backend.") + return self.driver_worker.remove_lora(lora_id) def list_loras(self) -> List[int]: - raise NotImplementedError( - "LoRA is not implemented for neuron backend.") + return self.driver_worker.list_loras() def check_health(self) -> None: # NeuronExecutor will always be healthy as long as diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 43cb37cfb5e0..6c0ccd7e64c9 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -10,7 +10,6 @@ VisionLanguageConfig) from vllm.engine.ray_utils import RayWorkerVllm, ray from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase -from vllm.executor.utils import check_block_size_valid from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -65,9 +64,6 @@ def __init__( # Create the parallel GPU workers. self._init_workers_ray(placement_group) - # Profile the memory usage and initialize the cache. - self._init_cache() - self.forward_dag = None if USE_RAY_COMPILED_DAG: self.forward_dag = self._compiled_ray_dag() @@ -154,8 +150,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", scheduler_config = copy.deepcopy(self.scheduler_config) device_config = copy.deepcopy(self.device_config) lora_config = copy.deepcopy(self.lora_config) + cache_config = copy.deepcopy(self.cache_config) vision_language_config = copy.deepcopy(self.vision_language_config) - kv_cache_dtype = self.cache_config.cache_dtype # Initialize the actual workers with the Worker class. for rank, (worker, (node_id, _)) in enumerate( @@ -165,32 +161,32 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", local_rank = node_workers[node_id].index(rank) worker.init_worker.remote( lambda rank=rank, local_rank=local_rank: Worker( - model_config, - parallel_config, - scheduler_config, - device_config, - local_rank, - rank, - distributed_init_method, + model_config=model_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + cache_config=cache_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, lora_config=lora_config, vision_language_config=vision_language_config, - kv_cache_dtype=kv_cache_dtype, )) # Initialize the driver worker with the Worker class. driver_rank = 0 driver_local_rank = node_workers[driver_node_id].index(driver_rank) self.driver_worker = Worker( - self.model_config, - self.parallel_config, - self.scheduler_config, - self.device_config, - driver_local_rank, - driver_rank, - distributed_init_method, + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + local_rank=driver_local_rank, + rank=driver_rank, + distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, - kv_cache_dtype=kv_cache_dtype, is_driver_worker=True, ) @@ -201,35 +197,18 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", max_parallel_loading_workers, ) - def _init_cache(self) -> None: - """Profiles the memory usage and initializes the KV cache. - - The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks - that can be allocated with the remaining free memory. - More details can be found in the - :meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method - from class :class:`~vllm.worker.Worker`. + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available KV blocks. - Afterwards, as there may be multiple workers, - we take the minimum number of blocks across all workers - to ensure this can be applied to all of them. + This invokes `determine_num_available_blocks` on each worker and takes + the min of the results, guaranteeing that the selected cache sizes are + compatible with all workers. - Finally, the engine will initialize the KV cache - with the calculated number of blocks. - - .. tip:: - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. + Returns: + - tuple[num_gpu_blocks, num_cpu_blocks] """ # Get the maximum number of blocks that can be allocated on GPU and CPU. - num_blocks = self._run_workers( - "profile_num_available_blocks", - block_size=self.cache_config.block_size, - gpu_memory_utilization=self.cache_config.gpu_memory_utilization, - cpu_swap_space=self.cache_config.swap_space_bytes, - cache_dtype=self.cache_config.cache_dtype, - ) + num_blocks = self._run_workers("determine_num_available_blocks", ) # Since we use a shared centralized controller, we take the minimum # number of blocks across all workers to make sure all the memory @@ -237,26 +216,25 @@ def _init_cache(self) -> None: num_gpu_blocks = min(b[0] for b in num_blocks) num_cpu_blocks = min(b[1] for b in num_blocks) - if self.cache_config.forced_num_gpu_blocks is not None: - forced_num_gpu_blocks = self.cache_config.forced_num_gpu_blocks - logger.info(f"Replacing profiled {num_gpu_blocks=} with " - f"{forced_num_gpu_blocks=}") - num_gpu_blocks = forced_num_gpu_blocks + return num_gpu_blocks, num_cpu_blocks + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache in all workers. + """ + + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. logger.info(f"# GPU blocks: {num_gpu_blocks}, " f"# CPU blocks: {num_cpu_blocks}") - check_block_size_valid(num_gpu_blocks, self.cache_config.block_size, - self.model_config.max_model_len) - self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - # Initialize the cache. - self._run_workers("init_cache_engine", cache_config=self.cache_config) - # Warm up the model. This includes capturing the model into CUDA graph - # if enforce_eager is False. - self._run_workers("warm_up_model") + self._run_workers("initialize_cache", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], diff --git a/vllm/executor/utils.py b/vllm/executor/utils.py deleted file mode 100644 index 44976696a77c..000000000000 --- a/vllm/executor/utils.py +++ /dev/null @@ -1,13 +0,0 @@ -def check_block_size_valid(num_gpu_blocks, block_size, max_model_len) -> None: - if num_gpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") - max_seq_len = block_size * num_gpu_blocks - if max_model_len > max_seq_len: - raise ValueError( - f"The model's max seq len ({max_model_len}) " - "is larger than the maximum number of tokens that can be " - f"stored in KV cache ({max_seq_len}). Try increasing " - "`gpu_memory_utilization` or decreasing `max_model_len` when " - "initializing the engine.") diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 59f9d5b5107f..885bf537568e 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -3,7 +3,6 @@ import torch -from vllm.config import CacheConfig from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, SequenceGroupOutput, SequenceOutput) @@ -15,9 +14,10 @@ from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, split_batch_by_proposal_len) from vllm.worker.worker import Worker +from vllm.worker.worker_base import LoraNotSupportedWorkerBase -class SpecDecodeWorker: +class SpecDecodeWorker(LoraNotSupportedWorkerBase): """Worker which implements speculative decoding. Speculative decoding reduces decoding per-token latency by using a proposal @@ -94,10 +94,7 @@ def init_device(self) -> None: device=self.device, vocab_size=self._vocab_size) - def profile_num_available_blocks(self, block_size: int, - gpu_memory_utilization: float, - cpu_swap_space: int, - cache_dtype: str) -> Tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of cache blocks to use. This is done by profiling the scorer model (which is typically the @@ -106,27 +103,26 @@ def profile_num_available_blocks(self, block_size: int, such that the number of blocks is equal in both KV caches. """ num_gpu_blocks, num_cpu_blocks = ( - self.scorer_worker.profile_num_available_blocks( - block_size, gpu_memory_utilization, cpu_swap_space, - cache_dtype)) + self.scorer_worker.determine_num_available_blocks()) scorer_cache_block_size_bytes = ( - self.scorer_worker.get_cache_block_size_bytes( - block_size, cache_dtype)) + self.scorer_worker.get_cache_block_size_bytes()) proposer_cache_block_size_bytes = ( - self.proposer_worker.get_cache_block_size_bytes( - block_size, cache_dtype)) + self.proposer_worker.get_cache_block_size_bytes()) new_num_gpu_blocks = split_num_cache_blocks_evenly( scorer_cache_block_size_bytes, proposer_cache_block_size_bytes, num_gpu_blocks) return new_num_gpu_blocks, num_cpu_blocks - def init_cache_engine(self, cache_config: CacheConfig): + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: """Initialize the cache engine of the scorer and proposer workers. """ - self.scorer_worker.init_cache_engine(cache_config) - self.proposer_worker.init_cache_engine(cache_config) + self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) @torch.inference_mode() def execute_model( @@ -351,6 +347,16 @@ def rank(self): def device(self): return self.scorer_worker.device + def get_cache_block_size_bytes(self): + """Return the size of a cache block in bytes. + + This function is only used to compose workers within a SpecDecodeWorker. + We leave composing a SpecDecodeWorker within a SpecDecodeWorker + undefined for now, although it could be implemented in the future. + See https://arxiv.org/abs/2308.04623. + """ + raise NotImplementedError + def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int, proposer_cache_block_size_bytes: int, diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 27d1727cd16a..c34ee0648626 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -82,8 +82,7 @@ def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: @staticmethod def get_cache_block_size( - block_size: int, - cache_dtype: str, + cache_config: CacheConfig, model_config: ModelConfig, parallel_config: ParallelConfig, ) -> int: @@ -91,13 +90,13 @@ def get_cache_block_size( num_heads = model_config.get_num_kv_heads(parallel_config) num_layers = model_config.get_num_layers(parallel_config) - key_cache_block = block_size * num_heads * head_size + key_cache_block = cache_config.block_size * num_heads * head_size value_cache_block = key_cache_block total = num_layers * (key_cache_block + value_cache_block) - if cache_dtype == "auto": + if cache_config.cache_dtype == "auto": dtype = model_config.dtype else: - dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] dtype_size = _get_dtype_size(dtype) return dtype_size * total diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index e1daa64346a9..42f0828b826e 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -17,6 +17,7 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.worker.model_runner import ModelRunner +from vllm.worker.worker_base import LoraNotSupportedWorkerBase logger = init_logger(__name__) @@ -112,7 +113,7 @@ def get_cache_block_size( return dtype_size * total -class CPUWorker: +class CPUWorker(LoraNotSupportedWorkerBase): """A worker class that executes (a partition of) the model on a CPU socket. Each worker is associated with a single CPU socket. The worker is @@ -127,6 +128,7 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + cache_config: CacheConfig, local_rank: int, rank: int, distributed_init_method: str, @@ -138,6 +140,7 @@ def __init__( self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config + self.cache_config = cache_config self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method @@ -154,8 +157,7 @@ def __init__( kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker) # Uninitialized cache engine. Will be initialized by - # self.init_cache_engine(). - self.cache_config = None + # initialize_cache. self.cache_engine = None self.cpu_cache = None @@ -167,28 +169,70 @@ def init_device(self) -> None: def load_model(self): self.model_runner.load_model() - def get_cpu_cache_block_num( - self, - block_size: int, - cache_space: int, - cache_dtype: str, - ) -> int: - """ - Args: - block_size: The size of the cache block. - cache_space: The size of the CPU KV cache space in bytes. + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of blocks available for the KV cache. + + This determines how many KV blocks can fit into the configured CPU + KV cache space. + + Note that since vLLM assumes a block resides on GPU if it can be + modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0. + This allows us to reuse the scheduler of vLLM without generalizing it + to different devices. """ # For CPU device, the block number will be calculated based on the # cpu_kvcache_space. - cache_block_size = CPUCacheEngine.get_cache_block_size( - block_size, cache_dtype, self.model_config, self.parallel_config) - num_cpu_blocks = int(cache_space // cache_block_size) + cache_block_size = self.get_cache_block_size_bytes() + num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes // + cache_block_size) num_cpu_blocks = max(num_cpu_blocks, 0) - return num_cpu_blocks + # Note: To reuse the cache management procedure, + # use cpu cache as 'gpu cache'. + num_gpu_blocks = num_cpu_blocks + num_cpu_blocks = 0 + return num_gpu_blocks, num_cpu_blocks - def init_cache_engine(self, cache_config: CacheConfig) -> None: - self.cache_config = cache_config + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache. Currently, swappable CPU memory is not + supported. + + Since this worker does not support GPUs, we use the num_gpu_blocks to + determine how many non-swappable CPU blocks to allocate. + """ + assert (num_cpu_blocks == 0 + ), f"{type(self)} does not support swappable cache" + + # Note: To reuse the cache management procedure, + # use cpu cache as 'gpu cache'. + num_cpu_blocks = num_gpu_blocks + + self._validate_num_cpu_blocks(num_cpu_blocks) + self.cache_config.num_gpu_blocks = num_cpu_blocks + self.cache_config.num_cpu_blocks = 0 + + # Initialize the cache. + self._init_cache_engine() + + def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None: + """Raise errors if the num_cpu_blocks is invalid. + """ + if num_cpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `VLLM_CPU_KVCACHE_SPACE` when " + "initializing the engine.") + + max_seq_len = self.cache_config.block_size * num_cpu_blocks + if self.model_config.max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({self.model_config.max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when " + "initializing the engine.") + + def _init_cache_engine(self) -> None: self.cache_engine = CPUCacheEngine(self.cache_config, self.model_config, self.parallel_config, @@ -264,3 +308,10 @@ def init_distributed_environment(self) -> None: ensure_model_parallel_initialized( parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) + + def get_cache_block_size_bytes(self) -> int: + """Return the size in bytes of a single KV cache block. + """ + return CPUCacheEngine.get_cache_block_size( + self.cache_config.block_size, self.cache_config.cache_dtype, + self.model_config, self.parallel_config) diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 0ae067aafb29..6136d50d0c06 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -4,14 +4,15 @@ import torch import torch.distributed -from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, SchedulerConfig) from vllm.model_executor import set_random_seed from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.neuron_model_runner import NeuronModelRunner +from vllm.worker.worker_base import LoraNotSupportedWorkerBase -class NeuronWorker: +class NeuronWorker(LoraNotSupportedWorkerBase): """A worker class that executes the model on a group of neuron cores. """ @@ -21,11 +22,13 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + cache_config: CacheConfig, ) -> None: self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config + self.cache_config = cache_config self.model_runner = NeuronModelRunner(model_config, parallel_config, scheduler_config, device_config) @@ -37,6 +40,35 @@ def init_device(self) -> None: def load_model(self): self.model_runner.load_model() + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available KV blocks. + + Swapping is not yet supported, so always return num_cpu_blocks=0. + + We configure num_gpu_blocks to be equal to max_num_seqs. + """ + # Set the number of GPU blocks to be the same as the maximum number of + # sequences that can be processed in a single batch. This is equivalent + # to schedule without PagedAttention. + num_gpu_blocks = self.scheduler_config.max_num_seqs + + # Swap not yet supported with Neuron backend. + num_cpu_blocks = 0 + + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache. + """ + + # Different values are not tested. + assert num_cpu_blocks == 0 + assert num_gpu_blocks == self.scheduler_config.max_num_seqs + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + @torch.inference_mode() def execute_model( self, @@ -50,3 +82,10 @@ def execute_model( output = self.model_runner.execute_model(seq_group_metadata_list) return output + + def get_cache_block_size_bytes(self) -> int: + """Determine the size in bytes of a cache block. + + This is required for speculative decoding; it is not yet implemented. + """ + raise NotImplementedError diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index bf0c6073ea9a..19de33089b2d 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -19,9 +19,10 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner +from vllm.worker.worker_base import WorkerBase -class Worker: +class Worker(WorkerBase): """A worker class that executes (a partition of) the model on a GPU. Each worker is associated with a single GPU. The worker is responsible for @@ -35,18 +36,19 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + cache_config: CacheConfig, local_rank: int, rank: int, distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None, - kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, ) -> None: self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config + self.cache_config = cache_config self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method @@ -66,12 +68,11 @@ def __init__( scheduler_config, device_config, lora_config=self.lora_config, - kv_cache_dtype=kv_cache_dtype, + kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, vision_language_config=vision_language_config) # Uninitialized cache engine. Will be initialized by - # self.init_cache_engine(). - self.cache_config = None + # initialize_cache. self.cache_engine = None self.gpu_cache = None @@ -107,20 +108,17 @@ def load_model(self): self.model_runner.load_model() @torch.inference_mode() - def profile_num_available_blocks( - self, - block_size: int, - gpu_memory_utilization: float, - cpu_swap_space: int, - cache_dtype: str, - ) -> Tuple[int, int]: - """Profiles the peak memory usage of the model and returns the maximum - number of GPU and CPU cache blocks that can be allocated. - - Args: - block_size: The size of the cache block. - gpu_memory_utilization: The fraction of the total GPU memory to use. - cpu_swap_space: The size of the CPU swap space in bytes. + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. """ # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. @@ -141,12 +139,12 @@ def profile_num_available_blocks( "Error in memory profiling. This happens when the GPU memory was " "not properly cleaned up before initializing the vLLM instance.") - cache_block_size = self.get_cache_block_size_bytes( - block_size, cache_dtype) + cache_block_size = self.get_cache_block_size_bytes() num_gpu_blocks = int( - (total_gpu_memory * gpu_memory_utilization - peak_memory) // - cache_block_size) - num_cpu_blocks = int(cpu_swap_space // cache_block_size) + (total_gpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) // cache_block_size) + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) if self.model_runner.lora_manager: @@ -155,14 +153,30 @@ def profile_num_available_blocks( torch.cuda.empty_cache() return num_gpu_blocks, num_cpu_blocks - def init_cache_engine(self, cache_config: CacheConfig) -> None: - self.cache_config = cache_config + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Allocate GPU and CPU KV cache with the specified number of blocks. + + This also warms up the model, which may record CUDA graphs. + """ + raise_if_cache_size_invalid(num_gpu_blocks, + self.cache_config.block_size, + self.model_config.max_model_len) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self._init_cache_engine() + self._warm_up_model() + + def _init_cache_engine(self): + assert self.cache_config.num_gpu_blocks is not None self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.parallel_config) self.gpu_cache = self.cache_engine.gpu_cache self.model_runner.set_block_size(self.cache_engine.block_size) - def warm_up_model(self) -> None: + def _warm_up_model(self) -> None: if not self.model_config.enforce_eager: self.model_runner.capture_model(self.gpu_cache) # Reset the seed to ensure that the random state is not affected by @@ -239,11 +253,10 @@ def max_model_len(self) -> int: def vocab_size(self) -> int: return self.model_runner.vocab_size - def get_cache_block_size_bytes(self, block_size: int, - cache_dtype: str) -> int: + def get_cache_block_size_bytes(self) -> int: """Get the size of the KV cache block size in bytes. """ - return CacheEngine.get_cache_block_size(block_size, cache_dtype, + return CacheEngine.get_cache_block_size(self.cache_config, self.model_config, self.parallel_config) @@ -300,3 +313,19 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): f"{compute_capability[0]}.{compute_capability[1]}. " "You can use float16 instead by explicitly setting the" "`dtype` flag in CLI, for example: --dtype=half.") + + +def raise_if_cache_size_invalid(num_gpu_blocks, block_size, + max_model_len) -> None: + if num_gpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + max_seq_len = block_size * num_gpu_blocks + if max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py new file mode 100644 index 000000000000..e3027c406ffe --- /dev/null +++ b/vllm/worker/worker_base.py @@ -0,0 +1,83 @@ +from abc import ABC, abstractmethod +from typing import Dict, List + +from vllm.lora.request import LoRARequest +from vllm.sequence import SamplerOutput, SequenceGroupMetadata + + +class WorkerBase(ABC): + """Worker interface that allows vLLM to cleanly separate implementations for + different hardware. + """ + + @abstractmethod + def init_device(self) -> None: + """Initialize device state, such as loading the model or other on-device + memory allocations. + """ + raise NotImplementedError + + @abstractmethod + def determine_num_available_blocks(self) -> tuple[int, int]: + """Determine the number of available blocks for the GPU KV cache and + swappable CPU KV cache. + + The implementation may run profiling or other heuristics to determine + the size of caches. + + Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + are blocks that are "active" on the device and can be appended to. + num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be + appended to. + """ + raise NotImplementedError + + @abstractmethod + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache with the given size in blocks. + """ + raise NotImplementedError + + @abstractmethod + def execute_model(self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: + """Executes one model step on the given sequences.""" + raise NotImplementedError + + @abstractmethod + def get_cache_block_size_bytes() -> int: + """Return the size of a single cache block, in bytes. Used in + speculative decoding. + """ + raise NotImplementedError + + @abstractmethod + def add_lora(self, lora_request: LoRARequest) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_lora(self, lora_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def list_loras(self) -> List[int]: + raise NotImplementedError + + +class LoraNotSupportedWorkerBase(WorkerBase): + """Partial implementation of WorkerBase that raises exceptions when LoRA + methods are invoked. + """ + + def add_lora(self, lora_request: LoRARequest) -> bool: + raise ValueError(f"{type(self)} does not support LoRA") + + def remove_lora(self, lora_id: int) -> bool: + raise ValueError(f"{type(self)} does not support LoRA") + + def list_loras(self) -> List[int]: + raise ValueError(f"{type(self)} does not support LoRA") From f6c7b2ecded9a7b7e9575aec2ca405d7ae3dd9a7 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Tue, 9 Apr 2024 11:59:09 -0700 Subject: [PATCH 057/120] Zhuohan offline pr feedback --- ...est_block_decode.py => test_multi_step.py} | 13 ++++---- vllm/engine/output_processor/interfaces.py | 25 ++++++++-------- .../{block_decode.py => multi_step.py} | 30 +++++++++++-------- .../{beam_search.py => single_step.py} | 19 +++++++----- 4 files changed, 46 insertions(+), 41 deletions(-) rename tests/engine/output_processor/{test_block_decode.py => test_multi_step.py} (96%) rename vllm/engine/output_processor/{block_decode.py => multi_step.py} (79%) rename vllm/engine/output_processor/{beam_search.py => single_step.py} (94%) diff --git a/tests/engine/output_processor/test_block_decode.py b/tests/engine/output_processor/test_multi_step.py similarity index 96% rename from tests/engine/output_processor/test_block_decode.py rename to tests/engine/output_processor/test_multi_step.py index c4a88d67cabc..6da3da091db7 100644 --- a/tests/engine/output_processor/test_block_decode.py +++ b/tests/engine/output_processor/test_multi_step.py @@ -6,8 +6,7 @@ from tests.core.utils import create_seq_group from vllm.core.scheduler import Scheduler -from vllm.engine.output_processor.block_decode import ( - BlockDecodeOutputProcessor) +from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor from vllm.engine.output_processor.stop_checker import StopChecker from vllm.sampling_params import SamplingParams from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput, @@ -20,7 +19,7 @@ @pytest.mark.parametrize("num_new_tokens", [1, 12]) @pytest.mark.skip_global_cleanup def test_appends_token_ids(num_new_tokens: int, seq_output_len: int): - """Verify block decoding appends token ids correctly. + """Verify multi-step decoding appends token ids correctly. We append token ids and verify all the token ids were appended correctly. Note that ignore_eos=True. @@ -30,7 +29,7 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int): stop_checker = MagicMock(spec=StopChecker) seq_counter = Counter() - output_processor = BlockDecodeOutputProcessor( + output_processor = MultiStepOutputProcessor( detokenizer=detokenizer, scheduler=scheduler, seq_counter=seq_counter, @@ -84,7 +83,7 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, stop_checker = MagicMock(spec=StopChecker) seq_counter = Counter() - output_processor = BlockDecodeOutputProcessor( + output_processor = MultiStepOutputProcessor( detokenizer=detokenizer, scheduler=scheduler, seq_counter=seq_counter, @@ -146,7 +145,7 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, eos_token_id = 100 - output_processor = BlockDecodeOutputProcessor( + output_processor = MultiStepOutputProcessor( detokenizer=detokenizer, scheduler=scheduler, seq_counter=seq_counter, @@ -213,7 +212,7 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int, eos_token_id = 100 - output_processor = BlockDecodeOutputProcessor( + output_processor = MultiStepOutputProcessor( detokenizer=detokenizer, scheduler=scheduler, seq_counter=seq_counter, diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index 26ec982cc13f..9ddac7a04cb3 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -16,9 +16,10 @@ class SequenceGroupOutputProcessor(ABC): the scheduler. This is highly coupled with the LLMEngine and should be seen as an extension - of it. The logic is separated out to simplify the LLMEngine class and to - allow a beam search implementation (which handles forking, etc) and a block - decode implementation (which handles decoding >1 token per step). + of it. The logic is separated to simplify the LLMEngine class and allow + separate implementations for single-step decoding (which supports beam + search sequence forking) and multi-step decoding (which does not support + beam search, but does support speculative decoding). """ @staticmethod @@ -32,16 +33,14 @@ def create_output_processor( ): """Create an output processor. - This returns an output processor compatible with beam search if the - scheduler is not configured to scheduler lookahead slots. Otherwise, it - returns an output processor that is incompatible with beam search but - which supports decoding more than one token per scheduling invocation. + This returns a single-step output processor if num_lookahead_slots is + zero, else returns a multi-step output processor. """ if scheduler_config.num_lookahead_slots == 0: # Importing here to avoid cycle. - from vllm.engine.output_processor.beam_search import ( - BeamSearchOutputProcessor) - return BeamSearchOutputProcessor( + from vllm.engine.output_processor.single_step import ( + SingleStepOutputProcessor) + return SingleStepOutputProcessor( scheduler_config, detokenizer, scheduler, @@ -50,9 +49,9 @@ def create_output_processor( ) else: # Importing here to avoid cycle. - from vllm.engine.output_processor.block_decode import ( - BlockDecodeOutputProcessor) - return BlockDecodeOutputProcessor( + from vllm.engine.output_processor.multi_step import ( + MultiStepOutputProcessor) + return MultiStepOutputProcessor( detokenizer, scheduler, seq_counter, diff --git a/vllm/engine/output_processor/block_decode.py b/vllm/engine/output_processor/multi_step.py similarity index 79% rename from vllm/engine/output_processor/block_decode.py rename to vllm/engine/output_processor/multi_step.py index e309b57af6de..6b01a94f59e4 100644 --- a/vllm/engine/output_processor/block_decode.py +++ b/vllm/engine/output_processor/multi_step.py @@ -15,17 +15,18 @@ logger = init_logger(__name__) -class BlockDecodeOutputProcessor(SequenceGroupOutputProcessor): +class MultiStepOutputProcessor(SequenceGroupOutputProcessor): """SequenceGroupOutputProcessor which handles logic related to - detokenization and stopping conditions. Besides not supporting beam search, - this differs from BeamSearchOutputProcessor in that it supports lookahead - scheduling (where the model may generate >1 token per scheduler invocation). - - This allows it to support speculative decoding and cases where the model - runs more than once. We generalize these cases as "block decoding", where - the model emits a block of tokens at the same time. In this case, this class - is responsible for correctly appending all token ids to sequences and - detokenizing new token ids. + detokenization and stopping conditions. It specializes to "multi-step + decoding", where vLLM's worker may generate multiple tokens per invocation. + This is currently mutually exclusive with advanced sampling techniques like + beam search, which motivates the separation of this logic from the single + step output processor. + + This class is responsible for things such as correctly appending all new + token ids to their sequence, detokenizing new token ids, truncating new + output tokens after an eos token, and correctly handling the case where the + number of new output tokens per sequence differs in a single batch. """ def __init__( @@ -56,7 +57,8 @@ def process_outputs(self, sequence_group: SequenceGroup, seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING) assert seqs, "expected running sequences" - assert len(seqs) == 1, ("Beam search not supported in block decoding.") + assert len(seqs) == 1, ( + "Beam search not supported in multi-step decoding.") seq = seqs[0] # Since there's only one sequence per sequence group, we can take the @@ -86,7 +88,9 @@ def _process_seq_outputs(self, seq: Sequence, output_token_ids = output_token_ids[:remaining_tokens] # Truncate any tokens after EOS. This is required as spec decode - # generates tokens in fixed blocks, which may go beyond the EOS token. + # generates a fixed number of tokens without evaluating stopping + # conditions within the block. This can cause an eos token to be + # unintentionally ignored. if not sampling_params.ignore_eos: eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id # Avoiding .index calls as exception throwing in the happy path @@ -100,7 +104,7 @@ def _process_seq_outputs(self, seq: Sequence, for output_token_id in output_token_ids: seq.append_token_id( token_id=output_token_id, - # TODO emit logprobs in block decoding. + # TODO emit logprobs in multi-step decoding. logprobs={output_token_id: Logprob(0.0)}, ) self.detokenizer.decode_sequence_inplace(seq, sampling_params) diff --git a/vllm/engine/output_processor/beam_search.py b/vllm/engine/output_processor/single_step.py similarity index 94% rename from vllm/engine/output_processor/beam_search.py rename to vllm/engine/output_processor/single_step.py index b0c0246b9935..a642070dce60 100644 --- a/vllm/engine/output_processor/beam_search.py +++ b/vllm/engine/output_processor/single_step.py @@ -14,15 +14,18 @@ logger = init_logger(__name__) -class BeamSearchOutputProcessor(SequenceGroupOutputProcessor): - """SequenceGroupOutputProcessor which handles logic related to beam search - sequence management and coupled logic like detokenization and stop logic. +class SingleStepOutputProcessor(SequenceGroupOutputProcessor): + """SequenceGroupOutputProcessor which handles "output processing" logic, + which happens after the model returns generated token ids and before + scheduling of the next batch. Output processing logic includes + detokenization, and determining if a sequence is finished (e.g. via max len + or eos token). - This class is in charge of sorting out which sequences survive after beam - sampling. It manages forking and freeing of sequences. - - It does not support lookahead decoding, e.g. where the model generates >1 - token per scheduling invocation. + The SingleStepOutputProcessor is specialized to the case where the model + emits at most a single token per invocation, which precludes configurations + such as speculative decoding or multi-step decoding. This enables beam + search sampling, which requires forking/finishing/freeing sequences in a way + that is currently difficult to schedule multiple steps ahead of time. """ def __init__( From e23a43aef8bc51c5201658775445f529324ed728 Mon Sep 17 00:00:00 2001 From: Junichi Sato Date: Wed, 10 Apr 2024 04:11:31 +0900 Subject: [PATCH 058/120] [Bugfix] Fix KeyError on loading GPT-NeoX (#3925) --- vllm/model_executor/models/gpt_neox.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 673900487cc9..a5b5d717d984 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -274,6 +274,11 @@ def load_weights(self, if ("attention.bias" in name or "attention.masked_bias" in name or "rotary_emb.inv_freq" in name): continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using OpenRLHF may include + # these tensors in the checkpoint. Skip them. + continue param = params_dict[name] if "query_key_value" in name: From 96f81c4abdb4157b68bd33db3ff07a7825e6695e Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Tue, 9 Apr 2024 12:18:07 -0700 Subject: [PATCH 059/120] lint --- vllm/spec_decode/spec_decode_worker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 29144f70ff6f..be3af7be9386 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -18,6 +18,7 @@ logger = init_logger(__name__) + class SpecDecodeWorker(LoraNotSupportedWorkerBase): """Worker which implements speculative decoding. From 6c0b04515fee7b402a6febde1467581825bb2164 Mon Sep 17 00:00:00 2001 From: Juan Villamizar <100237675+jpvillam-amd@users.noreply.github.com> Date: Tue, 9 Apr 2024 17:10:47 -0500 Subject: [PATCH 060/120] [ROCm][Hardware][AMD] Use Triton Kernel for default FA on ROCm (#3643) Co-authored-by: jpvillam Co-authored-by: Gregory Shtrasberg Co-authored-by: Woosuk Kwon --- Dockerfile.rocm | 14 + vllm/attention/backends/rocm_flash_attn.py | 348 ++++++++ vllm/attention/backends/xformers.py | 78 +- vllm/attention/ops/triton_flash_attention.py | 809 +++++++++++++++++++ vllm/attention/selector.py | 57 +- 5 files changed, 1213 insertions(+), 93 deletions(-) create mode 100644 vllm/attention/backends/rocm_flash_attn.py create mode 100644 vllm/attention/ops/triton_flash_attention.py diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 65a367994f96..10b8bf1e7fab 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -23,6 +23,9 @@ RUN echo "FA_BRANCH is $FA_BRANCH" # In that case, we need to use the python reference attention implementation in vllm ARG BUILD_FA="1" +# whether to build triton on rocm +ARG BUILD_TRITON="1" + # Install some basic utilities RUN apt-get update && apt-get install python3 python3-pip -y @@ -75,6 +78,17 @@ RUN if [ "$BUILD_FA" = "1" ]; then \ RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \ rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi +# build triton +RUN if [ "$BUILD_TRITON" = "1" ]; then \ + mkdir -p libs \ + && cd libs \ + && pip uninstall -y triton \ + && git clone https://github.com/ROCm/triton.git \ + && cd triton/python \ + && pip3 install . \ + && cd ../..; \ + fi + COPY ./ /app/vllm RUN python3 -m pip install --upgrade pip diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py new file mode 100644 index 000000000000..6019d917b449 --- /dev/null +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -0,0 +1,348 @@ +"""Attention layer ROCm GPUs.""" +import os +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata) +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class ROCmFlashAttentionBackend(AttentionBackend): + + @staticmethod + def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: + return ROCmFlashAttentionImpl + + @staticmethod + def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata": + return ROCmFlashAttentionMetadata(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for FlashAttentionBackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + # (batch_size,). The prompt length per sequence. None if it is a decoding. + prompt_lens: Optional[List[int]] + # prompt_lens stored as a tensor. + prompt_lens_tensor: Optional[torch.Tensor] + # The number of prompt tokens. Doesn't include padding. + num_prompt_tokens: int + # The number of generation tokens. Doesn't include padding. + num_generation_tokens: int + + # NOTE(sang): Definition of context_len, subquery_len, and seqlen. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seqlen ----------------------| + # |- subquery_len -| + + # WARNING(sang): context_len has different definition depending on if it is + # prefill vs decoding. When it is prefill, it doesn't include new tokens. + # When it is for decoding, it includes a new token. + + # Maximum subquery length in the batch. + max_subquery_len: Optional[int] + # Maximum prompt length in the batch. + max_prompt_len: Optional[int] + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + subquery_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + +class ROCmFlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + suppored_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") + + self.use_naive_attn = torch.cuda.get_device_capability()[0] != 9 + # NOTE: Allow for switching between Triton and CK. Defaulting to triton. + self.use_triton_flash_attn = (os.environ.get( + "VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")) + if self.use_naive_attn: + # AMD Radeon 7900 series (gfx1100) currently does not support + # xFormers nor FlashAttention. As a temporary workaround, we use + # naive PyTorch implementation of attention. + self.attn_fuc = _naive_attention() + logger.debug("Using naive attention in ROCmBackend") + elif self.use_triton_flash_attn: + from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 + triton_attention) + self.attn_func = triton_attention + logger.debug("Using Triton FA in ROCmBackend") + else: + from flash_attn import flash_attn_varlen_func # noqa: F401 + self.attn_func = flash_attn_varlen_func + logger.debug("Using CK FA in ROCmBackend") + + def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" + tokens, n_kv_heads, head_dim = x.shape + return (x[:, :, + None, :].expand(tokens, n_kv_heads, n_rep, + head_dim).reshape(tokens, n_kv_heads * n_rep, + head_dim)) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: ROCmFlashAttentionMetadata, + kv_scale: float = 1.0, + ) -> torch.Tensor: + """Forward pass with FlashAttention and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache is not None: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + attn_metadata.kv_cache_dtype, + kv_scale, + ) + + if attn_metadata.is_prompt: + # Prompt run. + if kv_cache is None or attn_metadata.block_tables.numel() == 0: + # triton attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + if self.use_naive_attn or self.use_triton_flash_attn: + if self.num_kv_heads != self.num_heads: + # Interleave for MQA workaround. + key = self.repeat_kv(key, self.num_queries_per_kv) + value = self.repeat_kv(value, self.num_queries_per_kv) + if self.use_naive_attn: + output = self.attn_fuc( + query, + key, + value, + attn_metadata.prompt_lens, + self.scale, + ) + else: + output, _ = self.attn_func( + query, + key, + value, + None, + attn_metadata.seq_start_loc, + attn_metadata.seq_start_loc, + attn_metadata.max_prompt_len, + attn_metadata.max_prompt_len, + True, + self.scale, + ) + else: + output = self.attn_func( + q=query, + k=key, + v=value, + cu_seqlens_q=attn_metadata.seq_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_prompt_len, + max_seqlen_k=attn_metadata.max_prompt_len, + softmax_scale=self.scale, + causal=True, + ) + + else: + # prefix-enabled attention + output = PagedAttention.forward_prefix( + query, + key, + value, + key_cache, + value_cache, + attn_metadata.block_tables, + attn_metadata.subquery_start_loc, + attn_metadata.prompt_lens_tensor, + attn_metadata.context_lens, + attn_metadata.max_subquery_len, + self.alibi_slopes, + ) + else: + # Decoding run. + output = PagedAttention.forward_decode( + query, + key_cache, + value_cache, + attn_metadata.block_tables, + attn_metadata.context_lens, + attn_metadata.max_context_len, + attn_metadata.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + kv_scale, + ) + + # Reshape the output tensor. + return output.view(num_tokens, hidden_size) + + +def _naive_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + prompt_lens: List[int], + scale: float, +) -> torch.Tensor: + num_tokens = query.shape[0] + output = torch.empty_like(query) + start = 0 + for _, prompt_len in enumerate(prompt_lens): + end = start + prompt_len + out = _naive_masked_attention( + query[None, start:end], + key[None, start:end], + value[None, start:end], + scale, + ) + # TODO(woosuk): Unnecessary copy. Optimize. + output[start:end].copy_(out) + start += prompt_len + + # Using view got RuntimeError: view size is not compatible + # with input tensor's size and stride (at least one + # dimension spans across two contiguous subspaces). + # Use reshape instead. + return output.reshape(num_tokens, -1) + + +def _naive_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, +) -> torch.Tensor: + seq_len, _, _ = query.shape + attn_mask = torch.triu(torch.ones(seq_len, + seq_len, + dtype=query.dtype, + device=query.device), + diagonal=1) + attn_mask = attn_mask * torch.finfo(query.dtype).min + + attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index d349c3ef19ea..05b68bba5e6e 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -1,5 +1,4 @@ """Attention layer with xFormers and PagedAttention.""" -import importlib from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Type @@ -14,7 +13,6 @@ from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger -from vllm.utils import is_hip logger = init_logger(__name__) @@ -166,11 +164,6 @@ def __init__( f"Head size {head_size} is not supported by PagedAttention. " f"Supported head sizes are: {suppored_head_sizes}.") - # AMD Radeon 7900 series (gfx1100) currently does not support xFormers - # nor FlashAttention. As a temporary workaround, we use naive PyTorch - # implementation of attention. - self.use_naive_attention = _check_use_naive_attention() - def forward( self, query: torch.Tensor, @@ -233,30 +226,6 @@ def forward( self.num_queries_per_kv, value.shape[-1]) - if self.use_naive_attention: - output = torch.empty_like(query) - start = 0 - for _, prompt_len in enumerate(attn_metadata.prompt_lens): - end = start + prompt_len - out = _naive_masked_attention( - query[None, start:end], - key[None, start:end], - value[None, start:end], - self.num_heads, - self.num_kv_heads, - self.head_size, - self.scale, - ) - # TODO(woosuk): Unnecessary copy. Optimize. - output[start:end].copy_(out) - start += prompt_len - - # Using view got RuntimeError: view size is not compatible - # with input tensor's size and stride (at least one - # dimension spans across two contiguous subspaces). - # Use reshape instead. - return output.reshape(num_tokens, hidden_size) - output = self._run_memory_efficient_xformers_forward( query, key, value, attn_metadata) else: @@ -329,8 +298,6 @@ def _run_memory_efficient_xformers_forward( self.alibi_slopes, self.num_kv_heads, query.dtype, attn_metadata.prompt_lens) - op = xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if ( - is_hip()) else None # No alibi slopes. # TODO(woosuk): Too many view operations. Let's try to reduce # them in the future for code readability. @@ -344,8 +311,7 @@ def _run_memory_efficient_xformers_forward( value, attn_bias=attn_metadata.attn_bias[0], p=0.0, - scale=self.scale, - op=op) + scale=self.scale) return out.view_as(query) @@ -363,8 +329,7 @@ def _run_memory_efficient_xformers_forward( value[None, start:end], attn_bias=attn_metadata.attn_bias[i], p=0.0, - scale=self.scale, - op=op) + scale=self.scale) # TODO(woosuk): Unnecessary copy. Optimize. output[start:end].copy_(out.squeeze(0)) start += prompt_len @@ -405,42 +370,3 @@ def _make_alibi_bias( attn_biases.append(LowerTriangularMaskWithTensorBias(bias)) return attn_biases - - -def _check_use_naive_attention() -> bool: - if not is_hip(): - return False - # For ROCm, check whether flash attention is installed or not. - use_naive_attention = importlib.util.find_spec("flash_attn") is None - if use_naive_attention: - logger.warning("flash_attn is not installed. Using naive attention. " - "This will take significantly more GPU memory.") - return True - return False - - -def _naive_masked_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - num_heads: int, - num_kv_heads: int, - head_size: int, - scale: float, -) -> torch.Tensor: - query = query.view(-1, num_heads, head_size) - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - seq_len, _, _ = query.shape - attn_mask = torch.triu(torch.ones(seq_len, - seq_len, - dtype=query.dtype, - device=query.device), - diagonal=1) - attn_mask = attn_mask * torch.finfo(query.dtype).min - - attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() - attn_weights = attn_weights + attn_mask.float() - attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) - out = torch.einsum("hqk,khd->qhd", attn_weights, value) - return out diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py new file mode 100644 index 000000000000..b86e845020b0 --- /dev/null +++ b/vllm/attention/ops/triton_flash_attention.py @@ -0,0 +1,809 @@ +#!/usr/bin/env python +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao +(https://tridao.me/publications/flash2/flash2.pdf) +Credits: OpenAI kernel team, AMD ML Frameworks Triton team + +Features supported: + +1) Fwd with causal masking +2) Any sequence lengths without padding (currently fwd kernel only) +3) Support for different sequence lengths for q and k +4) Nested tensor API currently does not support dropout or bias. + +Not currently supported: + +1) Non power of two head dims + +""" + +import torch +import triton +import triton.language as tl + +torch_dtype: tl.constexpr = torch.float16 + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + + +@triton.jit +def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): + ms = tl.arange(0, m) + ns = tl.arange(0, n) + return philox_offset + ms[:, None] * stride + ns[None, :] + + +@triton.jit +def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, + stride).to(tl.uint32) + # TODO: use tl.randint for better performance + return tl.rand(philox_seed, rng_offsets) + + +@triton.jit +def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, + stride) + rng_keep = rng_output > dropout_p + return rng_keep + + +@triton.jit +def load_fn(block_ptr, first, second, pad): + if first and second: + tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) + elif first: + tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) + elif second: + tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) + else: + tensor = tl.load(block_ptr) + return tensor + + +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + actual_seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, + OFFS_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + PADDED_HEAD: tl.constexpr, +): + # loop over k, v, and update accumulator + for start_n in range(block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + k = load_fn( + K_block_ptr, + PADDED_HEAD, + MASK_STEPS and (n_extra_tokens != 0), + "zero", + ) + if PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: # noqa: SIM102 + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps + # if not is_modulo_mn. last step might get wasted but that is okay. + # check if this masking works for that case. + if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], + actual_seqlen_k, + dtype=tl.int32) + size_n = start_n + OFFS_N[None, :] + mask = size_n < boundary_m[:, None] + qk = tl.where(mask, qk, float("-inf")) + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk = tl.where(causal_mask, qk, float("-inf")) + # -- compute qk ---- + qk += tl.dot(q, k) + if bias_ptr is not None: + bias = load_fn(bias_ptr, False, MASK_STEPS + and (n_extra_tokens != 0), "zero") + # While bias is added after multiplying qk with sm_scale, our + # optimization to use 2^x instead of e^x results in an additional + # scale factor of log2(e) which we must also multiply the bias with. + qk += bias * 1.44269504089 + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + philox_offset = (batch_philox_offset + + start_m * BLOCK_M * actual_seqlen_k + start_n - + BLOCK_N) + keep = dropout_mask( + philox_seed, + philox_offset, + dropout_p, + BLOCK_M, + BLOCK_N, + actual_seqlen_k, + ) + if RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + tl.where(keep, p, + -p).to(encoded_softmax_block_ptr.type.element_ty), + ) + p = tl.where(keep, p, 0.0) + elif RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + p.to(encoded_softmax_block_ptr.type.element_ty), + ) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, BLOCK_N)) + return acc, l_i, m_i + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_M": 256, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 128, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 256, + "BLOCK_N": 128, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 3, + "PRE_LOAD_V": True, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 3, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 64, + "BLOCK_N": 64, + "waves_per_eu": 4, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 32, + "BLOCK_N": 32, + "waves_per_eu": 4, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + # TODO: This config fails with head_size not pow2 with data mismatches. + # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, + # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config( + { + "BLOCK_M": 16, + "BLOCK_N": 16, + "waves_per_eu": 1, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + ], + key=["hq", "hk", "IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"], +) +@triton.jit +def attn_fwd( + Q, + K, + V, + bias, + sm_scale, + L, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + stride_bz, + stride_bh, + stride_bm, + stride_bn, + cu_seqlens_q, + cu_seqlens_k, + dropout_p, + philox_seed, + philox_offset_base, + encoded_softmax, + hq, + hk, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, + VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + BIAS_TYPE: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, +): + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = MAX_SEQLENS_K + + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if IS_CAUSAL: + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + # This captures the decrease in n_blocks if we have a rectangular attn + # matrix + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is + # part of the blocks that are all 0. We exit early. + if n_blocks <= 0: + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + # We still need to write 0s to the result + # tl.store(O_block_ptr, + # acc.to(Out.type.element_ty), boundary_check=(0,1)) + # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + # + offs_m + # We store inf to LSE, not -inf because in the bwd pass, + # we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 + # for these masked blocks. + # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + # tl.store(l_ptrs, l) + # TODO: Should dropout and return encoded softmax be handled here? + return + + is_mqa = hq != hk + off_h_k = off_h_q % hk if is_mqa else off_h_q + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL + + # Compute pointers for all the tensors used in this kernel. + q_offset = (off_z * stride_qz + off_h_q * stride_qh + + cu_seqlens_q_start * stride_qm) + Q_block_ptr = tl.make_block_ptr( + base=Q + q_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + k_offset = (off_z * stride_kz + off_h_k * stride_kh + + cu_seqlens_k_start * stride_kn) + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + v_offset = (off_z * stride_vz + off_h_k * stride_vh + + cu_seqlens_k_start * stride_vk) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + if BIAS_TYPE != 0: + bias_ptr = tl.make_block_ptr( + base=bias + off_h_q * stride_bh, + shape=(seqlen_q, seqlen_k), + strides=(stride_bm, stride_bn), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + bias_ptr = None + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base \ + + (off_z * hq + off_h_q) \ + * seqlen_q * seqlen_k + else: + batch_philox_offset = 0 + # We can ask to return the dropout mask without actually doing any dropout. + # In this case, we return an invalid pointer so indicate the mask is not i + # valid. + # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.make_block_ptr( + base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, + shape=(seqlen_q, seqlen_k), + strides=(seqlen_k, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + encoded_softmax_block_ptr = 0 + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do not + # have native e^x support in HW. + qk_scale = sm_scale * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q = load_fn(Q_block_ptr, True, padded_head, "zero") + q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional + # block. In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, + block_max, + 0, + 0, + 0, + bias_ptr, + # IS_CAUSAL, .... + False, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + False, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, + ) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if masked_blocks > 0: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 + K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, n_full_blocks)) + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + True, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, + ) + # epilogue + acc = acc / l_i[:, None] + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: # noqa: SIM102 + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL, ), + causal_start_idx, + dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = (mask_m_offsets[:, None] >= + out_mask_boundary[None, :]) + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + # write back LSE + # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last + # few rows. This is only true for the last M block. For others, + # overflow_size will be -ve + # overflow_size = end_m_idx - seqlen_q + # if overflow_size > 0: + # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) + # # This is a > check because mask being 0 blocks the store. + # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + # else: + # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # Need boundary check on this to make sure the padding from the + # Q and KV tensors in both dims are not part of what we store back. + # TODO: Do the boundary check optionally. + tl.store(O_block_ptr, acc, boundary_check=(0, 1)) + + +def check_args( + q, + k, + v, + o, + varlen=True, + max_seqlens=None, + cu_seqlens_q=None, + cu_seqlens_k=None, +): + assert q.dim() == k.dim() and q.dim() == v.dim() + if varlen: + assert q.dim() == 3 + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + assert cu_seqlens_q is not None + assert cu_seqlens_k is not None + assert len(cu_seqlens_q) == len(cu_seqlens_k) + else: + assert q.dim() == 4 + batch, nheads_q, seqlen_q, head_size = q.shape + _, nheads_k, seqlen_k, _ = k.shape + assert max_seqlens > 0 + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + assert q.dtype == k.dtype and q.dtype == v.dtype + # TODO: Fix assert to check head size <=256 once supported + assert head_size <= 128 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + o, + cu_seqlens_q, + cu_seqlens_k, + max_seqlens_q, + max_seqlens_k, + causal=False, + sm_scale=1.0, + bias=None, + ): + if o is None: + o = torch.empty_like(q, dtype=v.dtype) + + check_args( + q, + k, + v, + o, + varlen=True, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + if True: # varlen + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + batch = len(cu_seqlens_q) - 1 + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + else: + batch, seqlen_q, nheads_q, head_size = q.shape + _, seqlen_k, nheads_k, _ = k.shape + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + + # Get closest power of 2 over or equal to 32. + unpadded_head_dims = {32, 64, 128} + if head_size not in unpadded_head_dims: + padded_d_model = None + for i in unpadded_head_dims: + if i > head_size: + padded_d_model = i + break + assert padded_d_model is not None + else: + padded_d_model = head_size + + grid = lambda META: ( + triton.cdiv(max_seqlens_q, META["BLOCK_M"]), + nheads_q, + batch, + ) + + encoded_softmax = None + + # Seed the RNG so we get reproducible results for testing. + philox_seed = 0x1BF52 + philox_offset = 0x1D4B42 + + if bias is not None: + bias_strides = ( + bias.stride(0), + bias.stride(1), + bias.stride(2), + bias.stride(3), + ) + else: + bias_strides = (0, 0, 0, 0) + + attn_fwd[grid]( + q, + k, + v, + bias, + sm_scale, + None, + o, + *q_strides, + *k_strides, + *v_strides, + *o_strides, + *bias_strides, + cu_seqlens_q, + cu_seqlens_k, + dropout_p=0.0, + philox_seed=philox_seed, + philox_offset_base=philox_offset, + encoded_softmax=encoded_softmax, + hq=nheads_q, + hk=nheads_k, + ACTUAL_BLOCK_DMODEL=head_size, + MAX_SEQLENS_Q=max_seqlens_q, + MAX_SEQLENS_K=max_seqlens_k, + IS_CAUSAL=causal, + VARLEN=True, + BLOCK_DMODEL=padded_d_model, + BIAS_TYPE=0 if bias is None else 1, + ENABLE_DROPOUT=False, + RETURN_ENCODED_SOFTMAX=False, + ) + + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = head_size + ctx.causal = causal + ctx.dropout_p = 0.0 + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.encoded_softmax = encoded_softmax + ctx.return_encoded_softmax = False + return o, encoded_softmax + + +triton_attention = _attention.apply diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index b5cd39bbe625..4c699aed48d4 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,3 +1,4 @@ +import enum from functools import lru_cache from typing import Type @@ -10,46 +11,68 @@ logger = init_logger(__name__) +class _Backend(enum.Enum): + FLASH_ATTN = enum.auto() + XFORMERS = enum.auto() + ROCM_FLASH = enum.auto() + TORCH_SDPA = enum.auto() + + @lru_cache(maxsize=None) def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: - if _can_use_flash_attn(dtype): + backend = _which_attn_to_use(dtype) + if backend == _Backend.FLASH_ATTN: logger.info("Using FlashAttention backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) return FlashAttentionBackend - elif is_cpu(): - logger.info("Using Torch SDPA backend.") - from vllm.attention.backends.torch_sdpa import TorchSDPABackend - return TorchSDPABackend - else: + elif backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") from vllm.attention.backends.xformers import ( # noqa: F401 XFormersBackend) return XFormersBackend + elif backend == _Backend.ROCM_FLASH: + logger.info("Using ROCmFlashAttention backend.") + from vllm.attention.backends.rocm_flash_attn import ( # noqa: F401 + ROCmFlashAttentionBackend) + return ROCmFlashAttentionBackend + elif backend == _Backend.TORCH_SDPA: + logger.info("Using Torch SDPA backend.") + from vllm.attention.backends.torch_sdpa import TorchSDPABackend + return TorchSDPABackend + else: + raise ValueError("Invalid attention backend.") -def _can_use_flash_attn(dtype: torch.dtype) -> bool: +def _which_attn_to_use(dtype: torch.dtype) -> _Backend: + """Returns which flash attention backend to use.""" + if is_cpu(): + return _Backend.TORCH_SDPA + if is_hip(): # AMD GPUs. - logger.info("Cannot use FlashAttention backend for AMD GPUs.") - return False - if is_cpu(): - return False + if torch.cuda.get_device_capability()[0] != 9: + # not Instinct series GPUs. + logger.info("flash_atten is not supported on NAVI GPUs.") + return _Backend.ROCM_FLASH + + # NVIDIA GPUs. if torch.cuda.get_device_capability()[0] < 8: # Volta and Turing NVIDIA GPUs. logger.info("Cannot use FlashAttention backend for Volta and Turing " "GPUs.") - return False + return _Backend.XFORMERS + if dtype not in (torch.float16, torch.bfloat16): logger.info("Cannot use FlashAttention backend for dtype other than " "torch.float16 or torch.bfloat16.") - return False + return _Backend.XFORMERS try: import flash_attn # noqa: F401 except ImportError: logger.info( - "Cannot use FlashAttention because the package is not found. " - "Please install it for better performance.") - return False - return True + "Cannot use FlashAttention backend because the flash_attn package " + "is not found. Please install it for better performance.") + return _Backend.XFORMERS + return _Backend.FLASH_ATTN From 11dd6ebb8950a66c371f6fa5d2489eb01fc4f6d5 Mon Sep 17 00:00:00 2001 From: Jee Li Date: Wed, 10 Apr 2024 10:47:15 +0800 Subject: [PATCH 061/120] [Misc] Avoid loading incorrect LoRA config (#3777) --- tests/lora/test_lora_checkpoints.py | 40 +++++++++++++++++++++++++++++ vllm/lora/models.py | 17 ++++++++++-- vllm/lora/worker_manager.py | 11 ++++++++ 3 files changed, 66 insertions(+), 2 deletions(-) create mode 100644 tests/lora/test_lora_checkpoints.py diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py new file mode 100644 index 000000000000..35ad7342944c --- /dev/null +++ b/tests/lora/test_lora_checkpoints.py @@ -0,0 +1,40 @@ +import pytest + +from vllm.lora.models import LoRAModel +from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM + + +@pytest.mark.parametrize("lora_name", ["baichuan7B", "chatglm3-6b"]) +def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files): + supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules + packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping + embedding_modules = BaiChuanBaseForCausalLM.embedding_modules + embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules + expected_lora_modules = [] + for module in supported_lora_modules: + if module in packed_modules_mapping: + expected_lora_modules.extend(packed_modules_mapping[module]) + else: + expected_lora_modules.append(module) + if lora_name == "baichuan7B": + # For the baichuan7B model, load it's LoRA, + # and the test should pass. + LoRAModel.from_local_checkpoint( + baichuan_lora_files, + expected_lora_modules, + lora_model_id=1, + device="cpu", + embedding_modules=embedding_modules, + embedding_padding_modules=embed_padding_modules) + else: + # For the baichuan7B model, load chatglm3-6b's LoRA, + # and the test should raise the following error. + expected_error = "Please verify that the loaded LoRA module is correct" # noqa: E501 + with pytest.raises(ValueError, match=expected_error): + LoRAModel.from_local_checkpoint( + chatglm3_lora_files, + expected_lora_modules, + lora_model_id=1, + device="cpu", + embedding_modules=embedding_modules, + embedding_padding_modules=embed_padding_modules) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 945917a5aa86..62f150245800 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -191,6 +191,7 @@ def from_lora_tensors( def from_local_checkpoint( cls, lora_dir: str, + expected_lora_modules: List[str], lora_model_id: Optional[int] = None, device: str = "cuda", dtype: Optional[torch.dtype] = None, @@ -206,6 +207,20 @@ def from_local_checkpoint( lora_dir, "new_embeddings.safetensors") new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin") + with open(lora_config_path) as f: + config = json.load(f) + target_modules = config["target_modules"] + unexpected_modules = [] + for module in target_modules: + if module not in expected_lora_modules: + unexpected_modules.append(module) + # loaded lora's target modules must be a subset of expected_lora_modules + if unexpected_modules: + raise ValueError( + f"While loading {lora_dir}, expected" + f" target modules in {expected_lora_modules}" + f" but received {unexpected_modules}." + f" Please verify that the loaded LoRA module is correct") if os.path.isfile(lora_tensor_path): tensors = safetensors.torch.load_file(lora_tensor_path) elif os.path.isfile(lora_bin_file_path): @@ -220,8 +235,6 @@ def from_local_checkpoint( elif os.path.isfile(new_embeddings_bin_file_path): embeddings = torch.load(new_embeddings_bin_file_path) - with open(lora_config_path) as f: - config = json.load(f) rank = config["r"] lora_alpha = config["lora_alpha"] return cls.from_lora_tensors( diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 3224b3a9e3eb..a0868defbd3c 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -136,8 +136,19 @@ def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: try: + model = self._lora_manager.model + supported_lora_modules = model.supported_lora_modules + packed_modules_mapping = model.packed_modules_mapping + expected_lora_modules = [] + for module in supported_lora_modules: + if module in packed_modules_mapping: + expected_lora_modules.extend( + packed_modules_mapping[module]) + else: + expected_lora_modules.append(module) lora = self._lora_model_cls.from_local_checkpoint( lora_request.lora_local_path, + expected_lora_modules, lora_model_id=lora_request.lora_int_id, device="cpu", dtype=self.lora_config.lora_dtype, From c013d32c758699fbe5804af1b9d9408acd6cb8b7 Mon Sep 17 00:00:00 2001 From: Zedong Peng Date: Wed, 10 Apr 2024 12:30:03 +0800 Subject: [PATCH 062/120] [Benchmark] Add cpu options to bench scripts (#3915) --- benchmarks/benchmark_latency.py | 4 ++-- benchmarks/benchmark_throughput.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index e2d358ea6631..91510dafc57a 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -169,8 +169,8 @@ def run_to_completion(profile_dir: Optional[str] = None): "--device", type=str, default="cuda", - choices=["cuda"], - help='device type for vLLM execution, supporting CUDA only currently.') + choices=["cuda", "cpu"], + help='device type for vLLM execution, supporting CUDA and CPU.') parser.add_argument('--block-size', type=int, default=16, diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index d6bf18c82e46..e71338273d1e 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -329,8 +329,8 @@ def main(args: argparse.Namespace): "--device", type=str, default="cuda", - choices=["cuda"], - help='device type for vLLM execution, supporting CUDA only currently.') + choices=["cuda", "cpu"], + help='device type for vLLM execution, supporting CUDA and CPU.') parser.add_argument( "--enable-prefix-caching", action='store_true', From c2e00af523b0638dcca68c9a42a9187449841ced Mon Sep 17 00:00:00 2001 From: zhaotyer <89376832+zhaotyer@users.noreply.github.com> Date: Wed, 10 Apr 2024 12:49:11 +0800 Subject: [PATCH 063/120] [Bugfix] fix utils.py/merge_dict func TypeError: 'type' object is not subscriptable (#3955) Co-authored-by: tianyi_zhao --- vllm/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 380ffe76fea7..8ba03333d3b6 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -9,7 +9,7 @@ from collections import OrderedDict, defaultdict from functools import lru_cache, partial from platform import uname -from typing import (Any, Awaitable, Callable, Generic, Hashable, List, +from typing import (Any, Awaitable, Callable, Dict, Generic, Hashable, List, Optional, Tuple, TypeVar, Union) import psutil @@ -452,8 +452,8 @@ def maybe_expand_dim(tensor: torch.Tensor, return tensor -def merge_dicts(dict1: dict[Any, list[Any]], - dict2: dict[Any, list[Any]]) -> dict[Any, list[Any]]: +def merge_dicts(dict1: Dict[Any, List[Any]], + dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]: """Merge 2 dicts that have key -> List of items. When a key conflicts, the values in dict1 is prioritized. From b3104b2a10ab7cb7532442177ae0d0c40acf9d03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=83=A1=E8=AF=91=E6=96=87?= <1020030101@qq.com> Date: Wed, 10 Apr 2024 15:09:36 +0800 Subject: [PATCH 064/120] [Bugfix] Fix logits processor when prompt_logprobs is not None (#3899) --- tests/samplers/test_logits_processor.py | 62 +++++++++++++++++++ .../model_executor/layers/logits_processor.py | 11 +++- 2 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 tests/samplers/test_logits_processor.py diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py new file mode 100644 index 000000000000..3788e9e9752f --- /dev/null +++ b/tests/samplers/test_logits_processor.py @@ -0,0 +1,62 @@ +import pytest +import torch + +from vllm import SamplingParams + +MODELS = ["facebook/opt-125m"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_logits_processor_force_generate( + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + vllm_model = vllm_runner(model, dtype=dtype) + tokenizer = vllm_model.model.get_tokenizer() + repeat_times = 2 + enforced_answers = " vLLM" + vllm_token_ids = tokenizer.encode(enforced_answers, + add_special_tokens=False) + max_tokens = len(vllm_token_ids) * repeat_times + + def pick_vllm(token_ids, logits): + token_id = vllm_token_ids[len(token_ids) % len(vllm_token_ids)] + logits[token_id] = torch.finfo(logits.dtype).max + return logits + + params_with_logprobs = SamplingParams( + logits_processors=[pick_vllm], + prompt_logprobs=3, + max_tokens=max_tokens, + ) + + # test logits_processors when prompt_logprobs is not None + vllm_model.model._add_request( + prompt=example_prompts[0], + sampling_params=params_with_logprobs, + prompt_token_ids=None, + ) + + # test prompt_logprobs is not None + vllm_model.model._add_request( + prompt=example_prompts[1], + sampling_params=SamplingParams( + prompt_logprobs=3, + max_tokens=max_tokens, + ), + prompt_token_ids=None, + ) + + # test grouped requests + vllm_model.model._add_request( + prompt=example_prompts[2], + sampling_params=SamplingParams(max_tokens=max_tokens), + prompt_token_ids=None, + ) + + outputs = vllm_model.model._run_engine(False) + + assert outputs[0].outputs[0].text == enforced_answers * repeat_times diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 28e8f6bb7e63..ec531f79ced5 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -86,8 +86,16 @@ def _apply_logits_processors( ) -> torch.Tensor: logits_row_idx = 0 found_logits_processors = False - for seq_ids, sampling_params in sampling_metadata.seq_groups: + for i, seq_group in enumerate(sampling_metadata.seq_groups): + seq_ids, sampling_params = seq_group logits_processors = sampling_params.logits_processors + # handle prompt_logprobs by skipping rows in logits added for + # the prompt tokens (prompt logprobs are not processed) + if (i < sampling_metadata.num_prompts + and sampling_params.prompt_logprobs is not None): + assert len(seq_ids) == 1 + logits_row_idx += sampling_metadata.prompt_lens[i] - 1 + if logits_processors: found_logits_processors = True for seq_id in seq_ids: @@ -100,5 +108,6 @@ def _apply_logits_processors( else: logits_row_idx += len(seq_ids) if found_logits_processors: + # verifies that no rows in logits were missed unexpectedly assert logits_row_idx == logits.shape[0] return logits From 0258b7a94b08321ca01cf170f867b67c1920af87 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Wed, 10 Apr 2024 02:39:56 -0600 Subject: [PATCH 065/120] [Bugfix] handle prompt_logprobs in _apply_min_tokens_penalty (#3876) Signed-off-by: Travis Johnson --- tests/samplers/test_sampler.py | 116 +++++++++++++++++++++----- vllm/model_executor/layers/sampler.py | 19 ++++- 2 files changed, 112 insertions(+), 23 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 1626b7228207..26e2d29ffd04 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -1,3 +1,4 @@ +import itertools import random from typing import List, Optional, Tuple from unittest.mock import patch @@ -194,11 +195,15 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): def create_sampling_params(min_tokens, eos_token_id=0, - stop_token_ids=None): + *, + stop_token_ids: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None): sampling_params = SamplingParams( min_tokens=min_tokens, max_tokens=9999, # keep higher than max of min_tokens stop_token_ids=stop_token_ids, + # requesting prompt_logprobs changes the structure of `logits` + prompt_logprobs=prompt_logprobs, ) sampling_params.eos_token_id = eos_token_id return sampling_params @@ -217,9 +222,9 @@ def generate_test_case(): expected_penalization = [] sequence_metadata_list = [] + # 20% chance to generate seq group metadata list with all prompts + is_prompt = random.random() < 0.2 while batch_size > 0: - # 20% chance to generate prompt seq group with single sequence - is_prompt = random.random() < 0.2 num_seqs = 1 if is_prompt else random.randint(1, batch_size) eos_token_id = random.randint(0, VOCAB_SIZE - 1) @@ -240,7 +245,7 @@ def generate_test_case(): seq_group_penalization = [] for _ in range(num_seqs): num_input = random.randint(1, 100) - num_generated = random.randint(1, 100) if not is_prompt else 0 + num_generated = 0 if is_prompt else random.randint(1, 100) seq_data[next(seq_id_counter)] = create_sequence_data( num_input=num_input, num_generated=num_generated) seq_group_penalization.append(num_generated < min_tokens) @@ -292,6 +297,21 @@ def generate_test_case(): ] } + prompt_with_penalization_and_prompt_logprobs = { + "expected_penalization": [False, False, True], + "seq_group_metadata_list": [ + SequenceGroupMetadata( + request_id="test_1", + is_prompt=True, + seq_data={ + next(seq_id_counter): create_sequence_data(num_input=3), + }, + sampling_params=create_sampling_params(1, prompt_logprobs=3), + block_tables={}, + ), + ] + } + stop_penalizing_after_min_tokens = { "expected_penalization": [False], "seq_group_metadata_list": [ @@ -309,8 +329,34 @@ def generate_test_case(): } stop_token_ids = [42, 99, 42, 0] # intentional duplication - simple_combination = { - "expected_penalization": [True, False, False], + prompt_combination = { + "expected_penalization": [False, True, False], + "seq_group_metadata_list": [ + SequenceGroupMetadata( + request_id="test_2", + is_prompt=True, + seq_data={ + next(seq_id_counter): create_sequence_data(num_input=2), + }, + sampling_params=create_sampling_params(1, prompt_logprobs=3), + block_tables={}, + ), + SequenceGroupMetadata( + request_id="test_3", + is_prompt=True, + seq_data={ + next(seq_id_counter): create_sequence_data(), + }, + sampling_params=create_sampling_params( + 0, stop_token_ids=stop_token_ids), + block_tables={}, + ) + ] + } + + stop_token_ids = [1, 999, 37, 37] # intentional duplication + decode_combination = { + "expected_penalization": [True, False, False, True, False], "seq_group_metadata_list": [ SequenceGroupMetadata( request_id="test_1", @@ -327,14 +373,19 @@ def generate_test_case(): ), SequenceGroupMetadata( request_id="test_2", - is_prompt=True, + is_prompt=False, seq_data={ - next(seq_id_counter): create_sequence_data(), + next(seq_id_counter): + create_sequence_data(num_generated=20), + next(seq_id_counter): + create_sequence_data(num_generated=1), + next(seq_id_counter): + create_sequence_data(num_generated=10), }, sampling_params=create_sampling_params( - 0, stop_token_ids=stop_token_ids), + 10, prompt_logprobs=5, stop_token_ids=stop_token_ids), block_tables={}, - ) + ), ] } @@ -342,8 +393,10 @@ def generate_test_case(): test_cases = [ prompt_without_penalization, prompt_with_penalization, + prompt_with_penalization_and_prompt_logprobs, stop_penalizing_after_min_tokens, - simple_combination, + prompt_combination, + decode_combination, ] else: test_cases = [generate_test_case()] @@ -351,30 +404,49 @@ def generate_test_case(): def run_test_case(*, expected_penalization=None, seq_group_metadata_list=None): - assert expected_penalization, "Invalid test case" - assert seq_group_metadata_list, "Invalid test case" + assert expected_penalization, \ + "Invalid test case, need expected_penalization" + assert seq_group_metadata_list, \ + "Invalid test case, need seq_group_metadata_list" batch_size = 0 prompt_lens = [] - sampling_params_per_seq = [] + sampling_params_per_row = [] for sgm in seq_group_metadata_list: - num_seqs = len(sgm.seq_data) - batch_size += num_seqs sampling_params = sgm.sampling_params - for seq_id in sgm.seq_data: - prompt_lens.append(sgm.seq_data[seq_id].get_prompt_len()) - sampling_params_per_seq.append(sampling_params) + + num_rows = len(sgm.seq_data) + if sgm.is_prompt: + # a prompt seq_group has only one sequence + seq_data = next(iter(sgm.seq_data.values())) + prompt_len = seq_data.get_prompt_len() + prompt_lens.append(prompt_len) + + if sgm.sampling_params.prompt_logprobs: + # with prompt_logprobs each token in the prompt has a row in + # logits + num_rows = prompt_len + + batch_size += num_rows + sampling_params_per_row.extend( + itertools.repeat(sampling_params, num_rows)) + + assert len( + expected_penalization + ) == batch_size, \ + ("Invalid test case, expected_penalization does not match computed" + "batch size") _, fake_logits, sampler, model_runner = _prepare_test(batch_size) sampling_metadata = model_runner._prepare_sample( seq_group_metadata_list, - prompt_lens=prompt_lens, - subquery_lens=prompt_lens) + prompt_lens=prompt_lens if prompt_lens else None, + subquery_lens=prompt_lens if prompt_lens else None) # the logits tensor is modified in-place by the sampler _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) for logits_idx, (should_penalize, sampling_params) in enumerate( - zip(expected_penalization, sampling_params_per_seq)): + zip(expected_penalization, sampling_params_per_row)): tokens_to_check = [sampling_params.eos_token_id] if sampling_params.stop_token_ids: diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index cb1480de03e3..03bf38caebe0 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -27,6 +27,12 @@ class Sampler(nn.Module): 6. Sample the next tokens. Here, each sequence group within the batch can have different sampling parameters (e.g., sampling method, temperature, top-p, top-k, etc.). + + The structure of the logits tensor is coupled with the seq_groups in + sampling_metadata. Typically, each sequence in each seq_group has one row in + logits for the next token to be sampled; however, for a seq_group with a + prompt request with the prompt_logprobs sampling parameter, there are rows + in logits for each token in the input prompt. """ def forward( @@ -106,7 +112,16 @@ def _apply_min_tokens_penalty( # list of indices in logits that will be set to -inf logits_to_penalize = [] start_idx = 0 - for seq_ids, sampling_params in sampling_metadata.seq_groups: + for i, seq_group in enumerate(sampling_metadata.seq_groups): + seq_ids, sampling_params = seq_group + + # handle prompt_logprobs by skipping rows in logits added for the prompt + # tokens (prompt logprobs are not penalized) + if (i < sampling_metadata.num_prompts + and sampling_params.prompt_logprobs is not None): + assert len(seq_ids) == 1 + start_idx += sampling_metadata.prompt_lens[i] - 1 + min_tokens = sampling_params.min_tokens if min_tokens > 0: seqs_to_penalize = [] @@ -132,6 +147,8 @@ def _apply_min_tokens_penalty( # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) ) logits[tuple(zip(*logits_to_penalize))] = -float("inf") + # verifies that no rows in logits were missed unexpectedly + assert start_idx == logits.shape[0] return logits From bd3c144e0b8e82c9b3c5c40c6d557fe8665de5a3 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 10 Apr 2024 07:37:17 -0700 Subject: [PATCH 066/120] [Bugfix][ROCm] Add numba to Dockerfile.rocm (#3962) --- Dockerfile.rocm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 10b8bf1e7fab..b1c5fac9d78e 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -91,7 +91,7 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \ COPY ./ /app/vllm -RUN python3 -m pip install --upgrade pip +RUN python3 -m pip install --upgrade pip numba RUN python3 -m pip install xformers==0.0.23 --no-deps RUN cd /app \ From 8b317c6dd09ce566f4b4abeb446585ac75262cce Mon Sep 17 00:00:00 2001 From: James Whedbee Date: Wed, 10 Apr 2024 10:12:00 -0500 Subject: [PATCH 067/120] [Model][AMD] ROCm support for 256 head dims for Gemma (#3972) --- vllm/attention/ops/triton_flash_attention.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index b86e845020b0..87cf30cbef79 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -677,8 +677,7 @@ def check_args( assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] # TODO: Change assert if we support qkl f8 and v f16 assert q.dtype == k.dtype and q.dtype == v.dtype - # TODO: Fix assert to check head size <=256 once supported - assert head_size <= 128 + assert head_size <= 256 assert o.shape == q.shape assert (nheads_q % nheads_k) == 0 @@ -729,7 +728,7 @@ def forward( o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) # Get closest power of 2 over or equal to 32. - unpadded_head_dims = {32, 64, 128} + unpadded_head_dims = {32, 64, 128, 256} if head_size not in unpadded_head_dims: padded_d_model = None for i in unpadded_head_dims: From e35397468f36a857b8d2b7d92a472265e1c500cc Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 10 Apr 2024 10:03:02 -0700 Subject: [PATCH 068/120] [Doc] Add doc to state our model support policy (#3948) Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> --- docs/source/models/supported_models.rst | 26 +++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index e7bfdcb65316..c09b0ff25043 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -168,3 +168,29 @@ Alternatively, you can raise an issue on our `GitHub `_ and `test_big_models.py `_ for the models that have passed this test. +2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test. +3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to `functionality tests `_ and `examples `_ for the models that have passed this test. +4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. From e4c4072c94b346053768691451566c56664e26a7 Mon Sep 17 00:00:00 2001 From: Daniel E Marasco Date: Wed, 10 Apr 2024 13:15:51 -0400 Subject: [PATCH 069/120] [Bugfix] Remove key sorting for `guided_json` parameter in OpenAi compatible Server (#3945) --- vllm/model_executor/guided_decoding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index e56f74c7794f..8e710f1ac2b5 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -91,7 +91,7 @@ def _get_guide_and_mode( json = request.guided_json if isinstance(json, dict): # turn dict into hashable string - json = json_dumps(json, sort_keys=True) + json = json_dumps(json) elif isinstance(json, BaseModel): # use pydantic signature so that different model classes # with the same fields will get hashed the same From 92cd2e2f21e8ec65b2cb635a9f15de38157a1359 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=CE=B1n=C3=A7ois?= Date: Wed, 10 Apr 2024 20:05:52 +0200 Subject: [PATCH 070/120] [Doc] Fix getting stared to use publicly available model (#3963) --- docs/source/serving/openai_compatible_server.md | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 032fe5d03bd5..388b5daa79a9 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -4,7 +4,7 @@ vLLM provides an HTTP server that implements OpenAI's [Completions](https://plat You can start the server using Python, or using [Docker](deploying_with_docker.rst): ```bash -python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-hf --dtype float32 --api-key token-abc123 +python -m vllm.entrypoints.openai.api_server --model mistralai/Mistral-7B-Instruct-v0.2 --dtype auto --api-key token-abc123 ``` To call the server, you can use the official OpenAI Python client library, or any other HTTP client. @@ -16,9 +16,8 @@ client = OpenAI( ) completion = client.chat.completions.create( - model="meta-llama/Llama-2-7b-hf", + model="mistralai/Mistral-7B-Instruct-v0.2", messages=[ - {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello!"} ] ) @@ -38,9 +37,8 @@ Or directly merge them into the JSON payload if you are using HTTP call directly ```python completion = client.chat.completions.create( - model="meta-llama/Llama-2-7b-hf", + model="mistralai/Mistral-7B-Instruct-v0.2", messages=[ - {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"} ], extra_body={ @@ -89,7 +87,7 @@ In order for the language model to support chat protocol, vLLM requires the mode a chat template in its tokenizer configuration. The chat template is a Jinja2 template that specifies how are roles, messages, and other chat-specific tokens are encoded in the input. -An example chat template for `meta-llama/Llama-2-7b-chat-hf` can be found [here](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/09bd0f49e16738cdfaa6e615203e126038736eb0/tokenizer_config.json#L12) +An example chat template for `mistralai/Mistral-7B-Instruct-v0.2` can be found [here](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2#instruction-format) Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those model, you can manually specify their chat template in the `--chat-template` parameter with the file path to the chat From de1691929e58af704c72b329c9e608d06f2d8320 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Wed, 10 Apr 2024 14:13:32 -0700 Subject: [PATCH 071/120] pr feedback --- vllm/engine/llm_engine.py | 3 +-- vllm/engine/output_processor/stop_checker.py | 13 ++++++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e6e75ee59c76..59add1faba44 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -193,8 +193,7 @@ def __init__( self.seq_counter, self.get_tokenizer_for_seq, stop_checker=StopChecker( - self.scheduler, - self.scheduler_config, + self.scheduler_config.max_model_len, self.get_tokenizer_for_seq, ), )) diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index 2a6c79d2dc02..37d53fa3c7fa 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -1,4 +1,6 @@ -from typing import List +from typing import Callable, List + +from transformers import PreTrainedTokenizer from vllm.sampling_params import SamplingParams from vllm.sequence import Sequence, SequenceStatus @@ -11,9 +13,10 @@ class StopChecker: emitted, or if we have exceeded the max model len. """ - def __init__(self, scheduler, scheduler_config, get_tokenizer_for_seq): - self.scheduler = scheduler - self.scheduler_config = scheduler_config + def __init__(self, max_model_len: int, + get_tokenizer_for_seq: Callable[[Sequence], + PreTrainedTokenizer]): + self.max_model_len = max_model_len self.get_tokenizer_for_seq = get_tokenizer_for_seq def maybe_stop_sequence(self, seq: Sequence, @@ -23,7 +26,7 @@ def maybe_stop_sequence(self, seq: Sequence, """ # Check if the sequence has reached max_model_len. - if seq.get_len() > self.scheduler_config.max_model_len: + if seq.get_len() > self.max_model_len: seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED return From 934d3662f716d60abfb04cf9fdd6d20f6e75f140 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Wed, 10 Apr 2024 16:28:25 -0600 Subject: [PATCH 072/120] [Bugfix] handle hf_config with architectures == None (#3982) Signed-off-by: Travis Johnson Co-authored-by: Simon Mo --- vllm/config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 753fc33e9b71..bca250e92228 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -158,7 +158,9 @@ def _verify_load_format(self) -> None: # TODO: Remove this check once HF updates the pt weights of Mixtral. architectures = getattr(self.hf_config, "architectures", []) - if "MixtralForCausalLM" in architectures and load_format == "pt": + # architectures can be None instead of [] + if architectures and "MixtralForCausalLM" in architectures \ + and load_format == "pt": raise ValueError( "Currently, the 'pt' format is not supported for Mixtral. " "Please use the 'safetensors' format instead. ") From 63e7176f265be43dcc425f5ab4ab45c90234f5c3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 10 Apr 2024 15:33:30 -0700 Subject: [PATCH 073/120] [Core][Refactor] move parallel_utils into vllm/distributed (#3950) [WIP][Core][Refactor] move vllm/model_executor/parallel_utils into vllm/distributed and vllm/device_communicators (#3950) --- tests/conftest.py | 3 +-- tests/distributed/test_comm_ops.py | 6 +++--- tests/distributed/test_custom_all_reduce.py | 13 ++++++------- tests/distributed/test_pynccl.py | 4 ++-- tests/lora/conftest.py | 3 +-- vllm/distributed/__init__.py | 3 +++ .../communication_op.py | 14 ++++++++------ .../device_communicators}/__init__.py | 0 .../device_communicators}/custom_all_reduce.py | 5 +++-- .../device_communicators}/pynccl.py | 0 .../device_communicators}/pynccl_utils.py | 4 ++-- .../parallel_state.py | 4 ++-- .../parallel_utils => distributed}/utils.py | 0 vllm/lora/layers.py | 13 ++++++------- vllm/model_executor/layers/activation.py | 5 ++--- vllm/model_executor/layers/linear.py | 11 +++++------ vllm/model_executor/layers/logits_processor.py | 3 +-- .../layers/vocab_parallel_embedding.py | 8 +++----- vllm/model_executor/models/baichuan.py | 4 ++-- vllm/model_executor/models/bloom.py | 4 ++-- vllm/model_executor/models/chatglm.py | 3 +-- vllm/model_executor/models/commandr.py | 4 ++-- vllm/model_executor/models/dbrx.py | 7 +++---- vllm/model_executor/models/deepseek.py | 7 +++---- vllm/model_executor/models/falcon.py | 7 +++---- vllm/model_executor/models/gemma.py | 3 +-- vllm/model_executor/models/gpt2.py | 3 +-- vllm/model_executor/models/gpt_bigcode.py | 3 +-- vllm/model_executor/models/gpt_j.py | 3 +-- vllm/model_executor/models/gpt_neox.py | 3 +-- vllm/model_executor/models/internlm2.py | 3 +-- vllm/model_executor/models/jais.py | 4 ++-- vllm/model_executor/models/llama.py | 4 ++-- vllm/model_executor/models/minicpm.py | 7 +++---- vllm/model_executor/models/mixtral.py | 7 +++---- vllm/model_executor/models/mixtral_quant.py | 7 +++---- vllm/model_executor/models/mpt.py | 4 ++-- vllm/model_executor/models/olmo.py | 3 +-- vllm/model_executor/models/opt.py | 3 +-- vllm/model_executor/models/orion.py | 3 +-- vllm/model_executor/models/phi.py | 3 +-- vllm/model_executor/models/qwen.py | 3 +-- vllm/model_executor/models/qwen2.py | 3 +-- vllm/model_executor/models/qwen2_moe.py | 7 +++---- vllm/model_executor/models/stablelm.py | 3 +-- vllm/model_executor/models/starcoder2.py | 3 +-- vllm/model_executor/models/xverse.py | 3 +-- vllm/model_executor/parallel_utils/README.md | 1 - vllm/test_utils.py | 4 ++-- vllm/worker/cpu_worker.py | 7 +++---- vllm/worker/model_runner.py | 8 +++----- vllm/worker/worker.py | 12 ++++++------ 52 files changed, 111 insertions(+), 141 deletions(-) create mode 100644 vllm/distributed/__init__.py rename vllm/{model_executor/parallel_utils => distributed}/communication_op.py (94%) rename vllm/{model_executor/parallel_utils => distributed/device_communicators}/__init__.py (100%) rename vllm/{model_executor/parallel_utils => distributed/device_communicators}/custom_all_reduce.py (98%) rename vllm/{model_executor/parallel_utils => distributed/device_communicators}/pynccl.py (100%) rename vllm/{model_executor/parallel_utils => distributed/device_communicators}/pynccl_utils.py (91%) rename vllm/{model_executor/parallel_utils => distributed}/parallel_state.py (98%) rename vllm/{model_executor/parallel_utils => distributed}/utils.py (100%) delete mode 100644 vllm/model_executor/parallel_utils/README.md diff --git a/tests/conftest.py b/tests/conftest.py index e00f3eb871e3..a7e8963af0ed 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,8 +11,7 @@ from vllm import LLM, SamplingParams from vllm.config import TokenizerPoolConfig, VisionLanguageConfig -from vllm.model_executor.parallel_utils.parallel_state import ( - destroy_model_parallel) +from vllm.distributed import destroy_model_parallel from vllm.sequence import MultiModalData from vllm.transformers_utils.tokenizer import get_tokenizer diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index d1811cb694db..aa9e0537c691 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -8,9 +8,9 @@ import ray import torch -from vllm.model_executor.parallel_utils.communication_op import ( - broadcast_tensor_dict, tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce) +from vllm.distributed import (broadcast_tensor_dict, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) from vllm.test_utils import (init_test_distributed_environment, multi_process_tensor_parallel) diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 1e6e7f89a528..3b1cd1773af1 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -6,9 +6,8 @@ import torch import torch.distributed as dist -from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) +from vllm.distributed import tensor_model_parallel_all_reduce +from vllm.distributed.device_communicators import custom_all_reduce from vllm.test_utils import (init_test_distributed_environment, multi_process_tensor_parallel) @@ -26,10 +25,10 @@ def graph_allreduce(world_size, rank, distributed_init_port): init_test_distributed_environment(1, world_size, rank, distributed_init_port) - custom_ar.init_custom_ar() + custom_all_reduce.init_custom_all_reduce() for sz in test_sizes: for dtype in [torch.float32, torch.float16, torch.bfloat16]: - with custom_ar.capture(): + with custom_all_reduce.capture(): # use integers so result matches NCCL exactly inp1 = torch.randint(1, 16, (sz, ), @@ -62,8 +61,8 @@ def eager_allreduce(world_size, rank, distributed_init_port): distributed_init_port) sz = 1024 - custom_ar.init_custom_ar() - fa = custom_ar.get_handle() + custom_all_reduce.init_custom_all_reduce() + fa = custom_all_reduce.get_handle() inp = torch.ones(sz, dtype=torch.float32, device=device) out = fa.all_reduce_unreg(inp) assert torch.allclose(out, inp * world_size) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 29782045130a..b50eed1c8c72 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -4,8 +4,8 @@ import pytest import torch -from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator, - ncclGetUniqueId) +from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator, + ncclGetUniqueId) def distributed_run(fn, world_size): diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index acb5fa91e201..207c635e2dc8 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -12,6 +12,7 @@ import vllm from vllm.config import LoRAConfig +from vllm.distributed import destroy_model_parallel, initialize_model_parallel from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear) @@ -19,8 +20,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader import get_model -from vllm.model_executor.parallel_utils.parallel_state import ( - destroy_model_parallel, initialize_model_parallel) def cleanup(): diff --git a/vllm/distributed/__init__.py b/vllm/distributed/__init__.py new file mode 100644 index 000000000000..db325cfabf55 --- /dev/null +++ b/vllm/distributed/__init__.py @@ -0,0 +1,3 @@ +from .communication_op import * +from .parallel_state import * +from .utils import * diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/distributed/communication_op.py similarity index 94% rename from vllm/model_executor/parallel_utils/communication_op.py rename to vllm/distributed/communication_op.py index 9cbb40708dd5..cf15db099b30 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -4,12 +4,10 @@ import torch from torch.distributed import ProcessGroup -from vllm.model_executor.parallel_utils import pynccl_utils -from vllm.model_executor.parallel_utils.custom_all_reduce import ( - custom_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, is_pynccl_enabled_for_all_reduce) +from .parallel_state import (get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + is_pynccl_enabled_for_all_reduce) def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: @@ -24,6 +22,10 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: TLDR: always assume this function modifies its input, but use the return value as the output. """ + from vllm.distributed.device_communicators import pynccl_utils + from vllm.distributed.device_communicators.custom_all_reduce import ( + custom_all_reduce) + # Bypass the function if we are using only 1 GPU. if get_tensor_model_parallel_world_size() == 1: return input_ diff --git a/vllm/model_executor/parallel_utils/__init__.py b/vllm/distributed/device_communicators/__init__.py similarity index 100% rename from vllm/model_executor/parallel_utils/__init__.py rename to vllm/distributed/device_communicators/__init__.py diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py similarity index 98% rename from vllm/model_executor/parallel_utils/custom_all_reduce.py rename to vllm/distributed/device_communicators/custom_all_reduce.py index bf8ee07070c8..84238d2e4607 100644 --- a/vllm/model_executor/parallel_utils/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -5,8 +5,6 @@ import torch.distributed as dist from vllm.logger import init_logger -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) try: import pynvml @@ -25,6 +23,9 @@ def init_custom_ar() -> None: + from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) + global _CA_HANDLE if _CA_HANDLE is not None: return diff --git a/vllm/model_executor/parallel_utils/pynccl.py b/vllm/distributed/device_communicators/pynccl.py similarity index 100% rename from vllm/model_executor/parallel_utils/pynccl.py rename to vllm/distributed/device_communicators/pynccl.py diff --git a/vllm/model_executor/parallel_utils/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py similarity index 91% rename from vllm/model_executor/parallel_utils/pynccl_utils.py rename to vllm/distributed/device_communicators/pynccl_utils.py index a099777aa000..aeb73015733d 100644 --- a/vllm/model_executor/parallel_utils/pynccl_utils.py +++ b/vllm/distributed/device_communicators/pynccl_utils.py @@ -9,8 +9,8 @@ logger = init_logger(__name__) try: - from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator, - ncclGetVersion) + from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator, + ncclGetVersion) except Exception as e: # in non-NVIDIA environments, we can't import the nccl module # e.g. when running on machines with AMD GPUs diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/distributed/parallel_state.py similarity index 98% rename from vllm/model_executor/parallel_utils/parallel_state.py rename to vllm/distributed/parallel_state.py index 3bbfa1bd5443..4bb77146295a 100644 --- a/vllm/model_executor/parallel_utils/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -8,8 +8,6 @@ import torch -from vllm.model_executor.parallel_utils import pynccl_utils - # Tensor model parallel group that the current rank belongs to. _TENSOR_MODEL_PARALLEL_GROUP = None # Pipeline model parallel group that the current rank belongs to. @@ -266,6 +264,7 @@ def destroy_model_parallel(): _PIPELINE_MODEL_PARALLEL_GROUP = None global _PIPELINE_GLOBAL_RANKS _PIPELINE_GLOBAL_RANKS = None + from vllm.distributed.device_communicators import pynccl_utils # Destroy the pynccl states if any. pynccl_utils.destroy_process_group() @@ -279,6 +278,7 @@ def destroy_model_parallel(): @contextlib.contextmanager def with_pynccl_for_all_reduce(): + from vllm.distributed.device_communicators import pynccl_utils """use pynccl instead of torch.distributed for all reduce""" tp_size = get_tensor_model_parallel_world_size() if tp_size == 1: diff --git a/vllm/model_executor/parallel_utils/utils.py b/vllm/distributed/utils.py similarity index 100% rename from vllm/model_executor/parallel_utils/utils.py rename to vllm/distributed/utils.py diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 050501475395..dd33868f7630 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -10,6 +10,12 @@ from transformers import PretrainedConfig from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + tensor_model_parallel_gather) from vllm.lora.punica import add_lora, add_lora_slice, bgmv from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, @@ -18,13 +24,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, - tensor_model_parallel_gather) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.utils import ( - split_tensor_along_last_dim) if TYPE_CHECKING: pass diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index f569a5a49cbd..6786c48e0cab 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -7,10 +7,9 @@ import torch.nn.functional as F from vllm._C import ops +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.utils import divide from vllm.model_executor.utils import set_weight_attrs diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index f3d4d1789db2..8f42b3e8a4ab 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -5,13 +5,12 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) from vllm.logger import init_logger -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.utils import ( - divide, split_tensor_along_last_dim) from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index ec531f79ced5..e556e31f9937 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -4,8 +4,7 @@ import torch import torch.nn as nn -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_gather) +from vllm.distributed import tensor_model_parallel_gather from vllm.model_executor.sampling_metadata import SamplingMetadata diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 73bbfac33ed1..088c0849243c 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -4,11 +4,9 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.utils import divide +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.utils import set_weight_attrs DEFAULT_VOCAB_PADDING_SIZE = 64 diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index fa5a27b5a697..30588aecdebe 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -27,6 +27,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -38,8 +40,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index a9ff90909058..40966ab33631 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -24,6 +24,8 @@ from transformers import BloomConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -33,8 +35,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 4008896e48dd..7b46ba306619 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -10,6 +10,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -21,8 +22,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 29ba3844eb11..aa27f0a96c74 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -29,6 +29,8 @@ from transformers import CohereConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -39,8 +41,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.weight_utils import (default_weight_loader, diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 14c0fece6921..49eb7f1b2c18 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -5,6 +5,9 @@ import torch.nn as nn from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.linear import (LinearMethodBase, QKVParallelLinear, @@ -15,10 +18,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.weight_utils import (default_weight_loader, diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 2a2182ff4eba..c7dd11d07e6d 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -28,6 +28,9 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm @@ -41,10 +44,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 77c19b227d21..4f1ebcd5fb43 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -27,6 +27,9 @@ from transformers import FalconConfig as HF_FalconConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -37,10 +40,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 08609532b8b3..fc1fc3557036 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -23,6 +23,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -35,8 +36,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 3f816a9996be..43f0d47fcb12 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -24,6 +24,7 @@ from transformers import GPT2Config from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -33,8 +34,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 07c647c2e1c4..cec2d771adfa 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -25,6 +25,7 @@ from transformers import GPTBigCodeConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -34,8 +35,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 94048efe4842..566009765274 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -23,6 +23,7 @@ from transformers import GPTJConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -33,8 +34,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index a5b5d717d984..2f9e2171cf11 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -23,6 +23,7 @@ from transformers import GPTNeoXConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -33,8 +34,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index bdb48bf21042..6e9cbd3f9f43 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -6,6 +6,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -17,8 +18,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 12fc9dbd5073..a041b0c9a045 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -26,6 +26,8 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -34,8 +36,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 72fe21df67d8..c86e292e7df1 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -29,6 +29,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -40,8 +42,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator, diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 99d1b4eb97bb..49eda9c9a811 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -29,6 +29,9 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm @@ -42,10 +45,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.weight_utils import (default_weight_loader, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 429bc8109b9f..ff552a9d8653 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -29,6 +29,9 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -40,10 +43,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.weight_utils import (default_weight_loader, diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 75f86bc134ee..1f0c0e912bee 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -30,6 +30,9 @@ from transformers import MixtralConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, QKVParallelLinear, @@ -40,10 +43,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index a39f94359a94..af4cdce29d08 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -7,6 +7,8 @@ import torch.nn as nn from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -16,8 +18,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 611a48a9aad2..3513c7287910 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -44,6 +44,7 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -55,8 +56,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index c1ae1b2ae0f0..3a640850662c 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -24,6 +24,7 @@ from transformers import OPTConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -34,8 +35,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index ee910563b20d..c606ac027e9d 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -11,6 +11,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -21,8 +22,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 40e068acaba7..e91624da9095 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -42,6 +42,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -52,8 +53,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index a63b9c8d63d1..6213a2ded65a 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -11,6 +11,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -22,8 +23,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 8c92cd773f6b..796e30e633e8 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -30,6 +30,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -41,8 +42,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 6b4a74198fd5..f920b4f5a40c 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -30,6 +30,9 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm @@ -43,10 +46,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index b83637fd50dc..651598b770f1 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -26,6 +26,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -36,8 +37,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 50d23e0a3b6e..76e8e4867341 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -25,6 +25,7 @@ from transformers import Starcoder2Config from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -35,8 +36,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 83d2ddb2bcf3..7e9ce9e5c8e1 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -28,6 +28,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -39,8 +40,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/parallel_utils/README.md b/vllm/model_executor/parallel_utils/README.md deleted file mode 100644 index b25e3afddad9..000000000000 --- a/vllm/model_executor/parallel_utils/README.md +++ /dev/null @@ -1 +0,0 @@ -The files in this folder are ported from [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core). We only keep the codes that are used in inference. \ No newline at end of file diff --git a/vllm/test_utils.py b/vllm/test_utils.py index bc220d3b8a43..0cf23e4bb7e7 100644 --- a/vllm/test_utils.py +++ b/vllm/test_utils.py @@ -1,7 +1,7 @@ import ray -from vllm.model_executor.parallel_utils.parallel_state import ( - ensure_model_parallel_initialized, init_distributed_environment) +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) from vllm.utils import get_open_port diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 42f0828b826e..751384eb72af 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -7,13 +7,12 @@ from vllm.attention import get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) +from vllm.distributed import (broadcast_tensor_dict, + ensure_model_parallel_initialized, + init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.model_executor.model_loader import get_model -from vllm.model_executor.parallel_utils.communication_op import ( - broadcast_tensor_dict) -from vllm.model_executor.parallel_utils.parallel_state import ( - ensure_model_parallel_initialized, init_distributed_environment) from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.worker.model_runner import ModelRunner diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e7f20475ab1a..1de4748b7bcc 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -9,17 +9,15 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce +from vllm.distributed.device_communicators import (custom_all_reduce, + pynccl_utils) from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model -from vllm.model_executor.parallel_utils import custom_all_reduce, pynccl_utils -from vllm.model_executor.parallel_utils.communication_op import ( - broadcast_tensor_dict) -from vllm.model_executor.parallel_utils.parallel_state import ( - with_pynccl_for_all_reduce) from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, SequenceGroupMetadata) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 19de33089b2d..3f0b2fd83f3e 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -8,14 +8,14 @@ from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.distributed import (broadcast_tensor_dict, + ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.distributed.device_communicators import pynccl_utils +from vllm.distributed.device_communicators.custom_all_reduce import ( + init_custom_ar) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed -from vllm.model_executor.parallel_utils import pynccl_utils -from vllm.model_executor.parallel_utils.communication_op import ( - broadcast_tensor_dict) -from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar -from vllm.model_executor.parallel_utils.parallel_state import ( - ensure_model_parallel_initialized, init_distributed_environment) from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner From 67b4221a61ace91a79aff507df0a95a01978300e Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Thu, 11 Apr 2024 09:56:48 +0900 Subject: [PATCH 074/120] [Core][5/N] Fully working chunked prefill e2e (#3884) --- .buildkite/test-pipeline.yaml | 2 + benchmarks/benchmark_latency.py | 3 +- benchmarks/benchmark_throughput.py | 62 ++-- .../basic_correctness/test_chunked_prefill.py | 70 ++++ tests/core/test_chunked_prefill_scheduler.py | 16 +- .../test_basic_distributed_correctness.py | 7 +- .../test_chunked_prefill_distributed.py | 66 ++++ tests/entrypoints/test_openai_server.py | 2 +- tests/models/test_models.py | 2 +- tests/worker/test_model_runner.py | 189 ++++++++-- vllm/attention/__init__.py | 4 +- vllm/attention/backends/abstract.py | 42 ++- vllm/attention/backends/flash_attn.py | 85 +++-- vllm/attention/backends/rocm_flash_attn.py | 97 ++++-- vllm/attention/backends/torch_sdpa.py | 67 ++-- vllm/attention/backends/xformers.py | 138 ++++---- vllm/attention/layer.py | 5 +- vllm/attention/ops/paged_attn.py | 6 - vllm/config.py | 13 +- vllm/core/scheduler.py | 15 +- vllm/distributed/communication_op.py | 10 +- vllm/engine/arg_utils.py | 5 +- vllm/engine/llm_engine.py | 5 +- vllm/lora/layers.py | 5 +- vllm/sequence.py | 3 +- vllm/worker/model_runner.py | 323 +++++++++++++----- 26 files changed, 927 insertions(+), 315 deletions(-) create mode 100644 tests/basic_correctness/test_chunked_prefill.py create mode 100644 tests/distributed/test_chunked_prefill_distributed.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 27e44463a30a..695290ed74ab 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -29,6 +29,8 @@ steps: - pytest -v -s test_pynccl.py - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py - label: Engine Test command: pytest -v -s engine tokenization test_sequence.py test_config.py diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 91510dafc57a..aadbc441713f 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -177,8 +177,7 @@ def run_to_completion(profile_dir: Optional[str] = None): help='block size of key/value cache') parser.add_argument( '--enable-chunked-prefill', - type=bool, - default=False, + action='store_true', help='If True, the prefill requests can be chunked based on the ' 'max_num_batched_tokens') parser.add_argument( diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index e71338273d1e..6df1e1d628e6 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -74,25 +74,31 @@ def run_vllm( quantization_param_path: Optional[str], device: str, enable_prefix_caching: bool, + enable_chunked_prefill: bool, + max_num_batched_tokens: int, gpu_memory_utilization: float = 0.9, download_dir: Optional[str] = None, ) -> float: from vllm import LLM, SamplingParams - llm = LLM(model=model, - tokenizer=tokenizer, - quantization=quantization, - tensor_parallel_size=tensor_parallel_size, - seed=seed, - trust_remote_code=trust_remote_code, - dtype=dtype, - max_model_len=max_model_len, - gpu_memory_utilization=gpu_memory_utilization, - enforce_eager=enforce_eager, - kv_cache_dtype=kv_cache_dtype, - quantization_param_path=quantization_param_path, - device=device, - enable_prefix_caching=enable_prefix_caching, - download_dir=download_dir) + llm = LLM( + model=model, + tokenizer=tokenizer, + quantization=quantization, + tensor_parallel_size=tensor_parallel_size, + seed=seed, + trust_remote_code=trust_remote_code, + dtype=dtype, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + quantization_param_path=quantization_param_path, + device=device, + enable_prefix_caching=enable_prefix_caching, + download_dir=download_dir, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + ) # Add the requests to the engine. for prompt, _, output_len in requests: @@ -213,15 +219,15 @@ def main(args: argparse.Namespace): args.output_len) if args.backend == "vllm": - elapsed_time = run_vllm(requests, args.model, args.tokenizer, - args.quantization, args.tensor_parallel_size, - args.seed, args.n, args.use_beam_search, - args.trust_remote_code, args.dtype, - args.max_model_len, args.enforce_eager, - args.kv_cache_dtype, - args.quantization_param_path, args.device, - args.enable_prefix_caching, - args.gpu_memory_utilization, args.download_dir) + elapsed_time = run_vllm( + requests, args.model, args.tokenizer, args.quantization, + args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, + args.trust_remote_code, args.dtype, args.max_model_len, + args.enforce_eager, args.kv_cache_dtype, + args.quantization_param_path, args.device, + args.enable_prefix_caching, args.enable_chunked_prefill, + args.max_num_batched_tokens, args.gpu_memory_utilization, + args.download_dir) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -335,6 +341,14 @@ def main(args: argparse.Namespace): "--enable-prefix-caching", action='store_true', help="enable automatic prefix caching for vLLM backend.") + parser.add_argument("--enable-chunked-prefill", + action='store_true', + help="enable chunked prefill for vLLM backend.") + parser.add_argument('--max-num-batched-tokens', + type=int, + default=None, + help='maximum number of batched tokens per ' + 'iteration') parser.add_argument('--download-dir', type=str, default=None, diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py new file mode 100644 index 000000000000..9ff07b3c0902 --- /dev/null +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -0,0 +1,70 @@ +"""Compare the outputs of HF and vLLM when using greedy sampling. + +It tests chunked prefill. Chunked prefill can be enabled by +enable_chunked_prefill=True. If prefill size exceeds max_num_batched_tokens, +prefill requests are chunked. + +Run `pytest tests/models/test_chunked_prefill.py`. +""" +import pytest + +MODELS = [ + "facebook/opt-125m", + "meta-llama/Llama-2-7b-hf", +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) +@pytest.mark.parametrize("enforce_eager", [False, True]) +# NOTE: Increasing this in this suite will fail CI because we currently cannot +# reset distributed env properly. Use a value > 1 just when you test. +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + chunked_prefill_token_size: int, + enforce_eager: bool, + tensor_parallel_size: int, +) -> None: + if (tensor_parallel_size == 2 and chunked_prefill_token_size != 16 + and not enforce_eager): + pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} " + "for high TP to save testing time.") + max_num_seqs = min(chunked_prefill_token_size, 256) + enable_chunked_prefill = False + max_num_batched_tokens = None + if chunked_prefill_token_size != -1: + enable_chunked_prefill = True + max_num_batched_tokens = chunked_prefill_token_size + + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner( + model, + dtype=dtype, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=enable_chunked_prefill, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + max_num_seqs=max_num_seqs, + ) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + del vllm_model + print(vllm_outputs[0]) + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index 05e62ced5898..cce396bf4953 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -104,10 +104,10 @@ def test_chunk(): # One chunked prefill, and one decoding. seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert set(get_sequence_groups(out)) == set(running) - # The first one is decoding. - assert seq_group_meta[0].token_chunk_size == 1 + # The first one is prefill. Scheduler guarantees ordering. + assert seq_group_meta[0].token_chunk_size == 56 # The second one is a chunked prefill. - assert seq_group_meta[1].token_chunk_size == 56 + assert seq_group_meta[1].token_chunk_size == 1 assert out.num_prefill_groups == 1 assert out.num_batched_tokens == 57 @@ -157,12 +157,12 @@ def test_complex(): # Decoding & chunked prefill & first chunk of 3rd request is scheduled. seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert len(get_sequence_groups(out)) == 3 - # The first one is decoding. - assert seq_group_meta[0].token_chunk_size == 1 - # The second one is a chunked prefill. + # The first one is the first chunked prefill. + assert seq_group_meta[0].token_chunk_size == 7 + # The second one is the second new chunked prefill. assert seq_group_meta[1].token_chunk_size == 56 - # The third one is also chunked. - assert seq_group_meta[2].token_chunk_size == 7 + # The last one is decode. + assert seq_group_meta[2].token_chunk_size == 1 # Two of them are in chunked prefill. assert out.num_prefill_groups == 2 assert out.num_batched_tokens == 64 diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 1eba14d7a642..77aa90b12bf8 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -33,11 +33,16 @@ def test_models( dtype: str, max_tokens: int, ) -> None: + hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, dtype=dtype, tensor_parallel_size=2) + vllm_model = vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=2, + ) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py new file mode 100644 index 000000000000..737b1f316951 --- /dev/null +++ b/tests/distributed/test_chunked_prefill_distributed.py @@ -0,0 +1,66 @@ +"""Compare the outputs of HF and distributed vLLM when using greedy sampling. +vLLM will allocate all the available memory, so we need to run the tests one +by one. The solution is to pass arguments (model name) by environment +variables. + +Run: +```sh +TEST_DIST_MODEL=facebook/opt-125m pytest \ + test_chunked_prefill_distributed.py +TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \ + test_chunked_prefill_distributed.py +``` +""" +import os + +import pytest +import torch + +MODELS = [ + os.environ["TEST_DIST_MODEL"], +] + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("chunked_prefill_token_size", [16]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + chunked_prefill_token_size: int, +) -> None: + # Add a chunked prefill config. + max_num_seqs = min(chunked_prefill_token_size, 256) + assert chunked_prefill_token_size != -1 + enable_chunked_prefill = True + max_num_batched_tokens = chunked_prefill_token_size + + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=2, + max_num_seqs=max_num_seqs, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + ) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 442f8bdf3b4b..6f2086c4dd26 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -141,7 +141,7 @@ def server(zephyr_lora_files): "--max-cpu-loras", "2", "--max-num-seqs", - "128" + "128", ]) ray.get(server_runner.ready.remote()) yield server_runner diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 53a80d461964..cfe2539e3a05 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -12,7 +12,7 @@ "gpt2", "bigcode/tiny_starcoder_py", "EleutherAI/pythia-70m", - "bigscience/bloom-560m", + "bigscience/bloom-560m", # Testing alibi slopes. "microsoft/phi-2", "stabilityai/stablelm-3b-4e1t", # "allenai/OLMo-1B", # Broken diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 5b6f001f62fa..dcaae4af4a6f 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,14 +1,18 @@ import pytest import torch -from vllm.config import ModelConfig +from vllm.config import ModelConfig, SchedulerConfig from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size @pytest.mark.parametrize("batch_size", list(range(1, 257))) def test_prepare_prompt(batch_size): - model_runner = ModelRunner(None, None, None, None, None) + scheduler_config = SchedulerConfig(100000, + 100000, + 100000, + enable_chunked_prefill=False) + model_runner = ModelRunner(None, None, scheduler_config, None, None) model_runner.set_block_size(16) prompt_lens = [] @@ -36,8 +40,10 @@ def test_prepare_prompt(batch_size): prompt_len - 1) selected_token_start_idx += prompt_len (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, - _, _) = (model_runner._prepare_prompt(seq_group_metadata_list)) + _, _, + slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) assert return_prompt_lens == prompt_lens + assert len(slot_mapping) == len(input_tokens) # Verify input metadata is correct for prompts. device = model_runner.device @@ -45,8 +51,6 @@ def test_prepare_prompt(batch_size): assert torch.allclose(attn_metadata.prompt_lens_tensor, torch.tensor(prompt_lens, device=device)) assert attn_metadata.prompt_lens == prompt_lens - assert attn_metadata.num_prompt_tokens == sum(prompt_lens) - assert attn_metadata.num_generation_tokens == 0 assert attn_metadata.max_prompt_len == max(prompt_lens) # Test subquery start locs. @@ -83,23 +87,22 @@ def test_prepare_prompt(batch_size): assert torch.allclose(attn_metadata.block_tables, expected) # Cuda graph should not be used for prerill. assert attn_metadata.use_cuda_graph is False - assert attn_metadata.kv_cache_dtype == "auto" - assert input_tokens.shape == (sum(prompt_lens), ) - assert input_positions.shape == (sum(prompt_lens), ) + assert len(input_tokens) == sum(prompt_lens) + assert len(input_positions) == sum(prompt_lens) torch.testing.assert_close(input_tokens, input_positions) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens) - assert input_tokens.shape == (sum(prompt_lens), ) - assert input_positions.shape == (sum(prompt_lens), ) + assert len(input_tokens) == sum(prompt_lens) + assert len(input_positions) == sum(prompt_lens) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, dtype=actual.dtype) torch.testing.assert_close(actual, expected) - torch.testing.assert_close(input_tokens, input_positions) + assert input_tokens == input_positions actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, @@ -122,7 +125,12 @@ def test_prepare_decode_cuda_graph(batch_size): revision=None, enforce_eager=False, ) - model_runner = ModelRunner(model_config, None, None, None, None) + scheduler_config = SchedulerConfig(100000, + 100000, + 100000, + enable_chunked_prefill=False) + model_runner = ModelRunner(model_config, None, scheduler_config, None, + None) model_runner.set_block_size(16) prompt_lens = [] @@ -143,16 +151,15 @@ def test_prepare_decode_cuda_graph(batch_size): assert seq_group_metadata.token_chunk_size == 1 seq_group_metadata_list.append(seq_group_metadata) - input_tokens, input_positions, attn_metadata, _, _, _ = ( + input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( model_runner._prepare_decode(seq_group_metadata_list)) + assert len(slot_mapping) == len(input_tokens) expected_bs = _get_graph_batch_size(len(seq_group_metadata_list)) # Verify input metadata is correct for prompts. device = model_runner.device assert attn_metadata.is_prompt is False assert attn_metadata.prompt_lens is None - assert attn_metadata.num_prompt_tokens == 0 - assert attn_metadata.num_generation_tokens == expected_bs assert attn_metadata.max_prompt_len is None assert attn_metadata.subquery_start_loc is None assert attn_metadata.seq_start_loc is None @@ -170,11 +177,10 @@ def test_prepare_decode_cuda_graph(batch_size): model_runner.get_max_block_per_batch()) # Cuda graph should not be used for prerill. assert attn_metadata.use_cuda_graph is True - assert attn_metadata.kv_cache_dtype == "auto" - assert input_tokens.shape == (expected_bs, ) - assert input_positions.shape == (expected_bs, ) - torch.testing.assert_close(input_tokens, input_positions) + assert len(input_tokens) == expected_bs + assert len(input_positions) == expected_bs + assert input_tokens == input_positions # Verify Sampling expected_selected_token_indices = [] @@ -190,3 +196,148 @@ def test_prepare_decode_cuda_graph(batch_size): device=actual.device, dtype=actual.dtype) torch.testing.assert_close(actual, expected) + + +def test_empty_seq_group(): + """Verify prepare prompt and decode returns empty output.""" + model_config = ModelConfig( + "facebook/opt-125m", + "facebook/opt-125m", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + enforce_eager=False, + ) + model_runner = ModelRunner(model_config, None, None, None, None) + model_runner.set_block_size(16) + seq_group_metadata_list = [] + input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( + model_runner._prepare_decode(seq_group_metadata_list)) + assert len(input_tokens) == 0 + assert len(input_positions) == 0 + assert attn_metadata is None + assert len(slot_mapping) == 0 + + (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, + _, _, + slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) + assert len(input_tokens) == 0 + assert len(input_positions) == 0 + assert attn_metadata is None + assert len(slot_mapping) == 0 + assert len(return_prompt_lens) == 0 + + +@pytest.mark.parametrize("batch_size", list(range(2, 128))) +@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_hybrid_batches(batch_size, enforce_eager, monkeypatch): + + def get_world_size(group=None): + return 1 + + def mock_get_process_group_ranks(group=None): + return [0] + + monkeypatch.setattr(torch.distributed, "get_world_size", get_world_size) + monkeypatch.setattr(torch.distributed, "get_process_group_ranks", + mock_get_process_group_ranks) + + model_config = ModelConfig( + "facebook/opt-125m", + "facebook/opt-125m", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + enforce_eager=enforce_eager, + ) + scheduler_config = SchedulerConfig(100000, + 100000, + 100000, + enable_chunked_prefill=True) + model_runner = ModelRunner(model_config, + None, + scheduler_config, + None, + None, + is_driver_worker=True) + model_runner.set_block_size(16) + + # Add prefill requests. + prompt_lens = [] + seq_group_metadata_list = [] + prefill_metadata_list = [] + decode_metadata_list = [] + block_tables = {0: [1]} + prefill_batch_size = batch_size // 2 + decode_batch_size = batch_size - prefill_batch_size + for i in range(prefill_batch_size): + # make sure all tokens fit into one block + prompt_len = i % (model_runner.block_size - 1) + 1 + prompt_lens.append(prompt_len) + seq_data = SequenceData(list(range(prompt_len))) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables=block_tables, + ) + assert seq_group_metadata.token_chunk_size == seq_data.get_len() + seq_group_metadata_list.append(seq_group_metadata) + prefill_metadata_list.append(seq_group_metadata) + + # Add decode requests + for i in range(prefill_batch_size, batch_size): + # make sure all tokens fit into one block + prompt_len = i % (model_runner.block_size - 1) + 1 + prompt_toks = list(range(prompt_len)) + seq_data = SequenceData(prompt_toks) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=False, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables={0: [1]}, + ) + assert seq_group_metadata.token_chunk_size == 1 + seq_group_metadata_list.append(seq_group_metadata) + decode_metadata_list.append(seq_group_metadata) + + (input_tokens, input_positions, attn_metadata, _, _, _, + _) = model_runner.prepare_input_tensors(seq_group_metadata_list) + + prefill_meta_actual = attn_metadata.prefill_metadata + decode_meta_actual = attn_metadata.decode_metadata + + assert len(attn_metadata.slot_mapping) == len(input_tokens) + assert len(input_positions) == len(input_tokens) + assert attn_metadata.kv_cache_dtype == "auto" + assert attn_metadata.num_prefills == prefill_batch_size + if enforce_eager: + assert attn_metadata.num_decode_tokens == decode_batch_size + else: + assert attn_metadata.num_decode_tokens == _get_graph_batch_size( + decode_batch_size) + assert attn_metadata.num_prefill_tokens == sum(prompt_lens) + + # Verify attn metadata is consistent. We don't need to test individual + # values here because they are tested above. + prefill_meta = model_runner._prepare_prompt( + prefill_metadata_list).attn_metadata + decode_meta = model_runner._prepare_decode( + decode_metadata_list).attn_metadata + + for attr_expected, attr_actual in zip(vars(prefill_meta), + vars(prefill_meta_actual)): + assert attr_expected[1] == attr_actual[1] + for attr_expected, attr_actual in zip(vars(decode_meta), + vars(decode_meta_actual)): + assert attr_expected[1] == attr_actual[1] diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 9acb82c0df2c..7636b34a16fe 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -1,5 +1,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend @@ -8,4 +9,5 @@ "AttentionMetadata", "Attention", "get_attn_backend", + "AttentionMetadataPerStage", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index a03cf2dd7a6f..7a4ccecf702f 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, fields -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar import torch @@ -47,7 +47,8 @@ def copy_blocks( @dataclass -class AttentionMetadata: +class AttentionMetadataPerStage: + """Attention metadata for a specific stage. I.e., prefill or decode.""" def asdict_zerocopy(self) -> Dict[str, Any]: """Similar to dataclasses.asdict, but avoids deepcopying.""" @@ -59,6 +60,41 @@ def asdict_zerocopy(self) -> Dict[str, Any]: } +T = TypeVar("T", bound=AttentionMetadataPerStage) + + +@dataclass +class AttentionMetadata(Generic[T]): + """Attention metadata for prefill and decode batched together.""" + # Total number of prefill requests. + num_prefills: int + # Number of prefill tokens. + num_prefill_tokens: int + # Number of decode tokens. Note that it is equivalent to the number of + # decode requests. + num_decode_tokens: int + # The attention metadata for prefill requests in a batch. + # None if there's no prefill requests in a batch. + prefill_metadata: Optional[T] + # The attention metadata for decode requests in a batch. + # None if there's no decode requests in a batch. + decode_metadata: Optional[T] + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. + slot_mapping: torch.Tensor + # The kv cache's data type. + kv_cache_dtype: str + + def __post_init__(self): + if self.num_prefill_tokens > 0: + assert self.num_prefills > 0 + assert self.prefill_metadata is not None + if self.num_decode_tokens > 0: + assert self.decode_metadata is not None + + class AttentionImpl(ABC): @abstractmethod @@ -80,7 +116,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, + attn_metadata: AttentionMetadata[AttentionMetadataPerStage], kv_scale: float, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 4e0d9d1418b3..12e8c4404b94 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -11,7 +11,8 @@ from flash_attn import flash_attn_varlen_func from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -53,7 +54,8 @@ def copy_blocks( @dataclass -class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): +class FlashAttentionMetadata(AttentionMetadataPerStage, + PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -68,10 +70,6 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): prompt_lens: Optional[List[int]] # prompt_lens stored as a tensor. prompt_lens_tensor: Optional[torch.Tensor] - # The number of prompt tokens. Doesn't include padding. - num_prompt_tokens: int - # The number of generation tokens. Doesn't include padding. - num_generation_tokens: int # NOTE(sang): Definition of context_len, subquery_len, and seqlen. # |---------- N-1 iteration --------| @@ -107,18 +105,27 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): class FlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens -------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| - |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| Generation tokens can contain padding when cuda-graph is used. Currently, prompt tokens don't contain any padding. The prompts might have different lengths, while the generation tokens always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. """ def __init__( @@ -155,7 +162,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: FlashAttentionMetadata, + attn_metadata: AttentionMetadata[FlashAttentionMetadata], kv_scale: float, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -188,52 +195,70 @@ def forward( attn_metadata.kv_cache_dtype, kv_scale) - if attn_metadata.is_prompt: + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if kv_cache is None or attn_metadata.block_tables.numel() == 0: + if kv_cache is None or prefill_meta.block_tables.numel() == 0: # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - output = flash_attn_varlen_func( + out = flash_attn_varlen_func( q=query, k=key, v=value, - cu_seqlens_q=attn_metadata.seq_start_loc, - cu_seqlens_k=attn_metadata.seq_start_loc, - max_seqlen_q=attn_metadata.max_prompt_len, - max_seqlen_k=attn_metadata.max_prompt_len, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prompt_len, + max_seqlen_k=prefill_meta.max_prompt_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention # TODO(Hai) this triton kernel has regression issue (broke) to # deal with different data types between KV and FP8 KV cache, # to be addressed separately. - output = PagedAttention.forward_prefix( + output[:num_prefill_tokens] = PagedAttention.forward_prefix( query, key, value, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.prompt_lens_tensor, - attn_metadata.context_lens, - attn_metadata.max_subquery_len, + prefill_meta.block_tables, + prefill_meta.subquery_start_loc, + prefill_meta.prompt_lens_tensor, + prefill_meta.context_lens, + prefill_meta.max_subquery_len, self.alibi_slopes, ) - else: + if decode_meta := attn_metadata.decode_metadata: # Decoding run. - output = PagedAttention.forward_decode( - query, + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + decode_meta.block_tables, + decode_meta.context_lens, + decode_meta.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 6019d917b449..e55435cd2c94 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -6,7 +6,8 @@ import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -51,7 +52,8 @@ def copy_blocks( @dataclass -class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): +class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, + PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -66,10 +68,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): prompt_lens: Optional[List[int]] # prompt_lens stored as a tensor. prompt_lens_tensor: Optional[torch.Tensor] - # The number of prompt tokens. Doesn't include padding. - num_prompt_tokens: int - # The number of generation tokens. Doesn't include padding. - num_generation_tokens: int # NOTE(sang): Definition of context_len, subquery_len, and seqlen. # |---------- N-1 iteration --------| @@ -117,6 +115,15 @@ class ROCmFlashAttentionImpl(AttentionImpl): The prompts might have different lengths, while the generation tokens always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->| + |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. """ def __init__( @@ -181,7 +188,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: ROCmFlashAttentionMetadata, + attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata], kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -218,9 +225,25 @@ def forward( kv_scale, ) - if attn_metadata.is_prompt: + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if kv_cache is None or attn_metadata.block_tables.numel() == 0: + if kv_cache is None or prefill_meta.block_tables.numel() == 0: # triton attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. @@ -230,63 +253,69 @@ def forward( key = self.repeat_kv(key, self.num_queries_per_kv) value = self.repeat_kv(value, self.num_queries_per_kv) if self.use_naive_attn: - output = self.attn_fuc( + out = self.attn_fuc( query, key, value, - attn_metadata.prompt_lens, + prefill_meta.prompt_lens, self.scale, ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: - output, _ = self.attn_func( + out, _ = self.attn_func( query, key, value, None, - attn_metadata.seq_start_loc, - attn_metadata.seq_start_loc, - attn_metadata.max_prompt_len, - attn_metadata.max_prompt_len, + prefill_meta.seq_start_loc, + prefill_meta.seq_start_loc, + prefill_meta.max_prompt_len, + prefill_meta.max_prompt_len, True, self.scale, ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: - output = self.attn_func( + out = self.attn_func( q=query, k=key, v=value, - cu_seqlens_q=attn_metadata.seq_start_loc, - cu_seqlens_k=attn_metadata.seq_start_loc, - max_seqlen_q=attn_metadata.max_prompt_len, - max_seqlen_k=attn_metadata.max_prompt_len, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prompt_len, + max_seqlen_k=prefill_meta.max_prompt_len, softmax_scale=self.scale, causal=True, ) - + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention - output = PagedAttention.forward_prefix( + output[:num_prefill_tokens] = PagedAttention.forward_prefix( query, key, value, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.prompt_lens_tensor, - attn_metadata.context_lens, - attn_metadata.max_subquery_len, + prefill_meta.block_tables, + prefill_meta.subquery_start_loc, + prefill_meta.prompt_lens_tensor, + prefill_meta.context_lens, + prefill_meta.max_subquery_len, self.alibi_slopes, ) - else: + + if decode_meta := attn_metadata.decode_metadata: # Decoding run. - output = PagedAttention.forward_decode( - query, + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + decode_meta.block_tables, + decode_meta.context_lens, + decode_meta.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 9706e1910cb7..63904ea92987 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -7,7 +7,8 @@ from torch.nn.functional import scaled_dot_product_attention from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -49,17 +50,14 @@ def copy_blocks( @dataclass -class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): +class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): """Metadata for TorchSDPABackend. """ # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - slot_mapping: torch.Tensor prompt_lens: Optional[List[int]] prompt_lens_tensor: Optional[torch.Tensor] - num_prompt_tokens: int - num_generation_tokens: int max_subquery_len: Optional[int] = None max_prompt_len: Optional[int] = None @@ -113,7 +111,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: TorchSDPAMetadata, + attn_metadata: AttentionMetadata[TorchSDPAMetadata], kv_scale: float, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -142,36 +140,51 @@ def forward( attn_metadata.kv_cache_dtype, kv_scale) - if attn_metadata.is_prompt: - if (kv_cache is None or attn_metadata.block_tables.numel() == 0): + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + if (kv_cache is None or prefill_meta.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) value = value.repeat_interleave(self.num_queries_per_kv, dim=1) - if attn_metadata.attn_bias is None: + if prefill_meta.attn_bias is None: if self.alibi_slopes is not None: att_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, - attn_metadata.prompt_lens) # type: ignore + prefill_meta.prompt_lens) # type: ignore elif self.sliding_window is not None: att_masks = _make_sliding_window_bias( - attn_metadata.prompt_lens, self.sliding_window, + prefill_meta.prompt_lens, self.sliding_window, query.dtype) # type: ignore else: - att_masks = [None] * len(attn_metadata.prompt_lens) - attn_metadata.attn_bias = att_masks + att_masks = [None] * len(prefill_meta.prompt_lens) + prefill_meta.attn_bias = att_masks query = query.movedim(0, query.dim() - 2) key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) start = 0 - output = torch.empty( - (num_tokens, self.num_heads, self.head_size), - dtype=query.dtype) - for prompt_len, mask in zip(attn_metadata.prompt_lens, - attn_metadata.attn_bias): + out = torch.empty((num_tokens, self.num_heads, self.head_size), + dtype=query.dtype) + for prompt_len, mask in zip(prefill_meta.prompt_lens, + prefill_meta.attn_bias): end = start + prompt_len sub_out = scaled_dot_product_attention( query[:, start:end, :], @@ -181,28 +194,32 @@ def forward( dropout_p=0.0, is_causal=not self.need_mask, scale=self.scale).movedim(query.dim() - 2, 0) - output[start:end, :, :] = sub_out + out[start:end, :, :] = sub_out start = end + assert out.shape == output[:num_prefill_tokens].shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention raise RuntimeError( "Torch SDPA backend doesn't support prefix decoding.") - else: + if decode_meta := attn_metadata.decode_metadata: # Decoding run. - output = PagedAttention.forward_decode( - query, + out = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + decode_meta.block_tables, + decode_meta.context_lens, + decode_meta.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, kv_scale, ) + assert out.shape == output[num_prefill_tokens:].shape + output[num_prefill_tokens:] # Reshape the output tensor. return output.view(-1, self.num_heads * self.head_size) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 05b68bba5e6e..b745a04a143b 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -9,7 +9,8 @@ LowerTriangularMaskWithTensorBias) from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -54,7 +55,7 @@ def copy_blocks( @dataclass -class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): +class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): """Metadata for XFormersbackend. NOTE: Any python object stored here is not updated when it is @@ -65,19 +66,10 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - # (num_tokens,). The indices of the token slots that input tokens will be - # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size - # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot - # in block 0, and 1st slot in block 1, respectively. - slot_mapping: torch.Tensor # (batch_size,). The prompt length per sequence. None if it is a decoding. prompt_lens: Optional[List[int]] # prompt_lens stored as a tensor. prompt_lens_tensor: Optional[torch.Tensor] - # The number of prompt tokens. Doesn't include padding. - num_prompt_tokens: int - # The number of generation tokens. Doesn't include padding. - num_generation_tokens: int # NOTE(sang): Definition of context_len, subquery_len, and seqlen. # |---------- N-1 iteration --------| @@ -123,18 +115,27 @@ def __post_init__(self): class XFormersImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens --------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->| + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| - |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| Generation tokens can contain padding when cuda-graph is used. Currently, prompt tokens don't contain any padding. The prompts might have different lengths, while the generation tokens always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. """ def __init__( @@ -170,7 +171,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: XFormersMetadata, + attn_metadata: AttentionMetadata[XFormersMetadata], kv_scale: float, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -202,59 +203,61 @@ def forward( attn_metadata.kv_cache_dtype, kv_scale) - if attn_metadata.is_prompt: + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if kv_cache is None or attn_metadata.block_tables.numel() == 0: + if kv_cache is None or prefill_meta.block_tables.numel() == 0: # normal attention. # block tables are empty if the prompt does not have a cached # prefix. - if self.num_kv_heads != self.num_heads: - # As of Nov 2023, xformers only supports MHA. For MQA/GQA, - # project the key and value tensors to the desired number of - # heads. - # TODO(woosuk): Use MQA/GQA kernels for higher performance. - query = query.view(query.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - query.shape[-1]) - key = key[:, :, - None, :].expand(key.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], - self.num_kv_heads, - self.num_queries_per_kv, - value.shape[-1]) - - output = self._run_memory_efficient_xformers_forward( - query, key, value, attn_metadata) + out = self._run_memory_efficient_xformers_forward( + query, key, value, prefill_meta) + assert out.shape == output[:num_prefill_tokens].shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention # TODO(Hai) this triton kernel has regression issue (broke) to # deal with different data types between KV and FP8 KV cache, # to be addressed separately. - output = PagedAttention.forward_prefix( + out = PagedAttention.forward_prefix( query, key, value, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.prompt_lens_tensor, - attn_metadata.context_lens, - attn_metadata.max_subquery_len, + prefill_meta.block_tables, + prefill_meta.subquery_start_loc, + prefill_meta.prompt_lens_tensor, + prefill_meta.context_lens, + prefill_meta.max_subquery_len, self.alibi_slopes, ) - else: - # Decoding run. - output = PagedAttention.forward_decode( - query, + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out + + if decode_meta := attn_metadata.decode_metadata: + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + decode_meta.block_tables, + decode_meta.context_lens, + decode_meta.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -275,13 +278,30 @@ def _run_memory_efficient_xformers_forward( """Attention for 1D query of multiple prompts. Multiple prompt tokens are flattened in to `query` input. + See https://facebookresearch.github.io/xformers/components/ops.html + for API spec. + Args: - output: shape = [num_prompt_tokens, num_heads, head_size] - query: shape = [num_prompt_tokens, num_heads, head_size] - key: shape = [num_prompt_tokens, num_kv_heads, head_size] - value: shape = [num_prompt_tokens, num_kv_heads, head_size] + output: shape = [num_prefill_tokens, num_heads, head_size] + query: shape = [num_prefill_tokens, num_heads, head_size] + key: shape = [num_prefill_tokens, num_kv_heads, head_size] + value: shape = [num_prefill_tokens, num_kv_heads, head_size] attn_metadata: Metadata for attention. """ + original_query = query + if self.num_kv_heads != self.num_heads: + # GQA/MQA requires the shape [B, M, G, H, K]. + # Note that the output also has the same shape (which is different + # from a spec from the doc). + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) # Set attention bias if not provided. This typically happens at # the very attention layer of every iteration. # FIXME(woosuk): This is a hack. @@ -302,6 +322,7 @@ def _run_memory_efficient_xformers_forward( # TODO(woosuk): Too many view operations. Let's try to reduce # them in the future for code readability. if self.alibi_slopes is None: + # Add the batch dimension. query = query.unsqueeze(0) key = key.unsqueeze(0) value = value.unsqueeze(0) @@ -312,14 +333,13 @@ def _run_memory_efficient_xformers_forward( attn_bias=attn_metadata.attn_bias[0], p=0.0, scale=self.scale) - - return out.view_as(query) + return out.view_as(original_query) # Attention with alibi slopes. # FIXME(woosuk): Because xformers does not support dynamic sequence # lengths with custom attention bias, we process each prompt one by # one. This is inefficient, especially when we have many short prompts. - output = torch.empty_like(query) + output = torch.empty_like(original_query) start = 0 for i, prompt_len in enumerate(attn_metadata.prompt_lens): end = start + prompt_len @@ -331,7 +351,7 @@ def _run_memory_efficient_xformers_forward( p=0.0, scale=self.scale) # TODO(woosuk): Unnecessary copy. Optimize. - output[start:end].copy_(out.squeeze(0)) + output[start:end].copy_(out.view_as(original_query[start:end])) start += prompt_len return output diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 9856654fc5f9..fc65ae108dbb 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -4,7 +4,8 @@ import torch import torch.nn as nn -from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.backends.abstract import (AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.selector import get_attn_backend @@ -41,7 +42,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata, + attn_metadata: AttentionMetadata[AttentionMetadataPerStage], kv_scale: float = 1.0, ) -> torch.Tensor: return self.impl.forward(query, key, value, kv_cache, attn_metadata, diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 256bffdf032e..2d918491d657 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -13,11 +13,6 @@ @dataclass class PagedAttentionMetadata: """Metadata for PagedAttention.""" - # (num_tokens,). The indices of the token slots that input tokens will be - # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size - # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot - # in block 0, and 1st slot in block 1, respectively. - slot_mapping: torch.Tensor # (batch_size,). The length of context (tokens stored in KV cache) per # sequence. WARNING: When it is a prefill request, it doesn't include new # tokens. When it is for decoding, it includes a new token. @@ -31,7 +26,6 @@ class PagedAttentionMetadata: # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph # captured. block_tables: Optional[torch.Tensor] - kv_cache_dtype: str class PagedAttention: diff --git a/vllm/config.py b/vllm/config.py index bca250e92228..4102edbe01d3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -565,9 +565,16 @@ def __init__( if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens else: - # If max_model_len is too short, use 2048 as the default value for - # higher throughput. - self.max_num_batched_tokens = max(max_model_len, 2048) + if enable_chunked_prefill: + # For chunked prefill, choose the well-tuned batch size. + self.max_num_batched_tokens = 768 + else: + # If max_model_len is too short, use 2048 as the default value + # for higher throughput. + self.max_num_batched_tokens = max(max_model_len, 2048) + if enable_chunked_prefill: + logger.info("Chunked prefill is enabled (EXPERIMENTAL).") + self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len self.use_v2_block_manager = use_v2_block_manager diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 0ae53f937496..2942eab735a9 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -140,7 +140,11 @@ def _sort_by_lora_ids(self) -> bool: @property def lora_requests(self) -> Set[LoRARequest]: - return {g.seq_group.lora_request for g in self.scheduled_seq_groups} + return { + g.seq_group.lora_request + for g in self.scheduled_seq_groups + if g.seq_group.lora_request is not None + } @dataclass @@ -826,13 +830,12 @@ def _schedule_chunked_prefill(self): # Update swapped requests. self.swapped = remaining_swapped self.swapped.extend(running_scheduled.swapped_out) - return SchedulerOutputs( scheduled_seq_groups=(prefills.seq_groups + - running_scheduled.decode_seq_groups + running_scheduled.prefill_seq_groups + - swapped_in.decode_seq_groups + - swapped_in.prefill_seq_groups), + swapped_in.prefill_seq_groups + + running_scheduled.decode_seq_groups + + swapped_in.decode_seq_groups), num_prefill_groups=(len(prefills.seq_groups) + len(swapped_in.prefill_seq_groups) + len(running_scheduled.prefill_seq_groups)), @@ -907,7 +910,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # It assumes the scheduled_seq_groups is ordered by # prefill < decoding. - is_prompt = i < scheduler_outputs.num_prefill_groups + is_prompt = seq_group.is_prefill() seq_group_metadata = SequenceGroupMetadata( request_id=seq_group.request_id, is_prompt=is_prompt, diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index cf15db099b30..1004d626b6a4 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -173,10 +173,18 @@ def broadcast_tensor_dict( torch.distributed.broadcast_object_list([metadata_list], src=src, group=group) + async_handles = [] for key, value in metadata_list: if isinstance(value, TensorMetadata): tensor = tensor_dict[key] - torch.distributed.broadcast(tensor, src=src, group=group) + async_handles.append( + torch.distributed.broadcast(tensor, + src=src, + group=group, + async_op=True)) + for async_handle in async_handles: + async_handle.wait() + else: recv_metadata_list = [None] torch.distributed.broadcast_object_list(recv_metadata_list, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d4b573992c06..daefddc01b43 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -386,9 +386,8 @@ def add_cli_args( 'prompt latency) before scheduling next prompt.') parser.add_argument( '--enable-chunked-prefill', - type=bool, - default=False, - help='If True, the prefill requests can be chunked based on the ' + action='store_true', + help='If set, the prefill requests can be chunked based on the ' 'max_num_batched_tokens') parser.add_argument( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1c639af69654..ddfdda898a5c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -633,7 +633,10 @@ def _process_model_outputs( seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) - self._process_sequence_group_outputs(seq_group, outputs) + # If uncomputed tokens > 0, it means prefill is chunked. + # We don't need to process outputs in that case. + if seq_group.get_num_uncomputed_tokens() == 0: + self._process_sequence_group_outputs(seq_group, outputs) # Free the finished sequence groups. self.scheduler.free_finished_seq_groups() diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index dd33868f7630..84a94091486d 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -267,12 +267,13 @@ def set_mapping( def forward(self, x: torch.Tensor) -> torch.Tensor: added_tokens_mask = x > self.base_layer.org_vocab_size - 1 - indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x) + embedding_len = self.indices_len[3] + indices = self.embeddings_indices[1][:embedding_len].view_as(x) full_lora_a_embeddings = F.embedding( x + indices, self.lora_a_stacked_2d, ) - indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x) + indices = self.embeddings_indices[0][:embedding_len].view_as(x) full_output = self.base_layer.forward( x.add_(indices * added_tokens_mask)) diff --git a/vllm/sequence.py b/vllm/sequence.py index 576bbe8c4f6c..77029908c221 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -500,7 +500,8 @@ def update_num_computed_tokens(self, num_new_computed_tokens: int): def get_num_uncomputed_tokens(self) -> int: num_uncomputed_tokens = 0 for seq in self.get_seqs(): - num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() + if not seq.is_finished(): + num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() return num_uncomputed_tokens def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1de4748b7bcc..47ad8f0c9b78 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,12 +1,14 @@ import contextlib import time -from typing import Dict, List, Optional, Set, Tuple +from enum import IntEnum +from typing import Dict, List, NamedTuple, Optional, Set, Tuple import numpy as np import torch import torch.nn as nn -from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, + get_attn_backend) from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce @@ -37,6 +39,66 @@ ] +class PreparePromptMetadata(NamedTuple): + input_tokens: List[int] + input_positions: List[int] + attn_metadata: Optional[AttentionMetadataPerStage] + prompt_lens: List[int] + subquery_lens: List[int] + lora_index_mapping: List[int] + lora_prompt_mapping: List[int] + lora_requests: Set[LoRARequest] + multi_modal_input: Optional[torch.Tensor] + slot_mapping: List[int] + + @classmethod + def empty(cls): + return PreparePromptMetadata( + input_tokens=[], + input_positions=[], + attn_metadata=None, + prompt_lens=[], + subquery_lens=[], + lora_index_mapping=[], + lora_prompt_mapping=[], + lora_requests=set(), + multi_modal_input=None, + slot_mapping=[], + ) + + +class PrepareDecodeMetadata(NamedTuple): + input_tokens: List[int] + input_positions: List[int] + attn_metadata: Optional[AttentionMetadata] + lora_index_mapping: List[int] + lora_prompt_mapping: List[int] + lora_requests: Set[LoRARequest] + slot_mapping: List[int] + + @classmethod + def empty(cls): + return PrepareDecodeMetadata( + input_tokens=[], + input_positions=[], + attn_metadata=None, + lora_index_mapping=[], + lora_prompt_mapping=[], + lora_requests=set(), + slot_mapping=[], + ) + + +# How batches are constructed. +class BatchType(IntEnum): + # Every batch is prefill. + PREFILL = 0 + # Every batch is decode. + DECODE = 1 + # Batch is a mixture of prefill and decode. + MIXED = 2 + + class ModelRunner: def __init__( @@ -152,10 +214,7 @@ def get_max_block_per_batch(self) -> int: def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - List[int], List[int], List[int], Set[LoRARequest], - torch.Tensor]: - assert len(seq_group_metadata_list) > 0 + ) -> PreparePromptMetadata: input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] @@ -169,6 +228,9 @@ def _prepare_prompt( prefix_block_tables: List[List[int]] = [] multi_modal_input_list: List[torch.Tensor] = [] + if len(seq_group_metadata_list) == 0: + return PreparePromptMetadata.empty() + for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -178,7 +240,8 @@ def _prepare_prompt( computed_block_nums = seq_group_metadata.computed_block_nums if (self.scheduler_config is not None and self.scheduler_config.chunked_prefill_enabled - and computed_block_nums is not None): + and not (computed_block_nums is None + or computed_block_nums == [])): raise RuntimeError( "chunked prefill cannot be used with prefix caching " "now.") @@ -190,13 +253,8 @@ def _prepare_prompt( # it contains output tokens. prefill_end = min(seq_data.get_len(), computed_len + token_chunk_size) - # TODO(sang): Rename it after chunked prefill is introduced. prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end] - prompt_len = len(prompt_tokens) - # Right now, the prefill_end is always same as the length of - # sequence. However, once chunked prefill is introduced, this - # assumption can be changed. - assert prefill_end == seq_data.get_len() + prompt_len = prefill_end prompt_lens.append(prompt_len) # NOTE: This only works for oooooooxxx style attention. @@ -206,6 +264,14 @@ def _prepare_prompt( computed_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[computed_len:] prefix_block_tables.append(computed_block_nums) + elif self.scheduler_config.chunked_prefill_enabled: + if seq_group_metadata.block_tables is not None: + # Prefill has chunked before. + block_table = seq_group_metadata.block_tables[seq_id] + prefix_block_tables.append(block_table) + else: + # The first prefill. + prefix_block_tables.append([]) else: prefix_block_tables.append([]) # Right now, prefill start is always 0. However, this @@ -267,20 +333,8 @@ def _prepare_prompt( max_subquery_len = max(subquery_lens) max_prompt_len = max(prompt_lens) - num_prompt_tokens = len(input_tokens) assert max_subquery_len > 0 - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - lora_index_mapping = lora_index_mapping - context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device=self.device) @@ -332,11 +386,8 @@ def _prepare_prompt( attn_metadata = self.attn_backend.make_metadata( is_prompt=True, - slot_mapping=slot_mapping, prompt_lens=prompt_lens, prompt_lens_tensor=prompt_lens_tensor, - num_prompt_tokens=num_prompt_tokens, - num_generation_tokens=0, max_subquery_len=max_subquery_len, max_context_len=None, max_prompt_len=max_prompt_len, @@ -345,18 +396,25 @@ def _prepare_prompt( context_lens=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, - kv_cache_dtype=self.kv_cache_dtype, ) - return (input_tokens, input_positions, attn_metadata, prompt_lens, - subquery_lens, lora_index_mapping, lora_prompt_mapping, - lora_requests, multi_modal_input) + + return PreparePromptMetadata( + input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + prompt_lens=prompt_lens, + subquery_lens=subquery_lens, + lora_index_mapping=lora_index_mapping, + lora_prompt_mapping=lora_prompt_mapping, + lora_requests=lora_requests, + multi_modal_input=multi_modal_input, + slot_mapping=slot_mapping, + ) def _prepare_decode( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - List[int], Set[LoRARequest]]: - assert len(seq_group_metadata_list) > 0 + ) -> PrepareDecodeMetadata: input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] @@ -366,6 +424,9 @@ def _prepare_decode( lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() + if len(seq_group_metadata_list) == 0: + return PrepareDecodeMetadata.empty() + for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt assert seq_group_metadata.token_chunk_size == 1 @@ -424,15 +485,6 @@ def _prepare_decode( lora_index_mapping.append(0) batch_size = graph_batch_size - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) context_lens = torch.tensor(context_lens, dtype=torch.int, device=self.device) @@ -440,9 +492,9 @@ def _prepare_decode( if use_captured_graph: # When using cuda-graph all these tensors should be # padded. - assert context_lens.shape[0] == input_tokens.shape[0] - assert context_lens.shape[0] == input_positions.shape[0] - assert context_lens.shape[0] == slot_mapping.shape[0] + assert context_lens.shape[0] == len(input_tokens) + assert context_lens.shape[0] == len(input_positions) + assert context_lens.shape[0] == len(slot_mapping) # The shape of graph_block_tables is # [max batch size, max context len // block size]. @@ -464,11 +516,8 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, - slot_mapping=slot_mapping, prompt_lens=None, prompt_lens_tensor=None, - num_prompt_tokens=0, - num_generation_tokens=len(input_tokens), max_subquery_len=None, max_context_len=max_context_len, max_prompt_len=None, @@ -477,10 +526,16 @@ def _prepare_decode( context_lens=context_lens, block_tables=block_tables, use_cuda_graph=use_captured_graph, - kv_cache_dtype=self.kv_cache_dtype, ) - return (input_tokens, input_positions, attn_metadata, - lora_index_mapping, lora_prompt_mapping, lora_requests) + return PrepareDecodeMetadata( + input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + lora_index_mapping=lora_index_mapping, + lora_prompt_mapping=lora_prompt_mapping, + lora_requests=lora_requests, + slot_mapping=slot_mapping, + ) def _prepare_sample( self, @@ -586,26 +641,66 @@ def prepare_input_tensors( ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, Set[int], LoRAMapping, torch.Tensor]: if self.is_driver_worker: - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = seq_group_metadata_list[0].is_prompt + prefill_reqs = [] + decode_reqs = [] + for seq_group_meta in seq_group_metadata_list: + if seq_group_meta.is_prompt: + prefill_reqs.append(seq_group_meta) + else: + decode_reqs.append(seq_group_meta) + # Prepare input tensors. - if is_prompt: - (input_tokens, input_positions, attn_metadata, prompt_lens, - subquery_lens, lora_index_mapping, lora_prompt_mapping, - lora_requests, multi_modal_input - ) = self._prepare_prompt(seq_group_metadata_list) - else: - (input_tokens, input_positions, attn_metadata, - lora_index_mapping, lora_prompt_mapping, - lora_requests) = self._prepare_decode(seq_group_metadata_list) - prompt_lens = [] - subquery_lens = None - multi_modal_input = None + ( + input_tokens, + input_positions, + prefill_attn_metadata, + prompt_lens, + subquery_lens, + lora_index_mapping, + lora_prompt_mapping, + lora_requests, + multi_modal_input, + slot_mapping, + ) = self._prepare_prompt(prefill_reqs) + ( + decode_input_tokens, + decode_input_positions, + decode_attn_metadata, + decode_lora_index_mapping, + decode_lora_prompt_mapping, + decode_lora_requests, + decode_slot_mapping, + ) = self._prepare_decode(decode_reqs) sampling_metadata = self._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens) + if not self.scheduler_config.chunked_prefill_enabled: + assert (len(prefill_reqs) and len(decode_reqs)) == 0 + + num_prefills = len(prompt_lens) + num_prefill_tokens = len(input_tokens) + num_decode_tokens = len(decode_input_tokens) + + # Coalesce tensors. Note that attn_metadata is currently not + # coalesced for simplicity. + input_tokens.extend(decode_input_tokens) + input_positions.extend(decode_input_positions) + slot_mapping.extend(decode_slot_mapping) + lora_index_mapping.extend(decode_lora_index_mapping) + lora_prompt_mapping.extend(decode_lora_prompt_mapping) + lora_requests.update(decode_lora_requests) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + if self.lora_config: lora_mapping = LoRAMapping( lora_index_mapping, @@ -615,6 +710,16 @@ def prepare_input_tensors( lora_mapping = None # Broadcast the metadata. + # If batch contains both prefill and decode, it sends 2 broadcasts. + # If it only contains 1 type, it triggers a single broadcast. + if (prefill_attn_metadata is not None + and decode_attn_metadata is not None): + batch_type = BatchType.MIXED + elif prefill_attn_metadata is not None: + batch_type = BatchType.PREFILL + else: + batch_type = BatchType.DECODE + metadata_dict = { "input_tokens": input_tokens, "input_positions": input_positions, @@ -623,19 +728,49 @@ def prepare_input_tensors( "lora_requests": lora_requests, "lora_mapping": lora_mapping, "multi_modal_input": multi_modal_input, + "num_prefill_tokens": num_prefill_tokens, + "num_decode_tokens": num_decode_tokens, + "slot_mapping": slot_mapping, + "num_prefills": num_prefills, + "batch_type": batch_type, } - metadata_dict.update(attn_metadata.asdict_zerocopy()) + if prefill_attn_metadata is not None: + metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) + else: + metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) broadcast_tensor_dict(metadata_dict, src=0) + + # Broadcast decode attn metadata for mixed batch type. + # The additional broadcast costs 300us overhead on 4 A10 GPUs. + # We can potentially reduce the overhead by coelescing tensors. + if batch_type == BatchType.MIXED: + assert decode_attn_metadata is not None + metadata_dict = decode_attn_metadata.asdict_zerocopy() + broadcast_tensor_dict(metadata_dict, src=0) else: metadata_dict = broadcast_tensor_dict(src=0) input_tokens = metadata_dict.pop("input_tokens") input_positions = metadata_dict.pop("input_positions") + slot_mapping = metadata_dict.pop("slot_mapping") + num_prefills = metadata_dict.pop("num_prefills") selected_token_indices = metadata_dict.pop( "selected_token_indices") lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") multi_modal_input = metadata_dict.pop("multi_modal_input") - attn_metadata = self.attn_backend.make_metadata(**metadata_dict) + num_prefill_tokens = metadata_dict.pop("num_prefill_tokens") + num_decode_tokens = metadata_dict.pop("num_decode_tokens") + batch_type = metadata_dict.pop("batch_type") + + # Create an attention metadata. + prefill_attn_metadata = None + decode_attn_metadata = None + if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED: + prefill_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + else: + decode_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, @@ -646,6 +781,23 @@ def prepare_input_tensors( perform_sampling=False, ) + # if it is a mixed batch, decode attn_metadata is broadcasted + # separately. + if batch_type == BatchType.MIXED: + metadata_dict = broadcast_tensor_dict(src=0) + decode_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + + attn_metadata = AttentionMetadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + prefill_metadata=prefill_attn_metadata, + decode_metadata=decode_attn_metadata, + kv_cache_dtype=self.kv_cache_dtype, + ) + return (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, multi_modal_input) @@ -663,8 +815,10 @@ def execute_model( if self.lora_config: self.set_active_loras(lora_requests, lora_mapping) - # Execute the model. - if attn_metadata.use_cuda_graph: + # Currently cuda graph is only supported by the decode phase. + prefill_meta = attn_metadata.prefill_metadata + decode_meta = attn_metadata.decode_metadata + if prefill_meta is None and decode_meta.use_cuda_graph: graph_batch_size = input_tokens.shape[0] model_executable = self.graph_runners[graph_batch_size] else: @@ -842,13 +996,10 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: # memory usage of CUDA graph. for batch_size in reversed(batch_size_capture_list): # Create dummy attn_metadata. - attn_metadata = self.attn_backend.make_metadata( + decode_metadata = self.attn_backend.make_metadata( is_prompt=False, - slot_mapping=slot_mapping[:batch_size], prompt_lens=None, prompt_lens_tensor=None, - num_prompt_tokens=0, - num_generation_tokens=batch_size, max_subquery_len=None, max_context_len=self.max_context_len_to_capture, max_prompt_len=None, @@ -857,6 +1008,14 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: context_lens=context_lens[:batch_size], block_tables=block_tables[:batch_size], use_cuda_graph=True, + ) + attn_metadata = AttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=slot_mapping[:batch_size], + prefill_metadata=None, + decode_metadata=decode_metadata, kv_cache_dtype=self.kv_cache_dtype, ) @@ -950,8 +1109,8 @@ def capture( "positions": positions, "kv_caches": kv_caches, "slot_mapping": attn_metadata.slot_mapping, - "context_lens": attn_metadata.context_lens, - "block_tables": attn_metadata.block_tables, + "context_lens": attn_metadata.decode_metadata.context_lens, + "block_tables": attn_metadata.decode_metadata.block_tables, } self.output_buffers = {"hidden_states": hidden_states} return @@ -972,10 +1131,10 @@ def forward( self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, non_blocking=True) - self.input_buffers["context_lens"].copy_(attn_metadata.context_lens, - non_blocking=True) - self.input_buffers["block_tables"].copy_(attn_metadata.block_tables, - non_blocking=True) + self.input_buffers["context_lens"].copy_( + attn_metadata.decode_metadata.context_lens, non_blocking=True) + self.input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) # Run the graph. self.graph.replay() From caada5e50aa16cd5f59bd7889128a83588ca1f99 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 10 Apr 2024 18:48:26 -0700 Subject: [PATCH 075/120] [Core][Model] torch.compile for layernorm in commandr (#3985) [Core][Model] Use torch.compile to accelerate layernorm in commandr (#3985) --- vllm/model_executor/models/commandr.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index aa27f0a96c74..aa9b28b676e0 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -48,6 +48,18 @@ from vllm.sequence import SamplerOutput +@torch.compile +def layer_norm_func(hidden_states, weight, variance_epsilon): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + mean = hidden_states.mean(-1, keepdim=True) + variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) + hidden_states = (hidden_states - mean) * torch.rsqrt(variance + + variance_epsilon) + hidden_states = weight.to(torch.float32) * hidden_states + return hidden_states.to(input_dtype) + + class LayerNorm(nn.Module): def __init__(self, param_shape=None, eps=1e-5): @@ -57,14 +69,9 @@ def __init__(self, param_shape=None, eps=1e-5): set_weight_attrs(self.weight, {"weight_loader": self.weight_loader}) def forward(self, hidden_states, residuals=None): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - mean = hidden_states.mean(-1, keepdim=True) - variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) - hidden_states = (hidden_states - - mean) * torch.rsqrt(variance + self.variance_epsilon) - hidden_states = self.weight.to(torch.float32) * hidden_states - return hidden_states.to(input_dtype), residuals + hidden_states = layer_norm_func(hidden_states, self.weight, + self.variance_epsilon) + return hidden_states, residuals def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() From e42df7227d18e2b96785f8ee52053663ade05b63 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Thu, 11 Apr 2024 12:09:50 +0900 Subject: [PATCH 076/120] [Test] Add xformer and flash attn tests (#3961) Co-authored-by: Simon Mo --- tests/basic_correctness/test_basic_correctness.py | 6 ++++++ vllm/attention/selector.py | 9 +++++++++ 2 files changed, 15 insertions(+) diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 97cff623c5e1..bd4c7ea3301b 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -4,6 +4,8 @@ """ import pytest +from vllm.attention.selector import VLLM_ATTENTION_BACKEND + MODELS = [ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", @@ -14,6 +16,7 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [False, True]) +@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"]) def test_models( hf_runner, vllm_runner, @@ -22,7 +25,10 @@ def test_models( dtype: str, max_tokens: int, enforce_eager: bool, + attn_backend: str, + monkeypatch, ) -> None: + monkeypatch.setenv(VLLM_ATTENTION_BACKEND, attn_backend) hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 4c699aed48d4..554e802cd551 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,4 +1,5 @@ import enum +import os from functools import lru_cache from typing import Type @@ -10,6 +11,8 @@ logger = init_logger(__name__) +VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" + class _Backend(enum.Enum): FLASH_ATTN = enum.auto() @@ -75,4 +78,10 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend: "Cannot use FlashAttention backend because the flash_attn package " "is not found. Please install it for better performance.") return _Backend.XFORMERS + + backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) + if backend_by_env_var is not None: + return _Backend[backend_by_env_var] + + # Default case. return _Backend.FLASH_ATTN From e9da5a40c63ce7f8a85438d3c7d919b46e7939f5 Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Thu, 11 Apr 2024 03:26:07 +0000 Subject: [PATCH 077/120] [Misc] Add indirection layer for custom ops (#3913) --- .../kernels/benchmark_paged_attention.py | 2 +- tests/kernels/test_attention.py | 6 +- tests/kernels/test_cache.py | 25 ++- vllm/_custom_ops.py | 193 ++++++++++++++++++ vllm/attention/ops/paged_attn.py | 10 +- vllm/model_executor/layers/activation.py | 2 +- .../layers/fused_moe/fused_moe.py | 2 +- vllm/model_executor/layers/layernorm.py | 2 +- .../model_executor/layers/quantization/awq.py | 2 +- .../layers/quantization/gptq.py | 2 +- .../layers/quantization/marlin.py | 2 +- .../layers/quantization/squeezellm.py | 2 +- .../model_executor/layers/rotary_embedding.py | 2 +- vllm/utils.py | 4 +- 14 files changed, 224 insertions(+), 32 deletions(-) create mode 100644 vllm/_custom_ops.py diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index f71d1fcaaef5..5c3650fa72d1 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -5,7 +5,7 @@ import torch -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random NUM_BLOCKS = 1024 diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 03ea72924921..9b1f3e30b6dc 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -7,7 +7,7 @@ from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask -from vllm._C import cache_ops, ops +from vllm import _custom_ops as ops from vllm.utils import get_max_shared_memory_bytes, is_hip FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 @@ -237,14 +237,14 @@ def test_paged_attention( dequantized_key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device) - cache_ops.convert_fp8(key_cache, dequantized_key_cache) + ops.convert_fp8(key_cache, dequantized_key_cache) key_cache = dequantized_key_cache value_cache_shape = value_cache.shape dequantized_value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) - cache_ops.convert_fp8(value_cache, dequantized_value_cache) + ops.convert_fp8(value_cache, dequantized_value_cache) value_cache = dequantized_value_cache ref_output = torch.empty_like(query) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 4141aacafd0b..d1051fd7e2f4 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -4,7 +4,7 @@ import pytest import torch -from vllm._C import cache_ops +from vllm import _custom_ops as ops from vllm.utils import is_hip COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] @@ -80,7 +80,7 @@ def test_copy_blocks( cloned_value_caches = [value_cache.clone() for value_cache in value_caches] # Call the copy blocks kernel. - cache_ops.copy_blocks(key_caches, value_caches, block_mapping) + ops.copy_blocks(key_caches, value_caches, block_mapping) # Run the reference implementation. for src, dsts in block_mapping.items(): @@ -145,9 +145,9 @@ def test_reshape_and_cache( # Clone the KV caches. if kv_cache_dtype == "fp8": cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - cache_ops.convert_fp8(key_cache, cloned_key_cache) + ops.convert_fp8(key_cache, cloned_key_cache) cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - cache_ops.convert_fp8(value_cache, cloned_value_cache) + ops.convert_fp8(value_cache, cloned_value_cache) else: cloned_key_cache = key_cache.clone() cloned_value_cache = value_cache.clone() @@ -156,14 +156,14 @@ def test_reshape_and_cache( kv_scale = 1.0 # Call the reshape_and_cache kernel. - cache_ops.reshape_and_cache(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype, kv_scale) + ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, + kv_cache_dtype, kv_scale) if kv_cache_dtype == "fp8": result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - cache_ops.convert_fp8(key_cache, result_key_cache) + ops.convert_fp8(key_cache, result_key_cache) result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - cache_ops.convert_fp8(value_cache, result_value_cache) + ops.convert_fp8(value_cache, result_value_cache) # Run the reference implementation. reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) @@ -251,9 +251,8 @@ def test_swap_blocks( src_value_caches_clone = src_value_caches[0].clone() # Call the swap_blocks kernel. - cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping) - cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0], - block_mapping) + ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping) + ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping) for src, dst in block_mapping.items(): assert torch.allclose(src_key_caches_clone[src].cpu(), @@ -291,9 +290,9 @@ def test_fp8_conversion( cache.uniform_(low, high) cache_fp8 = torch.empty_like(cache, dtype=torch.uint8) - cache_ops.convert_fp8(cache, cache_fp8) + ops.convert_fp8(cache, cache_fp8) converted_cache = torch.empty_like(cache) - cache_ops.convert_fp8(cache_fp8, converted_cache) + ops.convert_fp8(cache_fp8, converted_cache) assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py new file mode 100644 index 000000000000..a0837a20875f --- /dev/null +++ b/vllm/_custom_ops.py @@ -0,0 +1,193 @@ +from typing import Dict, Optional + +import torch + +try: + from vllm._C import cache_ops as vllm_cache_ops + from vllm._C import ops as vllm_ops +except ImportError: + pass + + +# activation ops +def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.silu_and_mul(out, x) + + +def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.gelu_and_mul(out, x) + + +def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.gelu_tanh_and_mul(out, x) + + +def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.gelu_fast(out, x) + + +def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.gelu_new(out, x) + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + block_size: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + kv_scale: float, +) -> None: + vllm_ops.paged_attention_v1(out, query, key_cache, value_cache, + num_kv_heads, scale, block_tables, + context_lens, block_size, max_context_len, + alibi_slopes, kv_cache_dtype, kv_scale) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + block_size: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + kv_scale: float, +) -> None: + vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query, + key_cache, value_cache, num_kv_heads, scale, + block_tables, context_lens, block_size, + max_context_len, alibi_slopes, kv_cache_dtype, + kv_scale) + + +# pos encoding ops +def rotary_embedding( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool, +) -> None: + vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache, + is_neox) + + +def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, + key: torch.Tensor, head_size: int, + cos_sin_cache: torch.Tensor, is_neox: bool, + rot_dim: int, + cos_sin_cache_offsets: torch.Tensor) -> None: + vllm_ops.batched_rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox, rot_dim, + cos_sin_cache_offsets) + + +# layer norm ops +def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + epsilon: float) -> None: + vllm_ops.rms_norm(out, input, weight, epsilon) + + +def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, epsilon: float) -> None: + vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon) + + +# quantization ops +# awq +def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, + zeros: torch.Tensor, split_k_iters: int, thx: int, + thy: int) -> torch.Tensor: + return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, + thy) + + +def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, + scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: + return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters) + + +# gptq +def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, + b_g_idx: torch.Tensor, use_exllama: bool, + bit: int) -> torch.Tensor: + return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, + b_g_idx, use_exllama, bit) + + +def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, + bit: int) -> None: + vllm_ops.gptq_shuffle(q_weight, q_perm, bit) + + +# squeezellm +def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor, + lookup_table: torch.Tensor) -> None: + vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table) + + +# marlin +def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, + size_n: int, size_k: int) -> torch.Tensor: + return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, + size_n, size_k) + + +# moe +def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, + block_size: int, sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor) -> None: + vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size, + sorted_token_ids, experts_ids, + num_tokens_post_pad) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + kv_scale: float, +) -> None: + vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype, kv_scale) + + +def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, + block_mapping: torch.Tensor) -> None: + vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks(src: torch.Tensor, dst: torch.Tensor, + block_mapping: Dict[int, int]) -> None: + vllm_cache_ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None: + vllm_cache_ops.convert_fp8(output, input) + + +#TODO: cuda_utils, custom_ar diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 2d918491d657..cd0690a4ba95 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -3,7 +3,7 @@ import torch -from vllm._C import cache_ops, ops +from vllm import _custom_ops as ops from vllm.attention.ops.prefix_prefill import context_attention_fwd # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. @@ -69,7 +69,7 @@ def write_to_paged_cache( kv_cache_dtype: str, kv_scale: float, ) -> None: - cache_ops.reshape_and_cache( + ops.reshape_and_cache( key, value, key_cache, @@ -199,11 +199,11 @@ def swap_blocks( ) -> None: src_key_cache = src_kv_cache[0] dst_key_cache = dst_kv_cache[0] - cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) src_value_cache = src_kv_cache[1] dst_value_cache = dst_kv_cache[1] - cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) @staticmethod def copy_blocks( @@ -212,4 +212,4 @@ def copy_blocks( ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches] - cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) + ops.copy_blocks(key_caches, value_caches, src_to_dists) diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 6786c48e0cab..baf1d4f26618 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -6,7 +6,7 @@ import torch.nn as nn import torch.nn.functional as F -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.quantization import QuantizationConfig diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1ec09f0cd4c2..377b6588dbf4 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -8,7 +8,7 @@ import triton import triton.language as tl -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.utils import is_hip diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index cb3cee2bad5a..a6619714b8aa 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from vllm._C import ops +from vllm import _custom_ops as ops class RMSNorm(nn.Module): diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 2caef5f1ebf5..daea5ac73e42 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -3,7 +3,7 @@ import torch from torch.nn.parameter import Parameter -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 53baf710ed81..757ab1af8392 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -6,7 +6,7 @@ import torch from torch.nn.parameter import Parameter -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 784229878edf..a6482c059cc4 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -3,7 +3,7 @@ import torch from torch.nn.parameter import Parameter -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index ed25455e6ec1..bb295df2acc3 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -3,7 +3,7 @@ import torch from torch.nn.parameter import Parameter -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index d80e73bbe39e..eb8d5f6dfb2a 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -27,7 +27,7 @@ import torch import torch.nn as nn -from vllm._C import ops +from vllm import _custom_ops as ops def _rotate_neox(x: torch.Tensor) -> torch.Tensor: diff --git a/vllm/utils.py b/vllm/utils.py index 8ba03333d3b6..8ab8927512cc 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -279,10 +279,10 @@ def _generate_random_fp8( #-----|-------------|------------------- # Inf | N/A | s.11111.00 # NaN | s.1111.111 | s.11111.{01,10,11} - from vllm._C import cache_ops + from vllm import _custom_ops as ops tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) tensor_tmp.uniform_(low, high) - cache_ops.convert_fp8(tensor_tmp, tensor) + ops.convert_fp8(tensor_tmp, tensor) del tensor_tmp From f3d0bf7589d6e63a691dcbb9d1db538c184fde29 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 10 Apr 2024 20:33:02 -0700 Subject: [PATCH 078/120] [Doc][Installation] delete python setup.py develop (#3989) --- docs/source/getting_started/installation.rst | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index 5dfb32080f97..e7826114ffa9 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -85,13 +85,3 @@ You can also build and install vLLM from source: $ nvcc --version # verify that nvcc is in your PATH $ ${CUDA_HOME}/bin/nvcc --version # verify that nvcc is in your CUDA_HOME - -.. note:: - If you are developing the C++ backend of vLLM, consider building vLLM with - - .. code-block:: console - - $ python setup.py develop - - since it will give you incremental builds. The downside is that this method - is `deprecated by setuptools `_. From c1dc547129f5faaa2ca5ba557145b8ec8838693c Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Thu, 11 Apr 2024 07:50:00 -0700 Subject: [PATCH 079/120] [Kernel] Fused MoE Config for Mixtral 8x22 (#4002) --- ...048,device_name=NVIDIA_A100-SXM4-80GB.json | 146 ++++++++++++++++++ ...048,device_name=NVIDIA_H100_80GB_HBM3.json | 146 ++++++++++++++++++ ...096,device_name=NVIDIA_A100-SXM4-80GB.json | 146 ++++++++++++++++++ ...096,device_name=NVIDIA_H100_80GB_HBM3.json | 146 ++++++++++++++++++ 4 files changed, 584 insertions(+) create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 000000000000..0bb423b28f5a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000000..26bcbf26970c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 000000000000..dbc624731f5c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000000..32c0c9da471c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} From 08ccee1e830d39ecdb3c6cf382c843dbf5ae830e Mon Sep 17 00:00:00 2001 From: "fuchen.ljl" Date: Thu, 11 Apr 2024 23:59:26 +0800 Subject: [PATCH 080/120] punica fix-bgmv-kernel-640 (#4007) --- csrc/punica/bgmv/bgmv_config.h | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 2219d960ae62..1084a0f20df6 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -14,6 +14,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 128) \ f(in_T, out_T, W_T, narrow, 256) \ f(in_T, out_T, W_T, narrow, 512) \ + f(in_T, out_T, W_T, narrow, 640) \ f(in_T, out_T, W_T, narrow, 768) \ f(in_T, out_T, W_T, narrow, 1024) \ f(in_T, out_T, W_T, narrow, 1152) \ From 8afca50889bad6ad987c523c48c31fc52fcb72e4 Mon Sep 17 00:00:00 2001 From: bigPYJ1151 Date: Fri, 12 Apr 2024 02:56:49 +0800 Subject: [PATCH 081/120] [Hardware][Intel] Isolate CPUModelRunner and ModelRunner for better maintenance (#3824) --- vllm/attention/backends/torch_sdpa.py | 72 ++--- vllm/executor/cpu_executor.py | 10 + vllm/utils.py | 1 - vllm/worker/cpu_model_runner.py | 408 ++++++++++++++++++++++++++ vllm/worker/cpu_worker.py | 13 +- 5 files changed, 443 insertions(+), 61 deletions(-) create mode 100644 vllm/worker/cpu_model_runner.py diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 63904ea92987..d21b54b16db4 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -50,20 +50,15 @@ def copy_blocks( @dataclass -class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): +class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata, + AttentionMetadataPerStage): """Metadata for TorchSDPABackend. """ # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool + slot_mapping: torch.Tensor prompt_lens: Optional[List[int]] - prompt_lens_tensor: Optional[torch.Tensor] - - max_subquery_len: Optional[int] = None - max_prompt_len: Optional[int] = None - subquery_start_loc: Optional[torch.Tensor] = None - seq_start_loc: Optional[torch.Tensor] = None - use_cuda_graph: bool = False def __post_init__(self): # Set during the execution of the first attention op. @@ -111,7 +106,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata[TorchSDPAMetadata], + attn_metadata: TorchSDPAMetadata, kv_scale: float, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -140,51 +135,36 @@ def forward( attn_metadata.kv_cache_dtype, kv_scale) - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens - - output = torch.empty_like(query) - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] - # QKV for prefill. - query = query[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens - - if prefill_meta := attn_metadata.prefill_metadata: - if (kv_cache is None or prefill_meta.block_tables.numel() == 0): + if attn_metadata.is_prompt: + if (kv_cache is None or attn_metadata.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) value = value.repeat_interleave(self.num_queries_per_kv, dim=1) - if prefill_meta.attn_bias is None: + if attn_metadata.attn_bias is None: if self.alibi_slopes is not None: att_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, - prefill_meta.prompt_lens) # type: ignore + attn_metadata.prompt_lens) # type: ignore elif self.sliding_window is not None: att_masks = _make_sliding_window_bias( - prefill_meta.prompt_lens, self.sliding_window, + attn_metadata.prompt_lens, self.sliding_window, query.dtype) # type: ignore else: - att_masks = [None] * len(prefill_meta.prompt_lens) - prefill_meta.attn_bias = att_masks + att_masks = [None] * len(attn_metadata.prompt_lens) + attn_metadata.attn_bias = att_masks query = query.movedim(0, query.dim() - 2) key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) start = 0 - out = torch.empty((num_tokens, self.num_heads, self.head_size), - dtype=query.dtype) - for prompt_len, mask in zip(prefill_meta.prompt_lens, - prefill_meta.attn_bias): + output = torch.empty( + (num_tokens, self.num_heads, self.head_size), + dtype=query.dtype) + for prompt_len, mask in zip(attn_metadata.prompt_lens, + attn_metadata.attn_bias): end = start + prompt_len sub_out = scaled_dot_product_attention( query[:, start:end, :], @@ -194,32 +174,28 @@ def forward( dropout_p=0.0, is_causal=not self.need_mask, scale=self.scale).movedim(query.dim() - 2, 0) - out[start:end, :, :] = sub_out + output[start:end, :, :] = sub_out start = end - assert out.shape == output[:num_prefill_tokens].shape - output[:num_prefill_tokens] = out else: # prefix-enabled attention raise RuntimeError( "Torch SDPA backend doesn't support prefix decoding.") - if decode_meta := attn_metadata.decode_metadata: + else: # Decoding run. - out = PagedAttention.forward_decode( - decode_query, + output = PagedAttention.forward_decode( + query, key_cache, value_cache, - decode_meta.block_tables, - decode_meta.context_lens, - decode_meta.max_context_len, + attn_metadata.block_tables, + attn_metadata.context_lens, + attn_metadata.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, kv_scale, ) - assert out.shape == output[num_prefill_tokens:].shape - output[num_prefill_tokens:] # Reshape the output tensor. return output.view(-1, self.num_heads * self.head_size) @@ -241,7 +217,7 @@ def _make_alibi_bias( bias = bias[None, :] - bias[:, None] num_heads = alibi_slopes.shape[0] - bias = bias[None, :].expand(num_heads, prompt_len, prompt_len) + bias = bias[None, :].repeat((num_heads, 1, 1)) bias.mul_(alibi_slopes[:, None, None]) inf_mask = torch.empty( (1, prompt_len, prompt_len), diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 2bf97338da0e..eda4e8989c16 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -25,6 +25,7 @@ def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, assert lora_config is None, "cpu backend doesn't support LoRA" model_config = _verify_and_get_model_config(model_config) cache_config = _verify_and_get_cache_config(cache_config) + scheduler_config = _verify_and_get_scheduler_config(scheduler_config) self.model_config = model_config self.cache_config = cache_config @@ -116,6 +117,15 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: return config +def _verify_and_get_scheduler_config( + config: SchedulerConfig) -> SchedulerConfig: + if config.chunked_prefill_enabled: + logger.warning("Chunked prefill is not supported on CPU, disable it.") + config.chunked_prefill_enabled = False + + return config + + def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: _GB = 1 << 30 if config.enable_prefix_caching: diff --git a/vllm/utils.py b/vllm/utils.py index 8ab8927512cc..fdb0a3768ab0 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -372,7 +372,6 @@ def is_pin_memory_available() -> bool: print_warning_once("Pin memory is not supported on Neuron.") return False elif is_cpu(): - print_warning_once("Pin memory is not supported on CPU.") return False return True diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py new file mode 100644 index 000000000000..49e1ad5709f5 --- /dev/null +++ b/vllm/worker/cpu_model_runner.py @@ -0,0 +1,408 @@ +from typing import Dict, List, Optional, Tuple + +import torch + +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig) +from vllm.distributed import broadcast_tensor_dict +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.model_loader import get_model +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.utils import make_tensor_with_pad, maybe_expand_dim + +logger = init_logger(__name__) + +_PAD_SLOT_ID = -1 + + +class CPUModelRunner: + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + *args, + **kwargs, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.lora_config = lora_config + self.is_driver_worker = is_driver_worker + + # model_config can be None in tests/samplers/test_sampler.py. + # FIXME(woosuk): This is a hack to make the tests work. Refactor this. + self.sliding_window = (model_config.get_sliding_window() + if model_config is not None else None) + self.device_config = (device_config + if device_config is not None else DeviceConfig()) + self.device = self.device_config.device + + self.model = None + self.block_size = None # Set after initial profiling. + + self.kv_cache_dtype = kv_cache_dtype + + self.attn_backend = get_attn_backend( + self.model_config.dtype if model_config is not None else None) + + def load_model(self) -> None: + self.model = get_model(self.model_config, + self.device_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int]]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + prompt_lens: List[int] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() + computed_len = seq_data.get_num_computed_tokens() + prompt_len = len(prompt_tokens) + + prompt_lens.append(prompt_len) # Prompt token num + input_tokens.extend(prompt_tokens) # Token ids + + # Token position ids + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + input_positions.extend(list(range(computed_len, prompt_len))) + + # Compute the slot mapping. + block_table = seq_group_metadata.block_tables[seq_id] + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, + # where start_idx is max(0, prompt_len - sliding_window). + # For example, if the prompt len is 10, sliding window is 8, and + # block size is 4, the first two tokens are masked and the slot + # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 + if self.sliding_window is not None: + start_idx = max(0, prompt_len - self.sliding_window) + + for i in range(computed_len, prompt_len): + if i < start_idx: + slot_mapping.append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // + self.block_size] # type: ignore + block_offset = i % self.block_size # type: ignore + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + num_prompt_tokens = len(input_tokens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) # type: ignore + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) # type: ignore + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) # type: ignore + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=True, + prompt_lens=prompt_lens, + num_prefills=len(prompt_lens), + num_prefill_tokens=num_prompt_tokens, + num_decode_tokens=0, + prefill_metadata=None, + decode_metadata=None, + max_context_len=None, + context_lens=None, + block_tables=torch.tensor([]), + slot_mapping=slot_mapping, + kv_cache_dtype=self.kv_cache_dtype, + ) + return ( + input_tokens, + input_positions, + attn_metadata, + prompt_lens, + ) + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + context_lens: List[int] = [] + block_tables: List[List[int]] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + assert seq_group_metadata.token_chunk_size == 1 + + seq_ids = list(seq_group_metadata.seq_data.keys()) + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append(generation_token) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append(position) + + context_len = seq_len if self.sliding_window is None else min( + seq_len, self.sliding_window) + context_lens.append(context_len) + + block_table = seq_group_metadata.block_tables[seq_id] + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + block_tables.append(block_table) + + max_context_len = max(context_lens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + context_lens = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) + + max_block_table_len = max( + len(block_table) for block_table in block_tables) + block_tables = make_tensor_with_pad( + block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=self.device, + ) + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=False, + slot_mapping=slot_mapping, + prompt_lens=None, + num_prefill_tokens=0, + num_decode_tokens=len(input_tokens), + max_context_len=max_context_len, + num_prefills=0, + prefill_metadata=None, + decode_metadata=None, + context_lens=context_lens, + block_tables=block_tables, + kv_cache_dtype=self.kv_cache_dtype, + ) + return ( + input_tokens, + input_positions, + attn_metadata, + ) + + def _prepare_sample( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + ) -> SamplingMetadata: + seq_groups: List[Tuple[List[int], SamplingParams]] = [] + selected_token_indices: List[int] = [] + generators: List[torch.Generator] = [] + selected_token_start_idx = 0 + categorized_sample_indices = {t: [] for t in SamplingType} + categorized_sample_indices_start_idx = 0 + categorized_sampled_token_indices_start_idx = 0 + + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = list(seq_group_metadata.seq_data.keys()) + sampling_params = seq_group_metadata.sampling_params + seq_groups.append((seq_ids, sampling_params)) + + if seq_group_metadata.is_prompt: + assert len(seq_ids) == 1 + subquery_len = prompt_lens[i] + if sampling_params.prompt_logprobs is not None: + # NOTE: prompt token positions do not need sample, skip + categorized_sample_indices_start_idx += subquery_len - 1 + + categorized_sample_indices[ + sampling_params.sampling_type].append([ + categorized_sample_indices_start_idx, + categorized_sampled_token_indices_start_idx + ]) + categorized_sample_indices_start_idx += 1 + categorized_sampled_token_indices_start_idx += 1 + + if sampling_params.prompt_logprobs is not None: + selected_token_indices.extend( + range(selected_token_start_idx, + selected_token_start_idx + subquery_len - 1)) + selected_token_indices.append(selected_token_start_idx + + subquery_len - 1) + selected_token_start_idx += subquery_len + + if sampling_params.seed is not None: + seq_group_metadata.state.generator = torch.Generator( + device=self.device).manual_seed(sampling_params.seed) + else: + num_seqs = len(seq_ids) + selected_token_indices.extend( + range(selected_token_start_idx, + selected_token_start_idx + num_seqs)) + selected_token_start_idx += num_seqs + + categorized_sample_indices[ + sampling_params.sampling_type].extend( + zip( + range( + categorized_sample_indices_start_idx, + categorized_sample_indices_start_idx + + num_seqs), + range( + categorized_sampled_token_indices_start_idx, + categorized_sampled_token_indices_start_idx + + num_seqs))) + categorized_sample_indices_start_idx += num_seqs + categorized_sampled_token_indices_start_idx += num_seqs + + if sampling_params.seed is not None: + generators.append(seq_group_metadata.state.generator) + + selected_token_indices = torch.tensor(selected_token_indices, + dtype=torch.long) + + categorized_sample_indices = { + t: maybe_expand_dim(torch.tensor(seq_ids, dtype=torch.int), 2, 2) + for t, seq_ids in categorized_sample_indices.items() + } + + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + sampling_metadata = SamplingMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + selected_token_indices=selected_token_indices, + categorized_sample_indices=categorized_sample_indices, + generators=generators, + ) + return sampling_metadata + + def prepare_input_tensors( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, + SamplingMetadata]: + if self.is_driver_worker: + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = seq_group_metadata_list[0].is_prompt + # Prepare input tensors. + if is_prompt: + (input_tokens, input_positions, attn_metadata, + prompt_lens) = self._prepare_prompt(seq_group_metadata_list) + else: + (input_tokens, input_positions, + attn_metadata) = self._prepare_decode(seq_group_metadata_list) + prompt_lens = [] + sampling_metadata = self._prepare_sample(seq_group_metadata_list, + prompt_lens) + # Broadcast the metadata. + metadata_dict = { + "input_tokens": input_tokens, + "input_positions": input_positions, + "selected_token_indices": + sampling_metadata.selected_token_indices, + } + metadata_dict.update(attn_metadata.asdict_zerocopy()) + broadcast_tensor_dict(metadata_dict, src=0) + else: + metadata_dict = broadcast_tensor_dict(src=0) + input_tokens = metadata_dict.pop("input_tokens") + input_positions = metadata_dict.pop("input_positions") + selected_token_indices = metadata_dict.pop( + "selected_token_indices") + attn_metadata = self.attn_backend.make_metadata(**metadata_dict) + sampling_metadata = SamplingMetadata( + seq_groups=None, + seq_data=None, + prompt_lens=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + generators=None, + perform_sampling=False, + ) + + return ( + input_tokens, + input_positions, + attn_metadata, + sampling_metadata, + ) + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + kv_caches: List[torch.Tensor], + ) -> Optional[SamplerOutput]: + (input_tokens, input_positions, attn_metadata, sampling_metadata + ) = self.prepare_input_tensors(seq_group_metadata_list) + + model_executable = self.model + execute_model_kwargs = { + "input_ids": input_tokens, + "positions": input_positions, + "kv_caches": kv_caches, + "attn_metadata": attn_metadata, + } + + hidden_states = model_executable(**execute_model_kwargs) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Only perform sampling in the driver worker. + if not sampling_metadata.perform_sampling: + return None + + # Sample the next token. + output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + return output diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 751384eb72af..3989207e8dd8 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -12,25 +12,14 @@ init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.model_executor.model_loader import get_model from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from vllm.worker.model_runner import ModelRunner +from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.worker_base import LoraNotSupportedWorkerBase logger = init_logger(__name__) -class CPUModelRunner(ModelRunner): - - def load_model(self) -> None: - self.model = get_model(self.model_config, - self.device_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) - - class CPUCacheEngine: """Manages the KV cache for CPU backend. From a10d3056da644c31e4ebf95a2b6ad65a626a7350 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 11 Apr 2024 13:35:51 -0700 Subject: [PATCH 082/120] [Core] Set `linear_weights` directly on the layer (#3977) --- csrc/quantization/gptq/q_gemm.cu | 2 +- tests/kernels/test_moe.py | 2 +- vllm/lora/layers.py | 12 +-- vllm/model_executor/layers/linear.py | 77 ++++++++++--------- .../model_executor/layers/quantization/awq.py | 29 +++---- .../layers/quantization/gptq.py | 47 ++++++----- .../layers/quantization/marlin.py | 23 +++--- .../layers/quantization/squeezellm.py | 24 +++--- 8 files changed, 114 insertions(+), 102 deletions(-) diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index 655158e38f55..cc56649917a8 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -2067,7 +2067,7 @@ void gptq_shuffle const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); vllm::gptq::shuffle_exllama_weight( (uint32_t*) q_weight.data_ptr(), - q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(), + q_perm.device().is_meta() || q_perm.numel() == 0 ? NULL : (int*) q_perm.data_ptr(), q_weight.size(0) * 32 / bit, q_weight.size(1), bit diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index affbbfb4aa94..046f11d957bd 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -73,7 +73,7 @@ def test_mixtral_moe(dtype: torch.dtype): ).cuda() # Load the weights - vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data + vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data for i in range(config.num_local_experts): weights = (hf_moe.experts[i].w1.weight.data, hf_moe.experts[i].w3.weight.data) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 84a94091486d..a8ec4dcfd613 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -368,7 +368,7 @@ def set_mapping( def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.linear_method.apply_weights( - self.base_layer.linear_weights, x, bias) + self.base_layer, x, bias) _apply_lora( x, self.lora_a_stacked, @@ -402,10 +402,6 @@ def forward(self, input_): if self.base_layer.skip_bias_add else None) return output, output_bias - @property - def linear_weights(self): - return self.base_layer.linear_weights - @classmethod def can_replace_layer(cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: List, @@ -505,7 +501,7 @@ def set_lora( def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.linear_method.apply_weights( - self.base_layer.linear_weights, x, bias) + self.base_layer, x, bias) _apply_lora_packed_nslice( x, self.lora_a_stacked, @@ -746,7 +742,7 @@ def set_lora( def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.linear_method.apply_weights( - self.base_layer.linear_weights, x, bias) + self.base_layer, x, bias) _apply_lora_packed_nslice( x, self.lora_a_stacked, @@ -838,7 +834,7 @@ def set_mapping( def apply_weights(self, x: torch.Tensor) -> torch.Tensor: output = self.base_layer.linear_method.apply_weights( - self.base_layer.linear_weights, x) + self.base_layer, x) _apply_lora( x, self.lora_a_stacked, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 8f42b3e8a4ab..3ca870742efc 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import List, Optional import torch import torch.nn.functional as F @@ -28,19 +28,24 @@ class LinearMethodBase(ABC): """Base class for different (maybe quantized) linear methods.""" @abstractmethod - def create_weights(self, input_size_per_partition: int, + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: - """Create weights for a linear layer.""" + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + """Create weights for a linear layer. + + The weights will be set as attributes of the layer.""" raise NotImplementedError @abstractmethod def apply_weights(self, - weights: Dict[str, torch.Tensor], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - """Apply the weights to the input tensor.""" + """Apply the weights in layer to the input tensor. + + Expects create_weights to have been called before on the layer.""" raise NotImplementedError @@ -55,22 +60,24 @@ class UnquantizedLinearMethod(LinearMethodBase): def __init__(self, separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add - def create_weights(self, input_size_per_partition: int, + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): weight = Parameter(torch.empty(output_size_per_partition, input_size_per_partition, dtype=params_dtype), requires_grad=False) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - return {"weight": weight} + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) def apply_weights(self, - weights: Dict[str, torch.Tensor], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - weight = weights["weight"] + weight = layer.weight if self.separate_bias_add: if bias is not None: return F.linear(x, weight) + bias @@ -111,12 +118,9 @@ def __init__( if linear_method is None: linear_method = UnquantizedLinearMethod() self.linear_method = linear_method - self.linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size, self.input_size, - self.output_size, self.params_dtype) - for name, weight in self.linear_weights.items(): - if isinstance(weight, torch.Tensor): - self.register_parameter(name, weight) + self.linear_method.create_weights(self, self.input_size, + self.output_size, self.input_size, + self.output_size, self.params_dtype) if bias: self.bias = Parameter( torch.empty(self.output_size, dtype=self.params_dtype)) @@ -126,7 +130,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: bias = self.bias if not self.skip_bias_add else None - output = self.linear_method.apply_weights(self.linear_weights, x, bias) + output = self.linear_method.apply_weights(self, x, bias) output_bias = self.bias if self.skip_bias_add else None return output, output_bias @@ -177,13 +181,13 @@ def __init__( if linear_method is None: linear_method = UnquantizedLinearMethod() self.linear_method = linear_method - self.linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size_per_partition, self.input_size, - self.output_size, self.params_dtype) - for name, weight in self.linear_weights.items(): - if isinstance(weight, torch.Tensor): - self.register_parameter(name, weight) - set_weight_attrs(weight, {"weight_loader": self.weight_loader}) + self.linear_method.create_weights(self, + self.input_size, + self.output_size_per_partition, + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -211,8 +215,7 @@ def forward(self, input_): bias = self.bias if not self.skip_bias_add else None # Matrix multiply. - output_parallel = self.linear_method.apply_weights( - self.linear_weights, input_, bias) + output_parallel = self.linear_method.apply_weights(self, input_, bias) if self.gather_output: # All-gather across the partitions. output = tensor_model_parallel_all_gather(output_parallel) @@ -523,13 +526,13 @@ def __init__( if linear_method is None: linear_method = UnquantizedLinearMethod() self.linear_method = linear_method - self.linear_weights = self.linear_method.create_weights( - self.input_size_per_partition, self.output_size, self.input_size, - self.output_size, self.params_dtype) - for name, weight in self.linear_weights.items(): - if isinstance(weight, torch.Tensor): - self.register_parameter(name, weight) - set_weight_attrs(weight, {"weight_loader": self.weight_loader}) + self.linear_method.create_weights(self, + self.input_size_per_partition, + self.output_size, + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " @@ -569,7 +572,7 @@ def forward(self, input_): # Matrix multiply. output_parallel = self.linear_method.apply_weights( - self.linear_weights, input_parallel) + self, input_parallel) if self.reduce_results and self.tp_size > 1: output_ = tensor_model_parallel_all_reduce(output_parallel) else: diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index daea5ac73e42..98651aed8be0 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -79,10 +79,11 @@ class AWQLinearMethod(LinearMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - def create_weights(self, input_size_per_partition: int, + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( "The input size is not aligned with the quantized " @@ -136,19 +137,21 @@ def create_weights(self, input_size_per_partition: int, "input_dim": 0, "output_dim": 1, }) - return { - "qweight": qweight, - "qzeros": qzeros, - "scales": scales, - } + + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("qzeros", qzeros) + set_weight_attrs(qzeros, extra_weight_attrs) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) def apply_weights(self, - weights: Dict[str, Any], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qweight = weights["qweight"] - scales = weights["scales"] - qzeros = weights["qzeros"] + qweight = layer.qweight + scales = layer.scales + qzeros = layer.qzeros pack_factor = self.quant_config.pack_factor out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) reshaped_x = x.reshape(-1, x.shape[-1]) @@ -163,5 +166,5 @@ def apply_weights(self, out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor) if bias is not None: - out = out + bias + out.add_(bias) return out.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 757ab1af8392..f370b94a210e 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -89,12 +89,14 @@ def __init__(self, quant_config: GPTQConfig): def create_weights( self, + layer: torch.nn.Module, input_size_per_partition: int, output_size_per_partition: int, input_size: int, output_size: int, params_dtype: torch.dtype, - ) -> Dict[str, Any]: + **extra_weight_attrs, + ): del output_size # Unused. if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( @@ -179,37 +181,40 @@ def create_weights( "input_dim": scale_and_zero_input_dim, "output_dim": 1, }) - return { - "qweight": qweight, - "g_idx": g_idx, - "qzeros": qzeros, - "scales": scales, - "exllama_state": exllama_state, - } + + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("g_idx", g_idx) + set_weight_attrs(g_idx, extra_weight_attrs) + layer.register_parameter("qzeros", qzeros) + set_weight_attrs(qzeros, extra_weight_attrs) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) + + layer.exllama_state = exllama_state def apply_weights(self, - weights: Dict[str, Any], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qweight = weights["qweight"] + qweight = layer.qweight out_shape = x.shape[:-1] + (qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) # exllama needs to shuffle the weight after the weight is loaded # here we do the shuffle on first forward pass - if weights["exllama_state"] == ExllamaState.UNINITIALIZED: + if layer.exllama_state == ExllamaState.UNINITIALIZED: if self.quant_config.desc_act: - weights["g_idx"] = torch.argsort(weights["g_idx"]).to( - torch.int) + layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int) else: - weights["g_idx"] = torch.empty((1, 1), device="meta") - weights["exllama_state"] = ExllamaState.READY - ops.gptq_shuffle(weights["qweight"], weights["g_idx"], + layer.g_idx.data = torch.empty((0, ), + device=layer.g_idx.device) + layer.exllama_state = ExllamaState.READY + ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits) - output = ops.gptq_gemm(reshaped_x, weights["qweight"], - weights["qzeros"], weights["scales"], - weights["g_idx"], - weights["exllama_state"] == ExllamaState.READY, + output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros, + layer.scales, layer.g_idx, + layer.exllama_state == ExllamaState.READY, self.quant_config.weight_bits) if bias is not None: - output = output + bias + output.add_(bias) return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index a6482c059cc4..bf0500f1155a 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -91,12 +91,14 @@ def __init__(self, quant_config: MarlinConfig): def create_weights( self, + layer: torch.nn.Module, input_size_per_partition: int, output_size_per_partition: int, input_size: int, output_size: int, params_dtype: torch.dtype, - ) -> Dict[str, Any]: + **extra_weight_attrs, + ): del output_size # Unused. if params_dtype != torch.float16: @@ -187,21 +189,22 @@ def create_weights( dtype=torch.int), requires_grad=False) - return { - "B": qweight, - "s": scales, - "workspace": workspace, - } + layer.register_parameter("B", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("s", scales) + set_weight_attrs(scales, extra_weight_attrs) + layer.register_parameter("workspace", workspace) + set_weight_attrs(workspace, extra_weight_attrs) def apply_weights( self, - weights: Dict[str, Any], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - qweight = weights["B"] - scales = weights["s"] - workspace = weights["workspace"] + qweight = layer.B + scales = layer.s + workspace = layer.workspace x_2d = x.view(-1, x.shape[-1]) diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index bb295df2acc3..661ff9c55d0d 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -68,10 +68,11 @@ class SqueezeLLMLinearMethod(LinearMethodBase): def __init__(self, quant_config: SqueezeLLMConfig): self.quant_config = quant_config - def create_weights(self, input_size_per_partition: int, + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): if input_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( "The input size is not aligned with the quantized " @@ -103,17 +104,18 @@ def create_weights(self, input_size_per_partition: int, set_weight_attrs(lookup_table, { "output_dim": 0, }) - return { - "qweight": qweight, - "lookup_table": lookup_table, - } + + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("lookup_table", lookup_table) + set_weight_attrs(lookup_table, extra_weight_attrs) def apply_weights(self, - weights: Dict[str, Any], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qweight = weights["qweight"] - lookup_table = weights["lookup_table"] + qweight = layer.qweight + lookup_table = layer.lookup_table out_shape = x.shape[:-1] + (qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) if is_hip(): @@ -126,5 +128,5 @@ def apply_weights(self, ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table) if bias is not None: - out = out + bias + out.add_(bias) return out.reshape(out_shape) From 559eb852f83fe7867390dd2986b4f93a6572cf10 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 11 Apr 2024 14:00:48 -0700 Subject: [PATCH 083/120] [Core] init_distributed_environment align with init_process_group(#4014) [Core][Distributed] make init_distributed_environment compatible with init_process_group (#4014) --- vllm/distributed/parallel_state.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 4bb77146295a..9fceffe7cb88 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -39,9 +39,9 @@ def init_distributed_environment( - world_size: int, - rank: int, - distributed_init_method: Optional[str] = None, + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", local_rank: int = -1, backend: str = "nccl", ): From 95e7d4a97cd64f8c6dc226ec0bbceebef6458701 Mon Sep 17 00:00:00 2001 From: Dylan Hawk <51147702+dylanwhawk@users.noreply.github.com> Date: Thu, 11 Apr 2024 15:15:50 -0700 Subject: [PATCH 084/120] Fix echo/logprob OpenAI completion bug (#3441) Co-authored-by: Dylan Hawk --- tests/entrypoints/test_openai_server.py | 31 ++++++++++++ vllm/entrypoints/openai/serving_chat.py | 9 ++-- vllm/entrypoints/openai/serving_completion.py | 15 ++++-- vllm/entrypoints/openai/serving_engine.py | 47 +++++++++++-------- 4 files changed, 73 insertions(+), 29 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 6f2086c4dd26..7940430b8b65 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -742,5 +742,36 @@ async def test_guided_grammar(server, client: openai.AsyncOpenAI): assert content.strip() == ground_truth +@pytest.mark.parametrize( + # first test base model, then test loras + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], +) +async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI, + model_name: str): + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + # test using text and token IDs + for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]): + completion = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + echo=True, + logprobs=1) + + prompt_text = tokenizer.decode(prompt) if isinstance(prompt, + list) else prompt + assert (completion.choices[0].text is not None + and re.search(r"^" + prompt_text, completion.choices[0].text)) + logprobs = completion.choices[0].logprobs + assert logprobs is not None + assert len(logprobs.text_offset) > 5 + assert (len(logprobs.token_logprobs) > 5 + and logprobs.token_logprobs[0] is None) + assert (len(logprobs.top_logprobs) > 5 + and logprobs.top_logprobs[0] is None) + assert len(logprobs.tokens) > 5 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 0980c3d3cb61..a03c5dc88108 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -63,8 +63,9 @@ async def create_chat_completion( request_id = f"cmpl-{random_uuid()}" try: - token_ids = self._validate_prompt_and_tokenize(request, - prompt=prompt) + # Tokenize/detokenize depending on prompt format (string/token list) + prompt_ids, prompt_text = self._validate_prompt_and_tokenize( + request, prompt=prompt) sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) guided_decode_logits_processor = ( @@ -78,8 +79,8 @@ async def create_chat_completion( except ValueError as e: return self.create_error_response(str(e)) - result_generator = self.engine.generate(prompt, sampling_params, - request_id, token_ids, + result_generator = self.engine.generate(prompt_text, sampling_params, + request_id, prompt_ids, lora_request) # Streaming response if request.stream: diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 06e7a9225fef..c1f1744a118b 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -136,23 +136,24 @@ async def create_completion(self, request: CompletionRequest, for i, prompt in enumerate(prompts): if prompt_is_tokens: - input_ids = self._validate_prompt_and_tokenize( + prompt_formats = self._validate_prompt_and_tokenize( request, prompt_ids=prompt, truncate_prompt_tokens=sampling_params. truncate_prompt_tokens) else: - input_ids = self._validate_prompt_and_tokenize( + prompt_formats = self._validate_prompt_and_tokenize( request, prompt=prompt, truncate_prompt_tokens=sampling_params. truncate_prompt_tokens) + prompt_ids, prompt_text = prompt_formats generators.append( - self.engine.generate(prompt, + self.engine.generate(prompt_text, sampling_params, f"{request_id}-{i}", - prompt_token_ids=input_ids, + prompt_token_ids=prompt_ids, lora_request=lora_request)) except ValueError as e: # TODO: Use a vllm-specific Validation Error @@ -326,7 +327,8 @@ def request_output_to_completion_response( output_text = prompt_text elif request.echo and request.max_tokens > 0: token_ids = prompt_token_ids + output.token_ids - top_logprobs = prompt_logprobs + output.logprobs + top_logprobs = (prompt_logprobs + output.logprobs + if request.logprobs else None) output_text = prompt_text + output.text else: token_ids = output.token_ids @@ -334,6 +336,9 @@ def request_output_to_completion_response( output_text = output.text if request.logprobs is not None: + assert top_logprobs is not None, ( + "top_logprobs must be provided when logprobs " + "is requested") logprobs = self._create_logprobs( token_ids=token_ids, top_logprobs=top_logprobs, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 8f69388c0251..77a568b56403 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -2,7 +2,7 @@ import json from dataclasses import dataclass from http import HTTPStatus -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union from pydantic import conint @@ -99,27 +99,32 @@ def _create_logprobs( last_token_len = 0 if num_output_top_logprobs: logprobs.top_logprobs = [] + for i, token_id in enumerate(token_ids): step_top_logprobs = top_logprobs[i] - if step_top_logprobs is not None: - token_logprob = step_top_logprobs[token_id].logprob + if step_top_logprobs is None: + token = self.tokenizer.decode(token_id) + logprobs.tokens.append(token) + logprobs.token_logprobs.append(None) + logprobs.top_logprobs.append(None) else: - token_logprob = None - token = step_top_logprobs[token_id].decoded_token - logprobs.tokens.append(token) - logprobs.token_logprobs.append(token_logprob) + token_logprob = step_top_logprobs[token_id].logprob + token = step_top_logprobs[token_id].decoded_token + logprobs.tokens.append(token) + logprobs.token_logprobs.append(token_logprob) + + if num_output_top_logprobs: + logprobs.top_logprobs.append({ + p.decoded_token: p.logprob + for i, p in step_top_logprobs.items() + } if step_top_logprobs else None) + if len(logprobs.text_offset) == 0: logprobs.text_offset.append(initial_text_offset) else: logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len) last_token_len = len(token) - - if num_output_top_logprobs: - logprobs.top_logprobs.append({ - p.decoded_token: p.logprob - for i, p in step_top_logprobs.items() - } if step_top_logprobs else None) return logprobs def create_error_response( @@ -164,12 +169,12 @@ def _maybe_get_lora(self, request) -> Optional[LoRARequest]: raise ValueError("The model `{request.model}` does not exist.") def _validate_prompt_and_tokenize( - self, - request: Union[ChatCompletionRequest, CompletionRequest], - prompt: Optional[str] = None, - prompt_ids: Optional[List[int]] = None, - truncate_prompt_tokens: Optional[conint(ge=1)] = None - ) -> List[int]: + self, + request: Union[ChatCompletionRequest, CompletionRequest], + prompt: Optional[str] = None, + prompt_ids: Optional[List[int]] = None, + truncate_prompt_tokens: Optional[conint(ge=1)] = None + ) -> Tuple[List[int], str]: if not (prompt or prompt_ids): raise ValueError("Either prompt or prompt_ids should be provided.") if (prompt and prompt_ids): @@ -187,6 +192,8 @@ def _validate_prompt_and_tokenize( else: input_ids = prompt_ids + input_text = prompt if prompt is not None else self.tokenizer.decode( + prompt_ids) token_num = len(input_ids) if request.max_tokens is None: @@ -201,4 +208,4 @@ def _validate_prompt_and_tokenize( f"{request.max_tokens} in the completion). " f"Please reduce the length of the messages or completion.", ) else: - return input_ids + return input_ids, input_text From 1e96c3341a4e055ae392085fecc7a672295b71c2 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 11 Apr 2024 15:18:57 -0700 Subject: [PATCH 085/120] Add extra punica sizes to support bigger vocabs (#4015) --- csrc/punica/bgmv/bgmv_config.h | 12 +++++- csrc/punica/punica_ops.cc | 14 +++--- tests/lora/test_layers.py | 78 +++++++++++++++++++--------------- tests/lora/test_punica.py | 49 +++++++++++++++++++-- vllm/lora/layers.py | 4 +- 5 files changed, 109 insertions(+), 48 deletions(-) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 1084a0f20df6..9b76b98ab332 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -60,7 +60,17 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 33024) \ f(in_T, out_T, W_T, narrow, 36864) \ f(in_T, out_T, W_T, narrow, 49152) \ -// Keep above in sync with vllm/lora/layers::SamplerWithLoRA + f(in_T, out_T, W_T, narrow, 64000) \ + f(in_T, out_T, W_T, narrow, 64256) \ + f(in_T, out_T, W_T, narrow, 64512) \ + f(in_T, out_T, W_T, narrow, 102400) \ + f(in_T, out_T, W_T, narrow, 102656) \ + f(in_T, out_T, W_T, narrow, 102912) \ + f(in_T, out_T, W_T, narrow, 128000) \ + f(in_T, out_T, W_T, narrow, 128256) \ + f(in_T, out_T, W_T, narrow, 128512) \ +// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA +// and vllm/tests/lora/test_punica.py // Keep this in sync with vllm/config::LoRAConfig #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cc index 28739be14b86..7ebfd851c4fe 100644 --- a/csrc/punica/punica_ops.cc +++ b/csrc/punica/punica_ops.cc @@ -20,8 +20,8 @@ inline void check_shape(const torch::Tensor &a, const torch::Tensor &b, } } -inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { - return (uint32_t(a) << 16) | uint32_t(b); +inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) { + return (uint64_t(a) << 32) | uint64_t(b); } #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") @@ -46,13 +46,13 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { template inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, const int64_t *lora_indices, - uint16_t in_features, uint16_t out_features, + uint32_t in_features, uint32_t out_features, int64_t y_offset, int64_t full_y_size, int64_t batch_size, int64_t num_layers, int64_t layer_idx, float scale) { - switch (pack_u16(in_features, out_features)) { + switch (pack_u32(in_features, out_features)) { #define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \ - case pack_u16(feat_in, feat_out): \ + case pack_u32(feat_in, feat_out): \ bgmv_kernel(Y, X, W, lora_indices, y_offset, \ full_y_size, batch_size, num_layers, \ layer_idx, scale); \ @@ -93,7 +93,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, CHECK_EQ(y.size(0), x.size(0)); const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); bool ok = false; - if (h_in < 65536 && h_out < 65536) { + if (h_in <= 128512 && h_out <= 128512) { // TODO: See if we can get rid of this massive nested switch switch (x.scalar_type()) { case at::ScalarType::Half: @@ -325,7 +325,7 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, CHECK_EQ(y.size(0), x.size(0)); const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); bool ok = false; - if (h_in < 65536 && h_out < 65536) { + if (h_in <= 128512 && h_out <= 128512) { // TODO: See if we can get rid of this massive nested switch switch (x.scalar_type()) { case at::ScalarType::Half: diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 71ce6f176483..e9e0c8554c1e 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -170,7 +170,8 @@ def create_random_inputs( @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_embeddings(dist_init, num_loras, device) -> None: +@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) +def test_embeddings(dist_init, num_loras, device, vocab_size) -> None: torch.set_default_device(device) max_loras = 8 @@ -179,9 +180,9 @@ def test_embeddings(dist_init, num_loras, device) -> None: lora_dtype=torch.float16) def create_random_embedding_layer(): - embedding = VocabParallelEmbedding(512, 256) + embedding = VocabParallelEmbedding(vocab_size, 256) embedding.weight.data = torch.rand_like(embedding.weight.data) - embedding.weight.data[512:, :] = 0 + embedding.weight.data[vocab_size:, :] = 0 lora_embedding = VocabParallelEmbeddingWithLoRA(embedding) lora_embedding.create_lora_weights(max_loras, lora_config) @@ -203,12 +204,13 @@ def create_random_embedding_layer(): active_lora_ids=list(lora_dict.keys()), num_inputs=num_loras * 3, input_size=(200, ), - input_range=(1, 512), + input_range=(1, vocab_size), ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + vocab_size, + lora_config.lora_extra_vocab_size) lora_embedding.set_mapping(*mapping_info) lora_result = lora_embedding(torch.cat(inputs)) @@ -240,12 +242,13 @@ def create_random_embedding_layer(): active_lora_ids=[0], num_inputs=num_loras * 3, input_size=(200, ), - input_range=(1, 512), + input_range=(1, vocab_size), ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + vocab_size, + lora_config.lora_extra_vocab_size) lora_embedding.set_mapping(*mapping_info, ) lora_result = lora_embedding(torch.cat(inputs)) @@ -263,7 +266,9 @@ def create_random_embedding_layer(): # reason="Fails when loras are in any slot other than the first.") @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None: +@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) +def test_embeddings_with_new_embeddings(dist_init, num_loras, device, + vocab_size) -> None: torch.set_default_device(device) max_loras = 8 @@ -272,15 +277,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None: lora_dtype=torch.float16) def create_random_embedding_layer(): - embedding = VocabParallelEmbedding(512, 256) + embedding = VocabParallelEmbedding(vocab_size, 256) embedding_data = torch.rand_like(embedding.weight.data) embedding.weight.data = embedding_data - embedding.weight.data[512:, :] = 0 + embedding.weight.data[vocab_size:, :] = 0 expanded_embedding = VocabParallelEmbedding( - 512 + lora_config.lora_extra_vocab_size * max_loras, + vocab_size + lora_config.lora_extra_vocab_size * max_loras, 256, - org_num_embeddings=512) - expanded_embedding.weight.data[:512, :] = embedding_data + org_num_embeddings=vocab_size) + expanded_embedding.weight.data[:vocab_size, :] = embedding_data # We need to deepcopy the embedding as it will be modified # in place lora_embedding = VocabParallelEmbeddingWithLoRA( @@ -298,7 +303,7 @@ def create_random_embedding_layer(): id_to_index, layer=lora_embedding, layer_weights=torch.zeros( - (256, 512 + lora_config.lora_extra_vocab_size)), + (256, vocab_size + lora_config.lora_extra_vocab_size)), generate_embeddings_tensor=256, ) @@ -316,7 +321,7 @@ def create_random_embedding_layer(): active_lora_ids=list(lora_dict.keys()), num_inputs=num_loras * 3, input_size=(200, ), - input_range=(1, 512), + input_range=(1, vocab_size), ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) @@ -327,16 +332,18 @@ def create_random_embedding_layer(): for input_, original_input_, lora_id in zip(inputs, original_inputs, prompt_mapping): embedding_id = lora_id - 1 - input_[-1] = 512 + (embedding_id * embeddings_tensor_len) - original_input_[-1] = 512 - input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1) - original_input_[-2] = 512 + embeddings_tensor_len - 1 + input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len) + original_input_[-1] = vocab_size + input_[-2] = vocab_size + ( + (embedding_id + 1) * embeddings_tensor_len - 1) + original_input_[-2] = vocab_size + embeddings_tensor_len - 1 mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + vocab_size, + lora_config.lora_extra_vocab_size) lora_embedding.set_mapping(*mapping_info, ) - expanded_embedding.weight[512:512 + + expanded_embedding.weight[vocab_size:vocab_size + (embeddings_tensor_len * max_loras)] = torch.cat(embeddings_tensors) @@ -370,14 +377,15 @@ def create_random_embedding_layer(): active_lora_ids=[0], num_inputs=num_loras * 3, input_size=(200, ), - input_range=(1, 512), + input_range=(1, vocab_size), ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) original_inputs = deepcopy(inputs) mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + vocab_size, + lora_config.lora_extra_vocab_size) lora_embedding.set_mapping(*mapping_info, ) lora_result = lora_embedding(torch.cat(original_inputs)) @@ -393,7 +401,9 @@ def create_random_embedding_layer(): @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_lm_head_logits_processor(dist_init, num_loras, device) -> None: +@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) +def test_lm_head_logits_processor(dist_init, num_loras, device, + vocab_size) -> None: torch.set_default_device(device) max_loras = 8 @@ -402,12 +412,12 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None: lora_dtype=torch.float16) def _pretest(): - linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size, - 1024, 32000) + linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size, + 1024, vocab_size) linear.weight.data = torch.rand_like(linear.weight.data) - linear.weight.data[:, 32000:] = 0 + linear.weight.data[:, vocab_size:] = 0 logits_processor = LogitsProcessor( - 32000 + lora_config.lora_extra_vocab_size, 32000) + vocab_size + lora_config.lora_extra_vocab_size, vocab_size) lora_logits_processor = LogitsProcessorWithLoRA( logits_processor, 1024, linear.weight.dtype, linear.weight.device) lora_logits_processor.create_lora_weights(max_loras, lora_config) @@ -444,7 +454,7 @@ def _pretest(): lora_mapping, id_to_index, max_loras, - 32000, + vocab_size, lora_config.lora_extra_vocab_size, ) lora_logits_processor.set_mapping(*mapping_info, ) @@ -460,7 +470,7 @@ def _pretest(): org_vocab_size:logits_processor.org_vocab_size + embeddings_tensor_len] = embeddings_tensor - logits_processor.org_vocab_size = (32000 + + logits_processor.org_vocab_size = (vocab_size + lora_config.lora_extra_vocab_size) expected_results = [] for input_, lora_id in zip(inputs, prompt_mapping): @@ -468,11 +478,11 @@ def _pretest(): result = logits_processor._get_logits(hidden_states=input_, embedding=linear.weight, embedding_bias=None) - result[:, 32000 + embeddings_tensor_len:] = float("-inf") + result[:, vocab_size + embeddings_tensor_len:] = float("-inf") result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) - logits_processor.org_vocab_size = 32000 + logits_processor.org_vocab_size = vocab_size # Check that resetting the lora weights succeeds @@ -489,14 +499,14 @@ def _pretest(): lora_mapping = LoRAMapping(index_mapping, prompt_mapping) mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - 32000, + vocab_size, lora_config.lora_extra_vocab_size) lora_logits_processor.set_mapping(*mapping_info, ) lora_result = lora_logits_processor._get_logits( hidden_states=torch.cat(inputs), embedding=original_weight, - embedding_bias=None)[:, :32000] + embedding_bias=None)[:, :vocab_size] expected_result = logits_processor._get_logits( hidden_states=torch.cat(inputs), embedding=original_weight, diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index 2736a1c7ade2..cab8b44ccd2d 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -43,10 +43,51 @@ def _lora_ref_impl( H1 = H2 = [ - 128, 256, 512, 1024, 1152, 1280, 1536, 2048, 2304, 2560, 2752, 3072, 3456, - 3584, 4096, 4608, 5120, 5504, 5632, 6144, 6848, 6912, 7168, 8192, 9216, - 10240, 11008, 13824, 14336, 22016, 24576, 27392, 32000, 32256, 32512, - 32768, 33024 + 128, + 256, + 512, + 1024, + 1152, + 1280, + 1536, + 2048, + 2304, + 2560, + 2752, + 3072, + 3456, + 3584, + 4096, + 4608, + 5120, + 5504, + 5632, + 6144, + 6848, + 6912, + 7168, + 8192, + 9216, + 10240, + 11008, + 13824, + 14336, + 22016, + 24576, + 27392, + 32000, + 32256, + 32512, + 32768, + 33024, + 36864, + 49152, + 64000, + 64256, + 102400, + 102656, + 128000, + 128256, ] SEED = [0xabcdabcd987] CUDA_DEVICES = [ diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index a8ec4dcfd613..5456b5613c47 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -935,9 +935,9 @@ def create_lora_weights( model_config: Optional[PretrainedConfig] = None, ) -> None: # Keep this in sync with csrc/punica/bgmv/bgmv_config.h - if 32000 < self.base_layer.vocab_size > 33024: + if 32000 < self.base_layer.vocab_size > 128512: raise ValueError("When using LoRA, vocab size must be " - "32000 >= vocab_size <= 33024") + "32000 >= vocab_size <= 128512") self.lora_a_stacked = torch.zeros( ( max_loras, From e46a60aa4c90cf3dfd9b90782f2eeabbda935eef Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 11 Apr 2024 23:34:12 +0100 Subject: [PATCH 086/120] [BugFix] Fix handling of stop strings and stop token ids (#3672) --- tests/conftest.py | 2 +- .../{samplers => engine}/test_stop_reason.py | 2 +- tests/engine/test_stop_strings.py | 111 ++++++++++++++++++ vllm/engine/llm_engine.py | 98 ++++++++++------ vllm/outputs.py | 4 +- vllm/sampling_params.py | 9 ++ vllm/sequence.py | 6 + vllm/transformers_utils/detokenizer.py | 7 +- 8 files changed, 202 insertions(+), 37 deletions(-) rename tests/{samplers => engine}/test_stop_reason.py (97%) create mode 100644 tests/engine/test_stop_strings.py diff --git a/tests/conftest.py b/tests/conftest.py index a7e8963af0ed..5c50fc2d1bab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -401,7 +401,7 @@ def __del__(self): cleanup() -@pytest.fixture +@pytest.fixture(scope="session") def vllm_runner(): return VllmRunner diff --git a/tests/samplers/test_stop_reason.py b/tests/engine/test_stop_reason.py similarity index 97% rename from tests/samplers/test_stop_reason.py rename to tests/engine/test_stop_reason.py index b242c405a4fb..b2f521a8ae4c 100644 --- a/tests/samplers/test_stop_reason.py +++ b/tests/engine/test_stop_reason.py @@ -3,7 +3,7 @@ 2. One of the provided stop tokens 3. The EOS token -Run `pytest tests/samplers/test_stop_reason.py`. +Run `pytest tests/engine/test_stop_reason.py`. """ import pytest diff --git a/tests/engine/test_stop_strings.py b/tests/engine/test_stop_strings.py new file mode 100644 index 000000000000..6b747beb4b54 --- /dev/null +++ b/tests/engine/test_stop_strings.py @@ -0,0 +1,111 @@ +from typing import Any, List, Optional + +import pytest + +from vllm import CompletionOutput, LLMEngine, SamplingParams + +MODEL = "meta-llama/llama-2-7b-hf" +MAX_TOKENS = 200 + + +@pytest.fixture(scope="session") +def vllm_model(vllm_runner): + return vllm_runner(MODEL) + + +@pytest.mark.skip_global_cleanup +def test_stop_basic(vllm_model): + _test_stopping(vllm_model.model.llm_engine, + stop=["."], + include_in_output=False, + expected_output="VLLM is a 100% volunteer organization", + expected_reason=".") + + _test_stopping(vllm_model.model.llm_engine, + stop=["."], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organization.", + expected_reason=".") + + +@pytest.mark.skip_global_cleanup +def test_stop_multi_tokens(vllm_model): + _test_stopping( + vllm_model.model.llm_engine, + stop=["group of peo", "short"], + include_in_output=False, + expected_output="VLLM is a 100% volunteer organization. We are a ", + expected_reason="group of peo") + + _test_stopping( + vllm_model.model.llm_engine, + stop=["group of peo", "short"], + include_in_output=True, + expected_output= + "VLLM is a 100% volunteer organization. We are a group of peo", + expected_reason="group of peo") + + +@pytest.mark.skip_global_cleanup +def test_stop_partial_token(vllm_model): + _test_stopping(vllm_model.model.llm_engine, + stop=["gani"], + include_in_output=False, + expected_output="VLLM is a 100% volunteer or", + expected_reason="gani") + + _test_stopping(vllm_model.model.llm_engine, + stop=["gani"], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organi", + expected_reason="gani") + + +@pytest.mark.skip_global_cleanup +def test_stop_token_id(vllm_model): + # token id 13013 => " organization" + + _test_stopping(vllm_model.model.llm_engine, + stop_token_ids=[13013], + include_in_output=False, + expected_output="VLLM is a 100% volunteer", + expected_reason=13013) + + _test_stopping(vllm_model.model.llm_engine, + stop_token_ids=[13013], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organization", + expected_reason=13013) + + +def _test_stopping(llm_engine: LLMEngine, + expected_output: str, + expected_reason: Any, + stop: Optional[List[str]] = None, + stop_token_ids: Optional[List[int]] = None, + include_in_output: bool = False) -> None: + llm_engine.add_request( + "id", "A story about vLLM:\n", + SamplingParams( + temperature=0.0, + max_tokens=MAX_TOKENS, + stop=stop, + stop_token_ids=stop_token_ids, + include_stop_str_in_output=include_in_output, + ), None) + + output: Optional[CompletionOutput] = None + output_text = "" + stop_reason = None + while llm_engine.has_unfinished_requests(): + (request_output, ) = llm_engine.step() + (output, ) = request_output.outputs + + # Ensure we don't backtrack + assert output.text.startswith(output_text) + output_text = output.text + stop_reason = output.stop_reason + + assert output is not None + assert output_text == expected_output + assert stop_reason == expected_reason diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index ddfdda898a5c..a91629a63059 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -501,9 +501,11 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, for seq, _ in child_seqs: if seq_group.sampling_params.detokenize: - self.detokenizer.decode_sequence_inplace( + new_char_count = self.detokenizer.decode_sequence_inplace( seq, seq_group.sampling_params) - self._check_stop(seq, seq_group.sampling_params) + else: + new_char_count = 0 + self._check_stop(seq, new_char_count, seq_group.sampling_params) # Non-beam search case if not seq_group.sampling_params.use_beam_search: @@ -798,56 +800,86 @@ def _get_stats(self, time_e2e_requests=time_e2e_requests, ) - def _check_stop(self, seq: Sequence, + def _check_stop(self, seq: Sequence, new_char_count: int, sampling_params: SamplingParams) -> None: - """Stop the finished sequences.""" - # Check if the sequence has reached max_model_len. - if seq.get_len() > self.scheduler_config.max_model_len: - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return + """Stop the finished sequences. - # Check if the sequence has reached max_tokens. - if seq.get_output_len() == sampling_params.max_tokens: - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return + new_char_count is the number of chars added to the + sequence's output text for the newly generated token + """ # Check if the minimum number of tokens has been generated yet; # skip the stop string/token checks if not if seq.get_output_len() < sampling_params.min_tokens: return - if sampling_params.detokenize: - for stop_str in sampling_params.stop: - if seq.output_text.endswith(stop_str): - self._finalize_sequence(seq, sampling_params, stop_str) - seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = stop_str - return + # Check if the sequence has generated the EOS token. + if ((not sampling_params.ignore_eos) + and seq.get_last_token_id() == seq.eos_token_id): + seq.status = SequenceStatus.FINISHED_STOPPED + return + + # Check if a stop token was encountered. + # This assumes a single token produced per step. last_token_id = seq.get_last_token_id() if last_token_id in sampling_params.stop_token_ids: - stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens( - last_token_id) - self._finalize_sequence(seq, sampling_params, stop_str) + if new_char_count and ( + not sampling_params.include_stop_str_in_output): + # Remove last token + seq.output_text = seq.output_text[:-new_char_count] seq.status = SequenceStatus.FINISHED_STOPPED seq.stop_reason = last_token_id return - # Check if the sequence has generated the EOS token. - if ((not sampling_params.ignore_eos) - and seq.get_last_token_id() == seq.eos_token_id): + # Check if any stop strings are matched. + stop_str = self._check_stop_strings(seq, new_char_count, + sampling_params) + if stop_str is not None: seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = stop_str return - def _finalize_sequence(self, seq: Sequence, - sampling_params: SamplingParams, - stop_string: str) -> None: - if sampling_params.include_stop_str_in_output: + # Check if the sequence has reached max_model_len. + if seq.get_len() > self.scheduler_config.max_model_len: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED return - if stop_string and seq.output_text.endswith(stop_string): - # Truncate the output text so that the stop string is - # not included in the output. - seq.output_text = seq.output_text[:-len(stop_string)] + # Check if the sequence has reached max_tokens. + if seq.get_output_len() == sampling_params.max_tokens: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + @staticmethod + def _check_stop_strings(seq: Sequence, new_char_count: int, + sampling_params: SamplingParams) -> Optional[str]: + """Check if any stop strings are matched and truncate sequence + output text accordingly. + + Returns the stop string if matched or else None. + """ + if not new_char_count: + return None + + for stop_str in sampling_params.stop: + stop_string_len = len(stop_str) + # Avoid searching already-searched text. + stop_index = seq.output_text.find( + stop_str, -new_char_count - stop_string_len) + if stop_index == -1: + continue + + if sampling_params.include_stop_str_in_output: + # Truncate to end of stop string. + stop_index += stop_string_len + if stop_index >= len(seq.output_text): + # No truncation required. + return stop_str + + # Truncate the output text to either the beginning + # or end of the stop string. + seq.output_text = seq.output_text[:stop_index] + return stop_str + return None def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_executor.add_lora(lora_request) diff --git a/vllm/outputs.py b/vllm/outputs.py index 61fe20bfc274..d01be0eb0efd 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -112,8 +112,10 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": # always has the logprobs of the sampled tokens even if the # logprobs are not requested. include_logprobs = seq_group.sampling_params.logprobs is not None + text_buffer_length = seq_group.sampling_params.output_text_buffer_length outputs = [ - CompletionOutput(seqs.index(seq), seq.output_text, + CompletionOutput(seqs.index(seq), + seq.get_output_text_to_return(text_buffer_length), seq.get_output_token_ids(), seq.get_cumulative_logprob(), seq.output_logprobs if include_logprobs else None, diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 4fdc3c6dedae..0b9787608798 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -166,6 +166,13 @@ def __init__( self.logits_processors = logits_processors self.include_stop_str_in_output = include_stop_str_in_output self.truncate_prompt_tokens = truncate_prompt_tokens + # Number of characters to hold back for stop string evaluation + # until sequence is finished. + if self.stop and not include_stop_str_in_output: + self.output_text_buffer_length = max(len(s) for s in self.stop) - 1 + else: + self.output_text_buffer_length = 0 + self._verify_args() if self.use_beam_search: self._verify_beam_search() @@ -226,6 +233,8 @@ def _verify_args(self) -> None: and self.truncate_prompt_tokens < 1): raise ValueError(f"truncate_prompt_tokens must be >= 1, " f"got {self.truncate_prompt_tokens}") + if any(not stop_str for stop_str in self.stop): + raise ValueError("stop cannot contain an empty string.") if self.stop and not self.detokenize: raise ValueError( "stop strings are only supported when detokenize is True. " diff --git a/vllm/sequence.py b/vllm/sequence.py index 77029908c221..cdb6cce6f025 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -235,6 +235,12 @@ def __init__( def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 + def get_output_text_to_return(self, buffer_length: int): + # We return the full output text if the sequence is finished. + truncate = buffer_length and not self.is_finished() + return self.output_text[:-buffer_length] if truncate else ( + self.output_text) + def hash_of_block(self, logical_idx: int) -> int: # TODO This can produce incorrect hash when block size > prompt size diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 486c1938e1e1..005932f1e3df 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -87,12 +87,15 @@ def decode_prompt_logprobs_inplace( prev_tokens.extend(next_iter_tokens) def decode_sequence_inplace(self, seq: Sequence, - prms: SamplingParams) -> None: + prms: SamplingParams) -> int: """Decodes the new token for a sequence. In-place operation. Args: seq: The sequence to decode. prms: The sampling parameters used to generate the sequence. + + Returns: + The number of characters added to the output text. """ all_input_ids = seq.get_token_ids() token_id_generated_this_iteration = all_input_ids[-1] @@ -151,6 +154,8 @@ def decode_sequence_inplace(self, seq: Sequence, seq.read_offset = read_offset seq.output_text += new_decoded_token_text + return len(new_decoded_token_text) + def _convert_tokens_to_string_with_added_encoders( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], From c2b4a1bce9a7707179cdfab2fb498c20b2b221e6 Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Thu, 11 Apr 2024 17:17:21 -0700 Subject: [PATCH 087/120] [Doc] Add typing hints / mypy types cleanup (#3816) Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> --- benchmarks/backend_request_func.py | 62 ++++++++++--------- docs/source/conf.py | 3 +- setup.py | 5 +- vllm/core/block/interfaces.py | 31 ++++++---- vllm/engine/metrics.py | 10 ++- vllm/logger.py | 8 ++- .../model_executor/layers/rotary_embedding.py | 15 ++--- vllm/transformers_utils/config.py | 4 +- vllm/transformers_utils/configs/dbrx.py | 2 +- .../transformers_utils/tokenizers/baichuan.py | 10 +-- vllm/utils.py | 4 +- 11 files changed, 90 insertions(+), 64 deletions(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index ad428bd1c364..bab570252c92 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -27,8 +27,8 @@ class RequestFuncInput: class RequestFuncOutput: generated_text: str = "" success: bool = False - latency: float = 0 - ttft: float = 0 # Time to first token + latency: float = 0.0 + ttft: float = 0.0 # Time to first token itl: List[float] = field( default_factory=list) # List of inter-token latencies prompt_len: int = 0 @@ -58,23 +58,24 @@ async def async_request_tgi( output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len - ttft = 0 + ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: async with session.post(url=api_url, json=payload) as response: if response.status == 200: - async for chunk in response.content: - chunk = chunk.strip() - if not chunk: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: continue - chunk = remove_prefix(chunk.decode("utf-8"), "data:") + chunk = remove_prefix(chunk_bytes.decode("utf-8"), + "data:") data = json.loads(chunk) timestamp = time.perf_counter() # First token - if ttft == 0: + if ttft == 0.0: ttft = time.perf_counter() - st output.ttft = ttft @@ -119,23 +120,24 @@ async def async_request_trt_llm( output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len - ttft = 0 + ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: async with session.post(url=api_url, json=payload) as response: if response.status == 200: - async for chunk in response.content: - chunk = chunk.strip() - if not chunk: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: continue - chunk = remove_prefix(chunk.decode("utf-8"), "data:") + chunk = remove_prefix(chunk_bytes.decode("utf-8"), + "data:") data = json.loads(chunk) timestamp = time.perf_counter() # First token - if ttft == 0: + if ttft == 0.0: ttft = time.perf_counter() - st output.ttft = ttft @@ -151,7 +153,7 @@ async def async_request_trt_llm( output.success = True else: - output.error = response.reason + output.error = response.reason or "" output.success = False except Exception: output.success = False @@ -195,7 +197,7 @@ async def async_request_deepspeed_mii( output.generated_text = parsed_resp["text"][0] output.success = True else: - output.error = response.reason + output.error = response.reason or "" output.success = False except Exception: output.success = False @@ -234,19 +236,20 @@ async def async_request_openai_completions( output.prompt_len = request_func_input.prompt_len generated_text = "" - ttft = 0 + ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: - async for chunk in response.content: - chunk = chunk.strip() - if not chunk: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: continue - chunk = remove_prefix(chunk.decode("utf-8"), "data: ") + chunk = remove_prefix(chunk_bytes.decode("utf-8"), + "data: ") if chunk == "[DONE]": latency = time.perf_counter() - st else: @@ -255,7 +258,7 @@ async def async_request_openai_completions( if data["choices"][0]["text"]: timestamp = time.perf_counter() # First token - if ttft == 0: + if ttft == 0.0: ttft = time.perf_counter() - st output.ttft = ttft @@ -315,19 +318,20 @@ async def async_request_openai_chat_completions( output.prompt_len = request_func_input.prompt_len generated_text = "" - ttft = 0 + ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: - async for chunk in response.content: - chunk = chunk.strip() - if not chunk: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: continue - chunk = remove_prefix(chunk.decode("utf-8"), "data: ") + chunk = remove_prefix(chunk_bytes.decode("utf-8"), + "data: ") if chunk == "[DONE]": latency = time.perf_counter() - st else: @@ -337,7 +341,7 @@ async def async_request_openai_chat_completions( delta = data["choices"][0]["delta"] if delta.get("content", None): # First token - if ttft == 0: + if ttft == 0.0: ttft = time.perf_counter() - st output.ttft = ttft @@ -354,7 +358,7 @@ async def async_request_openai_chat_completions( output.success = True output.latency = latency else: - output.error = response.reason + output.error = response.reason or "" output.success = False except Exception: output.success = False diff --git a/docs/source/conf.py b/docs/source/conf.py index 44cda7c99cdd..7a8c365ffb3b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -12,6 +12,7 @@ import logging import sys +from typing import List from sphinx.ext import autodoc @@ -45,7 +46,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = [] +exclude_patterns: List[str] = [] # Exclude the prompt "$" when copying code copybutton_prompt_text = r"\$ " diff --git a/setup.py b/setup.py index 98c92f9196e7..9f0814e9f3bf 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ import subprocess import sys from shutil import which -from typing import List +from typing import Dict, List import torch from packaging.version import Version, parse @@ -52,7 +52,7 @@ def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None: class cmake_build_ext(build_ext): # A dict of extension directories that have been configured. - did_config = {} + did_config: Dict[str, bool] = {} # # Determine number of compilation jobs and optionally nvcc compile threads. @@ -261,6 +261,7 @@ def get_nvcc_cuda_version() -> Version: Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py """ + assert CUDA_HOME is not None, "CUDA_HOME is not set" nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True) output = nvcc_output.split() diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 9f466566f096..fbceacf0ec41 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -1,5 +1,5 @@ -from abc import ABC, abstractmethod, abstractproperty -from typing import Dict, List, Optional, Protocol +from abc import ABC, abstractmethod +from typing import Dict, FrozenSet, List, Optional, Protocol from vllm.utils import Device @@ -10,23 +10,28 @@ class Block(ABC): def append_token_ids(self, token_ids: List[int]) -> None: pass - @abstractproperty + @property + @abstractmethod def block_id(self) -> Optional[int]: pass - @abstractproperty + @property + @abstractmethod def token_ids(self) -> List[int]: pass - @abstractproperty + @property + @abstractmethod def num_empty_slots(self) -> int: pass - @abstractproperty + @property + @abstractmethod def is_full(self) -> bool: pass - @abstractproperty + @property + @abstractmethod def prev_block(self) -> Optional["Block"]: pass @@ -47,12 +52,13 @@ def __call__( class BlockAllocator(ABC): @abstractmethod - def allocate_mutable(self, prev_block: Optional[Block]) -> Block: + def allocate_mutable(self, prev_block: Optional[Block], + device: Device) -> Block: pass @abstractmethod def allocate_immutable(self, prev_block: Optional[Block], - token_ids: List[int]) -> Block: + token_ids: List[int], device: Device) -> Block: pass @abstractmethod @@ -64,11 +70,12 @@ def fork(self, last_block: Block) -> List[Block]: pass @abstractmethod - def get_num_free_blocks(self) -> int: + def get_num_free_blocks(self, device: Device) -> int: pass - @abstractproperty - def all_block_ids(self) -> frozenset[int]: + @property + @abstractmethod + def all_block_ids(self) -> FrozenSet[int]: pass @abstractmethod diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 905db52a1912..02560907a128 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -1,6 +1,6 @@ import time from dataclasses import dataclass -from typing import Dict, List +from typing import Dict, List, Protocol import numpy as np from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info, @@ -119,6 +119,12 @@ class Stats: time_e2e_requests: List[float] +class SupportsMetricsInfo(Protocol): + + def metrics_info(self) -> Dict[str, str]: + ... + + class StatLogger: """StatLogger is used LLMEngine to log to Promethus and Stdout.""" @@ -135,7 +141,7 @@ def __init__(self, local_interval: float, labels: Dict[str, str]) -> None: self.labels = labels self.metrics = Metrics(labelnames=list(labels.keys())) - def info(self, type: str, obj: object) -> None: + def info(self, type: str, obj: SupportsMetricsInfo) -> None: if type == "cache_config": self.metrics.info_cache_config.info(obj.metrics_info()) diff --git a/vllm/logger.py b/vllm/logger.py index e5e46f5cce3f..af9575085ef3 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -4,6 +4,7 @@ import logging import os import sys +from typing import Optional VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")) @@ -26,7 +27,7 @@ def format(self, record): _root_logger = logging.getLogger("vllm") -_default_handler = None +_default_handler: Optional[logging.Handler] = None def _setup_logger(): @@ -55,7 +56,12 @@ def init_logger(name: str): # Use the same settings as above for root logger logger = logging.getLogger(name) logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG")) + if VLLM_CONFIGURE_LOGGING: + if _default_handler is None: + raise ValueError( + "_default_handler is not set up. This should never happen!" + " Please open an issue on Github.") logger.addHandler(_default_handler) logger.propagate = False return logger diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index eb8d5f6dfb2a..6519781c8a8e 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -247,11 +247,12 @@ def _yarn_find_correction_dim(num_rotations: int, # Find dim range bounds based on rotations -def _yarn_find_correction_range(low_rot: int, - high_rot: int, - dim: int, - base: float = 10000, - max_position_embeddings: int = 2048) -> int: +def _yarn_find_correction_range( + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048) -> Tuple[int, int]: low = math.floor( _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) high = math.ceil( @@ -293,8 +294,8 @@ def __init__( *, extrapolation_factor: float = 1, attn_factor: float = 1, - beta_fast: float = 32, - beta_slow: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, ) -> None: self.scaling_factor = scaling_factor self.extrapolation_factor = extrapolation_factor diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 8a6ba6c5b396..ce7a30dce72f 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -1,10 +1,10 @@ -from typing import Optional +from typing import Dict, Optional from transformers import AutoConfig, PretrainedConfig from vllm.transformers_utils.configs import * -_CONFIG_REGISTRY = { +_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = { "chatglm": ChatGLMConfig, "dbrx": DbrxConfig, "mpt": MPTConfig, diff --git a/vllm/transformers_utils/configs/dbrx.py b/vllm/transformers_utils/configs/dbrx.py index 3a19af7129e7..1d2724f22abd 100644 --- a/vllm/transformers_utils/configs/dbrx.py +++ b/vllm/transformers_utils/configs/dbrx.py @@ -12,7 +12,7 @@ logger = logging.get_logger(__name__) -DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} +DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} # type: ignore class DbrxAttentionConfig(PretrainedConfig): diff --git a/vllm/transformers_utils/tokenizers/baichuan.py b/vllm/transformers_utils/tokenizers/baichuan.py index 02045bdcb2cc..79894035cb1f 100644 --- a/vllm/transformers_utils/tokenizers/baichuan.py +++ b/vllm/transformers_utils/tokenizers/baichuan.py @@ -16,11 +16,11 @@ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} -PRETRAINED_VOCAB_FILES_MAP = { +PRETRAINED_VOCAB_FILES_MAP = { # type: ignore "vocab_file": {}, "tokenizer_file": {}, } -PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} # type: ignore class BaichuanTokenizer(PreTrainedTokenizer): @@ -148,9 +148,9 @@ def save_vocabulary(self, `Tuple(str)`: Paths to the files saved. """ if not os.path.isdir(save_directory): - logger.error(f"Vocabulary path ({save_directory}) " - "should be a directory") - return + raise ValueError(f"Vocabulary path ({save_directory}) " + "should be a directory") + out_vocab_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + diff --git a/vllm/utils.py b/vllm/utils.py index fdb0a3768ab0..669b65891d0d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -294,7 +294,7 @@ def create_kv_caches_with_random( head_size: int, cache_dtype: Optional[Union[str, torch.dtype]], model_dtype: Optional[Union[str, torch.dtype]] = None, - seed: Optional[int] = 0, + seed: int = 0, device: Optional[str] = "cuda", ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: torch.random.manual_seed(seed) @@ -400,7 +400,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): gc.collect() -def str_to_int_tuple(s: str) -> Tuple[int]: +def str_to_int_tuple(s: str) -> Tuple[int, ...]: """Convert a string to a tuple of integers.""" try: return tuple(map(int, s.split(","))) From 1096717ae9e0b414ad625c1a12354dd1d949ffb1 Mon Sep 17 00:00:00 2001 From: Jee Li Date: Fri, 12 Apr 2024 12:02:44 +0800 Subject: [PATCH 088/120] [Core] Support LoRA on quantized models (#4012) --- tests/lora/conftest.py | 5 + tests/lora/test_quant_model.py | 179 +++++++++++++++++++++++++++++++++ vllm/config.py | 9 +- vllm/lora/layers.py | 67 +++++++----- 4 files changed, 234 insertions(+), 26 deletions(-) create mode 100644 tests/lora/test_quant_model.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 207c635e2dc8..1127cc33183c 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -143,6 +143,11 @@ def baichuan_lora_files(): return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider") +@pytest.fixture(scope="session") +def tinyllama_lora_files(): + return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") + + @pytest.fixture def llama_2_7b_engine_extra_embeddings() -> nn.Module: cleanup() diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py new file mode 100644 index 000000000000..3d86a4366aa5 --- /dev/null +++ b/tests/lora/test_quant_model.py @@ -0,0 +1,179 @@ +# Adapted from +# https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/tests/lora/test_llama.py +from dataclasses import dataclass +from typing import List + +import pytest + +import vllm +from vllm.lora.request import LoRARequest + +from .conftest import cleanup + + +@dataclass +class ModelWithQuantization: + model_path: str + quantization: str + + +MODELS: List[ModelWithQuantization] = [ + ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", + quantization="AWQ"), + ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", + quantization="GPTQ"), +] + + +def do_sample(llm, lora_path: str, lora_id: int, max_tokens=256): + raw_prompts = [ + "Give me an orange-ish brown color", + "Give me a neon pink color", + ] + + def format_prompt_tuples(prompt): + return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + + prompts = [format_prompt_tuples(p) for p in raw_prompts] + + sampling_params = vllm.SamplingParams(temperature=0, + max_tokens=max_tokens, + stop=["<|im_end|>"]) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. + generated_texts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("tp_size", [1]) +def test_quant_model_lora(tinyllama_lora_files, model, tp_size): + # Cannot use as it will initialize torch.cuda too early... + # if torch.cuda.device_count() < tp_size: + # pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + + llm = vllm.LLM(model=model.model_path, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + max_model_len=400, + tensor_parallel_size=tp_size, + quantization=model.quantization, + trust_remote_code=True) + + if model.quantization is None: + expected_no_lora_output = [ + "Here are some examples of orange-brown colors", + "I'm sorry, I don't have" + ] + expected_lora_output = [ + "#ff8050", + "#ff8080", + ] + elif model.quantization == "AWQ": + expected_no_lora_output = [ + "I'm sorry, I don't understand", + "I'm sorry, I don't understand", + ] + expected_lora_output = [ + "#f07700: A v", + "#f00000: A v", + ] + elif model.quantization == "GPTQ": + expected_no_lora_output = [ + "I'm sorry, I don't have", + "I'm sorry, I don't have", + ] + expected_lora_output = [ + "#f08800: This is", + "#f07788 \n#", + ] + + def expect_match(output, expected_output): + # HACK: GPTQ lora outputs are just incredibly unstable. + # Assert that the outputs changed. + if (model.quantization == "GPTQ" + and expected_output is expected_lora_output): + assert output != expected_no_lora_output + for i, o in enumerate(output): + assert o.startswith( + '#'), f"Expected example {i} to start with # but got {o}" + return + assert output == expected_output + + max_tokens = 10 + + print("lora adapter created") + output = do_sample(llm, + tinyllama_lora_files, + lora_id=0, + max_tokens=max_tokens) + expect_match(output, expected_no_lora_output) + + print("lora 1") + output = do_sample(llm, + tinyllama_lora_files, + lora_id=1, + max_tokens=max_tokens) + expect_match(output, expected_lora_output) + + print("no lora") + output = do_sample(llm, + tinyllama_lora_files, + lora_id=0, + max_tokens=max_tokens) + expect_match(output, expected_no_lora_output) + + print("lora 2") + output = do_sample(llm, + tinyllama_lora_files, + lora_id=2, + max_tokens=max_tokens) + expect_match(output, expected_lora_output) + + print("removing lora") + + del llm + cleanup() + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.skip("Requires multiple GPUs") +def test_quant_model_tp_equality(tinyllama_lora_files, model): + # Cannot use as it will initialize torch.cuda too early... + # if torch.cuda.device_count() < 2: + # pytest.skip(f"Not enough GPUs for tensor parallelism {2}") + + llm_tp1 = vllm.LLM(model=model.model_path, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=1, + quantization=model.quantization, + trust_remote_code=True) + output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1) + + del llm_tp1 + cleanup() + + llm_tp2 = vllm.LLM(model=model.model_path, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=2, + quantization=model.quantization) + output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1) + + del llm_tp2 + cleanup() + + assert output_tp1 == output_tp2 diff --git a/vllm/config.py b/vllm/config.py index 4102edbe01d3..da7eb2810ff0 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -822,9 +822,12 @@ def verify_with_model_config(self, model_config: ModelConfig): self.lora_dtype = model_config.dtype elif isinstance(self.lora_dtype, str): self.lora_dtype = getattr(torch, self.lora_dtype) - if model_config.quantization is not None: - raise ValueError( - "LoRA is not supported with quantized models yet.") + if model_config.quantization and model_config.quantization not in [ + "awq", "gptq" + ]: + # TODO support marlin and squeezellm + logger.warning(f"{model_config.quantization} quantization is not " + "tested with LoRA yet.") def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): if scheduler_config.max_num_batched_tokens > 65528: diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 5456b5613c47..4b9653de73a8 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -29,6 +29,19 @@ pass +def _get_lora_device(base_layer: nn.Module) -> torch.device: + # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 + """Returns the device for where to place the LoRA tensors.""" + if hasattr(base_layer, "weight"): + return base_layer.weight.device + if hasattr(base_layer, "linear_weights") and isinstance( + base_layer.linear_weights, dict): + values = list(base_layer.linear_weights.values()) + if len(values) and isinstance(values[0], torch.Tensor): + return values[0].device + raise ValueError(f"Unsupported base layer: {base_layer}") + + def _apply_lora( x: torch.Tensor, lora_a_stacked: torch.Tensor, @@ -302,6 +315,9 @@ def __init__(self, base_layer: ColumnParallelLinear) -> None: super().__init__() self.base_layer = base_layer self.tp_size = get_tensor_model_parallel_world_size() + self.input_size = self.base_layer.input_size + self.output_size = self.base_layer.output_size_per_partition + self.device = _get_lora_device(self.base_layer) def create_lora_weights( self, @@ -312,17 +328,17 @@ def create_lora_weights( max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) self.lora_b_stacked = torch.zeros( max_loras, 1, - self.base_layer.weight.shape[0], + self.output_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) self.indices: Optional[torch.Tensor] = None @@ -442,18 +458,18 @@ def create_lora_weights( max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) for _ in range(n_slices)) self.lora_b_stacked = tuple( torch.zeros( max_loras, 1, - self.base_layer.weight.shape[0] // 2, + self.output_size // 2, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) for _ in range(n_slices)) self.indices: Optional[torch.Tensor] = None @@ -619,25 +635,25 @@ def create_lora_weights( max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), torch.zeros( max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), torch.zeros( max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), ) self.lora_b_stacked = ( @@ -647,7 +663,7 @@ def create_lora_weights( self.q_proj_shard_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), torch.zeros( max_loras, @@ -655,7 +671,7 @@ def create_lora_weights( self.kv_proj_shard_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), torch.zeros( max_loras, @@ -663,7 +679,7 @@ def create_lora_weights( self.kv_proj_shard_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), ) @@ -766,6 +782,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: RowParallelLinear) -> None: super().__init__() self.base_layer = base_layer + self.input_size = self.base_layer.input_size_per_partition + self.output_size = self.base_layer.output_size + self.device = _get_lora_device(self.base_layer) def create_lora_weights( self, @@ -777,20 +796,20 @@ def create_lora_weights( max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, ), dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) self.lora_b_stacked = torch.zeros( ( max_loras, 1, - self.base_layer.weight.shape[0], + self.output_size, lora_config.max_lora_rank, ), dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) self.indices: Optional[torch.Tensor] = None self.indices_len: Optional[List[int]] = None @@ -809,7 +828,7 @@ def set_lora( self.reset_lora(index) if self.base_layer.tp_size > 1: tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.base_layer.weight.shape[1] + shard_size = self.input_size start_idx = tensor_model_parallel_rank * shard_size end_idx = (tensor_model_parallel_rank + 1) * shard_size lora_a = lora_a[start_idx:end_idx, :] @@ -884,7 +903,9 @@ def forward(self, input_): @property def weight(self): - return self.base_layer.weight + + return self.base_layer.weight if hasattr( + self.base_layer, "weight") else self.base_layer.qweight @classmethod def can_replace_layer(cls, source_layer: nn.Module, From 7fd3949a0b1c6cd0dcd7066aca48d9d589f2f68e Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 12 Apr 2024 13:30:54 +0800 Subject: [PATCH 089/120] [Frontend][Core] Move `merge_async_iterators` to utils (#4026) --- vllm/entrypoints/openai/serving_completion.py | 38 +----------------- vllm/utils.py | 40 ++++++++++++++++++- 2 files changed, 39 insertions(+), 39 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index c1f1744a118b..e24aa2489a80 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -1,4 +1,3 @@ -import asyncio import time from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, Optional, Tuple) @@ -17,7 +16,7 @@ from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.outputs import RequestOutput -from vllm.utils import random_uuid +from vllm.utils import merge_async_iterators, random_uuid logger = init_logger(__name__) @@ -50,41 +49,6 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]: return prompt_is_tokens, prompts -def merge_async_iterators(*iterators): - """Merge multiple asynchronous iterators into a single iterator. - - This method handle the case where some iterators finish before others. - When it yields, it yields a tuple (i, item) where i is the index of the - iterator that yields the item. - """ - queue = asyncio.Queue() - - finished = [False] * len(iterators) - - async def producer(i, iterator): - try: - async for item in iterator: - await queue.put((i, item)) - except Exception as e: - await queue.put(e) - finished[i] = True - - _tasks = [ - asyncio.create_task(producer(i, iterator)) - for i, iterator in enumerate(iterators) - ] - - async def consumer(): - while not all(finished) or not queue.empty(): - item = await queue.get() - if isinstance(item, Exception): - raise item - yield item - await asyncio.gather(*_tasks) - - return consumer() - - class OpenAIServingCompletion(OpenAIServing): def __init__(self, diff --git a/vllm/utils.py b/vllm/utils.py index 669b65891d0d..0967dfc969c8 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -9,8 +9,8 @@ from collections import OrderedDict, defaultdict from functools import lru_cache, partial from platform import uname -from typing import (Any, Awaitable, Callable, Dict, Generic, Hashable, List, - Optional, Tuple, TypeVar, Union) +from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, + Hashable, List, Optional, Tuple, TypeVar, Union) import psutil import torch @@ -181,6 +181,42 @@ def _async_wrapper(*args, **kwargs) -> asyncio.Future: return _async_wrapper +def merge_async_iterators( + *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]: + """Merge multiple asynchronous iterators into a single iterator. + + This method handle the case where some iterators finish before others. + When it yields, it yields a tuple (i, item) where i is the index of the + iterator that yields the item. + """ + queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue() + + finished = [False] * len(iterators) + + async def producer(i: int, iterator: AsyncIterator[T]): + try: + async for item in iterator: + await queue.put((i, item)) + except Exception as e: + await queue.put(e) + finished[i] = True + + _tasks = [ + asyncio.create_task(producer(i, iterator)) + for i, iterator in enumerate(iterators) + ] + + async def consumer(): + while not all(finished) or not queue.empty(): + item = await queue.get() + if isinstance(item, Exception): + raise item + yield item + await asyncio.gather(*_tasks) + + return consumer() + + def get_ip() -> str: host_ip = os.environ.get("HOST_IP") if host_ip: From 36729bac1303b655b816b77f45b17237bfafd692 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sat, 13 Apr 2024 01:56:57 +0900 Subject: [PATCH 090/120] [Test] Test multiple attn backend for chunked prefill. (#4023) --- .buildkite/test-pipeline.yaml | 8 +++++++- .../test_basic_correctness.py | 6 ------ .../basic_correctness/test_chunked_prefill.py | 4 ---- vllm/attention/backends/rocm_flash_attn.py | 18 ++++++------------ 4 files changed, 13 insertions(+), 23 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 695290ed74ab..8d7d6304cf12 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -12,7 +12,13 @@ steps: command: pytest -v -s async_engine - label: Basic Correctness Test - command: pytest -v -s basic_correctness + commands: + - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py + - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py + - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_basic_correctness.py + - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py + - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py + - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_chunked_prefill.py - label: Core Test command: pytest -v -s core diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index bd4c7ea3301b..97cff623c5e1 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -4,8 +4,6 @@ """ import pytest -from vllm.attention.selector import VLLM_ATTENTION_BACKEND - MODELS = [ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", @@ -16,7 +14,6 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [False, True]) -@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"]) def test_models( hf_runner, vllm_runner, @@ -25,10 +22,7 @@ def test_models( dtype: str, max_tokens: int, enforce_eager: bool, - attn_backend: str, - monkeypatch, ) -> None: - monkeypatch.setenv(VLLM_ATTENTION_BACKEND, attn_backend) hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index 9ff07b3c0902..d83416eb51b4 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -33,10 +33,6 @@ def test_models( enforce_eager: bool, tensor_parallel_size: int, ) -> None: - if (tensor_parallel_size == 2 and chunked_prefill_token_size != 16 - and not enforce_eager): - pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} " - "for high TP to save testing time.") max_num_seqs = min(chunked_prefill_token_size, 256) enable_chunked_prefill = False max_num_batched_tokens = None diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index e55435cd2c94..c42660fb8f74 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -162,7 +162,7 @@ def __init__( # AMD Radeon 7900 series (gfx1100) currently does not support # xFormers nor FlashAttention. As a temporary workaround, we use # naive PyTorch implementation of attention. - self.attn_fuc = _naive_attention() + self.attn_fuc = _naive_attention logger.debug("Using naive attention in ROCmBackend") elif self.use_triton_flash_attn: from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 @@ -334,26 +334,21 @@ def _naive_attention( prompt_lens: List[int], scale: float, ) -> torch.Tensor: - num_tokens = query.shape[0] output = torch.empty_like(query) start = 0 for _, prompt_len in enumerate(prompt_lens): end = start + prompt_len out = _naive_masked_attention( - query[None, start:end], - key[None, start:end], - value[None, start:end], + query[start:end], + key[start:end], + value[start:end], scale, ) # TODO(woosuk): Unnecessary copy. Optimize. output[start:end].copy_(out) start += prompt_len - # Using view got RuntimeError: view size is not compatible - # with input tensor's size and stride (at least one - # dimension spans across two contiguous subspaces). - # Use reshape instead. - return output.reshape(num_tokens, -1) + return output def _naive_masked_attention( @@ -362,14 +357,13 @@ def _naive_masked_attention( value: torch.Tensor, scale: float, ) -> torch.Tensor: - seq_len, _, _ = query.shape + seq_len, head_size, head_dim = query.shape attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=query.dtype, device=query.device), diagonal=1) attn_mask = attn_mask * torch.finfo(query.dtype).min - attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() attn_weights = attn_weights + attn_mask.float() attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) From 96b6a6d790115d04bb87d410f3bdd5d7d85b43f1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 12 Apr 2024 12:35:44 -0700 Subject: [PATCH 091/120] [Bugfix] fix type hint for py 3.8 (#4036) --- vllm/executor/executor_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index c18edd75d7a4..55bccfa8e3ca 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, @@ -31,7 +31,7 @@ def __init__( raise NotImplementedError @abstractmethod - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available blocks for the GPU KV cache and swappable CPU KV cache. From d4ec9ffb9574988132d927fd1615180522877262 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 12 Apr 2024 13:56:04 -0700 Subject: [PATCH 092/120] [Misc] Fix typo in scheduler.py (#4022) --- vllm/core/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 2942eab735a9..e44f983e1537 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -674,7 +674,7 @@ def _schedule_prefills( def _schedule_default(self) -> SchedulerOutputs: """Schedule queued requests. - The current policy is designed to opimimize the throughput. First, + The current policy is designed to optimize the throughput. First, it batches as many prefill requests as possible. And it schedules decodes. If there's a pressure on GPU memory, decode requests can be swapped or preempted. From 09473ee41c0a22c4d18936ea7eb2328071c19308 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sat, 13 Apr 2024 06:35:50 +0900 Subject: [PATCH 093/120] [mypy] Add mypy type annotation part 1 (#4006) --- .github/workflows/mypy.yaml | 50 ++++++++++++++++++++++++++ format.sh | 22 +++++++++--- pyproject.toml | 5 ++- requirements-common.txt | 3 +- requirements-dev.txt | 2 +- vllm/config.py | 9 +++-- vllm/core/block_manager_v1.py | 12 ++++--- vllm/core/block_manager_v2.py | 4 ++- vllm/core/interfaces.py | 4 ++- vllm/core/scheduler.py | 25 +++++++------ vllm/distributed/communication_op.py | 10 +++--- vllm/engine/ray_utils.py | 18 ++++++---- vllm/entrypoints/api_server.py | 1 + vllm/entrypoints/llm.py | 8 +++-- vllm/executor/cpu_executor.py | 4 +-- vllm/executor/gpu_executor.py | 4 +-- vllm/executor/neuron_executor.py | 4 +-- vllm/executor/ray_gpu_executor.py | 11 +++--- vllm/sampling_params.py | 5 +-- vllm/sequence.py | 8 ++--- vllm/transformers_utils/config.py | 3 +- vllm/transformers_utils/detokenizer.py | 7 ++-- vllm/transformers_utils/tokenizer.py | 4 +-- vllm/usage/usage_lib.py | 8 ++--- vllm/utils.py | 12 ++++--- 25 files changed, 171 insertions(+), 72 deletions(-) create mode 100644 .github/workflows/mypy.yaml diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml new file mode 100644 index 000000000000..fbe0f816fd4a --- /dev/null +++ b/.github/workflows/mypy.yaml @@ -0,0 +1,50 @@ +name: mypy + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + ruff: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.8"] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install mypy==1.9.0 + pip install types-setuptools + pip install types-PyYAML + pip install types-requests + pip install types-setuptools + - name: Mypy + run: | + mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml + + # TODO(sang): Follow up + # mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml + # mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml + # mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml + # mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml + # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml + diff --git a/format.sh b/format.sh index deb57b2b049d..1c195b899c74 100755 --- a/format.sh +++ b/format.sh @@ -93,9 +93,23 @@ fi echo 'vLLM yapf: Done' # Run mypy -# TODO(zhuohan): Enable mypy -# echo 'vLLM mypy:' -# mypy +echo 'vLLM mypy:' +mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml + +# TODO(sang): Follow up +# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml +# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml +# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml +# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml +# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml + CODESPELL_EXCLUDES=( '--skip' '*docs/source/_build/**' @@ -228,5 +242,3 @@ if ! git diff --quiet &>/dev/null; then exit 1 fi - - diff --git a/pyproject.toml b/pyproject.toml index 2a00d6796ee0..b870a4b85897 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,10 +46,13 @@ ignore = [ python_version = "3.8" ignore_missing_imports = true + check_untyped_defs = true files = "vllm" # TODO(woosuk): Include the code from Megatron and HuggingFace. -exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/" +exclude = [ + "vllm/model_executor/parallel_utils/|vllm/model_executor/models/", +] [tool.codespell] diff --git a/requirements-common.txt b/requirements-common.txt index ff053388a23e..c96f9c9937fb 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -11,4 +11,5 @@ uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. prometheus_client >= 0.18.0 tiktoken == 0.6.0 # Required for DBRX tokenizer -outlines == 0.0.34 # Requires torch >= 2.1.0 \ No newline at end of file +outlines == 0.0.34 # Requires torch >= 2.1.0 +typing_extensions \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index 75d22bbdb2a1..96dfda6faf00 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,7 +7,7 @@ codespell==2.2.6 isort==5.13.2 # type checking -mypy==0.991 +mypy==1.9.0 types-PyYAML types-requests types-setuptools diff --git a/vllm/config.py b/vllm/config.py index da7eb2810ff0..bbda4ecf3cc5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2,7 +2,7 @@ import json import os from dataclasses import dataclass, fields -from typing import TYPE_CHECKING, ClassVar, Optional, Union +from typing import TYPE_CHECKING, ClassVar, List, Optional, Union import torch from packaging.version import Version @@ -141,7 +141,7 @@ def _verify_load_format(self) -> None: supported_load_format = [ "auto", "pt", "safetensors", "npcache", "dummy" ] - rocm_not_supported_load_format = [] + rocm_not_supported_load_format: List[str] = [] if load_format not in supported_load_format: raise ValueError( f"Unknown load format: {self.load_format}. Must be one of " @@ -679,6 +679,9 @@ def maybe_create_spec_config( "num_speculative_tokens to be provided, but found " f"{speculative_model=} and {num_speculative_tokens=}.") + assert (speculative_model is not None + and num_speculative_tokens is not None) + # TODO: The user should be able to specify revision/quantization/max # model len for the draft model. It is not currently supported. draft_revision = None @@ -993,7 +996,7 @@ def _get_and_verify_max_len( derived_max_model_len *= scaling_factor if max_model_len is None: - max_model_len = derived_max_model_len + max_model_len = int(derived_max_model_len) elif max_model_len > derived_max_model_len: # Some models might have a separate key for specifying model_max_length # that will be bigger than derived_max_model_len. We compare user input diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index e7e3b4dc1e9b..e391a3b1e5a3 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -1,5 +1,6 @@ """A block manager that manages token blocks.""" from abc import ABC, abstractmethod +from collections.abc import Sequence as GenericSequence from itertools import count, takewhile from os.path import commonprefix from typing import Dict, List, Optional, Set @@ -231,10 +232,10 @@ def __init__( if self.enable_caching: logger.info("Automatic prefix caching is enabled.") - self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size, - num_gpu_blocks) - self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size, - num_cpu_blocks) + self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator( + Device.GPU, block_size, num_gpu_blocks) + self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator( + Device.CPU, block_size, num_cpu_blocks) else: self.gpu_allocator = UncachedBlockAllocator( Device.GPU, block_size, num_gpu_blocks) @@ -588,7 +589,8 @@ def get_all_computed_blocks(self, seq: Sequence) -> List[int]: for b in takewhile(lambda b: b.computed, block_table[:-1]) ] - def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]: + def get_common_computed_block_ids( + self, seqs: List[Sequence]) -> GenericSequence[int]: """Return the block ids that are common for a given sequence group. Used in prefill (can skip prefill of some blocks). diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 813e71ad883b..19f0cf415eb3 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -1,4 +1,5 @@ """A block manager that manages token blocks.""" +from collections.abc import Sequence as GenericSequence from typing import Dict, List, Optional from vllm.core.block.block_table import BlockTable @@ -205,7 +206,8 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup): # as computed. self.block_allocator.mark_blocks_as_computed() - def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]: + def get_common_computed_block_ids( + self, seqs: List[Sequence]) -> GenericSequence[int]: """Determine which blocks for which we skip prefill. With prefix caching we can skip prefill for previously-generated blocks. diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 711536bcc97b..c1f68a2e891b 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -1,5 +1,6 @@ import enum from abc import ABC, abstractmethod +from collections.abc import Sequence as GenericSequence from typing import Dict, List from vllm.sequence import Sequence, SequenceGroup @@ -103,7 +104,8 @@ def access_all_blocks_in_seq( pass @abstractmethod - def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]: + def get_common_computed_block_ids( + self, seqs: List[Sequence]) -> GenericSequence[int]: pass @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index e44f983e1537..18ddcd1d6d46 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -42,8 +42,8 @@ class SchedulingBudget: """ token_budget: int max_num_seqs: int - _requeset_ids_num_batched_tokens: Set[int] = field(default_factory=set) - _requeset_ids_num_curr_seqs: Set[int] = field(default_factory=set) + _requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set) + _requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set) _num_batched_tokens: int = 0 _num_curr_seqs: int = 0 @@ -133,7 +133,7 @@ def is_empty(self) -> bool: return (not self.scheduled_seq_groups and not self.blocks_to_swap_in and not self.blocks_to_swap_out and not self.blocks_to_copy) - def _sort_by_lora_ids(self) -> bool: + def _sort_by_lora_ids(self): self.scheduled_seq_groups = sorted( self.scheduled_seq_groups, key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id)) @@ -337,7 +337,8 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: self.free_seq(seq) def has_unfinished_seqs(self) -> bool: - return self.waiting or self.running or self.swapped + return len(self.waiting) != 0 or len(self.running) != 0 or len( + self.swapped) != 0 def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) @@ -404,7 +405,7 @@ def _schedule_running( budget.subtract_num_seqs(seq_group.request_id, num_running_seqs) if curr_loras is not None and seq_group.lora_int_id > 0: - curr_loras.pop(seq_group.lora_int_id) + curr_loras.remove(seq_group.lora_int_id) if running_queue: # Preempt the lowest-priority sequence groups. @@ -496,7 +497,7 @@ def _schedule_swapped( now = time.time() swapped_queue = policy.sort_by_priority(now, swapped_queue) - leftover_swapped = deque() + leftover_swapped: Deque[SequenceGroup] = deque() while swapped_queue: seq_group = swapped_queue[0] @@ -507,7 +508,9 @@ def _schedule_swapped( lora_int_id = 0 if self.lora_enabled: lora_int_id = seq_group.lora_int_id - if (lora_int_id > 0 and lora_int_id not in curr_loras + assert curr_loras is not None + assert self.lora_config is not None + if (lora_int_id > 0 and (lora_int_id not in curr_loras) and len(curr_loras) >= self.lora_config.max_loras): # We don't have a space for another LoRA, so # we ignore this request for now. @@ -593,7 +596,7 @@ def _schedule_prefills( # Copy the queue so that the input queue is not modified. waiting_queue = deque([s for s in waiting_queue]) - leftover_waiting_sequences = deque() + leftover_waiting_sequences: Deque[SequenceGroup] = deque() while self._passed_delay(time.time()) and waiting_queue: seq_group = waiting_queue[0] @@ -635,6 +638,8 @@ def _schedule_prefills( lora_int_id = 0 if self.lora_enabled: lora_int_id = seq_group.lora_int_id + assert curr_loras is not None + assert self.lora_config is not None if (self.lora_enabled and lora_int_id > 0 and lora_int_id not in curr_loras and len(curr_loras) >= self.lora_config.max_loras): @@ -780,7 +785,7 @@ def _schedule_chunked_prefill(self): token_budget=self.scheduler_config.max_num_batched_tokens, max_num_seqs=self.scheduler_config.max_num_seqs, ) - curr_loras = set() + curr_loras: Set[int] = set() remaining_waiting, prefills = (self.waiting, SchedulerPrefillOutputs.create_empty()) @@ -1087,7 +1092,7 @@ def _get_num_lookahead_slots(self, is_prefill: bool) -> int: def _get_num_new_tokens(self, seq_group: SequenceGroup, status: SequenceStatus, enable_chunking: bool, - budget: SchedulingBudget) -> Tuple[int, bool]: + budget: SchedulingBudget) -> int: """Get the next new tokens to compute for a given sequence group that's in a given `status`. diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 1004d626b6a4..a3e93691a1e8 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -1,5 +1,5 @@ from collections import namedtuple -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch.distributed import ProcessGroup @@ -144,7 +144,7 @@ def broadcast_tensor_dict( tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0, group: Optional[ProcessGroup] = None, -) -> Dict[Any, Union[torch.Tensor, Any]]: +) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: """Broadcast the input tensor dictionary.""" group = group or torch.distributed.group.WORLD ranks = torch.distributed.get_process_group_ranks(group) @@ -157,10 +157,10 @@ def broadcast_tensor_dict( rank = torch.distributed.get_rank() if rank == src: + metadata_list: List[Tuple[Any, Any]] = [] assert isinstance( tensor_dict, dict), (f"Expecting a dictionary, got {type(tensor_dict)}") - metadata_list = [] for key, value in tensor_dict.items(): if isinstance(value, torch.Tensor): assert value.is_cuda, ( @@ -190,10 +190,10 @@ def broadcast_tensor_dict( torch.distributed.broadcast_object_list(recv_metadata_list, src=src, group=group) - metadata_list = recv_metadata_list[0] + assert recv_metadata_list[0] is not None tensor_dict = {} async_handles = [] - for key, value in metadata_list: + for key, value in recv_metadata_list[0]: if isinstance(value, TensorMetadata): tensor = torch.empty(value.size, dtype=value.dtype, diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index 70d5c9b1fae0..04d4ed83976d 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -1,9 +1,10 @@ import pickle -from typing import List, Optional, Tuple +from typing import Callable, List, Optional, Tuple from vllm.config import ParallelConfig from vllm.logger import init_logger from vllm.utils import get_ip, is_hip, set_cuda_visible_devices +from vllm.worker.worker import Worker logger = init_logger(__name__) @@ -18,15 +19,20 @@ def __init__(self, init_cached_hf_modules=False) -> None: if init_cached_hf_modules: from transformers.dynamic_module_utils import init_hf_modules init_hf_modules() - self.worker = None + self._worker: Optional[Worker] = None # Since the compiled DAG runs a main execution # in a different thread that calls cuda.set_device. # The flag indicates is set_device is called on # that thread. self.compiled_dag_cuda_device_set = False - def init_worker(self, worker_init_fn): - self.worker = worker_init_fn() + def init_worker(self, worker_init_fn: Callable[[], Worker]): + self._worker = worker_init_fn() + + @property + def worker(self) -> Worker: + assert self._worker is not None + return self._worker def __getattr__(self, name): return getattr(self.worker, name) @@ -70,8 +76,8 @@ def execute_model_compiled_dag_remote(self, ignored): logger.warning(f"Failed to import Ray with {e!r}. " "For distributed inference, please install Ray with " "`pip install ray`.") - ray = None - RayWorkerVllm = None + ray = None # type: ignore + RayWorkerVllm = None # type: ignore def initialize_ray_cluster( diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 2a47eae112c1..587142adb9c6 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -47,6 +47,7 @@ async def generate(request: Request) -> Response: sampling_params = SamplingParams(**request_dict) request_id = random_uuid() + assert engine is not None results_generator = engine.generate(prompt, sampling_params, request_id) # Streaming case diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 5777e8179a1c..63ff0b30da55 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -170,8 +170,12 @@ def generate( multi_modal_data.data = multi_modal_data.data.to(torch.float16) # Add requests to the engine. - num_requests = len(prompts) if prompts is not None else len( - prompt_token_ids) + if prompts is not None: + num_requests = len(prompts) + else: + assert prompt_token_ids is not None + num_requests = len(prompt_token_ids) + for i in range(num_requests): prompt = prompts[i] if prompts is not None else None token_ids = None if prompt_token_ids is None else prompt_token_ids[ diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index eda4e8989c16..33e67d8b3eec 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -1,5 +1,5 @@ import os -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import torch @@ -61,7 +61,7 @@ def _init_worker(self): self.driver_worker.init_device() self.driver_worker.load_model() - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks by invoking the underlying worker. """ diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 80ca5cb7367c..f20221a0b941 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, @@ -66,7 +66,7 @@ def _init_worker(self): self.driver_worker.init_device() self.driver_worker.load_model() - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks by invoking the underlying worker. """ diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index 57436a85cfa2..ee8e87432fa6 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, @@ -47,7 +47,7 @@ def _init_worker(self): self.driver_worker.init_device() self.driver_worker.load_model() - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks by invoking the underlying worker. """ diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 6c0ccd7e64c9..b937693c9225 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -3,7 +3,7 @@ import os import pickle from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, @@ -197,7 +197,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", max_parallel_loading_workers, ) - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks. This invokes `determine_num_available_blocks` on each worker and takes @@ -205,7 +205,7 @@ def determine_num_available_blocks(self) -> tuple[int, int]: compatible with all workers. Returns: - - tuple[num_gpu_blocks, num_cpu_blocks] + - Tuple[num_gpu_blocks, num_cpu_blocks] """ # Get the maximum number of blocks that can be allocated on GPU and CPU. num_blocks = self._run_workers("determine_num_available_blocks", ) @@ -276,7 +276,7 @@ def _run_workers( self, method: str, *args, - driver_args: Optional[List[Any]] = None, + driver_args: Optional[Tuple[Any, ...]] = None, driver_kwargs: Optional[Dict[str, Any]] = None, max_concurrent_workers: Optional[int] = None, use_ray_compiled_dag: bool = False, @@ -291,6 +291,7 @@ def _run_workers( if use_ray_compiled_dag: # Right now, compiled DAG can only accept a single # input. TODO(sang): Fix it. + assert self.forward_dag is not None output_channels = self.forward_dag.execute(1) else: # Start the ray workers first. @@ -369,7 +370,7 @@ async def _run_workers_async( self, method: str, *args, - driver_args: Optional[List[Any]] = None, + driver_args: Optional[Tuple[Any, ...]] = None, driver_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ) -> Any: diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 0b9787608798..53a38b25bfda 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -5,7 +5,8 @@ from typing import Callable, List, Optional, Union import torch -from pydantic import conint +from pydantic import Field +from typing_extensions import Annotated _SAMPLING_EPS = 1e-5 @@ -127,7 +128,7 @@ def __init__( skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, logits_processors: Optional[List[LogitsProcessor]] = None, - truncate_prompt_tokens: Optional[conint(ge=1)] = None, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, ) -> None: self.n = n self.best_of = best_of if best_of is not None else n diff --git a/vllm/sequence.py b/vllm/sequence.py index cdb6cce6f025..dcde81df1992 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -171,10 +171,10 @@ def get_last_token_id(self) -> int: return self.prompt_token_ids[-1] return self.output_token_ids[-1] - def get_prompt_token_ids(self) -> int: + def get_prompt_token_ids(self) -> List[int]: return self.prompt_token_ids - def get_output_token_ids(self) -> int: + def get_output_token_ids(self) -> List[int]: return self.output_token_ids @property @@ -370,7 +370,7 @@ class SequenceGroupState: """Mutable state tied to a specific sequence group""" # torch.Generator used in seeded sampling - generator: Optional = None + generator: Optional = None # type: ignore class MultiModalData: @@ -599,7 +599,7 @@ def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @property - def token_chunk_size(self) -> int: + def token_chunk_size(self) -> Optional[int]: """Return the number of tokens to be processed (chunk size).""" return self._token_chunk_size diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index ce7a30dce72f..1756c91a612f 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -2,7 +2,8 @@ from transformers import AutoConfig, PretrainedConfig -from vllm.transformers_utils.configs import * +from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, + JAISConfig, MPTConfig, RWConfig) _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = { "chatglm": ChatGLMConfig, diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 005932f1e3df..f064c26c3f40 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -168,8 +168,8 @@ def _convert_tokens_to_string_with_added_encoders( # NOTE(woosuk): The following code is slow because it runs a for loop over # the output_tokens. In Python, running a for loop over a list can be slow # even when the loop body is very simple. - sub_texts = [] - current_sub_text = [] + sub_texts: List[str] = [] + current_sub_text: List[str] = [] all_special_tokens = set(tokenizer.all_special_tokens) for token in output_tokens: if skip_special_tokens and token in all_special_tokens: @@ -263,6 +263,7 @@ def detokenize_incrementally( tokenizer, all_input_ids[:-1], skip_special_tokens=skip_special_tokens) + assert prev_tokens is not None # If the new token id is out of bounds, return an empty string. if new_token_id >= len(tokenizer): @@ -271,6 +272,8 @@ def detokenize_incrementally( # Put new_token_id in a list so skip_special_tokens is respected new_tokens = tokenizer.convert_ids_to_tokens( [new_token_id], skip_special_tokens=skip_special_tokens) + if isinstance(new_tokens, str): + new_tokens = [new_tokens] output_tokens = prev_tokens + new_tokens # If this is the first iteration, return all tokens. diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index e216a99af91f..5d3d5801c960 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -5,7 +5,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizers import * +from vllm.transformers_utils.tokenizers import BaichuanTokenizer from vllm.utils import make_async logger = init_logger(__name__) @@ -28,7 +28,7 @@ def get_cached_tokenizer( tokenizer_all_special_tokens = set(tokenizer.all_special_tokens) tokenizer_len = len(tokenizer) - class CachedTokenizer(tokenizer.__class__): + class CachedTokenizer(tokenizer.__class__): # type: ignore @property def all_special_ids(self): diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 658fe5c98f5e..b2672f7f1da6 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -7,7 +7,7 @@ from enum import Enum from pathlib import Path from threading import Thread -from typing import Dict, Optional +from typing import Any, Dict, Optional from uuid import uuid4 import cpuinfo @@ -124,7 +124,7 @@ def __init__(self) -> None: def report_usage(self, model_architecture: str, usage_context: UsageContext, - extra_kvs: Dict[str, any] = None) -> None: + extra_kvs: Optional[Dict[str, Any]] = None) -> None: t = Thread(target=self._report_usage_worker, args=(model_architecture, usage_context, extra_kvs or {}), daemon=True) @@ -132,13 +132,13 @@ def report_usage(self, def _report_usage_worker(self, model_architecture: str, usage_context: UsageContext, - extra_kvs: Dict[str, any]) -> None: + extra_kvs: Dict[str, Any]) -> None: self._report_usage_once(model_architecture, usage_context, extra_kvs) self._report_continous_usage() def _report_usage_once(self, model_architecture: str, usage_context: UsageContext, - extra_kvs: Dict[str, any]) -> None: + extra_kvs: Dict[str, Any]) -> None: # Platform information if torch.cuda.is_available(): device_property = torch.cuda.get_device_properties(0) diff --git a/vllm/utils.py b/vllm/utils.py index 0967dfc969c8..4c0dc9ca729a 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -60,7 +60,7 @@ def __contains__(self, key: Hashable) -> bool: def __len__(self) -> int: return len(self.cache) - def __getitem__(self, key: Hashable) -> T: + def __getitem__(self, key: Hashable) -> Optional[T]: return self.get(key) def __setitem__(self, key: Hashable, value: T) -> None: @@ -76,7 +76,7 @@ def get(self, key: Hashable, default_value: Optional[T] = None) -> Optional[T]: if key in self.cache: - value = self.cache[key] + value: Optional[T] = self.cache[key] self.cache.move_to_end(key) else: value = default_value @@ -87,7 +87,7 @@ def put(self, key: Hashable, value: T) -> None: self.cache.move_to_end(key) self._remove_old_if_needed() - def _on_remove(self, key: Hashable, value: T): + def _on_remove(self, key: Hashable, value: Optional[T]): pass def remove_oldest(self): @@ -100,9 +100,11 @@ def _remove_old_if_needed(self) -> None: while len(self.cache) > self.capacity: self.remove_oldest() - def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T: + def pop(self, + key: Hashable, + default_value: Optional[T] = None) -> Optional[T]: run_on_remove = key in self.cache - value = self.cache.pop(key, default_value) + value: Optional[T] = self.cache.pop(key, default_value) if run_on_remove: self._on_remove(key, value) return value From fbb9d9eef48a29e0ea821bbf399e4bf9a08d6ac1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 12 Apr 2024 16:40:39 -0700 Subject: [PATCH 094/120] [Core] fix custom allreduce default value (#4040) --- vllm/entrypoints/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 63ff0b30da55..9e08c253dc53 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -86,7 +86,7 @@ def __init__( swap_space: int = 4, enforce_eager: bool = False, max_context_len_to_capture: int = 8192, - disable_custom_all_reduce: bool = True, + disable_custom_all_reduce: bool = False, **kwargs, ) -> None: if "disable_log_stats" not in kwargs: From d04973ad5446fe05c06035f6b2d99402fc3ac7bf Mon Sep 17 00:00:00 2001 From: Bellk17 Date: Fri, 12 Apr 2024 16:41:26 -0700 Subject: [PATCH 095/120] Fix triton compilation issue (#3984) Co-authored-by: Woosuk Kwon --- vllm/attention/ops/triton_flash_attention.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index 87cf30cbef79..e160411859f0 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -415,7 +415,11 @@ def attn_fwd( return is_mqa = hq != hk - off_h_k = off_h_q % hk if is_mqa else off_h_q + if is_mqa: # noqa: SIM108 + off_h_k = off_h_q % hk + else: + off_h_k = off_h_q + n_extra_tokens = 0 if seqlen_k < BLOCK_N: n_extra_tokens = BLOCK_N - seqlen_k From b8aacac31a4e2e03381fdaef6f1e4bbb895f3b64 Mon Sep 17 00:00:00 2001 From: Jee Li Date: Sat, 13 Apr 2024 07:56:37 +0800 Subject: [PATCH 096/120] [Bugfix] Fix LoRA bug (#4032) --- vllm/lora/layers.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 4b9653de73a8..aac86351b15e 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -32,14 +32,17 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device: # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 """Returns the device for where to place the LoRA tensors.""" + # unquantizedLinear if hasattr(base_layer, "weight"): return base_layer.weight.device - if hasattr(base_layer, "linear_weights") and isinstance( - base_layer.linear_weights, dict): - values = list(base_layer.linear_weights.values()) - if len(values) and isinstance(values[0], torch.Tensor): - return values[0].device - raise ValueError(f"Unsupported base layer: {base_layer}") + # GPTQ/AWQ/SqueezeLLM + elif hasattr(base_layer, "qweight"): + return base_layer.qweight.device + # marlin + elif hasattr(base_layer, "B"): + return base_layer.B.device + else: + raise ValueError(f"Unsupported base layer: {base_layer}") def _apply_lora( From 546e7211684a28bbe53088961b4cf5123e235760 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 12 Apr 2024 18:43:37 -0700 Subject: [PATCH 097/120] [CI/Test] expand ruff and yapf for all supported python version (#4037) --- .github/workflows/mypy.yaml | 2 +- .github/workflows/ruff.yml | 2 +- .github/workflows/yapf.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index fbe0f816fd4a..6db0bb7645ec 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index e8060e369a88..e71033f82800 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/yapf.yml b/.github/workflows/yapf.yml index b163c960db55..04f307bcf8b0 100644 --- a/.github/workflows/yapf.yml +++ b/.github/workflows/yapf.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} From 5c2e66e4871917c5d59cc4a8b89ef53e690e9bd9 Mon Sep 17 00:00:00 2001 From: Dylan Hawk <51147702+dylanwhawk@users.noreply.github.com> Date: Fri, 12 Apr 2024 21:07:04 -0700 Subject: [PATCH 098/120] [Bugfix] More type hint fixes for py 3.8 (#4039) --- vllm/executor/executor_base.py | 2 +- vllm/worker/cpu_worker.py | 4 ++-- vllm/worker/neuron_worker.py | 4 ++-- vllm/worker/worker_base.py | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 55bccfa8e3ca..bbfbfc689c99 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -39,7 +39,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: ExecutorBase may require modification of the result, e.g. to ensure the selected cache sizes are compatible with all workers. - Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks are blocks that are "active" on the device and can be appended to. num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be appended to. diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 3989207e8dd8..41341b063bed 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -1,5 +1,5 @@ """A CPU worker class.""" -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import torch import torch.distributed @@ -157,7 +157,7 @@ def init_device(self) -> None: def load_model(self): self.model_runner.load_model() - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of blocks available for the KV cache. This determines how many KV blocks can fit into the configured CPU diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 6136d50d0c06..2f22f82c045d 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -1,5 +1,5 @@ """A Neuron worker class.""" -from typing import List, Optional +from typing import List, Optional, Tuple import torch import torch.distributed @@ -40,7 +40,7 @@ def init_device(self) -> None: def load_model(self): self.model_runner.load_model() - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks. Swapping is not yet supported, so always return num_cpu_blocks=0. diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index e3027c406ffe..d8c9febb1158 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, List +from typing import Dict, List, Tuple from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -18,14 +18,14 @@ def init_device(self) -> None: raise NotImplementedError @abstractmethod - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available blocks for the GPU KV cache and swappable CPU KV cache. The implementation may run profiling or other heuristics to determine the size of caches. - Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks are blocks that are "active" on the device and can be appended to. num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be appended to. From 98afde19fc273b1e6a695990b93ec07157b856f1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 13 Apr 2024 07:12:53 -0700 Subject: [PATCH 099/120] [Core][Distributed] improve logging for init dist (#4042) --- vllm/distributed/parallel_state.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 9fceffe7cb88..1258bf58cb45 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -8,6 +8,10 @@ import torch +from vllm.logger import init_logger + +logger = init_logger(__name__) + # Tensor model parallel group that the current rank belongs to. _TENSOR_MODEL_PARALLEL_GROUP = None # Pipeline model parallel group that the current rank belongs to. @@ -45,6 +49,8 @@ def init_distributed_environment( local_rank: int = -1, backend: str = "nccl", ): + logger.debug(f"{world_size=} {rank=} {local_rank=} " + f"{distributed_init_method=} {backend=}") if not torch.distributed.is_initialized(): assert distributed_init_method is not None, ( "distributed_init_method must be provided when initializing " From ec8e3c695f2dce080bde569746180300e91084a3 Mon Sep 17 00:00:00 2001 From: zspo Date: Sat, 13 Apr 2024 22:52:36 +0800 Subject: [PATCH 100/120] [Bugfix] fix_log_time_in_metrics (#4050) --- vllm/engine/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 02560907a128..04e27e69ce0f 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -130,7 +130,7 @@ class StatLogger: def __init__(self, local_interval: float, labels: Dict[str, str]) -> None: # Metadata for logging locally. - self.last_local_log = time.monotonic() + self.last_local_log = time.time() self.local_interval = local_interval # Tracked stats over current local logging interval. From 0a430b4ae2763c2f161e3bfb1529acf4685f7caa Mon Sep 17 00:00:00 2001 From: zspo Date: Sat, 13 Apr 2024 22:54:03 +0800 Subject: [PATCH 101/120] [Bugfix] fix_small_bug_in_neuron_executor (#4051) --- vllm/executor/neuron_executor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index ee8e87432fa6..d45f18e46625 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -25,6 +25,7 @@ def __init__( speculative_config: Optional[SpeculativeConfig], ) -> None: self.model_config = model_config + self.cache_config = cache_config assert lora_config is None, "LoRA is not supported for Neuron backend." self.parallel_config = parallel_config self.scheduler_config = scheduler_config @@ -43,6 +44,7 @@ def _init_worker(self): self.parallel_config, self.scheduler_config, self.device_config, + self.cache_config, ) self.driver_worker.init_device() self.driver_worker.load_model() From 989ae2538df211ca3a31f77ac8e106c5c97c6e53 Mon Sep 17 00:00:00 2001 From: Jee Li Date: Sat, 13 Apr 2024 22:55:05 +0800 Subject: [PATCH 102/120] [Kernel] Add punica dimension for Baichuan-13B (#4053) --- csrc/punica/bgmv/bgmv_config.h | 1 + tests/lora/test_baichuan.py | 2 +- tests/lora/test_punica.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 9b76b98ab332..d2906914f927 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -47,6 +47,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 13696) \ f(in_T, out_T, W_T, narrow, 13824) \ f(in_T, out_T, W_T, narrow, 14336) \ + f(in_T, out_T, W_T, narrow, 15360) \ f(in_T, out_T, W_T, narrow, 16384) \ f(in_T, out_T, W_T, narrow, 20480) \ f(in_T, out_T, W_T, narrow, 22016) \ diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py index 2178266d2e0c..5ab863eea94b 100644 --- a/tests/lora/test_baichuan.py +++ b/tests/lora/test_baichuan.py @@ -62,7 +62,7 @@ def test_baichuan_lora(baichuan_lora_files): @pytest.mark.skip("Requires multiple GPUs") -def test_llama_tensor_parallel_equality(baichuan_lora_files): +def test_baichuan_tensor_parallel_equality(baichuan_lora_files): # Cannot use as it will initialize torch.cuda too early... # if torch.cuda.device_count() < 4: # pytest.skip(f"Not enough GPUs for tensor parallelism {4}") diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index cab8b44ccd2d..8b174f01d87d 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -72,6 +72,7 @@ def _lora_ref_impl( 11008, 13824, 14336, + 15360, 22016, 24576, 27392, From 711a000255eac3e034f0b73aa5cc62b45201a571 Mon Sep 17 00:00:00 2001 From: Sanger Steel Date: Sat, 13 Apr 2024 20:13:01 -0400 Subject: [PATCH 103/120] [Frontend] [Core] feat: Add model loading using `tensorizer` (#3476) --- .buildkite/test-pipeline.yaml | 3 + docs/source/conf.py | 1 + docs/source/models/engine_args.rst | 3 +- examples/tensorize_vllm_model.py | 254 ++++++++++++++ requirements-cpu.txt | 2 +- requirements-dev.txt | 1 + setup.py | 3 + tests/tensorizer/__init__.py | 0 .../tensorize_vllm_model_for_testing.py | 245 ++++++++++++++ tests/tensorizer/test_tensorizer.py | 302 +++++++++++++++++ vllm/config.py | 74 +++- vllm/engine/arg_utils.py | 45 ++- vllm/engine/llm_engine.py | 8 +- vllm/executor/gpu_executor.py | 23 +- vllm/executor/ray_gpu_executor.py | 6 +- vllm/model_executor/model_loader.py | 61 +++- vllm/model_executor/tensorizer_loader.py | 319 ++++++++++++++++++ vllm/model_executor/weight_utils.py | 34 +- vllm/worker/model_runner.py | 9 +- vllm/worker/worker.py | 9 +- 20 files changed, 1351 insertions(+), 51 deletions(-) create mode 100644 examples/tensorize_vllm_model.py create mode 100644 tests/tensorizer/__init__.py create mode 100644 tests/tensorizer/tensorize_vllm_model_for_testing.py create mode 100644 tests/tensorizer/test_tensorizer.py create mode 100644 vllm/model_executor/tensorizer_loader.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 8d7d6304cf12..aa4582bbda0c 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -91,6 +91,9 @@ steps: command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 4 +- label: Tensorizer Test + command: apt-get install curl libsodium23 && pytest -v -s tensorizer + - label: Metrics Test command: pytest -v -s metrics diff --git a/docs/source/conf.py b/docs/source/conf.py index 7a8c365ffb3b..19cc8557a754 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -83,6 +83,7 @@ "vllm._C", "numpy", "tqdm", + "tensorizer", ] for mock_target in autodoc_mock_imports: diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst index d8a7ac72e017..886a806934c0 100644 --- a/docs/source/models/engine_args.rst +++ b/docs/source/models/engine_args.rst @@ -36,7 +36,7 @@ Below, you can find an explanation of every engine argument for vLLM: Directory to download and load the weights, default to the default cache dir of huggingface. -.. option:: --load-format {auto,pt,safetensors,npcache,dummy} +.. option:: --load-format {auto,pt,safetensors,npcache,dummy,tensorizer} The format of the model weights to load. @@ -45,6 +45,7 @@ Below, you can find an explanation of every engine argument for vLLM: * "safetensors" will load the weights in the safetensors format. * "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading. * "dummy" will initialize the weights with random values, mainly for profiling. + * "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer. `_. See `tensorized_vllm_model.py` in the examples folder to serialize a vLLM model, and for more information. Tensorizer support for vLLM can be installed with `pip install vllm[tensorizer]`. .. option:: --dtype {auto,half,float16,bfloat16,float,float32} diff --git a/examples/tensorize_vllm_model.py b/examples/tensorize_vllm_model.py new file mode 100644 index 000000000000..3c20a38c7f72 --- /dev/null +++ b/examples/tensorize_vllm_model.py @@ -0,0 +1,254 @@ +import argparse +import dataclasses +import os +import time +import uuid +from functools import partial +from typing import Type + +import torch +import torch.nn as nn +from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer, + TensorSerializer, stream_io) +from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor +from transformers import AutoConfig, PretrainedConfig + +from vllm.distributed import initialize_model_parallel +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.llm_engine import LLMEngine +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.tensorizer_loader import TensorizerArgs + +# yapf conflicts with isort for this docstring +# yapf: disable +""" +tensorize_vllm_model.py is a script that can be used to serialize and +deserialize vLLM models. These models can be loaded using tensorizer directly +to the GPU extremely quickly. Tensor encryption and decryption is also +supported, although libsodium must be installed to use it. Install +vllm with tensorizer support using `pip install vllm[tensorizer]`. + +To serialize a model, you can run something like this: + +python tensorize_vllm_model.py \ + --model EleutherAI/gpt-j-6B \ + --dtype float16 \ + serialize \ + --serialized-directory s3://my-bucket/ \ + --suffix vllm + +Which downloads the model from HuggingFace, loads it into vLLM, serializes it, +and saves it to your S3 bucket. A local directory can also be used. + +You can also encrypt the model weights with a randomly-generated key by +providing a `--keyfile` argument. + +To deserialize a model, you can run something like this: + +python tensorize_vllm_model.py \ + --model EleutherAI/gpt-j-6B \ + --dtype float16 \ + deserialize \ + --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors + +Which downloads the model tensors from your S3 bucket and deserializes them. +To provide S3 credentials, you can provide `--s3-access-key-id` and +`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script, +the OpenAI entrypoint, as arguments for LLM(), or as environment variables +in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`. + + +You can also provide a `--keyfile` argument to decrypt the model weights if +they were serialized with encryption. + +For more information on the available arguments, run +`python tensorize_vllm_model.py --help`. +""" + + +def parse_args(): + parser = argparse.ArgumentParser( + description="An example script that can be used to serialize and " + "deserialize vLLM models. These models " + "can be loaded using tensorizer directly to the GPU " + "extremely quickly. Tensor encryption and decryption is " + "also supported, although libsodium must be installed to " + "use it.") + parser = EngineArgs.add_cli_args(parser) + subparsers = parser.add_subparsers(dest='command') + + serialize_parser = subparsers.add_parser( + 'serialize', help="Serialize a model to `--serialized-directory`") + + serialize_parser.add_argument( + "--suffix", + type=str, + required=False, + help=( + "The suffix to append to the serialized model directory, which is " + "used to construct the location of the serialized model tensors, " + "e.g. if `--serialized-directory` is `s3://my-bucket/` and " + "`--suffix` is `v1`, the serialized model tensors will be " + "saved to " + "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. " + "If none is provided, a random UUID will be used.")) + serialize_parser.add_argument( + "--serialized-directory", + type=str, + required=True, + help="The directory to serialize the model to. " + "This can be a local directory or S3 URI. The path to where the " + "tensors are saved is a combination of the supplied `dir` and model " + "reference ID. For instance, if `dir` is the serialized directory, " + "and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will " + "be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, " + "where `suffix` is given by `--suffix` or a random UUID if not " + "provided.") + + serialize_parser.add_argument( + "--keyfile", + type=str, + required=False, + help=("Encrypt the model weights with a randomly-generated binary key," + " and save the key at this path")) + + deserialize_parser = subparsers.add_parser( + 'deserialize', + help=("Deserialize a model from `--path-to-tensors`" + " to verify it can be loaded and used.")) + + deserialize_parser.add_argument( + "--path-to-tensors", + type=str, + required=True, + help="The local path or S3 URI to the model tensors to deserialize. ") + + deserialize_parser.add_argument( + "--keyfile", + type=str, + required=False, + help=("Path to a binary key to use to decrypt the model weights," + " if the model was serialized with encryption")) + + return parser.parse_args() + + +def make_model_contiguous(model): + # Ensure tensors are saved in memory contiguously + for param in model.parameters(): + param.data = param.data.contiguous() + + +def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: + architectures = getattr(config, "architectures", []) + for arch in architectures: + model_cls = ModelRegistry.load_model_cls(arch) + if model_cls is not None: + return model_cls + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def serialize(): + + eng_args_dict = {f.name: getattr(args, f.name) for f in + dataclasses.fields(EngineArgs)} + engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict)) + engine = LLMEngine.from_engine_args(engine_args) + + model = (engine.model_executor.driver_worker. + model_runner.model) + + encryption_params = EncryptionParams.random() if keyfile else None + if keyfile: + with _write_stream(keyfile) as stream: + stream.write(encryption_params.key) + + with _write_stream(model_path) as stream: + serializer = TensorSerializer(stream, encryption=encryption_params) + serializer.write_module(model) + serializer.close() + + print("Serialization complete. Model tensors saved to", model_path) + if keyfile: + print("Key saved to", keyfile) + + +def deserialize(): + config = AutoConfig.from_pretrained(model_ref) + + with no_init_or_tensor(): + model_class = _get_vllm_model_architecture(config) + model = model_class(config) + + before_mem = get_mem_usage() + start = time.time() + + if keyfile: + with _read_stream(keyfile) as stream: + key = stream.read() + decryption_params = DecryptionParams.from_key(key) + tensorizer_args.deserializer_params['encryption'] = \ + decryption_params + + with (_read_stream(model_path)) as stream, TensorDeserializer( + stream, **tensorizer_args.deserializer_params) as deserializer: + deserializer.load_into_module(model) + end = time.time() + + # Brag about how fast we are. + total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) + duration = end - start + per_second = convert_bytes(deserializer.total_tensor_bytes / duration) + after_mem = get_mem_usage() + print( + f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s" + ) + print(f"Memory usage before: {before_mem}") + print(f"Memory usage after: {after_mem}") + + return model + + +args = parse_args() + +s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID") + or None) +s3_secret_access_key = (args.s3_secret_access_key + or os.environ.get("S3_SECRET_ACCESS_KEY") or None) + +s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None) + +_read_stream, _write_stream = (partial( + stream_io.open_stream, + mode=mode, + s3_access_key_id=s3_access_key_id, + s3_secret_access_key=s3_secret_access_key, + s3_endpoint=s3_endpoint, +) for mode in ("rb", "wb+")) + +model_ref = args.model + +model_name = model_ref.split("/")[1] + +os.environ["MASTER_ADDR"] = "127.0.0.1" +os.environ["MASTER_PORT"] = "8080" + +torch.distributed.init_process_group(world_size=1, rank=0) +initialize_model_parallel() + +keyfile = args.keyfile if args.keyfile else None + +if args.command == "serialize": + input_dir = args.serialized_directory.rstrip('/') + suffix = args.suffix if args.suffix else uuid.uuid4().hex + base_path = f"{input_dir}/vllm/{model_ref}/{suffix}" + model_path = f"{base_path}/model.tensors" + serialize() +elif args.command == "deserialize": + tensorizer_args = TensorizerArgs.from_cli_args(args) + model_path = args.path_to_tensors + deserialize() +else: + raise ValueError("Either serialize or deserialize must be specified.") diff --git a/requirements-cpu.txt b/requirements-cpu.txt index 36d20bc9473e..5779b38b24e6 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -3,4 +3,4 @@ # Dependencies for x86_64 CPUs torch == 2.2.1+cpu -triton >= 2.1.0 # FIXME(woosuk): This is a hack to avoid import error. +triton >= 2.1.0 # FIXME(woosuk): This is a hack to avoid import error. \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index 96dfda6faf00..1317e51b2dd1 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -14,6 +14,7 @@ types-setuptools # testing pytest +tensorizer==2.9.0a0 pytest-forked pytest-asyncio pytest-rerunfailures diff --git a/setup.py b/setup.py index 9f0814e9f3bf..813321efe796 100644 --- a/setup.py +++ b/setup.py @@ -405,6 +405,9 @@ def _read_requirements(filename: str) -> List[str]: python_requires=">=3.8", install_requires=get_requirements(), ext_modules=ext_modules, + extras_require={ + "optional": ["tensorizer==2.9.0a1"], + }, cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {}, package_data=package_data, ) diff --git a/tests/tensorizer/__init__.py b/tests/tensorizer/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/tensorizer/tensorize_vllm_model_for_testing.py b/tests/tensorizer/tensorize_vllm_model_for_testing.py new file mode 100644 index 000000000000..d0be08329fd6 --- /dev/null +++ b/tests/tensorizer/tensorize_vllm_model_for_testing.py @@ -0,0 +1,245 @@ +import argparse +import dataclasses +import os +import time +import uuid +from functools import partial +from typing import Type + +import torch +import torch.nn as nn +from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer, + TensorSerializer, stream_io) +from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor +from transformers import AutoConfig, PretrainedConfig + +from vllm.distributed import initialize_model_parallel +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.llm_engine import LLMEngine +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.tensorizer_loader import TensorizerArgs + +# yapf conflicts with isort for this docstring +# yapf: disable +""" +tensorize_vllm_model.py is a script that can be used to serialize and +deserialize vLLM models. These models can be loaded using tensorizer directly +to the GPU extremely quickly. Tensor encryption and decryption is also +supported, although libsodium must be installed to use it. Install +vllm with tensorizer support using `pip install vllm[tensorizer]`. + +To serialize a model, you can run something like this: + +python tensorize_vllm_model.py \ + --model EleutherAI/gpt-j-6B \ + --dtype float16 \ + serialize \ + --serialized-directory s3://my-bucket/ \ + --suffix vllm + +Which downloads the model from HuggingFace, loads it into vLLM, serializes it, +and saves it to your S3 bucket. A local directory can also be used. + +You can also encrypt the model weights with a randomly-generated key by +providing a `--keyfile` argument. + +To deserialize a model, you can run something like this: + +python tensorize_vllm_model.py \ + --model EleutherAI/gpt-j-6B \ + --dtype float16 \ + deserialize \ + --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors + +Which downloads the model tensors from your S3 bucket and deserializes them. +To provide S3 credentials, you can provide `--s3-access-key-id` and +`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script, +the OpenAI entrypoint, as arguments for LLM(), or as environment variables +in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`. + + +You can also provide a `--keyfile` argument to decrypt the model weights if +they were serialized with encryption. + +For more information on the available arguments, run +`python tensorize_vllm_model.py --help`. +""" + + +def parse_args(): + parser = argparse.ArgumentParser( + description="An example script that can be used to serialize and " + "deserialize vLLM models. These models " + "can be loaded using tensorizer directly to the GPU " + "extremely quickly. Tensor encryption and decryption is " + "also supported, although libsodium must be installed to " + "use it.") + parser = EngineArgs.add_cli_args(parser) + subparsers = parser.add_subparsers(dest='command') + + serialize_parser = subparsers.add_parser( + 'serialize', help="Serialize a model to `--serialized-directory`") + + serialize_parser.add_argument( + "--suffix", + type=str, + required=False, + help=( + "The suffix to append to the serialized model directory, which is " + "used to construct the location of the serialized model tensors, " + "e.g. if `--serialized-directory` is `s3://my-bucket/` and " + "`--suffix` is `v1`, the serialized model tensors will be " + "saved to " + "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. " + "If none is provided, a random UUID will be used.")) + serialize_parser.add_argument( + "--serialized-directory", + type=str, + required=True) + + serialize_parser.add_argument( + "--keyfile", + type=str, + required=False, + help=("Encrypt the model weights with a randomly-generated binary key," + " and save the key at this path")) + + deserialize_parser = subparsers.add_parser( + 'deserialize', + help=("Deserialize a model from `--path-to-tensors`" + " to verify it can be loaded and used.")) + + deserialize_parser.add_argument( + "--path-to-tensors", + type=str, + required=True, + help="The local path or S3 URI to the model tensors to deserialize. ") + + deserialize_parser.add_argument( + "--keyfile", + type=str, + required=False, + help=("Path to a binary key to use to decrypt the model weights," + " if the model was serialized with encryption")) + + return parser.parse_args() + + +def make_model_contiguous(model): + # Ensure tensors are saved in memory contiguously + for param in model.parameters(): + param.data = param.data.contiguous() + + +def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: + architectures = getattr(config, "architectures", []) + for arch in architectures: + model_cls = ModelRegistry.load_model_cls(arch) + if model_cls is not None: + return model_cls + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def serialize(): + eng_args_dict = {f.name: getattr(args, f.name) for f in + dataclasses.fields(EngineArgs)} + engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict)) + engine = LLMEngine.from_engine_args(engine_args) + + model = (engine.model_executor.driver_worker. + model_runner.model) + + encryption_params = EncryptionParams.random() if keyfile else None + if keyfile: + with _write_stream(keyfile) as stream: + stream.write(encryption_params.key) + + with _write_stream(model_path) as stream: + serializer = TensorSerializer(stream, encryption=encryption_params) + serializer.write_module(model) + serializer.close() + + print("Serialization complete. Model tensors saved to", model_path) + if keyfile: + print("Key saved to", keyfile) + + +def deserialize(): + config = AutoConfig.from_pretrained(model_ref) + + with no_init_or_tensor(): + model_class = _get_vllm_model_architecture(config) + model = model_class(config) + + before_mem = get_mem_usage() + start = time.time() + + if keyfile: + with _read_stream(keyfile) as stream: + key = stream.read() + decryption_params = DecryptionParams.from_key(key) + tensorizer_args.deserializer_params['encryption'] = \ + decryption_params + + with (_read_stream(model_path)) as stream, TensorDeserializer( + stream, **tensorizer_args.deserializer_params) as deserializer: + deserializer.load_into_module(model) + end = time.time() + + # Brag about how fast we are. + total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) + duration = end - start + per_second = convert_bytes(deserializer.total_tensor_bytes / duration) + after_mem = get_mem_usage() + print( + f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s" + ) + print(f"Memory usage before: {before_mem}") + print(f"Memory usage after: {after_mem}") + + return model + + +args = parse_args() + +s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID") + or None) +s3_secret_access_key = (args.s3_secret_access_key + or os.environ.get("S3_SECRET_ACCESS_KEY") or None) + +s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None) + +_read_stream, _write_stream = (partial( + stream_io.open_stream, + mode=mode, + s3_access_key_id=s3_access_key_id, + s3_secret_access_key=s3_secret_access_key, + s3_endpoint=s3_endpoint, +) for mode in ("rb", "wb+")) + +model_ref = args.model + +model_name = model_ref.split("/")[1] + +os.environ["MASTER_ADDR"] = "127.0.0.1" +os.environ["MASTER_PORT"] = "8080" + +torch.distributed.init_process_group(world_size=1, rank=0) +initialize_model_parallel() + +keyfile = args.keyfile if args.keyfile else None + +if args.command == "serialize": + input_dir = args.serialized_directory.rstrip('/') + suffix = args.suffix if args.suffix else uuid.uuid4().hex + base_path = f"{input_dir}/vllm/{model_ref}/{suffix}" + model_path = f"{base_path}/model.tensors" + serialize() +elif args.command == "deserialize": + tensorizer_args = TensorizerArgs.from_cli_args(args) + model_path = args.path_to_tensors + deserialize() +else: + raise ValueError("Either serialize or deserialize must be specified.") diff --git a/tests/tensorizer/test_tensorizer.py b/tests/tensorizer/test_tensorizer.py new file mode 100644 index 000000000000..2ab893e95da9 --- /dev/null +++ b/tests/tensorizer/test_tensorizer.py @@ -0,0 +1,302 @@ +import gc +import subprocess +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from tests.entrypoints.test_openai_server import ServerRunner +from vllm import SamplingParams +from vllm.config import TensorizerConfig +from vllm.model_executor.tensorizer_loader import ( + EncryptionParams, TensorSerializer, is_vllm_serialized_tensorizer, + load_with_tensorizer, open_stream) + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0) + +model_ref = "facebook/opt-125m" + + +def is_curl_installed(): + try: + subprocess.check_call(['curl', '--version']) + return True + except (subprocess.CalledProcessError, FileNotFoundError): + return False + + +@pytest.fixture(autouse=True) +def tensorizer_config(): + config = TensorizerConfig(tensorizer_uri="vllm", vllm_tensorized=True) + return config + + +@patch('vllm.model_executor.tensorizer_loader.TensorizerAgent') +def test_load_with_tensorizer(mock_agent, tensorizer_config): + mock_linear_method = MagicMock() + mock_agent_instance = mock_agent.return_value + mock_agent_instance.deserialize.return_value = MagicMock() + + result = load_with_tensorizer(tensorizer_config, + linear_method=mock_linear_method) + + mock_agent.assert_called_once_with(tensorizer_config, + linear_method=mock_linear_method) + mock_agent_instance.deserialize.assert_called_once() + assert result == mock_agent_instance.deserialize.return_value + + +def test_is_vllm_model_with_vllm_in_uri(tensorizer_config): + tensorizer_config.vllm_tensorized = True + + result = is_vllm_serialized_tensorizer(tensorizer_config) + + assert result is True + + +def test_is_vllm_model_without_vllm_in_uri(tensorizer_config): + tensorizer_config.vllm_tensorized = False + + result = is_vllm_serialized_tensorizer(tensorizer_config) + + assert result is False + + +def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path): + vllm_model = vllm_runner(model_ref) + model_path = tmp_path / (model_ref + ".tensors") + outputs = vllm_model.generate(prompts, sampling_params) + model = (vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) + with open_stream(model_path, "wb+") as stream: + serializer = TensorSerializer(stream) + serializer.write_module(model) + del vllm_model, model + gc.collect() + torch.cuda.empty_cache() + loaded_vllm_model = vllm_runner(model_ref, + load_format="tensorizer", + tensorizer_uri=model_path, + num_readers=1, + vllm_tensorized=True) + deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) + + # Assumes SamplingParams being seeded ensures the outputs are deterministic + assert outputs == deserialized_outputs + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_can_deserialize_s3(vllm_runner): + model_ref = "EleutherAI/pythia-1.4b" + tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" + + loaded_hf_model = vllm_runner( + model_ref, + tensorizer_uri=tensorized_path, + load_format="tensorizer", + num_readers=1, + vllm_tensorized=False, + s3_endpoint="object.ord1.coreweave.com", + ) + + deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params) + + assert deserialized_outputs + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_deserialized_encrypted_vllm_model_has_same_outputs( + vllm_runner, tmp_path): + vllm_model = vllm_runner(model_ref) + model_path = tmp_path / (model_ref + ".tensors") + key_path = tmp_path / (model_ref + ".key") + outputs = vllm_model.generate(prompts, sampling_params) + model = (vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) + + encryption_params = EncryptionParams.random() + with open_stream(model_path, "wb+") as stream: + serializer = TensorSerializer(stream, encryption=encryption_params) + serializer.write_module(model) + with open_stream(key_path, "wb+") as stream: + stream.write(encryption_params.key) + del vllm_model, model + gc.collect() + torch.cuda.empty_cache() + loaded_vllm_model = vllm_runner(model_ref, + tensorizer_uri=model_path, + load_format="tensorizer", + encryption_keyfile=key_path, + num_readers=1, + vllm_tensorized=True) + + deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) + + # Assumes SamplingParams being seeded ensures the outputs are deterministic + assert outputs == deserialized_outputs + + +def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, + tmp_path): + hf_model = hf_runner(model_ref) + model_path = tmp_path / (model_ref + ".tensors") + max_tokens = 50 + outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens) + with open_stream(model_path, "wb+") as stream: + serializer = TensorSerializer(stream) + serializer.write_module(hf_model.model) + del hf_model + gc.collect() + torch.cuda.empty_cache() + loaded_hf_model = vllm_runner(model_ref, + tensorizer_uri=model_path, + load_format="tensorizer", + num_readers=1, + vllm_tensorized=False) + + deserialized_outputs = loaded_hf_model.generate_greedy( + prompts, max_tokens=max_tokens) + + assert outputs == deserialized_outputs + + +def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): + from huggingface_hub import snapshot_download + + from examples.multilora_inference import (create_test_prompts, + process_requests) + + model_ref = "meta-llama/Llama-2-7b-hf" + lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + test_prompts = create_test_prompts(lora_path) + + # Serialize model before deserializing and binding LoRA adapters + vllm_model = vllm_runner(model_ref, ) + model_path = tmp_path / (model_ref + ".tensors") + model = (vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) + with open_stream(model_path, "wb+") as stream: + serializer = TensorSerializer(stream) + serializer.write_module(model) + del vllm_model, model + gc.collect() + torch.cuda.empty_cache() + loaded_vllm_model = vllm_runner( + model_ref, + tensorizer_uri=model_path, + load_format="tensorizer", + num_readers=1, + vllm_tensorized=True, + enable_lora=True, + max_loras=1, + max_lora_rank=8, + max_cpu_loras=2, + max_num_seqs=50, + max_model_len=1000, + ) + process_requests(loaded_vllm_model.model.llm_engine, test_prompts) + + assert loaded_vllm_model + + +def test_load_without_tensorizer_load_format(vllm_runner): + with pytest.raises(ValueError): + vllm_runner(model_ref, tensorizer_uri="test") + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_tensorize_vllm_model(tmp_path): + # Test serialize command + serialize_args = [ + "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", + model_ref, "--dtype", "float16", "serialize", "--serialized-directory", + tmp_path, "--suffix", "tests" + ] + result = subprocess.run(serialize_args, capture_output=True, text=True) + print(result.stdout) # Print the output of the serialize command + + assert result.returncode == 0, (f"Serialize command failed with output:" + f"\n{result.stdout}\n{result.stderr}") + + path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors" + + # Test deserialize command + deserialize_args = [ + "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", + model_ref, "--dtype", "float16", "deserialize", "--path-to-tensors", + path_to_tensors + ] + result = subprocess.run(deserialize_args, capture_output=True, text=True) + assert result.returncode == 0, (f"Deserialize command failed with output:" + f"\n{result.stdout}\n{result.stderr}") + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_openai_apiserver_with_tensorizer(tmp_path): + ## Serialize model + serialize_args = [ + "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", + model_ref, "--dtype", "float16", "serialize", "--serialized-directory", + tmp_path, "--suffix", "tests" + ] + result = subprocess.run(serialize_args, capture_output=True, text=True) + print(result.stdout) # Print the output of the serialize command + + assert result.returncode == 0, (f"Serialize command failed with output:" + f"\n{result.stdout}\n{result.stderr}") + + path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors" + + ## Start OpenAI API server + openai_args = [ + "--model", model_ref, "--dtype", "float16", "--load-format", + "tensorizer", "--tensorizer-uri", path_to_tensors, "--vllm-tensorized", + "--port", "8000" + ] + + server = ServerRunner.remote(openai_args) + + print("Server ready.") + assert server.ready.remote() + + +def test_raise_value_error_on_invalid_load_format(vllm_runner): + with pytest.raises(ValueError): + vllm_runner(model_ref, + load_format="safetensors", + tensorizer_uri="test") + + +def test_tensorizer_with_tp(vllm_runner): + with pytest.raises(ValueError): + model_ref = "EleutherAI/pythia-1.4b" + tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" + + vllm_runner( + model_ref, + tensorizer_uri=tensorized_path, + load_format="tensorizer", + num_readers=1, + vllm_tensorized=False, + s3_endpoint="object.ord1.coreweave.com", + tensor_parallel_size=2, + ) + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_tensorizer_warn_quant(tmp_path): + model_ref = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit" + serialize_args = [ + "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", + model_ref, "--quantization", "gptq", "--tensorizer-uri", "test", + "serialize", "--serialized-directory", tmp_path, "--suffix", "tests" + ] + result = subprocess.run(serialize_args, capture_output=True, text=True) + assert 'PerformanceWarning' in result.stderr diff --git a/vllm/config.py b/vllm/config.py index bbda4ecf3cc5..dce2944b2ee8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,6 +1,8 @@ import enum +import io import json import os +import typing from dataclasses import dataclass, fields from typing import TYPE_CHECKING, ClassVar, List, Optional, Union @@ -16,6 +18,8 @@ if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup + from vllm.model_executor.tensorizer_loader import TensorizerArgs + logger = init_logger(__name__) _GB = 1 << 30 @@ -139,13 +143,14 @@ def __init__( def _verify_load_format(self) -> None: load_format = self.load_format.lower() supported_load_format = [ - "auto", "pt", "safetensors", "npcache", "dummy" + "auto", "pt", "safetensors", "npcache", "dummy", "tensorizer" ] rocm_not_supported_load_format: List[str] = [] if load_format not in supported_load_format: raise ValueError( f"Unknown load format: {self.load_format}. Must be one of " - "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.") + "'auto', 'pt', 'safetensors', 'npcache', 'tensorizer', or " + "'dummy'.") if is_hip() and load_format in rocm_not_supported_load_format: rocm_supported_load_format = [ f for f in supported_load_format @@ -882,6 +887,65 @@ def get_image_input_enum_type( f"{[x.name for x in cls.ImageInputType]}.") from e +@dataclass +class TensorizerConfig: + tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, + str, bytes, os.PathLike, int] + vllm_tensorized: bool + verify_hash: Optional[bool] = False + num_readers: Optional[int] = 1 + encryption_keyfile: Optional[str] = None + s3_access_key_id: Optional[str] = None + s3_secret_access_key: Optional[str] = None + s3_endpoint: Optional[str] = None + model_class: Optional[torch.nn.Module] = None + hf_config: Optional[PretrainedConfig] = None + dtype: Union[str, torch.dtype] = None + + def _construct_tensorizer_args(self) -> "TensorizerArgs": + from vllm.model_executor.tensorizer_loader import TensorizerArgs + tensorizer_args = { + "tensorizer_uri": self.tensorizer_uri, + "vllm_tensorized": self.vllm_tensorized, + "verify_hash": self.verify_hash, + "num_readers": self.num_readers, + "encryption_keyfile": self.encryption_keyfile, + "s3_access_key_id": self.s3_access_key_id, + "s3_secret_access_key": self.s3_secret_access_key, + "s3_endpoint": self.s3_endpoint, + } + return TensorizerArgs(**tensorizer_args) + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + if (parallel_config.tensor_parallel_size > 1 + and self.tensorizer_uri is not None): + raise ValueError( + "Loading to multiple GPUs is not currently supported with " + "vLLM-serialized models. Please set tensor_parallel_size=1." + " or use a non-vLLM-serialized model, such as a " + "serialized Hugging Face `PretrainedModel`.") + + def verify_with_model_config(self, model_config) -> None: + if (model_config.quantization is not None + and self.tensorizer_uri is not None): + from vllm.model_executor.tensorizer_loader import ( + tensorizer_warning) + tensorizer_warning( + "Loading a model using Tensorizer with quantization on vLLM" + " is unstable and may lead to errors.") + + if (model_config.load_format != "tensorizer" + and self.tensorizer_uri is not None): + raise ValueError( + "A tensorizer uri was passed for tensorizer loading, but the " + f"load format was set to {model_config.load_format}. " + "Please set the load format to 'tensorizer' to use " + f"tensorizer args.") + + _STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.float16, "float16": torch.float16, @@ -1029,6 +1093,7 @@ class EngineConfig: lora_config: Optional[LoRAConfig] vision_language_config: Optional[VisionLanguageConfig] speculative_config: Optional[SpeculativeConfig] + tensorizer_config: Optional[TensorizerConfig] def __post_init__(self): """Verify configs are valid & consistent with each other. @@ -1036,6 +1101,11 @@ def __post_init__(self): self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.tensorizer_config: + self.tensorizer_config.verify_with_parallel_config( + self.parallel_config) + self.tensorizer_config.verify_with_model_config(self.model_config) + if self.lora_config: self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_scheduler_config( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index daefddc01b43..831a03be65f6 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,12 +1,15 @@ import argparse import dataclasses +import io +import os from dataclasses import dataclass -from typing import Optional +from typing import BinaryIO, Optional, Union from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, TokenizerPoolConfig, - VisionLanguageConfig) + SpeculativeConfig, TensorizerConfig, + TokenizerPoolConfig, VisionLanguageConfig) +from vllm.model_executor.tensorizer_loader import TensorizerArgs from vllm.utils import str_to_int_tuple @@ -58,12 +61,22 @@ class EngineArgs: num_gpu_blocks_override: Optional[int] = None num_lookahead_slots: int = 0 + # Tensorizer configuration parameters + tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str, + bytes, os.PathLike, int] = None + vllm_tensorized: bool = False + verify_hash: Optional[bool] = False + num_readers: Optional[int] = 1 + encryption_keyfile: Optional[str] = None + s3_access_key_id: Optional[str] = None + s3_secret_access_key: Optional[str] = None + s3_endpoint: Optional[str] = None + # Related to Vision-language models such as llava image_input_type: Optional[str] = None image_token_id: Optional[int] = None image_input_shape: Optional[str] = None image_feature_size: Optional[int] = None - scheduler_delay_factor: float = 0.0 enable_chunked_prefill: bool = False @@ -135,7 +148,9 @@ def add_cli_args( '--load-format', type=str, default=EngineArgs.load_format, - choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'], + choices=[ + 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer' + ], help='The format of the model weights to load. ' '"auto" will try to load the weights in the safetensors format ' 'and fall back to the pytorch bin format if safetensors format ' @@ -145,7 +160,10 @@ def add_cli_args( '"npcache" will load the weights in pytorch format and store ' 'a numpy cache to speed up the loading. ' '"dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.') + 'which is mainly for profiling.' + '"tensorizer" will load the weights using tensorizer from CoreWeave' + 'which assumes tensorizer_uri is set to the location of the ' + 'serialized weights.') parser.add_argument( '--dtype', type=str, @@ -403,6 +421,7 @@ def add_cli_args( default=None, help='The number of speculative tokens to sample from ' 'the draft model in speculative decoding') + parser = TensorizerArgs.add_cli_args(parser) return parser @classmethod @@ -465,6 +484,17 @@ def create_engine_config(self, ) -> EngineConfig: max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None) if self.enable_lora else None + tensorizer_config = TensorizerConfig( + tensorizer_uri=self.tensorizer_uri, + vllm_tensorized=self.vllm_tensorized, + verify_hash=self.verify_hash, + num_readers=self.num_readers, + encryption_keyfile=self.encryption_keyfile, + s3_access_key_id=self.s3_access_key_id, + s3_secret_access_key=self.s3_secret_access_key, + s3_endpoint=self.s3_endpoint, + ) + if self.image_input_type: if (not self.image_token_id or not self.image_input_shape or not self.image_feature_size): @@ -488,7 +518,8 @@ def create_engine_config(self, ) -> EngineConfig: device_config=device_config, lora_config=lora_config, vision_language_config=vision_language_config, - speculative_config=speculative_config) + speculative_config=speculative_config, + tensorizer_config=tensorizer_config) @dataclass diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a91629a63059..8c37c5a9d6ee 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -6,7 +6,7 @@ import vllm from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) + TensorizerConfig, VisionLanguageConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import StatLogger, Stats @@ -74,6 +74,7 @@ def __init__( lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], + tensorizer_config: Optional[TensorizerConfig], executor_class: Type[ExecutorBase], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, @@ -110,6 +111,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.speculative_config = speculative_config + self.tensorizer_config = tensorizer_config self.log_stats = log_stats self._init_tokenizer() @@ -125,6 +127,7 @@ def __init__( lora_config=lora_config, vision_language_config=vision_language_config, speculative_config=speculative_config, + tensorizer_config=tensorizer_config, ) self._initialize_kv_caches() @@ -264,6 +267,9 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.tensorizer_config: + self.tensorizer_config.verify_with_parallel_config( + self.parallel_config) if self.lora_config: self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_scheduler_config( diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index f20221a0b941..30577ecf62fa 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -2,7 +2,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) + TensorizerConfig, VisionLanguageConfig) from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -15,17 +15,14 @@ class GPUExecutor(ExecutorBase): - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], - speculative_config: Optional[SpeculativeConfig], - ) -> None: + def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + speculative_config: Optional[SpeculativeConfig], + tensorizer_config: Optional[TensorizerConfig]) -> None: self.model_config = model_config self.cache_config = cache_config self.lora_config = lora_config @@ -33,6 +30,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.vision_language_config = vision_language_config + self.tensorizer_config = tensorizer_config assert (not speculative_config ), "Speculative decoding not yet supported for GPU backend" @@ -61,6 +59,7 @@ def _init_worker(self): distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, + tensorizer_config=self.tensorizer_config, is_driver_worker=True, ) self.driver_worker.init_device() diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index b937693c9225..28dc3e0db312 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -7,7 +7,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) + TensorizerConfig, VisionLanguageConfig) from vllm.engine.ray_utils import RayWorkerVllm, ray from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger @@ -42,6 +42,7 @@ def __init__( lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], + tensorizer_config: Optional[TensorizerConfig], ) -> None: self.model_config = model_config self.cache_config = cache_config @@ -50,6 +51,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.vision_language_config = vision_language_config + self.tensorizer_config = tensorizer_config assert (not speculative_config ), "Speculative decoding not yet supported for RayGPU backend." @@ -171,6 +173,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", distributed_init_method=distributed_init_method, lora_config=lora_config, vision_language_config=vision_language_config, + tensorizer_config=self.tensorizer_config, )) # Initialize the driver worker with the Worker class. @@ -187,6 +190,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, + tensorizer_config=self.tensorizer_config, is_driver_worker=True, ) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 2745dbd89ab0..c70ca48bca70 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -3,11 +3,14 @@ from typing import Tuple, Type import torch -import torch.nn as nn +from torch import nn from vllm.config import DeviceConfig, ModelConfig from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.llava import LlavaForConditionalGeneration +from vllm.model_executor.tensorizer_loader import ( + ParameterizedLoadFormat, is_vllm_serialized_tensorizer, + load_with_tensorizer) from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights) @@ -51,6 +54,7 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig, **kwargs) -> nn.Module: lora_config = kwargs.get("lora_config", None) vision_language_config = kwargs.get("vision_language_config", None) + tensorizer_config = kwargs.get("tensorizer_config", None) model_class = _get_model_architecture(model_config)[0] # Get the (maybe quantized) linear method. @@ -71,33 +75,54 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig, f"{model_config.dtype} is not supported for quantization " f"method {model_config.quantization}. Supported dtypes: " f"{supported_dtypes}") + linear_method = quant_config.get_linear_method() with _set_default_torch_dtype(model_config.dtype): # Create a model instance. # The weights will be initialized as empty tensors. + extra_kwargs = {} + if hasattr(model_class, "supported_lora_modules"): + extra_kwargs["lora_config"] = lora_config + elif lora_config: + raise ValueError( + f"Model {model_class.__name__} does not support LoRA, " + "but LoRA is enabled. Support for this model may " + "be added in the future. If this is important to you, " + "please open an issue on github.") + elif model_class in _VISION_MODEL_CLASSES: + extra_kwargs["vision_language_config"] = vision_language_config + with torch.device(device_config.device): - if hasattr(model_class, "supported_lora_modules"): - model = model_class(model_config.hf_config, linear_method, - lora_config) - elif lora_config: - raise ValueError( - f"Model {model_class.__name__} does not support LoRA, " - "but LoRA is enabled. Support for this model may " - "be added in the future. If this is important to you, " - "please open an issue on github.") - else: - if model_class not in _VISION_MODEL_CLASSES: - model = model_class(model_config.hf_config, linear_method) - else: - model = model_class(model_config.hf_config, - vision_language_config, linear_method) + if (model_config.load_format == "tensorizer" + and is_vllm_serialized_tensorizer(tensorizer_config)): + extra_kwargs["linear_method"] = linear_method + tensorizer_config.model_class = model_class + tensorizer_config.hf_config = model_config.hf_config + tensorizer_config.dtype = model_config.dtype + model = load_with_tensorizer(tensorizer_config, **extra_kwargs) + return model.eval() + model = model_class(config=model_config.hf_config, + linear_method=linear_method, + **extra_kwargs) if model_config.load_format == "dummy": # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. initialize_dummy_weights(model) else: # Load the weights from the cached or downloaded files. - model.load_weights(model_config.model, model_config.download_dir, - model_config.load_format, model_config.revision) + if model_config.load_format == "tensorizer": + # Provide a dynamic load format for `model.load_weights` + # to retain tensorizer args from CLI. + model_config.load_format = ParameterizedLoadFormat( + model_config.load_format) + model_config.load_format.params = ( + tensorizer_config._construct_tensorizer_args()) + + model.load_weights( + model_config.model, + model_config.download_dir, + model_config.load_format, + model_config.revision, + ) return model.eval() diff --git a/vllm/model_executor/tensorizer_loader.py b/vllm/model_executor/tensorizer_loader.py new file mode 100644 index 000000000000..ed3ad9e2ffa1 --- /dev/null +++ b/vllm/model_executor/tensorizer_loader.py @@ -0,0 +1,319 @@ +import argparse +import dataclasses +import io +import os +import time +import typing +import warnings +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn + +from vllm.config import TensorizerConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) + +tensorizer_load_fail = False + +try: + from tensorizer import (DecryptionParams, EncryptionParams, + TensorDeserializer, TensorSerializer) + from tensorizer.stream_io import open_stream + from tensorizer.utils import (convert_bytes, get_mem_usage, + no_init_or_tensor) +except ImportError: + tensorizer_load_fail = True + +__all__ = [ + 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer', + 'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage', + 'no_init_or_tensor' +] + +logger = init_logger(__name__) + + +def load_with_tensorizer(tensorizer_config: TensorizerConfig, + **extra_kwargs) -> nn.Module: + tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs) + return tensorizer.deserialize() + + +def tensorizer_warning(message: str): + return warnings.warn(message, category=PerformanceWarning, stacklevel=2) + + +def is_vllm_serialized_tensorizer(tensorizer_config: TensorizerConfig) -> bool: + if tensorizer_config is None: + return False + return tensorizer_config.vllm_tensorized + + +class ParameterizedLoadFormat(str): + __slots__ = "params" + + +class PerformanceWarning(UserWarning): + + def __str__(self): + return (f"{super().__str__()}" + " (set the VLLM_SILENCE_PERFORMANCE_WARNINGS" + " environment variable to hide this)") + + +if (os.getenv("VLLM_SILENCE_PERFORMANCE_WARNINGS", "").lower() + not in ("", "0", "n", "no", "off", "disable")): + warnings.simplefilter("ignore", category=PerformanceWarning) + + +@dataclass +class TensorizerArgs: + tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, + str, bytes, os.PathLike, int] + vllm_tensorized: bool + verify_hash: Optional[bool] = False + num_readers: Optional[int] = 1 + encryption_keyfile: Optional[str] = None + s3_access_key_id: Optional[str] = None + s3_secret_access_key: Optional[str] = None + s3_endpoint: Optional[str] = None + """ + Args for the TensorizerAgent class. These are used to configure the behavior + of the TensorDeserializer when loading tensors from a serialized model. + + Args: + tensorizer_uri: Path to serialized model tensors. Can be a local file + path or a S3 URI. + vllm_tensorized: If True, indicates that the serialized model is a + vLLM model. This is used to determine the behavior of the + TensorDeserializer when loading tensors from a serialized model. + It is far faster to deserialize a vLLM model as it utilizes + tensorizer's optimized GPU loading. + verify_hash: If True, the hashes of each tensor will be verified against + the hashes stored in the metadata. A `HashMismatchError` will be + raised if any of the hashes do not match. + num_readers: Controls how many threads are allowed to read concurrently + from the source file. Default is 1. This greatly increases + performance. + encryption_keyfile: File path to a binary file containing a + binary key to use for decryption. `None` (the default) means + no decryption. See the example script in + examples/tensorize_vllm_model.py. + s3_access_key_id: The access key for the S3 bucket. Can also be set via + the S3_ACCESS_KEY_ID environment variable. + s3_secret_access_key: The secret access key for the S3 bucket. Can also + be set via the S3_SECRET_ACCESS_KEY environment variable. + s3_endpoint: The endpoint for the S3 bucket. Can also be set via the + S3_ENDPOINT_URL environment variable. + """ + + def __post_init__(self): + self.file_obj = self.tensorizer_uri + self.s3_access_key_id = (self.s3_access_key_id + or os.environ.get("S3_ACCESS_KEY_ID")) or None + self.s3_secret_access_key = ( + self.s3_secret_access_key + or os.environ.get("S3_SECRET_ACCESS_KEY")) or None + self.s3_endpoint = (self.s3_endpoint + or os.environ.get("S3_ENDPOINT_URL")) or None + self.stream_params = { + "s3_access_key_id": self.s3_access_key_id, + "s3_secret_access_key": self.s3_secret_access_key, + "s3_endpoint": self.s3_endpoint, + } + + # Omitting self.dtype and self.device as this behaves weirdly + self.deserializer_params = { + "verify_hash": self.verify_hash, + "encryption": self.encryption_keyfile, + "num_readers": self.num_readers + } + if self.encryption_keyfile: + with open_stream( + self.encryption_keyfile, + **self.stream_params, + ) as stream: + key = stream.read() + decryption_params = DecryptionParams.from_key(key) + self.deserializer_params['encryption'] = decryption_params + + def add_cli_args( + parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Tensorizer CLI arguments""" + + # Create the argument group + group = parser.add_argument_group( + 'tensorizer options', + description=('Options for configuring the behavior of the' + ' tensorizer deserializer when ' + '--load-format=tensorizer')) + + group.add_argument( + "--tensorizer-uri", + help="Path to serialized model tensors. Can be a local file path," + " or an HTTP(S) or S3 URI.", + ) + group.add_argument( + "--verify-hash", + action="store_true", + help="If enabled, the hashes of each tensor will be verified" + " against the hashes stored in the file metadata. An exception" + " will be raised if any of the hashes do not match.", + ) + group.add_argument( + "--encryption-keyfile", + default=None, + help="The file path to a binary file containing a binary key to " + "use for decryption. Can be a file path or S3 network URI.") + group.add_argument( + "--num-readers", + default=1, + type=int, + help="Controls how many threads are allowed to read concurrently " + "from the source file.") + group.add_argument( + "--s3-access-key-id", + default=None, + help="The access key for the S3 bucket. Can also be set via the " + "S3_ACCESS_KEY_ID environment variable.", + ) + group.add_argument( + "--s3-secret-access-key", + default=None, + help="The secret access key for the S3 bucket. Can also be set via " + "the S3_SECRET_ACCESS_KEY environment variable.", + ) + group.add_argument( + "--s3-endpoint", + default=None, + help="The endpoint for the S3 bucket. Can also be set via the " + "S3_ENDPOINT_URL environment variable.", + ) + group.add_argument( + "--vllm-tensorized", + action="store_true", + help="If enabled, indicates that the serialized model is a vLLM " + "model. This is used to determine the behavior of the " + "TensorDeserializer when loading tensors from a " + "serialized model.") + + return parser + + @classmethod + def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs": + # Get the list of attributes of this dataclass. + attrs = [attr.name for attr in dataclasses.fields(cls)] + # Set the attributes from the parsed arguments. + tensorizer_args = cls(**{ + attr: getattr(args, attr) + for attr in attrs if hasattr(args, attr) + }) + return tensorizer_args + + +class TensorizerAgent: + """ + A class for performing tensorizer deserializations specifically for + vLLM models using plaid_mode. Uses TensorizerArgs to configure the + behavior of the TensorDeserializer when loading tensors from a serialized + model. For deserializations of HuggingFace models, TensorDeserializer is + instead used as an iterator directly in the func hf_model_weights_iterator + in vllm/model_executor/weight_utils.py + """ + + def __init__(self, tensorizer_config: TensorizerConfig, + linear_method: LinearMethodBase, **extra_kwargs): + self.tensorizer_config = tensorizer_config + self.tensorizer_args = ( + self.tensorizer_config._construct_tensorizer_args()) + self.extra_kwargs = extra_kwargs + if extra_kwargs.get("linear_method", None) is not None: + self.linear_method = extra_kwargs["linear_method"] + else: + self.linear_method = linear_method + self.model = self._init_model() + + if tensorizer_load_fail: + raise ImportError( + "Tensorizer is not installed. Please install tensorizer " + "to use this feature with `pip install vllm[tensorizer]`.") + + def _init_model(self): + model_args = self.tensorizer_config.hf_config + model_args.torch_dtype = self.tensorizer_config.dtype + with no_init_or_tensor(): + return self.tensorizer_config.model_class( + config=model_args, + linear_method=self.linear_method, + **self.extra_kwargs) + + def _resize_lora_embeddings(self): + """Modify LoRA embedding layers to use bigger tensors + to allow for adapter added tokens.""" + for child in self.model.modules(): + if (isinstance(child, VocabParallelEmbedding) + and child.weight.shape[0] < + child.num_embeddings_per_partition): + new_weight = torch.empty(child.num_embeddings_per_partition, + child.embedding_dim, + dtype=child.weight.dtype, + device=child.weight.device) + new_weight[:child.weight.shape[0]].copy_(child.weight.data) + new_weight[child.weight.shape[0]:].fill_(0) + child.weight.data = new_weight + + def _check_tensors_on_meta_device(self): + for tensor in self.model.state_dict().values(): + if tensor.device.type == 'meta': + raise ValueError( + "The serialized model contains tensors on the meta device," + " indicating that some tensors were not loaded properly." + " Please check that the parameters of the model being" + " specified match that of the serialized model, such as" + " its quantization.") + + def deserialize(self): + """ + Deserialize the model using the TensorDeserializer. This method is + specifically for vLLM models using tensorizer's plaid_mode. + + The deserializer makes use of tensorizer_args.stream_params + to configure the behavior of the stream when loading tensors from a + serialized model. The deserializer_params are used to configure the + behavior of the TensorDeserializer when loading tensors themselves. + Documentation on these params can be found in TensorizerArgs + + Returns: + nn.Module: The deserialized model. + """ + before_mem = get_mem_usage() + # Lazy load the tensors from S3 into the model. + start = time.perf_counter() + with open_stream( + self.tensorizer_args.tensorizer_uri, + mode="rb", + **self.tensorizer_args.stream_params, + ) as stream, TensorDeserializer( + stream, + dtype=self.tensorizer_config.dtype, + **self.tensorizer_args.deserializer_params) as deserializer: + deserializer.load_into_module(self.model) + end = time.perf_counter() + + total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) + duration = end - start + per_second = convert_bytes(deserializer.total_tensor_bytes / duration) + after_mem = get_mem_usage() + deserializer.close() + logger.info(f"Deserialized {total_bytes_str} in " + f"{end - start:0.2f}s, {per_second}/s") + logger.info(f"Memory usage before: {before_mem}") + logger.info(f"Memory usage after: {after_mem}") + + self._check_tensors_on_meta_device() + self._resize_lora_embeddings() + return self.model.eval() diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 0961478930d7..08425604f051 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -5,7 +5,7 @@ import json import os from collections import defaultdict -from typing import Any, Iterable, Iterator, List, Optional, Tuple +from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union import filelock import huggingface_hub.constants @@ -161,7 +161,8 @@ def prepare_hf_model_weights( revision: Optional[str] = None, ) -> Tuple[str, List[str], bool]: # Download model weights from huggingface. - is_local = os.path.isdir(model_name_or_path) + is_local = os.path.isdir(model_name_or_path) \ + and load_format != "tensorizer" use_safetensors = False # Some quantized models use .pt files for storing the weights. if load_format == "auto": @@ -173,13 +174,15 @@ def prepare_hf_model_weights( allow_patterns = ["*.pt"] elif load_format == "npcache": allow_patterns = ["*.bin"] + elif load_format == "tensorizer": + allow_patterns = ["*.tensors"] else: raise ValueError(f"Unknown load_format: {load_format}") if fall_back_to_pt: allow_patterns += ["*.pt"] - if not is_local: + if not is_local and load_format != "tensorizer": # Before we download we look at that is available: fs = HfFileSystem() file_list = fs.ls(model_name_or_path, detail=False, revision=revision) @@ -224,6 +227,9 @@ def prepare_hf_model_weights( if not any(f.endswith(x) for x in blacklist) ] + if load_format == "tensorizer": + return hf_folder, hf_weights_files, use_safetensors + if len(hf_weights_files) == 0: raise RuntimeError( f"Cannot find any model weights with `{model_name_or_path}`") @@ -234,7 +240,7 @@ def prepare_hf_model_weights( def hf_model_weights_iterator( model_name_or_path: str, cache_dir: Optional[str] = None, - load_format: str = "auto", + load_format: Union[Tuple, str] = "auto", revision: Optional[str] = None, fall_back_to_pt: Optional[bool] = True, ) -> Iterator[Tuple[str, torch.Tensor]]: @@ -277,6 +283,26 @@ def hf_model_weights_iterator( with open(param_path, "rb") as f: param = np.load(f) yield name, torch.from_numpy(param) + elif load_format == "tensorizer": + from vllm.model_executor.tensorizer_loader import (TensorDeserializer, + open_stream, + tensorizer_warning) + tensorizer_args = load_format.params + tensorizer_warning( + "Deserializing HuggingFace models is not optimized for " + "loading on vLLM, as tensorizer is forced to load to CPU. " + "Consider deserializing a vLLM model instead for faster " + "load times. See the examples/tensorize_vllm_model.py example " + "script for serializing vLLM models.") + + deserializer_args = tensorizer_args.deserializer_params + stream_params = tensorizer_args.stream_params + stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params) + with TensorDeserializer(stream, **deserializer_args, + device="cpu") as state: + for name, param in state.items(): + yield name, param + del state elif use_safetensors: for st_file in hf_weights_files: with safe_open(st_file, framework="pt") as f: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 47ad8f0c9b78..7dbe14ead097 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -10,7 +10,8 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, get_attn_backend) from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig, VisionLanguageConfig) + SchedulerConfig, TensorizerConfig, + VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce from vllm.distributed.device_communicators import (custom_all_reduce, pynccl_utils) @@ -111,11 +112,13 @@ def __init__( kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, vision_language_config: Optional[VisionLanguageConfig] = None, + tensorizer_config: Optional[TensorizerConfig] = None, ): self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.lora_config = lora_config + self.tensorizer_config = tensorizer_config self.is_driver_worker = is_driver_worker # model_config can be None in tests/samplers/test_sampler.py. @@ -158,7 +161,9 @@ def load_model(self) -> None: lora_config=self.lora_config, vision_language_config=self.vision_language_config, parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + scheduler_config=self.scheduler_config, + tensorizer_config=self.tensorizer_config, + ) self.model_memory_usage = m.consumed_memory logger.info(f"Loading model weights took " diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 3f0b2fd83f3e..82491c6df661 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -7,7 +7,8 @@ import torch.distributed from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) + ParallelConfig, SchedulerConfig, TensorizerConfig, + VisionLanguageConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment) @@ -42,6 +43,7 @@ def __init__( distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None, + tensorizer_config: Optional[TensorizerConfig] = None, is_driver_worker: bool = False, ) -> None: self.model_config = model_config @@ -53,6 +55,7 @@ def __init__( self.rank = rank self.distributed_init_method = distributed_init_method self.lora_config = lora_config + self.tensorizer_config = tensorizer_config self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." @@ -70,7 +73,9 @@ def __init__( lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, - vision_language_config=vision_language_config) + vision_language_config=vision_language_config, + tensorizer_config=tensorizer_config, + ) # Uninitialized cache engine. Will be initialized by # initialize_cache. self.cache_engine = None From 2cd6b4f3625466eb5849bcfd7a6fb316735adab8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 13 Apr 2024 23:40:21 -0700 Subject: [PATCH 104/120] [Core] avoid too many cuda context by caching p2p test (#4021) --- .../device_communicators/custom_all_reduce.py | 53 +++++------ vllm/distributed/parallel_state.py | 9 ++ vllm/distributed/utils.py | 87 ++++++++++++++++++- 3 files changed, 116 insertions(+), 33 deletions(-) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 84238d2e4607..f83caef879da 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -42,12 +42,17 @@ def init_custom_ar() -> None: " disable_custom_all_reduce=True explicitly.", world_size, str(_SUPPORTED_WORLD_SIZES)) return - if not _can_p2p(rank, world_size): + num_dev = torch.cuda.device_count() + # note: num dev can be larger than world_size if we're only using + # first few GPUs + if num_dev < world_size: logger.warn( - "Custom allreduce is disabled because your platform lacks GPU P2P" - " capability or P2P test failed. To silence this warning, specify" - " disable_custom_all_reduce=True explicitly.") - return + "Cannot test GPU P2P because not all GPUs are visible to the " + "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'" + " is set.") + return False + # test nvlink first, this will filter out most of the cases + # where custom allreduce is not supported full_nvlink = _is_full_nvlink(rank, world_size) if world_size > 2 and not full_nvlink: logger.warn( @@ -55,6 +60,15 @@ def init_custom_ar() -> None: " than two PCIe-only GPUs. To silence this warning, specify" " disable_custom_all_reduce=True explicitly.") return + # test P2P capability + # this is expensive to compute at the first time + # then we cache the result + if not _can_p2p(rank, world_size): + logger.warn( + "Custom allreduce is disabled because your platform lacks GPU P2P" + " capability or P2P test failed. To silence this warning, specify" + " disable_custom_all_reduce=True explicitly.") + return _CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink) @@ -143,40 +157,15 @@ def _is_full_nvlink(rank, world_size): def _can_p2p(rank: int, world_size: int) -> bool: - num_dev = torch.cuda.device_count() - # note: num dev can be larger than world_size if we're only using - # first few GPUs - if num_dev < world_size: - logger.warn( - "Cannot test GPU P2P because not all GPUs are visible to the " - "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'" - " is set.") - return False + from vllm.distributed.utils import gpu_p2p_access_check for i in range(world_size): if i == rank: continue - if not torch.cuda.can_device_access_peer(rank, i): - return False - # on some platforms, P2P support might be buggy and we need - # additional checks. See also: - # https://github.com/vllm-project/vllm/issues/2728 - if not _can_actually_p2p(rank, i): + if not gpu_p2p_access_check(rank, i): return False return True -# code partly borrowed from -# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10 -# License: MIT -def _can_actually_p2p(idx_a, idx_b): - dev_i = f"cuda:{idx_a}" - dev_j = f"cuda:{idx_b}" - a = torch.randn(5, device=dev_i) + 123.0 - b = a.to(dev_j) - c = b.to(dev_i) - return torch.all(a == c) - - class CustomAllreduce: # max_size: max supported allreduce size diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 1258bf58cb45..e2473736375e 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -41,6 +41,13 @@ # source rank when broadcasting from the first or last pipeline stage. _PIPELINE_GLOBAL_RANKS = None +_LOCAL_RANK = -1 + + +def get_local_rank(): + global _LOCAL_RANK + return _LOCAL_RANK + def init_distributed_environment( world_size: int = -1, @@ -66,6 +73,8 @@ def init_distributed_environment( ranks = list(range(torch.distributed.get_world_size())) _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks, backend="gloo") + global _LOCAL_RANK + _LOCAL_RANK = local_rank def initialize_model_parallel( diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 0cd420c8e11b..e0a871ebe175 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -2,9 +2,18 @@ # Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -from typing import Sequence +import json +import os +from typing import Dict, Optional, Sequence import torch +import torch.distributed as dist + +from vllm.logger import init_logger + +from .parallel_state import get_cpu_world_group, get_local_rank + +logger = init_logger(__name__) def ensure_divisibility(numerator, denominator): @@ -46,3 +55,79 @@ def split_tensor_along_last_dim( return tuple(chunk.contiguous() for chunk in tensor_list) return tensor_list + + +# code partly borrowed from +# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10 +# License: MIT +def _can_actually_p2p(idx_a, idx_b): + dev_i = f"cuda:{idx_a}" + dev_j = f"cuda:{idx_b}" + a = torch.randn(5, device=dev_i) + 123.0 + b = a.to(dev_j) + c = b.to(dev_i) + return torch.all(a == c).cpu().item() + + +# why do we need this cache? +# 1. we can have runtime checks for P2P access, where every process checks +# P2P access to all other GPUs. Unfortunately, the test might cost many +# (world_size * world_size) cuda context, and reduce the memory available +# for the model. see https://github.com/vllm-project/vllm/issues/3821 +# 2. alternatively, we can have a p2p map that is generated by the master +# process and broadcasted to all other processes. This still requires +# #world_size of cuda context, belonging to the master process, on each GPU. +# 3. we can have a cache file, that records the p2p access status. The first +# time the master process checks the p2p access, it will generate the cache +# file, at the cost of #world_size of cuda context. Later on, all processes +# can read the cache file to check the p2p access status without any cost of +# additional cuda context. +# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we +# can have different cache files for different CUDA_VISIBLE_DEVICES settings, +# e.g. used by different vllm engines. The device id in the cache file is a +# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number +# of visible devices in the vllm engine. +_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None + + +def gpu_p2p_access_check(i: int, j: int) -> bool: + """Check if GPU i can access GPU j.""" + + # if the cache variable is already calculated, + # read from the cache instead of checking it again + global _gpu_p2p_access_cache + if _gpu_p2p_access_cache is not None: + return _gpu_p2p_access_cache[f"{i}->{j}"] + + is_distributed = dist.is_initialized() + + num_dev = torch.cuda.device_count() + cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) + if cuda_visible_devices is None: + cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) + path = os.path.expanduser( + f"~/.config/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json") + os.makedirs(os.path.dirname(path), exist_ok=True) + if (not is_distributed or get_local_rank() == 0) \ + and (not os.path.exists(path)): + # only the local master process (with local_rank == 0) can + # enter this block to calculate the cache + logger.info(f"generating GPU P2P access cache for in {path}") + cache = {} + for _i in range(num_dev): + for _j in range(num_dev): + # on some platforms, P2P support might be buggy and we need + # additional checks. See also: + # https://github.com/vllm-project/vllm/issues/2728 + cache[f"{_i}->{_j}"] = torch.cuda.can_device_access_peer( + _i, _j) and _can_actually_p2p(_i, _j) + with open(path, "w") as f: + json.dump(cache, f, indent=4) + if is_distributed: + cpu_world_group = get_cpu_world_group() + dist.barrier(cpu_world_group) + logger.info(f"reading GPU P2P access cache from {path}") + with open(path, "r") as f: + cache = json.load(f) + _gpu_p2p_access_cache = cache + return _gpu_p2p_access_cache[f"{i}->{j}"] From 563c54f760f870ae44c7662c8a9ec3a223a3c4c4 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sun, 14 Apr 2024 22:12:42 +0100 Subject: [PATCH 105/120] [BugFix] Fix tensorizer extra in setup.py (#4072) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 813321efe796..19a9150ad2e6 100644 --- a/setup.py +++ b/setup.py @@ -406,7 +406,7 @@ def _read_requirements(filename: str) -> List[str]: install_requires=get_requirements(), ext_modules=ext_modules, extras_require={ - "optional": ["tensorizer==2.9.0a1"], + "tensorizer": ["tensorizer==2.9.0a1"], }, cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {}, package_data=package_data, From aceb17cf2d629175a484c3d9df355f44bd334cb3 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Sun, 14 Apr 2024 14:35:55 -0700 Subject: [PATCH 106/120] [Docs] document that mixtral 8x22b is supported (#4073) --- README.md | 2 +- docs/source/models/supported_models.rst | 36 ++++++++++++------------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index d53227b82d87..8434c1188334 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) - MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.) - Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) -- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.) +- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.) - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.) - OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.) - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index c09b0ff25043..5e5ce871f61d 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -30,23 +30,23 @@ Alongside each architecture, we include some popular models that use it. * - :code:`CohereForCausalLM` - Command-R - :code:`CohereForAI/c4ai-command-r-v01`, etc. - - + - * - :code:`DbrxForCausalLM` - DBRX - :code:`databricks/dbrx-base`, :code:`databricks/dbrx-instruct`, etc. - - + - * - :code:`DeciLMForCausalLM` - DeciLM - :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc. - - + - * - :code:`BloomForCausalLM` - BLOOM, BLOOMZ, BLOOMChat - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc. - - + - * - :code:`FalconForCausalLM` - Falcon - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc. - - + - * - :code:`GemmaForCausalLM` - Gemma - :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc. @@ -54,19 +54,19 @@ Alongside each architecture, we include some popular models that use it. * - :code:`GPT2LMHeadModel` - GPT-2 - :code:`gpt2`, :code:`gpt2-xl`, etc. - - + - * - :code:`GPTBigCodeForCausalLM` - StarCoder, SantaCoder, WizardCoder - :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc. - - + - * - :code:`GPTJForCausalLM` - GPT-J - :code:`EleutherAI/gpt-j-6b`, :code:`nomic-ai/gpt4all-j`, etc. - - + - * - :code:`GPTNeoXForCausalLM` - GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM - :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc. - - + - * - :code:`InternLMForCausalLM` - InternLM - :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc. @@ -93,32 +93,32 @@ Alongside each architecture, we include some popular models that use it. - ✅︎ * - :code:`MixtralForCausalLM` - Mixtral-8x7B, Mixtral-8x7B-Instruct - - :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, etc. + - :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, :code:`mistral-community/Mixtral-8x22B-v0.1`, etc. - ✅︎ * - :code:`MPTForCausalLM` - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc. - - + - * - :code:`OLMoForCausalLM` - OLMo - :code:`allenai/OLMo-1B`, :code:`allenai/OLMo-7B`, etc. - - + - * - :code:`OPTForCausalLM` - OPT, OPT-IML - :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc. - - + - * - :code:`OrionForCausalLM` - Orion - :code:`OrionStarAI/Orion-14B-Base`, :code:`OrionStarAI/Orion-14B-Chat`, etc. - - + - * - :code:`PhiForCausalLM` - Phi - :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc. - - + - * - :code:`QWenLMHeadModel` - Qwen - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc. - - + - * - :code:`Qwen2ForCausalLM` - Qwen2 - :code:`Qwen/Qwen2-beta-7B`, :code:`Qwen/Qwen2-beta-7B-Chat`, etc. @@ -126,11 +126,11 @@ Alongside each architecture, we include some popular models that use it. * - :code:`Qwen2MoeForCausalLM` - Qwen2MoE - :code:`Qwen/Qwen1.5-MoE-A2.7B`, :code:`Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. - - + - * - :code:`StableLmForCausalLM` - StableLM - :code:`stabilityai/stablelm-3b-4e1t/` , :code:`stabilityai/stablelm-base-alpha-7b-v2`, etc. - - + - If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. Otherwise, please refer to :ref:`Adding a New Model ` for instructions on how to implement support for your model. From 8db1bf32f8924403c6a845b5ce71ba0f14beb038 Mon Sep 17 00:00:00 2001 From: Roy Date: Mon, 15 Apr 2024 08:43:54 +0800 Subject: [PATCH 107/120] [Misc] Upgrade triton to 2.2.0 (#4061) --- requirements-cpu.txt | 2 +- requirements-cuda.txt | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements-cpu.txt b/requirements-cpu.txt index 5779b38b24e6..e911ad03295f 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -3,4 +3,4 @@ # Dependencies for x86_64 CPUs torch == 2.2.1+cpu -triton >= 2.1.0 # FIXME(woosuk): This is a hack to avoid import error. \ No newline at end of file +triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error. \ No newline at end of file diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 6ee75e8139c0..c6d2cd46aee5 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -7,4 +7,3 @@ pynvml == 11.5.0 vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.2.1 xformers == 0.0.25 # Requires PyTorch 2.2.1 -triton >= 2.1.0 From e11e2007368b22fce05b9ecdf00dd48eda471f9e Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sun, 14 Apr 2024 21:50:08 -0700 Subject: [PATCH 108/120] [Bugfix] Fix filelock version requirement (#4075) --- requirements-common.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index c96f9c9937fb..90a3bc8abc1d 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -12,4 +12,5 @@ pydantic >= 2.0 # Required for OpenAI server. prometheus_client >= 0.18.0 tiktoken == 0.6.0 # Required for DBRX tokenizer outlines == 0.0.34 # Requires torch >= 2.1.0 -typing_extensions \ No newline at end of file +typing_extensions +filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 From 0003e9154bf1091d0de7ca7a6c7f1253df1eca5b Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Mon, 15 Apr 2024 23:35:55 +0800 Subject: [PATCH 109/120] [Misc][Minor] Fix CPU block num log in CPUExecutor. (#4088) --- vllm/executor/cpu_executor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 33e67d8b3eec..e63a88be7868 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -74,7 +74,10 @@ def initialize_cache(self, num_gpu_blocks: int, # NOTE: We log here to avoid multiple logs when number of workers is # greater than one. We could log in the engine, but not all executors # have GPUs. - logger.info(f"# CPU blocks: {num_cpu_blocks}") + # NOTE: `cpu block` for CPU backend is located on CPU memory but is + # referred as `gpu block`. Because we want to reuse the existing block + # management procedure. + logger.info(f"# CPU blocks: {num_gpu_blocks}") self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) def execute_model(self, From eb46fbfda25348422918c4a876e17aef05fc5e34 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 15 Apr 2024 13:05:09 -0700 Subject: [PATCH 110/120] [Core] Simplifications to executor classes (#4071) --- vllm/executor/cpu_executor.py | 31 +++++++++------------------- vllm/executor/executor_base.py | 27 +++++++++++++++++------- vllm/executor/gpu_executor.py | 32 ++++------------------------- vllm/executor/neuron_executor.py | 29 ++++++-------------------- vllm/executor/ray_gpu_executor.py | 34 ++++--------------------------- 5 files changed, 44 insertions(+), 109 deletions(-) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index e63a88be7868..f562e4e0ae3d 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -1,10 +1,9 @@ import os -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Set, Tuple import torch -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import CacheConfig, ModelConfig, SchedulerConfig from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -16,23 +15,13 @@ class CPUExecutor(ExecutorBase): - def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], *args, **kwargs) -> None: - assert device_config.device_type == "cpu" - assert lora_config is None, "cpu backend doesn't support LoRA" - model_config = _verify_and_get_model_config(model_config) - cache_config = _verify_and_get_cache_config(cache_config) - scheduler_config = _verify_and_get_scheduler_config(scheduler_config) - - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config + def _init_executor(self) -> None: + assert self.device_config.device_type == "cpu" + assert self.lora_config is None, "cpu backend doesn't support LoRA" + self.model_config = _verify_and_get_model_config(self.model_config) + self.cache_config = _verify_and_get_cache_config(self.cache_config) + self.scheduler_config = _verify_and_get_scheduler_config( + self.scheduler_config) # Instantiate the worker and load the model to CPU. self._init_worker() @@ -99,7 +88,7 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def remove_lora(self, lora_id: int) -> bool: return self.driver_worker.remove_lora(lora_id) - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() def check_health(self) -> None: diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index bbfbfc689c99..bbb6ec80f7b7 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Set, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) + TensorizerConfig, VisionLanguageConfig) from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -16,7 +16,6 @@ class ExecutorBase(ABC): that can execute the model on multiple devices. """ - @abstractmethod def __init__( self, model_config: ModelConfig, @@ -27,8 +26,23 @@ def __init__( lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], + tensorizer_config: Optional[TensorizerConfig], ) -> None: - raise NotImplementedError + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.vision_language_config = vision_language_config + self.speculative_config = speculative_config + self.tensorizer_config = tensorizer_config + + self._init_executor() + + @abstractmethod + def _init_executor(self) -> None: + pass @abstractmethod def determine_num_available_blocks(self) -> Tuple[int, int]: @@ -71,7 +85,7 @@ def remove_lora(self, lora_id: int) -> bool: raise NotImplementedError @abstractmethod - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: raise NotImplementedError @abstractmethod @@ -94,8 +108,7 @@ async def execute_model_async( """Executes one model step on the given sequences.""" raise NotImplementedError - @abstractmethod async def check_health_async(self) -> None: """Checks if the executor is healthy. If not, it should raise an exception.""" - raise NotImplementedError + self.check_health() diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 30577ecf62fa..bae509f48025 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,8 +1,5 @@ -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Set, Tuple -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - TensorizerConfig, VisionLanguageConfig) from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -15,24 +12,8 @@ class GPUExecutor(ExecutorBase): - def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], - speculative_config: Optional[SpeculativeConfig], - tensorizer_config: Optional[TensorizerConfig]) -> None: - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.vision_language_config = vision_language_config - self.tensorizer_config = tensorizer_config - - assert (not speculative_config + def _init_executor(self) -> None: + assert (not self.speculative_config ), "Speculative decoding not yet supported for GPU backend" # Instantiate the worker and load the model to GPU. @@ -103,7 +84,7 @@ def remove_lora(self, lora_id: int) -> bool: assert lora_id > 0, "lora_id must be greater than 0." return self.driver_worker.remove_lora(lora_id) - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() def check_health(self) -> None: @@ -127,8 +108,3 @@ async def execute_model_async( blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy) return output - - async def check_health_async(self) -> None: - # GPUExecutor will always be healthy as long as - # it's running. - return diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index d45f18e46625..273b17a927ef 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -1,8 +1,5 @@ -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Set, Tuple -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -13,24 +10,10 @@ class NeuronExecutor(ExecutorBase): - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], - speculative_config: Optional[SpeculativeConfig], - ) -> None: - self.model_config = model_config - self.cache_config = cache_config - assert lora_config is None, "LoRA is not supported for Neuron backend." - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - assert (not speculative_config + def _init_executor(self) -> None: + assert (self.lora_config is + None), "LoRA is not supported for Neuron backend." + assert (not self.speculative_config ), "Speculative decoding not yet supported for Neuron backend." # Instantiate the worker and load the model to the device. @@ -80,7 +63,7 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def remove_lora(self, lora_id: int) -> bool: return self.driver_worker.remove_lora(lora_id) - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() def check_health(self) -> None: diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 28dc3e0db312..5db2f3f65253 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -3,11 +3,8 @@ import os import pickle from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - TensorizerConfig, VisionLanguageConfig) from vllm.engine.ray_utils import RayWorkerVllm, ray from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger @@ -32,27 +29,8 @@ class RayGPUExecutor(ExecutorBase): - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], - speculative_config: Optional[SpeculativeConfig], - tensorizer_config: Optional[TensorizerConfig], - ) -> None: - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.vision_language_config = vision_language_config - self.tensorizer_config = tensorizer_config - assert (not speculative_config + def _init_executor(self) -> None: + assert (not self.speculative_config ), "Speculative decoding not yet supported for RayGPU backend." assert self.parallel_config.worker_use_ray @@ -273,7 +251,7 @@ def remove_lora(self, lora_id: int) -> bool: lora_id=lora_id, ) - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: return self._run_workers("list_loras") def _run_workers( @@ -416,7 +394,3 @@ async def execute_model_async( # Only the driver worker returns the sampling results. output = all_outputs[0] return output - - async def check_health_async(self) -> None: - """Raises an error if engine is unhealthy.""" - self._check_if_any_actor_is_dead() From d619ae2d19c41d9aa8f68fa0e5e32cc410dc2522 Mon Sep 17 00:00:00 2001 From: Sanger Steel Date: Mon, 15 Apr 2024 16:28:25 -0400 Subject: [PATCH 111/120] [Doc] Add better clarity for tensorizer usage (#4090) Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> --- docs/source/models/engine_args.rst | 2 +- examples/tensorize_vllm_model.py | 60 +++++++++++++++++------- vllm/model_executor/tensorizer_loader.py | 6 +-- 3 files changed, 46 insertions(+), 22 deletions(-) diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst index 886a806934c0..235cb4e128c9 100644 --- a/docs/source/models/engine_args.rst +++ b/docs/source/models/engine_args.rst @@ -45,7 +45,7 @@ Below, you can find an explanation of every engine argument for vLLM: * "safetensors" will load the weights in the safetensors format. * "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading. * "dummy" will initialize the weights with random values, mainly for profiling. - * "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer. `_. See `tensorized_vllm_model.py` in the examples folder to serialize a vLLM model, and for more information. Tensorizer support for vLLM can be installed with `pip install vllm[tensorizer]`. + * "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer. `_ See `examples/tensorize_vllm_model.py `_ to serialize a vLLM model, and for more information. .. option:: --dtype {auto,half,float16,bfloat16,float,float32} diff --git a/examples/tensorize_vllm_model.py b/examples/tensorize_vllm_model.py index 3c20a38c7f72..8cf8be09d0b9 100644 --- a/examples/tensorize_vllm_model.py +++ b/examples/tensorize_vllm_model.py @@ -23,14 +23,16 @@ # yapf: disable """ tensorize_vllm_model.py is a script that can be used to serialize and -deserialize vLLM models. These models can be loaded using tensorizer directly -to the GPU extremely quickly. Tensor encryption and decryption is also -supported, although libsodium must be installed to use it. Install -vllm with tensorizer support using `pip install vllm[tensorizer]`. +deserialize vLLM models. These models can be loaded using tensorizer +to the GPU extremely quickly over an HTTP/HTTPS endpoint, an S3 endpoint, +or locally. Tensor encryption and decryption is also supported, although +libsodium must be installed to use it. Install vllm with tensorizer support +using `pip install vllm[tensorizer]`. -To serialize a model, you can run something like this: +To serialize a model, install vLLM from source, then run something +like this from the root level of this repository: -python tensorize_vllm_model.py \ +python -m examples.tensorize_vllm_model \ --model EleutherAI/gpt-j-6B \ --dtype float16 \ serialize \ @@ -38,31 +40,57 @@ --suffix vllm Which downloads the model from HuggingFace, loads it into vLLM, serializes it, -and saves it to your S3 bucket. A local directory can also be used. +and saves it to your S3 bucket. A local directory can also be used. This +assumes your S3 credentials are specified as environment variables +in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`. +To provide S3 credentials directly, you can provide `--s3-access-key-id` and +`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this +script. You can also encrypt the model weights with a randomly-generated key by providing a `--keyfile` argument. -To deserialize a model, you can run something like this: +To deserialize a model, you can run something like this from the root +level of this repository: -python tensorize_vllm_model.py \ +python -m examples.tensorize_vllm_model \ --model EleutherAI/gpt-j-6B \ --dtype float16 \ deserialize \ --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors Which downloads the model tensors from your S3 bucket and deserializes them. -To provide S3 credentials, you can provide `--s3-access-key-id` and -`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script, -the OpenAI entrypoint, as arguments for LLM(), or as environment variables -in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`. - You can also provide a `--keyfile` argument to decrypt the model weights if they were serialized with encryption. -For more information on the available arguments, run -`python tensorize_vllm_model.py --help`. +For more information on the available arguments for serializing, run +`python -m examples.tensorize_vllm_model serialize --help`. + +Or for deserializing: + +`python -m examples.tensorize_vllm_model deserialize --help`. + +Once a model is serialized, it can be used to load the model when running the +OpenAI inference client at `vllm/entrypoints/openai/api_server.py` by providing +the `--tensorizer-uri` CLI argument that is functionally the same as the +`--path-to-tensors` argument in this script, along with `--vllm-tensorized`, to +signify that the model to be deserialized is a vLLM model, rather than a +HuggingFace `PreTrainedModel`, which can also be deserialized using tensorizer +in the same inference server, albeit without the speed optimizations. To +deserialize an encrypted file, the `--encryption-keyfile` argument can be used +to provide the path to the keyfile used to encrypt the model weights. For +information on all the arguments that can be used to configure tensorizer's +deserialization, check out the tensorizer options argument group in the +`vllm/entrypoints/openai/api_server.py` script with `--help`. + +Tensorizer can also be invoked with the `LLM` class directly to load models: + + llm = LLM(model="facebook/opt-125m", + load_format="tensorizer", + tensorizer_uri=path_to_opt_tensors, + num_readers=3, + vllm_tensorized=True) """ diff --git a/vllm/model_executor/tensorizer_loader.py b/vllm/model_executor/tensorizer_loader.py index ed3ad9e2ffa1..8550cc97aefe 100644 --- a/vllm/model_executor/tensorizer_loader.py +++ b/vllm/model_executor/tensorizer_loader.py @@ -126,7 +126,6 @@ def __post_init__(self): "s3_endpoint": self.s3_endpoint, } - # Omitting self.dtype and self.device as this behaves weirdly self.deserializer_params = { "verify_hash": self.verify_hash, "encryption": self.encryption_keyfile, @@ -145,7 +144,7 @@ def add_cli_args( parser: argparse.ArgumentParser) -> argparse.ArgumentParser: """Tensorizer CLI arguments""" - # Create the argument group + # Tensorizer options arg group group = parser.add_argument_group( 'tensorizer options', description=('Options for configuring the behavior of the' @@ -205,9 +204,7 @@ def add_cli_args( @classmethod def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs": - # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] - # Set the attributes from the parsed arguments. tensorizer_args = cls(**{ attr: getattr(args, attr) for attr in attrs if hasattr(args, attr) @@ -291,7 +288,6 @@ def deserialize(self): nn.Module: The deserialized model. """ before_mem = get_mem_usage() - # Lazy load the tensors from S3 into the model. start = time.perf_counter() with open_stream( self.tensorizer_args.tensorizer_uri, From 4695397dcfef693a0a10f1eb8bf77ea905c54829 Mon Sep 17 00:00:00 2001 From: Ricky Xu Date: Mon, 15 Apr 2024 14:24:45 -0700 Subject: [PATCH 112/120] [Bugfix] Fix ray workers profiling with nsight (#4095) --- vllm/executor/ray_gpu_executor.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 5db2f3f65253..7aca5e36107a 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -48,6 +48,21 @@ def _init_executor(self) -> None: if USE_RAY_COMPILED_DAG: self.forward_dag = self._compiled_ray_dag() + def _configure_ray_workers_use_nsight(self, + ray_remote_kwargs) -> Dict[str, Any]: + # If nsight profiling is enabled, we need to set the profiling + # configuration for the ray workers as runtime env. + runtime_env = ray_remote_kwargs.setdefault("runtime_env", {}) + runtime_env.update({ + "nsight": { + "t": "cuda,cudnn,cublas", + "o": "'worker_process_%p'", + "cuda-graph-trace": "node", + } + }) + + return ray_remote_kwargs + def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): if self.parallel_config.tensor_parallel_size == 1: @@ -63,6 +78,10 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # The remaining workers are the actual ray actors. self.workers: List[RayWorkerVllm] = [] + if self.parallel_config.ray_workers_use_nsight: + ray_remote_kwargs = self._configure_ray_workers_use_nsight( + ray_remote_kwargs) + # Create the workers. driver_ip = get_ip() for bundle_id, bundle in enumerate(placement_group.bundle_specs): From 37e84a403d6d11b670a42e84153204cd8b76b849 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 16 Apr 2024 06:47:31 +0900 Subject: [PATCH 113/120] [Typing] Fix Sequence type GenericAlias only available after Python 3.9. (#4092) --- vllm/core/block_manager_v1.py | 5 +++-- vllm/core/block_manager_v2.py | 2 +- vllm/core/interfaces.py | 2 +- vllm/utils.py | 7 ++++--- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index e391a3b1e5a3..be093922b84f 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -1,9 +1,10 @@ """A block manager that manages token blocks.""" from abc import ABC, abstractmethod -from collections.abc import Sequence as GenericSequence from itertools import count, takewhile from os.path import commonprefix -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional +from typing import Sequence as GenericSequence +from typing import Set from vllm.block import BlockTable, PhysicalTokenBlock from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 19f0cf415eb3..6339a6baf416 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -1,6 +1,6 @@ """A block manager that manages token blocks.""" -from collections.abc import Sequence as GenericSequence from typing import Dict, List, Optional +from typing import Sequence as GenericSequence from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index c1f68a2e891b..56c2c5995c38 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -1,7 +1,7 @@ import enum from abc import ABC, abstractmethod -from collections.abc import Sequence as GenericSequence from typing import Dict, List +from typing import Sequence as GenericSequence from vllm.sequence import Sequence, SequenceGroup diff --git a/vllm/utils.py b/vllm/utils.py index 4c0dc9ca729a..aad62516ad1b 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -6,11 +6,12 @@ import subprocess import uuid import warnings -from collections import OrderedDict, defaultdict +from collections import defaultdict from functools import lru_cache, partial from platform import uname from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, - Hashable, List, Optional, Tuple, TypeVar, Union) + Hashable, List, Optional, OrderedDict, Tuple, TypeVar, + Union) import psutil import torch @@ -51,7 +52,7 @@ def reset(self) -> None: class LRUCache(Generic[T]): def __init__(self, capacity: int): - self.cache = OrderedDict[Hashable, T]() + self.cache: OrderedDict[Hashable, T] = OrderedDict() self.capacity = capacity def __contains__(self, key: Hashable) -> bool: From 4e7ee664e201442e24e2298a36a5264b98691626 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 16 Apr 2024 14:24:53 +0900 Subject: [PATCH 114/120] [Core] Fix engine-use-ray broken (#4105) --- tests/async_engine/test_api_server.py | 17 +++++++++++++---- vllm/engine/async_llm_engine.py | 7 +++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/async_engine/test_api_server.py b/tests/async_engine/test_api_server.py index 248bfbc8ab5c..7f57d5cf9b18 100644 --- a/tests/async_engine/test_api_server.py +++ b/tests/async_engine/test_api_server.py @@ -25,21 +25,30 @@ def _query_server_long(prompt: str) -> dict: @pytest.fixture -def api_server(tokenizer_pool_size: int): +def api_server(tokenizer_pool_size: int, engine_use_ray: bool, + worker_use_ray: bool): script_path = Path(__file__).parent.joinpath( "api_server_async_engine.py").absolute() - uvicorn_process = subprocess.Popen([ + commands = [ sys.executable, "-u", str(script_path), "--model", "facebook/opt-125m", "--host", "127.0.0.1", "--tokenizer-pool-size", str(tokenizer_pool_size) - ]) + ] + if engine_use_ray: + commands.append("--engine-use-ray") + if worker_use_ray: + commands.append("--worker-use-ray") + uvicorn_process = subprocess.Popen(commands) yield uvicorn_process.terminate() @pytest.mark.parametrize("tokenizer_pool_size", [0, 2]) -def test_api_server(api_server, tokenizer_pool_size: int): +@pytest.mark.parametrize("worker_use_ray", [False, True]) +@pytest.mark.parametrize("engine_use_ray", [False, True]) +def test_api_server(api_server, tokenizer_pool_size: int, worker_use_ray: bool, + engine_use_ray: bool): """ Run the API server and test it. diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index f61049513512..1dbf58904541 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -333,8 +333,7 @@ def from_engine_args( if engine_config.device_config.device_type == "neuron": raise NotImplementedError("Neuron is not supported for " "async engine yet.") - elif (engine_config.parallel_config.worker_use_ray - or engine_args.engine_use_ray): + elif engine_config.parallel_config.worker_use_ray: initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync executor_class = RayGPUExecutorAsync @@ -410,8 +409,8 @@ def _init_engine(self, *args, else: # FIXME(woosuk): This is a bit hacky. Be careful when changing the # order of the arguments. - cache_config = args[1] - parallel_config = args[2] + cache_config = kwargs["cache_config"] + parallel_config = kwargs["parallel_config"] if parallel_config.tensor_parallel_size == 1: num_gpus = cache_config.gpu_memory_utilization else: From 05434764cd99990035779cf9a4ed86623b528825 Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Tue, 16 Apr 2024 08:54:57 +0300 Subject: [PATCH 115/120] LM Format Enforcer Guided Decoding Support (#3868) Co-authored-by: Simon Mo --- requirements-common.txt | 1 + tests/entrypoints/test_guided_processors.py | 42 +++++++- tests/entrypoints/test_openai_server.py | 69 ++++++++---- vllm/config.py | 26 ++++- vllm/engine/arg_utils.py | 18 +++- vllm/engine/llm_engine.py | 10 +- vllm/entrypoints/openai/protocol.py | 12 +++ vllm/entrypoints/openai/serving_chat.py | 6 +- vllm/entrypoints/openai/serving_completion.py | 6 +- .../guided_decoding/__init__.py | 25 +++++ .../lm_format_enforcer_decoding.py | 69 ++++++++++++ .../outlines_decoding.py} | 7 +- .../outlines_logits_processors.py} | 100 +++++++++--------- 13 files changed, 304 insertions(+), 87 deletions(-) create mode 100644 vllm/model_executor/guided_decoding/__init__.py create mode 100644 vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py rename vllm/model_executor/{guided_decoding.py => guided_decoding/outlines_decoding.py} (93%) rename vllm/model_executor/{guided_logits_processors.py => guided_decoding/outlines_logits_processors.py} (70%) diff --git a/requirements-common.txt b/requirements-common.txt index 90a3bc8abc1d..c1614d2537b2 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -11,6 +11,7 @@ uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. prometheus_client >= 0.18.0 tiktoken == 0.6.0 # Required for DBRX tokenizer +lm-format-enforcer == 0.9.3 outlines == 0.0.34 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 diff --git a/tests/entrypoints/test_guided_processors.py b/tests/entrypoints/test_guided_processors.py index 5622744566bc..30f0ad5d8272 100644 --- a/tests/entrypoints/test_guided_processors.py +++ b/tests/entrypoints/test_guided_processors.py @@ -1,11 +1,14 @@ # This unit test should be moved to a new # tests/test_guided_decoding directory. - +import pytest import torch from transformers import AutoTokenizer -from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor, - RegexLogitsProcessor) +from vllm.entrypoints.openai.protocol import CompletionRequest +from vllm.model_executor.guided_decoding import ( + get_guided_decoding_logits_processor) +from vllm.model_executor.guided_decoding.outlines_logits_processors import ( + JSONLogitsProcessor, RegexLogitsProcessor) TEST_SCHEMA = { "type": "object", @@ -73,3 +76,36 @@ def test_guided_logits_processors(): json_LP(token_ids, tensor) assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"]) +async def test_guided_logits_processor_black_box(backend: str): + tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') + token_ids = tokenizer.encode( + f"Give an example IPv4 address with this regex: {TEST_REGEX}") + regex_request = CompletionRequest(model='test', + prompt=token_ids, + guided_regex=TEST_REGEX) + regex_lp = await get_guided_decoding_logits_processor( + backend, regex_request, tokenizer) + assert regex_lp is not None + tensor = torch.rand(32000) + original_tensor = torch.clone(tensor) + tensor = regex_lp(token_ids, tensor) + assert tensor.shape == original_tensor.shape + assert not torch.allclose(tensor, original_tensor) + + token_ids = tokenizer.encode( + f"Give an employee profile that fits this schema: {TEST_SCHEMA}") + json_request = CompletionRequest(model='test', + prompt=token_ids, + guided_json=TEST_SCHEMA) + json_lp = await get_guided_decoding_logits_processor( + backend, json_request, tokenizer) + assert json_lp is not None + tensor = torch.rand(32000) + original_tensor = torch.clone(tensor) + tensor = json_lp(token_ids, tensor) + assert tensor.shape == original_tensor.shape + assert not torch.allclose(tensor, original_tensor) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 7940430b8b65..14e6ee0ffe9d 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -506,7 +506,10 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI): assert first_response != completion.choices[0].text -async def test_guided_json_completion(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_json_completion(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): completion = await client.completions.create( model=MODEL_NAME, prompt=f"Give an example JSON for an employee profile " @@ -514,7 +517,8 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI): n=3, temperature=1.0, max_tokens=500, - extra_body=dict(guided_json=TEST_SCHEMA)) + extra_body=dict(guided_json=TEST_SCHEMA, + guided_decoding_backend=guided_decoding_backend)) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 3 @@ -524,7 +528,10 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI): jsonschema.validate(instance=output_json, schema=TEST_SCHEMA) -async def test_guided_json_chat(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_json_chat(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): messages = [{ "role": "system", "content": "you are a helpful assistant" @@ -538,8 +545,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI): chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, - max_tokens=500, - extra_body=dict(guided_json=TEST_SCHEMA)) + max_tokens=1000, + extra_body=dict(guided_json=TEST_SCHEMA, + guided_decoding_backend=guided_decoding_backend)) message = chat_completion.choices[0].message assert message.content is not None json1 = json.loads(message.content) @@ -555,8 +563,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI): chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, - max_tokens=500, - extra_body=dict(guided_json=TEST_SCHEMA)) + max_tokens=1000, + extra_body=dict(guided_json=TEST_SCHEMA, + guided_decoding_backend=guided_decoding_backend)) message = chat_completion.choices[0].message assert message.content is not None json2 = json.loads(message.content) @@ -565,14 +574,18 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI): assert json1["age"] != json2["age"] -async def test_guided_regex_completion(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_regex_completion(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): completion = await client.completions.create( model=MODEL_NAME, prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}", n=3, temperature=1.0, max_tokens=20, - extra_body=dict(guided_regex=TEST_REGEX)) + extra_body=dict(guided_regex=TEST_REGEX, + guided_decoding_backend=guided_decoding_backend)) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 3 @@ -581,7 +594,10 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI): assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None -async def test_guided_regex_chat(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_regex_chat(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): messages = [{ "role": "system", "content": "you are a helpful assistant" @@ -595,7 +611,8 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI): model=MODEL_NAME, messages=messages, max_tokens=20, - extra_body=dict(guided_regex=TEST_REGEX)) + extra_body=dict(guided_regex=TEST_REGEX, + guided_decoding_backend=guided_decoding_backend)) ip1 = chat_completion.choices[0].message.content assert ip1 is not None assert re.fullmatch(TEST_REGEX, ip1) is not None @@ -606,21 +623,26 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI): model=MODEL_NAME, messages=messages, max_tokens=20, - extra_body=dict(guided_regex=TEST_REGEX)) + extra_body=dict(guided_regex=TEST_REGEX, + guided_decoding_backend=guided_decoding_backend)) ip2 = chat_completion.choices[0].message.content assert ip2 is not None assert re.fullmatch(TEST_REGEX, ip2) is not None assert ip1 != ip2 -async def test_guided_choice_completion(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_choice_completion(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): completion = await client.completions.create( model=MODEL_NAME, prompt="The best language for type-safe systems programming is ", n=2, temperature=1.0, max_tokens=10, - extra_body=dict(guided_choice=TEST_CHOICE)) + extra_body=dict(guided_choice=TEST_CHOICE, + guided_decoding_backend=guided_decoding_backend)) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 2 @@ -628,7 +650,10 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI): assert completion.choices[i].text in TEST_CHOICE -async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_choice_chat(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): messages = [{ "role": "system", "content": "you are a helpful assistant" @@ -642,7 +667,8 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): model=MODEL_NAME, messages=messages, max_tokens=10, - extra_body=dict(guided_choice=TEST_CHOICE)) + extra_body=dict(guided_choice=TEST_CHOICE, + guided_decoding_backend=guided_decoding_backend)) choice1 = chat_completion.choices[0].message.content assert choice1 in TEST_CHOICE @@ -655,18 +681,23 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): model=MODEL_NAME, messages=messages, max_tokens=10, - extra_body=dict(guided_choice=TEST_CHOICE)) + extra_body=dict(guided_choice=TEST_CHOICE, + guided_decoding_backend=guided_decoding_backend)) choice2 = chat_completion.choices[0].message.content assert choice2 in TEST_CHOICE assert choice1 != choice2 -async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): with pytest.raises(openai.BadRequestError): _ = await client.completions.create( model=MODEL_NAME, prompt="Give an example JSON that fits this schema: 42", - extra_body=dict(guided_json=42)) + extra_body=dict(guided_json=42, + guided_decoding_backend=guided_decoding_backend)) messages = [{ "role": "system", diff --git a/vllm/config.py b/vllm/config.py index dce2944b2ee8..bf31b03b7c6c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -66,8 +66,8 @@ class ModelConfig: weights. If None, we assume the model weights are not quantized. quantization_param_path: Path to JSON file containing scaling factors. Used to load KV cache scaling factors into the model when KV cache - type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also - be used to load activation and weight scaling factors when the + type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also + be used to load activation and weight scaling factors when the model dtype is FP8_E4M3 on ROCm. enforce_eager: Whether to enforce eager execution. If True, we will disable CUDA graph and always execute the model in eager mode. @@ -422,7 +422,7 @@ def verify_with_parallel_config( @dataclass class TokenizerPoolConfig: """Configuration for the tokenizer pool. - + Args: pool_size: Number of tokenizer workers in the pool. pool_type: Type of the pool. @@ -446,9 +446,9 @@ def create_config( tokenizer_pool_extra_config: Optional[Union[str, dict]] ) -> Optional["TokenizerPoolConfig"]: """Create a TokenizerPoolConfig from the given parameters. - + If tokenizer_pool_size is 0, return None. - + Args: tokenizer_pool_size: Number of tokenizer workers in the pool. tokenizer_pool_type: Type of the pool. @@ -1079,6 +1079,21 @@ def _get_and_verify_max_len( return int(max_model_len) +@dataclass +class DecodingConfig: + """Dataclass which contains the decoding strategy of the engine""" + + # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer' + guided_decoding_backend: str = 'outlines' + + def __post_init__(self): + valid_guided_backends = ['outlines', 'lm-format-enforcer'] + backend = self.guided_decoding_backend + if backend not in valid_guided_backends: + raise ValueError(f"Invalid guided_decoding_backend '{backend}," + f"must be one of {valid_guided_backends}") + + @dataclass(frozen=True) class EngineConfig: """Dataclass which contains all engine-related configuration. This @@ -1093,6 +1108,7 @@ class EngineConfig: lora_config: Optional[LoRAConfig] vision_language_config: Optional[VisionLanguageConfig] speculative_config: Optional[SpeculativeConfig] + decoding_config: Optional[DecodingConfig] tensorizer_config: Optional[TensorizerConfig] def __post_init__(self): diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 831a03be65f6..3de74b0ac28b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -5,9 +5,9 @@ from dataclasses import dataclass from typing import BinaryIO, Optional, Union -from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, TensorizerConfig, +from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, + EngineConfig, LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig, SpeculativeConfig, TensorizerConfig, TokenizerPoolConfig, VisionLanguageConfig) from vllm.model_executor.tensorizer_loader import TensorizerArgs from vllm.utils import str_to_int_tuple @@ -80,6 +80,7 @@ class EngineArgs: scheduler_delay_factor: float = 0.0 enable_chunked_prefill: bool = False + guided_decoding_backend: str = 'outlines' # Speculative decoding configuration. speculative_model: Optional[str] = None num_speculative_tokens: Optional[int] = None @@ -200,6 +201,13 @@ def add_cli_args( default=EngineArgs.max_model_len, help='model context length. If unspecified, ' 'will be automatically derived from the model.') + parser.add_argument( + '--guided-decoding-backend', + type=str, + default='outlines', + choices=['outlines', 'lm-format-enforcer'], + help='Which engine will be used for guided decoding' + ' (JSON schema / regex etc)') # Parallel arguments parser.add_argument('--worker-use-ray', action='store_true', @@ -511,6 +519,9 @@ def create_engine_config(self, ) -> EngineConfig: else: vision_language_config = None + decoding_config = DecodingConfig( + guided_decoding_backend=self.guided_decoding_backend) + return EngineConfig(model_config=model_config, cache_config=cache_config, parallel_config=parallel_config, @@ -519,6 +530,7 @@ def create_engine_config(self, ) -> EngineConfig: lora_config=lora_config, vision_language_config=vision_language_config, speculative_config=speculative_config, + decoding_config=decoding_config, tensorizer_config=tensorizer_config) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8c37c5a9d6ee..f06c1d18ace4 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -4,9 +4,10 @@ from transformers import PreTrainedTokenizer import vllm -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - TensorizerConfig, VisionLanguageConfig) +from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig, TensorizerConfig, + VisionLanguageConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import StatLogger, Stats @@ -74,6 +75,7 @@ def __init__( lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], + decoding_config: Optional[DecodingConfig], tensorizer_config: Optional[TensorizerConfig], executor_class: Type[ExecutorBase], log_stats: bool, @@ -100,6 +102,7 @@ def __init__( f"kv_cache_dtype={cache_config.cache_dtype}, " f"quantization_param_path={model_config.quantization_param_path}, " f"device_config={device_config.device}, " + f"decoding_config={decoding_config!r}, " f"seed={model_config.seed})") # TODO(woosuk): Print more configs in debug mode. @@ -111,6 +114,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.speculative_config = speculative_config + self.decoding_config = decoding_config or DecodingConfig() self.tensorizer_config = tensorizer_config self.log_stats = log_stats diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index f94d22d279cc..cf779d44c816 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -133,6 +133,12 @@ class ChatCompletionRequest(BaseModel): description=( "If specified, the output will follow the context free grammar."), ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be either " + "'outlines' / 'lm-format-enforcer'")) # doc: end-chat-completion-extra-params @@ -265,6 +271,12 @@ class CompletionRequest(BaseModel): description=( "If specified, the output will follow the context free grammar."), ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be one of " + "'outlines' / 'lm-format-enforcer'")) # doc: end-completion-extra-params diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index a03c5dc88108..c9ed4a9de20f 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -68,9 +68,13 @@ async def create_chat_completion( request, prompt=prompt) sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) + decoding_config = self.engine.engine.decoding_config + guided_decoding_backend = request.guided_decoding_backend \ + or decoding_config.guided_decoding_backend guided_decode_logits_processor = ( await get_guided_decoding_logits_processor( - request, await self.engine.get_tokenizer())) + guided_decoding_backend, request, await + self.engine.get_tokenizer())) if guided_decode_logits_processor: if sampling_params.logits_processors is None: sampling_params.logits_processors = [] diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index e24aa2489a80..a71f2d6a4426 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -88,9 +88,13 @@ async def create_completion(self, request: CompletionRequest, try: sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) + decoding_config = self.engine.engine.decoding_config + guided_decoding_backend = request.guided_decoding_backend \ + or decoding_config.guided_decoding_backend guided_decode_logit_processor = ( await get_guided_decoding_logits_processor( - request, await self.engine.get_tokenizer())) + guided_decoding_backend, request, await + self.engine.get_tokenizer())) if guided_decode_logit_processor is not None: if sampling_params.logits_processors is None: sampling_params.logits_processors = [] diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py new file mode 100644 index 000000000000..0558d6c95d97 --- /dev/null +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -0,0 +1,25 @@ +from typing import Optional, Union + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + CompletionRequest) +from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( + get_lm_format_enforcer_guided_decoding_logits_processor) +from vllm.model_executor.guided_decoding.outlines_decoding import ( + get_outlines_guided_decoding_logits_processor) +from vllm.sampling_params import LogitsProcessor + + +async def get_guided_decoding_logits_processor( + guided_decoding_backend: str, request: Union[CompletionRequest, + ChatCompletionRequest], + tokenizer) -> Optional[LogitsProcessor]: + if guided_decoding_backend == 'outlines': + return await get_outlines_guided_decoding_logits_processor( + request, tokenizer) + if guided_decoding_backend == 'lm-format-enforcer': + return await get_lm_format_enforcer_guided_decoding_logits_processor( + request, tokenizer) + + raise ValueError( + f"Unknown guided decoding backend '{guided_decoding_backend}'. " + "Must be one of 'outlines, 'lm-format-enforcer'") diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py new file mode 100644 index 000000000000..0d74a5f8e81f --- /dev/null +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -0,0 +1,69 @@ +from functools import lru_cache +from json import loads as json_loads +from typing import Optional, Union + +from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser, + RegexParser, StringParser, + TokenEnforcerTokenizerData, UnionParser) +from lmformatenforcer.integrations.vllm import ( + build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data) +from pydantic import BaseModel +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + CompletionRequest) +from vllm.model_executor.guided_decoding.outlines_decoding import ( + get_outlines_guided_decoding_logits_processor) +from vllm.sampling_params import LogitsProcessor + + +async def get_lm_format_enforcer_guided_decoding_logits_processor( + request: Union[CompletionRequest, ChatCompletionRequest], + tokenizer) -> Optional[LogitsProcessor]: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + We cache logit processors by (guide, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying FSM. + """ + + tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer) + character_level_parser: CharacterLevelParser + if request.guided_json: + schema = _normalize_json_schema_object(request.guided_json) + character_level_parser = JsonSchemaParser(schema) + elif request.guided_choice: + character_level_parser = UnionParser( + [StringParser(choice) for choice in request.guided_choice]) + elif request.guided_regex: + character_level_parser = RegexParser(request.guided_regex) + elif request.guided_grammar: + # CFG grammar not supported by LMFE, revert to outlines + return await get_outlines_guided_decoding_logits_processor( + request, tokenizer) + elif (request.response_format is not None + and request.response_format.type == "json_object"): + character_level_parser = JsonSchemaParser( + None) # None means any json object + else: + return None + + logits_processor = build_vllm_logits_processor(tokenizer_data, + character_level_parser) + return logits_processor + + +def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict: + if isinstance(schema, str): + return json_loads(schema) + if isinstance(schema, dict): + return schema + if isinstance(schema, BaseModel): + return schema.model_json_schema() + + +@lru_cache +def _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData: + return build_vllm_token_enforcer_tokenizer_data(tokenizer) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py similarity index 93% rename from vllm/model_executor/guided_decoding.py rename to vllm/model_executor/guided_decoding/outlines_decoding.py index 8e710f1ac2b5..bd4564a36e1e 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -12,9 +12,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest) -from vllm.model_executor.guided_logits_processors import (CFGLogitsProcessor, - JSONLogitsProcessor, - RegexLogitsProcessor) +from vllm.model_executor.guided_decoding.outlines_logits_processors import ( + CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) class GuidedDecodingMode(Enum): @@ -54,7 +53,7 @@ class GuidedDecodingMode(Enum): global_thread_pool = None # used for generating logits processor fsm -async def get_guided_decoding_logits_processor( +async def get_outlines_guided_decoding_logits_processor( request: Union[CompletionRequest, ChatCompletionRequest], tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]: """ diff --git a/vllm/model_executor/guided_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py similarity index 70% rename from vllm/model_executor/guided_logits_processors.py rename to vllm/model_executor/guided_decoding/outlines_logits_processors.py index 035fe0003732..28041695546d 100644 --- a/vllm/model_executor/guided_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -13,9 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy import json import math from collections import defaultdict +from functools import lru_cache from typing import Callable, DefaultDict, Dict, List, Optional, Union import torch @@ -27,50 +29,6 @@ class BaseLogitsProcessor: - def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase): - """Adapt vLLM's tokenizer to use to compile the FSM. - - The API of Outlines tokenizers is slightly different to that of - `transformers`. The decoder of outlines, returns a list whereas - the decode of vLLM returns an str. To sync the vLLM decoder with - outlines internal api, the decoder should be adapted. In addition - we need to handle the missing spaces to Llama's tokenizer to be - able to compile FSMs for this model. - - """ - if getattr(tokenizer, "_outlines_adapted", False): - return tokenizer - - tokenizer.vocabulary = tokenizer.get_vocab() - tokenizer.special_tokens = set(tokenizer.all_special_tokens) - - def convert_token_to_string(token: str) -> str: - from transformers.file_utils import SPIECE_UNDERLINE - - string = tokenizer.convert_tokens_to_string([token]) - - # A hack to handle missing spaces to HF's Llama tokenizers - if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": - return " " + string - - return string - - def change_decoder( - decoder: Callable[[List[int]], str] - ) -> Callable[[List[int]], List[str]]: - """Sync vLLM's decoder with the outlines by returning list.""" - - def new_decoder(inp_tokens: List[int]) -> List[str]: - return [decoder(inp_tokens)] - - return new_decoder - - tokenizer.convert_token_to_string = convert_token_to_string - tokenizer.decode = change_decoder(tokenizer.decode) - setattr(tokenizer, "_outlines_adapted", True) # noqa: B010 - - return tokenizer - def init_state(self): """Initialize the FSM states.""" self.fsm_state: DefaultDict[int, int] = defaultdict(int) @@ -78,7 +36,6 @@ def init_state(self): def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: """Use the FSM to bias the logits before sampling the next token.""" - seq_id = hash(tuple(input_ids)) if len(input_ids) == 0: @@ -96,7 +53,6 @@ def __call__(self, input_ids: List[int], device=scores.device) mask[allowed_tokens] = 0 scores.add_(mask) - return scores @@ -113,7 +69,7 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase): The model's tokenizer """ - tokenizer = self.adapt_tokenizer(tokenizer) + tokenizer = _adapt_tokenizer(tokenizer) fsm = RegexFSM(regex_string, tokenizer) self.fsm = fsm @@ -167,6 +123,54 @@ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase): The model's tokenizer """ - tokenizer = self.adapt_tokenizer(tokenizer) + tokenizer = _adapt_tokenizer(tokenizer) fsm = CFGFSM(cfg, tokenizer) self.fsm = fsm + + +@lru_cache +def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase): + """Adapt vLLM's tokenizer to use to compile the FSM. + + The API of Outlines tokenizers is slightly different to that of + `transformers`. The decoder of outlines, returns a list whereas + the decode of vLLM returns an str. To sync the vLLM decoder with + outlines internal api, the decoder should be adapted. In addition + we need to handle the missing spaces to Llama's tokenizer to be + able to compile FSMs for this model. + + """ + if getattr(tokenizer, "_outlines_adapted", False): + return tokenizer + + tokenizer = copy.deepcopy(tokenizer) + + tokenizer.vocabulary = tokenizer.get_vocab() + tokenizer.special_tokens = set(tokenizer.all_special_tokens) + + def convert_token_to_string(token: str) -> str: + from transformers.file_utils import SPIECE_UNDERLINE + + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": + return " " + string + + return string + + def change_decoder( + decoder: Callable[[List[int]], + str]) -> Callable[[List[int]], List[str]]: + """Sync vLLM's decoder with the outlines by returning list.""" + + def new_decoder(inp_tokens: List[int]) -> List[str]: + return [decoder(inp_tokens)] + + return new_decoder + + tokenizer.convert_token_to_string = convert_token_to_string + tokenizer.decode = change_decoder(tokenizer.decode) + setattr(tokenizer, "_outlines_adapted", True) # noqa: B010 + + return tokenizer From 2a19f5e58f36efb090434adb57e55a411144669b Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Tue, 16 Apr 2024 01:39:25 -0700 Subject: [PATCH 116/120] allow append empty tokens in block table --- tests/core/block/e2e/test_correctness.py | 65 ++++++++++++++++++++++++ vllm/core/block/block_table.py | 1 - 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index 94b65401e1dd..a403d442d7af 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -229,6 +229,71 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, assert baseline_token_ids == test_token_ids +@pytest.mark.parametrize("common_llm_kwargs", [ + { + # Use a small model for a fast test. + "model": "facebook/opt-125m", + + # skip cuda graph creation for fast test. + "enforce_eager": True, + + "enable_chunked_prefill": True, + "max_num_batched_tokens": 2, + "max_num_seqs": 2, + }, +]) +@pytest.mark.parametrize("per_test_common_llm_kwargs",[{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [ + { + "use_v2_block_manager": False, + }, +]) +@pytest.mark.parametrize( + "test_llm_kwargs", [ + { + "use_v2_block_manager": True, + "num_lookahead_slots": 0, + }, + { + "use_v2_block_manager": True, + "num_lookahead_slots": 5, + }, +]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_chunked_prefill_block_manager_v2(baseline_llm_generator, test_llm_generator, batch_size): + output_len = 32 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + print('Getting token ids with BlockManagerV1') + baseline_token_ids = get_token_ids_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + print('Getting token ids with BlockManagerV2') + test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, + prompts, sampling_params) + + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, + test_token_ids): + assert expected_token_ids == actual_token_ids + + assert baseline_token_ids == test_token_ids + def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): for llm in llm_generator: diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index ba061bbc4fbc..560267e55ea3 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -104,7 +104,6 @@ def append_token_ids(self, token_ids (List[int]): The sequence of token IDs to be appended. """ assert self._is_allocated - assert token_ids, "can't append empty token ids" self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + num_lookahead_slots) From b6e9e826604123654224a5d598fd140c1cfedde5 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Tue, 16 Apr 2024 02:58:43 -0700 Subject: [PATCH 117/120] rebase on stop string fixes --- vllm/engine/output_processor/multi_step.py | 15 +++-- vllm/engine/output_processor/stop_checker.py | 63 +------------------- 2 files changed, 13 insertions(+), 65 deletions(-) diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 6b01a94f59e4..bae903acda66 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -101,17 +101,24 @@ def _process_seq_outputs(self, seq: Sequence, valid_samples = valid_samples[:i + 1] break + # Incrementally append tokens to the sequence, as if we had only one new + # token. for output_token_id in output_token_ids: seq.append_token_id( token_id=output_token_id, # TODO emit logprobs in multi-step decoding. logprobs={output_token_id: Logprob(0.0)}, ) - self.detokenizer.decode_sequence_inplace(seq, sampling_params) - self.stop_checker.maybe_stop_sequence(seq, - sampling_params, - new_token_ids=output_token_ids) + new_char_count = 0 + if sampling_params.detokenize: + new_char_count = self.detokenizer.decode_sequence_inplace(seq, sampling_params) + + self.stop_checker.maybe_stop_sequence(seq, + new_char_count=new_char_count, + sampling_params=sampling_params) + if seq.is_finished(): + break if seq.is_finished(): self.scheduler.free_seq(seq) diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index f259b818748e..93e2fe6ac17c 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -1,4 +1,4 @@ -from typing import Callable, List +from typing import Callable, List, Optional from transformers import PreTrainedTokenizer @@ -61,7 +61,7 @@ def maybe_stop_sequence(self, seq: Sequence, return # Check if the sequence has reached max_model_len. - if seq.get_len() > self.scheduler_config.max_model_len: + if seq.get_len() > self.max_model_len: seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED return @@ -101,62 +101,3 @@ def _check_stop_strings(seq: Sequence, new_char_count: int, seq.output_text = seq.output_text[:stop_index] return stop_str return None - # TODO spec decode - ## - # """Check if the sequences should be stopped. If so, mark it as finished. - # """ - - # # Check if the sequence has reached max_model_len. - # if seq.get_len() > self.max_model_len: - # seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - # return - - # # Check if the sequence has reached max_tokens. - # if seq.get_output_len() == sampling_params.max_tokens: - # seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - # return - - # # Check if the minimum number of tokens has been generated yet; - # # skip the stop string/token checks if not - # if seq.get_output_len() < sampling_params.min_tokens: - # return - - # if sampling_params.detokenize: - # for stop_str in sampling_params.stop: - # # TODO(cade) Fix this for speculative decoding. - # if seq.output_text.endswith(stop_str): - # self._finalize_sequence(seq, sampling_params, stop_str) - # seq.status = SequenceStatus.FINISHED_STOPPED - # seq.stop_reason = stop_str - # return - - # # Determine if any stop_token_ids are in new_token_ids. - # intersection = set(new_token_ids).intersection( - # sampling_params.stop_token_ids) - # if intersection: - # # Get arbitrary token id that caused the stop. - # stop_token_id = next(iter(intersection)) - - # stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens( - # stop_token_id) - # self._finalize_sequence(seq, sampling_params, stop_str) - # seq.status = SequenceStatus.FINISHED_STOPPED - # seq.stop_reason = stop_token_id - # return - - # # Check if the sequence has generated the EOS token. - # if ((not sampling_params.ignore_eos) - # and seq.eos_token_id in new_token_ids): - # seq.status = SequenceStatus.FINISHED_STOPPED - # return - - #def _finalize_sequence(self, seq: Sequence, - # sampling_params: SamplingParams, - # stop_string: str) -> None: - # if sampling_params.include_stop_str_in_output: - # return - - # if stop_string and seq.output_text.endswith(stop_string): - # # Truncate the output text so that the stop string is - # # not included in the output. - # seq.output_text = seq.output_text[:-len(stop_string)] From bf0c37cbbd2f0f034edbd77a6292d9ba3509bf19 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Tue, 16 Apr 2024 03:00:13 -0700 Subject: [PATCH 118/120] test spec --- vllm/executor/gpu_executor.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 4fd9735669fd..9268b646a18a 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -13,13 +13,6 @@ class GPUExecutor(ExecutorBase): def _init_executor(self) -> None: - assert (not self.speculative_config - ), "Speculative decoding not yet supported for GPU backend" - - # Instantiate the worker and load the model to GPU. - self._init_worker() - - def _init_worker(self): if self.speculative_config is None: self._init_non_spec_worker() else: From a158256acb08f0c954feaf953590b0668d6f8904 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Tue, 16 Apr 2024 03:07:16 -0700 Subject: [PATCH 119/120] lint & mypy --- tests/core/block/e2e/test_correctness.py | 34 +++++++++++--------- vllm/engine/output_processor/multi_step.py | 10 +++--- vllm/engine/output_processor/single_step.py | 3 +- vllm/engine/output_processor/stop_checker.py | 6 ++-- vllm/executor/gpu_executor.py | 2 ++ vllm/executor/neuron_executor.py | 5 ++- 6 files changed, 34 insertions(+), 26 deletions(-) diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index a403d442d7af..1015892b67a4 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -229,27 +229,28 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, assert baseline_token_ids == test_token_ids -@pytest.mark.parametrize("common_llm_kwargs", [ - { - # Use a small model for a fast test. - "model": "facebook/opt-125m", - - # skip cuda graph creation for fast test. - "enforce_eager": True, - "enable_chunked_prefill": True, - "max_num_batched_tokens": 2, - "max_num_seqs": 2, - }, -]) -@pytest.mark.parametrize("per_test_common_llm_kwargs",[{}]) +@pytest.mark.parametrize( + "common_llm_kwargs", + [ + { + # Use a small model for a fast test. + "model": "facebook/opt-125m", + + # skip cuda graph creation for fast test. + "enforce_eager": True, + "enable_chunked_prefill": True, + "max_num_batched_tokens": 2, + "max_num_seqs": 2, + }, + ]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [ { "use_v2_block_manager": False, }, ]) -@pytest.mark.parametrize( - "test_llm_kwargs", [ +@pytest.mark.parametrize("test_llm_kwargs", [ { "use_v2_block_manager": True, "num_lookahead_slots": 0, @@ -261,7 +262,8 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, ]) @pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("seed", [1]) -def test_chunked_prefill_block_manager_v2(baseline_llm_generator, test_llm_generator, batch_size): +def test_chunked_prefill_block_manager_v2(baseline_llm_generator, + test_llm_generator, batch_size): output_len = 32 temperature = 0.0 diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index bae903acda66..50da0d35fcec 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -112,11 +112,13 @@ def _process_seq_outputs(self, seq: Sequence, new_char_count = 0 if sampling_params.detokenize: - new_char_count = self.detokenizer.decode_sequence_inplace(seq, sampling_params) + new_char_count = self.detokenizer.decode_sequence_inplace( + seq, sampling_params) - self.stop_checker.maybe_stop_sequence(seq, - new_char_count=new_char_count, - sampling_params=sampling_params) + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count=new_char_count, + sampling_params=sampling_params) if seq.is_finished(): break diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 3ded72db3092..1b7eb014f802 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -110,7 +110,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, seq, seq_group.sampling_params) else: new_char_count = 0 - self.stop_checker.maybe_stop_sequence(seq, new_char_count, seq_group.sampling_params) + self.stop_checker.maybe_stop_sequence(seq, new_char_count, + seq_group.sampling_params) # Non-beam search case if not seq_group.sampling_params.use_beam_search: diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index 93e2fe6ac17c..66deb9b59174 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Optional +from typing import Callable, Optional from transformers import PreTrainedTokenizer @@ -19,10 +19,8 @@ def __init__(self, max_model_len: int, self.max_model_len = max_model_len self.get_tokenizer_for_seq = get_tokenizer_for_seq - def maybe_stop_sequence(self, seq: Sequence, - new_char_count: int, + def maybe_stop_sequence(self, seq: Sequence, new_char_count: int, sampling_params: SamplingParams) -> None: - """Stop the finished sequences. new_char_count is the number of chars added to the diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 9268b646a18a..b7ab9481eb9f 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -48,6 +48,8 @@ def _init_non_spec_worker(self): def _init_spec_worker(self): """Initialize a SpecDecodeWorker, using a draft model for proposals. """ + assert self.speculative_config is not None + from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker from vllm.worker.worker import Worker diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index 273b17a927ef..7cc187e297c9 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -48,10 +48,13 @@ def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: + blocks_to_copy: Dict[int, List[int]], + num_lookahead_slots: int) -> List[SamplerOutput]: assert (blocks_to_swap_in == {} and blocks_to_swap_out == {} and blocks_to_copy == {}), ( "Cache operations are not supported for Neuron backend.") + assert num_lookahead_slots == 0, ( + "lookahead not supported for Neuron backend.") output = self.driver_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list) From 5a69f6c25ad51515fcc9d1e5ecc9d43fea3af89c Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Tue, 16 Apr 2024 03:15:31 -0700 Subject: [PATCH 120/120] doc --- tests/core/block/e2e/test_correctness.py | 3 +++ vllm/executor/gpu_executor.py | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index 1015892b67a4..0ee78a9b0a8e 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -264,6 +264,9 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, @pytest.mark.parametrize("seed", [1]) def test_chunked_prefill_block_manager_v2(baseline_llm_generator, test_llm_generator, batch_size): + """Verify that chunked prefill works with BlockManagerV2, with and without + lookahead scheduling. + """ output_len = 32 temperature = 0.0 diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index b7ab9481eb9f..962cac585bb2 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -13,6 +13,11 @@ class GPUExecutor(ExecutorBase): def _init_executor(self) -> None: + """Initialize the worker and load the model. + + If speculative decoding is enabled, we instead create the speculative + worker. + """ if self.speculative_config is None: self._init_non_spec_worker() else: