-
-
Notifications
You must be signed in to change notification settings - Fork 11.7k
[core] Sampling controller interface #6273
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5877a7a
0190aef
3c6723c
80c5091
1273203
3744143
a70c68e
039db20
a4db333
e0eb2da
6fec2b0
59f2e5e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,7 @@ | |
|
|
||
| if TYPE_CHECKING: | ||
| from vllm.inputs import LLMInputs | ||
| from vllm.model_executor.sampling_metadata import SamplingMetadata | ||
| from vllm.multimodal import MultiModalDataDict | ||
| from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics | ||
|
|
||
|
|
@@ -185,9 +186,14 @@ def get_num_computed_tokens(self) -> int: | |
|
|
||
| def update_num_computed_tokens(self, num_new_computed_tokens: int): | ||
| """Update number of tokens computed so far.""" | ||
| seq_len = self.get_len() | ||
| self._num_computed_tokens += num_new_computed_tokens | ||
| assert self._num_computed_tokens <= self.get_len(), ( | ||
| self._num_computed_tokens, self.get_len()) | ||
| # We can overflow by 1 if previous sampling was updated by | ||
| # SamplingController to generate an empty sequence of tokens. | ||
| if self._num_computed_tokens == seq_len + 1: | ||
| self._num_computed_tokens = seq_len | ||
| assert self._num_computed_tokens <= seq_len, ( | ||
| self._num_computed_tokens, seq_len) | ||
| # If all tokens are computed, it means it is in decoding phase. | ||
| if self.get_num_uncomputed_tokens() == 0: | ||
| self._stage = SequenceStage.DECODE | ||
|
|
@@ -468,8 +474,8 @@ def lora_int_id(self) -> int: | |
|
|
||
| def get_last_latency(self, now: float) -> Optional[float]: | ||
| """Sets the last token time for Request level timings.""" | ||
| # If still in prefill phase, raise Error. | ||
| if self.is_prefill(): | ||
| # If still in initial prefill phase, raise Error. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what's the case it is not "initial"?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This happens for fast-forward tokens - if there is more than one, we switch to prefill mode. |
||
| if self.is_prefill() and self.get_seqs()[0].get_output_len() == 0: | ||
| raise ValueError( | ||
| "seq_group.get_last_latency() should not be called " | ||
| "if the seq_group is in prefill phase.") | ||
|
|
@@ -701,6 +707,36 @@ def __init__( | |
| self.parent_seq_id = parent_seq_id | ||
| self.output_token = output_token | ||
| self.logprobs = logprobs | ||
| # If present, these tokens should be appended to the output | ||
| # instead of output_token. | ||
| self.fast_forward_tokens: Optional[List[int]] = None | ||
|
|
||
| def append_to(self, seq: Sequence) -> None: | ||
mmoskal marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
| Append the sampling output to the sequence. | ||
|
|
||
| If fast forward tokens are set, this appends them, generating appropriate | ||
| Logprobs, and switching the sequence to PREFILL if needed. | ||
| Otherwise, just the output token is appended. | ||
| """ | ||
| if self.fast_forward_tokens is not None: | ||
| logprobs = self.logprobs | ||
| for token in self.fast_forward_tokens: | ||
| # On first iteration, use the existing self.logprobs, provided | ||
| # they contain the token. | ||
| if token not in logprobs: | ||
| logprobs = { | ||
| token: Logprob(logprob=0.0, rank=1, decoded_token=None) | ||
| } | ||
| seq.append_token_id(token, logprobs) | ||
| # On subsequent iterations always use artificially created | ||
| # logprobs. | ||
| logprobs = {} | ||
| # If more than one token was appended, switch to prefill stage. | ||
| if seq.data.get_num_uncomputed_tokens() > 1: | ||
| seq.data._stage = SequenceStage.PREFILL | ||
| else: | ||
| seq.append_token_id(self.output_token, self.logprobs) | ||
|
|
||
| def __repr__(self) -> str: | ||
| return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " | ||
|
|
@@ -912,6 +948,53 @@ def prune(self, | |
| self.seq_ids = seq_ids | ||
|
|
||
|
|
||
| class SamplingController: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Make it an abstract class?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also in the docstring, we should probably mention this class is a singleton and stateful. And prepare & transform logits need to be called 1 to 1
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added docstrings. It seems python doesn't have abstract classes, only abstract methods. I can imagine use cases where you only override some of these methods, and none of them you have to override, so I don't think it makes sense to make any of them abstract. One thing we could do is to make the engine.sampling_controller field non-optional and use this base class as a no-op implementation. Not sure if that would be cleaner though? |
||
| """ | ||
| This is used to modify the sampling process for a given LLMEngine. | ||
| There is only one instance of this class per LLMEngine. | ||
|
|
||
| In each generation step, one of the following things can happen: | ||
|
|
||
| There are no sequences to run, and empty_step() is called; | ||
| this can be used to run actions that normally run in sync with step, | ||
| when there are no sequences to run | ||
|
|
||
| Otherwise (normal case), the following methods are run in this exact order: | ||
| - prepare() causes the sampling controller to start logit bias preparation | ||
| for the sequences that will be run; typically the logit indices from | ||
| sampling_metadata will have to be stored in the sampling controller | ||
| - forward pass is started | ||
| - transform_logits() is called after the forward pass has finished, to | ||
| modify the logits | ||
| - sampling happens on biased logits | ||
| - transform_sampler_output() is called to modify the sampler output | ||
|
|
||
| This class does nothing for each of these steps. Subclasses can override | ||
| any and each of these methods to modify the sampling process; they will | ||
| be stateful. | ||
|
|
||
| Currently, you just have to assign an instance of your subclass to | ||
| engine.sampling_controller to use it. | ||
| """ | ||
|
|
||
| def prepare(self, sampling_metadata: "SamplingMetadata"): | ||
| """Prepare the sampling controller for the next step.""" | ||
| pass | ||
|
|
||
| def empty_step(self): | ||
| """Called instead of prepare() when the scheduler found no sequences | ||
| to run.""" | ||
| pass | ||
|
|
||
| def transform_logits(self, logits: torch.Tensor) -> torch.Tensor: | ||
| """Apply the sampling controller to the logits.""" | ||
| return logits | ||
|
|
||
| def transform_sampler_output(self, output: SamplerOutput) -> SamplerOutput: | ||
| """Apply the sampling controller to the sampler output.""" | ||
| return output | ||
|
|
||
|
|
||
| @dataclass | ||
| class ExecuteModelRequest: | ||
| """The model execution request, containing CPU metadata only. The LLM | ||
|
|
@@ -936,6 +1019,8 @@ class ExecuteModelRequest: | |
| num_steps: int = 1 | ||
| # Finished request ids since last step. | ||
| finished_requests_ids: List[str] = field(default_factory=list) | ||
| # Sampling controller to use for this step. | ||
| sampling_controller: Optional[SamplingController] = None | ||
|
|
||
| def clone( | ||
| self, seq_group_metadata_list: List[SequenceGroupMetadata] | ||
|
|
@@ -951,4 +1036,5 @@ def clone( | |
| running_queue_size=self.running_queue_size, | ||
| previous_hidden_states=self.previous_hidden_states, | ||
| num_steps=self.num_steps, | ||
| finished_requests_ids=self.finished_requests_ids) | ||
| finished_requests_ids=self.finished_requests_ids, | ||
| sampling_controller=self.sampling_controller) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1240,6 +1240,11 @@ def execute_model( | |
| "finished_requests_ids": model_input.finished_requests_ids, | ||
| "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, | ||
| } if self.has_seqlen_agnostic else {} | ||
|
|
||
| if (ctrl := model_input.sampling_controller) is not None: | ||
| assert model_input.sampling_metadata is not None | ||
| ctrl.prepare(model_input.sampling_metadata) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This has been changed from RFC (we pass seq group metadata, here, we are using sampling params). Is there any way to hook this with seq group metadata?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. yeah I guess what we have in sampling metdata is probably sufficient. let me get back to you soon about this. |
||
|
|
||
| hidden_or_intermediate_states = model_executable( | ||
| input_ids=model_input.input_tokens, | ||
| positions=model_input.input_positions, | ||
|
|
@@ -1259,12 +1264,18 @@ def execute_model( | |
| if not self.is_driver_worker: | ||
| return [] | ||
|
|
||
| if ctrl is not None: | ||
| logits = ctrl.transform_logits(logits) | ||
|
|
||
| # Sample the next token. | ||
| output: SamplerOutput = self.model.sample( | ||
| logits=logits, | ||
| sampling_metadata=model_input.sampling_metadata, | ||
| ) | ||
|
|
||
| if ctrl is not None: | ||
| output = ctrl.transform_sampler_output(output) | ||
|
|
||
| if self.return_hidden_states: | ||
| # we only need to pass hidden states of most recent token | ||
| assert model_input.sampling_metadata is not None | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
QQ: is this change related to this PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this happens when the "fast forward" is 0-tokens long - see comment in PR description about
BlockTable.append_token_ids