12 changes: 5 additions & 7 deletions vllm/attention/backends/flash_attn.py
@@ -172,7 +172,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
max_query_len=self.max_query_len,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None,
@@ -304,8 +304,6 @@ def forward(
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]

assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
@@ -319,8 +317,8 @@ def forward(
# prompt, and they have the same length.
out = flash_attn_varlen_func(
q=query,
k=key,
v=value,
k=key[:num_prefill_tokens],
v=value[:num_prefill_tokens],
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
@@ -353,15 +351,15 @@ def forward(
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
output[num_prefill_tokens:] = flash_attn_with_kvcache(
decode_query.unsqueeze(1),
decode_query.unflatten(0, (-1, decode_meta.max_query_len)),
key_cache,
value_cache,
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
).squeeze(1)
).flatten(0, 1)

# Reshape the output tensor.
return output.view(num_tokens, hidden_size)
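
A minimal sketch of the reshape the decode hunk above relies on, using plain PyTorch and toy shapes of my own choosing, and assuming every decode sequence in the batch shares the same max_query_len (which the -1 passed to unflatten implies):

import torch

# Toy shapes: 2 decode sequences, each scoring 3 query positions (max_query_len=3),
# with 4 heads of size 8. decode_query is laid out as (num_tokens, heads, head_size).
num_seqs, max_query_len, num_heads, head_size = 2, 3, 4, 8
decode_query = torch.randn(num_seqs * max_query_len, num_heads, head_size)

# Old path: unsqueeze(1) -> (6, 1, 4, 8), i.e. one query position per sequence.
single_query = decode_query.unsqueeze(1)

# New path: unflatten(0, (-1, max_query_len)) -> (2, 3, 4, 8), grouping all query
# positions of a sequence together, which is what multi-query scoring needs.
multi_query = decode_query.unflatten(0, (-1, max_query_len))
assert multi_query.shape == (num_seqs, max_query_len, num_heads, head_size)

# The kernel output comes back as (batch, seqlen_q, heads, head_size); flatten(0, 1)
# restores the flat per-token layout expected by the output buffer.
assert torch.equal(multi_query.flatten(0, 1), decode_query)
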
5 changes: 5 additions & 0 deletions vllm/envs.py
@@ -37,6 +37,7 @@
VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_TARGET_DEVICE: str = "cuda"
VLLM_USE_MULTI_QUERY_SCORER: bool = False
MAX_JOBS: Optional[str] = None
NVCC_THREADS: Optional[str] = None
VLLM_USE_PRECOMPILED: bool = False
@@ -251,6 +252,10 @@
lambda: os.getenv("VLLM_XLA_CACHE_PATH", "~/.vllm/xla_cache/"),
"VLLM_FUSED_MOE_CHUNK_SIZE":
lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "65536")),

# Use multi-query scorer for speculative decoding
"VLLM_USE_MULTI_QUERY_SCORER":
lambda: bool(int(os.getenv("VLLM_USE_MULTI_QUERY_SCORER", "0"))),
}

# end-env-vars-definition
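
A quick note on the flag added above: the getter parses the value with bool(int(...)), so only integer strings are valid and an unset variable stays disabled. Below is a standalone sketch of the same parsing (not importing vLLM; presumably the scorer selection elsewhere reads envs.VLLM_USE_MULTI_QUERY_SCORER, but that wiring is not part of this hunk):

from typing import Optional

def parse_flag(value: Optional[str]) -> bool:
    # Mirrors the lambda above: unset falls back to "0"; a non-integer value
    # such as "true" raises ValueError rather than silently enabling the flag.
    return bool(int(value if value is not None else "0"))

assert parse_flag(None) is False      # variable not set -> disabled
assert parse_flag("0") is False
assert parse_flag("1") is True        # e.g. VLLM_USE_MULTI_QUERY_SCORER=1
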
33 changes: 23 additions & 10 deletions vllm/model_executor/layers/sampler.py
@@ -295,12 +295,15 @@ def _greedy_sample(

seq_ids = seq_group.seq_ids
num_parent_seqs = len(seq_ids)
query_len = len(seq_group.sample_indices)
assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.")
parent_ids = list(range(num_parent_seqs))
next_token_ids = [samples_lst[sample_idx]]
parent_ids = [
i for i in range(num_parent_seqs) for _ in range(query_len)
]
next_token_ids = samples_lst[sample_idx:sample_idx + query_len]
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
sample_idx += num_parent_seqs * query_len
return results


@@ -333,18 +336,25 @@ def _random_sample(
sampling_params = seq_group.sampling_params
is_prompt = seq_group.is_prompt
num_parent_seqs = len(seq_ids)
query_len = len(seq_group.sample_indices)
if is_prompt:
# Prompt phase.
parent_ids = [0] * sampling_params.best_of
parent_ids = [0] * (query_len * sampling_params.best_of)
next_token_ids = random_samples[
sample_idx, :sampling_params.best_of].tolist()
sample_idx:sample_idx +
query_len, :sampling_params.best_of].flatten().tolist()
sample_idx += query_len
else:
# Generation phase.
parent_ids = list(range(num_parent_seqs))
parent_ids = [
seq_id for seq_id in range(num_parent_seqs)
for _ in range(query_len)
]
next_token_ids = random_samples[sample_idx:sample_idx +
num_parent_seqs, 0].tolist()
num_parent_seqs * query_len,
0].tolist()
sample_idx += num_parent_seqs * query_len
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
return results


@@ -752,8 +762,11 @@ def _get_logprobs(
# The current step may have different number of seq_ids, and
# we can obtain it from `sample_result[1]`.
query_idx = seq_group.sample_indices[0]
query_indices.extend(
[query_idx + parent_id for parent_id in parent_seq_ids])
if seq_group.is_prompt and len(seq_group.sample_indices) > 1:
query_indices.extend(seq_group.sample_indices)
else:
query_indices.extend(
[query_idx + parent_id for parent_id in parent_seq_ids])
next_token_ids.extend(token_ids)

if sampling_params.logprobs is not None:
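
To make the indexing change concrete, here is a hypothetical walk-through of the greedy path with toy numbers (plain Python, not vLLM API): with multi-query scoring a sequence group can contribute query_len sampled rows, so the flat samples list is sliced per group instead of taking a single element.

# Two sequence groups: the first scored 3 positions in one pass, the second is a
# normal single-token decode.
samples_lst = [11, 12, 13, 21]   # argmax token id per sampled row, flattened
query_lens = [3, 1]              # len(seq_group.sample_indices) per group

sample_idx, results = 0, []
for query_len in query_lens:
    num_parent_seqs = 1          # greedy sampling asserts a single parent sequence
    parent_ids = [0] * query_len
    next_token_ids = samples_lst[sample_idx:sample_idx + query_len]
    results.append((next_token_ids, parent_ids))
    sample_idx += num_parent_seqs * query_len

assert results == [([11, 12, 13], [0, 0, 0]), ([21], [0])]

The random-sampling path follows the same pattern, additionally flattening over best_of in the prompt phase.
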
27 changes: 16 additions & 11 deletions vllm/model_executor/sampling_metadata.py
@@ -214,15 +214,17 @@ def _prepare_seq_groups(
assert num_prefill_sample == 1
assert query_lens is not None and seq_lens is not None
query_len, seq_len = query_lens[i], seq_lens[i]

# If we need sampling, exclude num_prefill_sample tokens from
# prompt logprob.
prompt_logprob_len = (query_len - num_prefill_sample
if do_sample else query_len)
sample_len = num_prefill_sample if do_sample else 0
else:
# Decode
assert query_lens is not None
prompt_logprob_len = 0
sample_len = len(seq_ids) if do_sample else 0
sample_len = query_lens[i] if do_sample else 0

# Update indices to select from the model output.
"""
@@ -389,14 +391,14 @@ def from_sampling_metadata(

if seq_group.do_sample:
sample_lens = len(seq_group.sample_indices)
assert sample_lens == len(seq_ids)
temperatures += [temperature] * len(seq_ids)
top_ps += [top_p] * len(seq_ids)
top_ks += [top_k] * len(seq_ids)
min_ps += [min_p] * len(seq_ids)
presence_penalties += [p] * len(seq_ids)
frequency_penalties += [f] * len(seq_ids)
repetition_penalties += [r] * len(seq_ids)
assert sample_lens >= len(seq_ids)
temperatures += [temperature] * sample_lens
top_ps += [top_p] * sample_lens
top_ks += [top_k] * sample_lens
min_ps += [min_p] * sample_lens
presence_penalties += [p] * sample_lens
frequency_penalties += [f] * sample_lens
repetition_penalties += [r] * sample_lens

if is_prompt:
prompt_best_of.append(sampling_params.best_of)
@@ -425,10 +427,13 @@ def from_sampling_metadata(
prompt_tokens.extend([] for _ in range(prefill_len))
output_tokens.extend([] for _ in range(prefill_len))
if seq_group.do_sample:
sample_lens = len(seq_group.sample_indices)
for seq_id in seq_ids:
seq_data = seq_group.seq_data[seq_id]
prompt_tokens.append(list(seq_data.prompt_token_ids))
output_tokens.append(list(seq_data.output_token_ids))
for k in range(sample_lens, 0, -1):
prompt_tokens.append(seq_data._prompt_token_ids)
output_tokens.append(
seq_data._output_token_ids[:-k])

sampling_tensors = SamplingTensors.from_lists(
temperatures, top_ps, top_ks, min_ps, presence_penalties,
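
A small hypothetical illustration of the bookkeeping above (toy token ids, not vLLM API): when the last sample_lens output tokens of a sequence are appended speculative proposals, each scored position should only be penalized against the tokens that precede it, which is exactly what the [:-k] slices build. The per-sequence sampling parameters are repeated sample_lens times above so the penalty tensors line up row for row with these slices.

output_token_ids = [5, 6, 7, 101, 102, 103]   # 101-103 are the appended proposals
sample_lens = 3

seen_per_position = [output_token_ids[:-k] for k in range(sample_lens, 0, -1)]
assert seen_per_position == [
    [5, 6, 7],            # position that scores the first proposal token
    [5, 6, 7, 101],       # position that scores the second proposal token
    [5, 6, 7, 101, 102],  # position that scores the third proposal token
]
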
18 changes: 16 additions & 2 deletions vllm/sequence.py
@@ -680,6 +680,18 @@ def token_chunk_size(self) -> int:
assert self._token_chunk_size is not None
return self._token_chunk_size

@token_chunk_size.setter
def token_chunk_size(self, value: int) -> None:
self._token_chunk_size = value

def __repr__(self) -> str:
return (f"SequenceGroupMetadata(request_id={self.request_id}, "
f"is_prompt={self.is_prompt}, "
f"seq_data={self.seq_data}, "
f"sampling_params={self.sampling_params}, "
f"block_tables={self.block_tables}, "
f"token_chunk_size={self.token_chunk_size})")


class SequenceOutput:
"""The model output associated with a sequence.
@@ -890,14 +902,16 @@ class HiddenStates:

def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata],
hidden_states: torch.Tensor):
assert len(seq_group_metadata_list) == len(hidden_states)
assert seq_group_metadata_list[0].is_prompt or len(
seq_group_metadata_list) == len(hidden_states)
self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list)
self.hidden_states: torch.Tensor = hidden_states

def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
hidden_states: torch.Tensor) -> None:
"""Update hidden states from target model invocation."""
assert len(seq_group_metadata_list) == len(hidden_states)
assert seq_group_metadata_list[0].is_prompt or len(
seq_group_metadata_list) == len(hidden_states)
self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
self.hidden_states = torch.cat([self.hidden_states, hidden_states])

4 changes: 2 additions & 2 deletions vllm/spec_decode/batch_expansion.py
@@ -275,8 +275,8 @@ def _create_single_target_seq_group_metadata(
input sequence.
"""
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_token_ids = seq_data.get_prompt_token_ids()
new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
prompt_token_ids = seq_data._prompt_token_ids
new_output_token_ids = [*seq_data._output_token_ids, *token_ids]

new_seq_data_dict = {
target_seq_id:
139 changes: 139 additions & 0 deletions vllm/spec_decode/multi_query.py
@@ -0,0 +1,139 @@
from typing import List, Tuple

import torch

from vllm.sequence import (ExecuteModelRequest, SamplerOutput)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import nvtx_range
from vllm.worker.worker_base import WorkerBase

SeqId = int
TargetSeqId = int
TokenId = int


class MultiQueryTop1Scorer(SpeculativeScorer):
"""Implements a speculative scorer that uses batch expansion to get
probabilities of speculative tokens according to the scoring model.
Batch expansion converts a list of sequences and multiple query positions
to a new batch of sequences, each with a single query position. This allows
for MQA-like scoring in speculative decoding without requiring an MQA
kernel.
It is strictly less efficient than MQA scoring.
It only supports scoring the top1 proposal tokens of the proposer, instead
of topk/tree.
"""

def __init__(self, scorer_worker: WorkerBase, device: str,
vocab_size: int):
self._scorer_worker = scorer_worker
self._device = device
self._vocab_size = vocab_size

@nvtx_range("BatchExpansionTop1Scorer.score_proposals")
def score_proposals(
self,
execute_model_req: ExecuteModelRequest,
proposals: SpeculativeProposals,
) -> SpeculativeScores:
"""Score the proposed tokens via the scorer model.
This appends each sequence's proposal tokens to its sequence data and runs
the scorer model once in prompt (multi-query) mode, so all proposal
positions of a sequence are scored in a single forward pass without
expanding the batch. The proposal tokens and token-chunk sizes are
restored to their original values after scoring.
Args:
execute_model_req: The execution request.
proposals: The speculative proposals to score.
Returns:
SpeculativeScores: The scores of each speculative token, along with
which sequences were ignored during scoring.
"""

# TODO(cade) perform this on GPU to remove blocking call.
proposal_token_ids_list = proposals.proposal_token_ids.tolist()

for seq_group_metadata, proposal_token_ids in zip(
execute_model_req.seq_group_metadata_list,
proposal_token_ids_list,
):
seq_id, seq_data = next(iter(seq_group_metadata.seq_data.items()))
if proposal_token_ids:
seq_data.update_num_computed_tokens(
(seq_data.get_len() - 1) -
seq_data.get_num_computed_tokens())
for token in proposal_token_ids:
seq_data._output_token_ids.append(token)
seq_data._cached_all_token_ids.append(token)
# use the prompt mode for multi-query sampling
seq_group_metadata.token_chunk_size += len(proposal_token_ids)

target_sampler_output = self._scorer_worker.execute_model(
execute_model_req=execute_model_req.clone(
seq_group_metadata_list=execute_model_req.
seq_group_metadata_list))
assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0]

for seq_group_metadata, proposal_token_ids in zip(
execute_model_req.seq_group_metadata_list,
proposal_token_ids_list,
):
seq_id, seq_data = next(iter(seq_group_metadata.seq_data.items()))
if proposal_token_ids:
for token in proposal_token_ids:
seq_data._output_token_ids.pop()
seq_data._cached_all_token_ids.pop()
seq_group_metadata.token_chunk_size -= len(proposal_token_ids)

all_tokens, all_probs, spec_logprobs = self._contract_batch(
contracted_bs=len(execute_model_req.seq_group_metadata_list),
target_sampler_output=target_sampler_output,
proposals=proposals,
k=execute_model_req.num_lookahead_slots,
)

return SpeculativeScores(
probs=all_probs,
token_ids=all_tokens,
logprobs=spec_logprobs,
hidden_states=target_sampler_output.hidden_states,
)

def _contract_batch(
self, contracted_bs: int, target_sampler_output: SamplerOutput,
proposals: SpeculativeProposals,
k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
contracted_bs is the original batch size, and the batch size that the
target_sampler_output will be contracted to.
"""
target_token_ids = target_sampler_output.sampled_token_ids
target_probs = target_sampler_output.sampled_token_probs
target_logprobs = target_sampler_output.logprobs

all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
fill_value=-1)
all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
all_logprobs = target_logprobs.new_full(size=all_probs.shape,
fill_value=-float("inf"))

seq_indices: List[int] = []
rank_indices: List[int] = []
for i, output in enumerate(target_sampler_output.outputs):
num_tokens = len(output.samples)
seq_indices.extend([i] * num_tokens)
rank_indices.extend(range(num_tokens))

seq_indices = torch.tensor(seq_indices, device=self._device)
rank_indices = torch.tensor(rank_indices, device=self._device)
all_tokens[seq_indices, rank_indices] = target_token_ids.flatten()
all_probs[seq_indices,
rank_indices] = target_probs.view(-1, self._vocab_size)
all_logprobs[seq_indices, rank_indices] = target_logprobs.view(
-1, self._vocab_size)

return all_tokens, all_probs, all_logprobs
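
As a closing illustration, a minimal sketch of the scatter that _contract_batch performs, with toy shapes and made-up token ids: each sequence's scored positions land in row i of a [contracted_bs, k + 1] tensor, and slots that were never scored keep the fill value. The same seq_indices/rank_indices pair is reused for the probability and logprob tensors.

import torch

contracted_bs, k = 2, 2
# Per-sequence sampled token ids from the target model: the first sequence scored
# k + 1 positions, the second produced only a single position.
sampled = [torch.tensor([3, 1, 4]), torch.tensor([2])]

all_tokens = torch.full((contracted_bs, k + 1), -1)
seq_indices, rank_indices = [], []
for i, toks in enumerate(sampled):
    seq_indices.extend([i] * len(toks))
    rank_indices.extend(range(len(toks)))

all_tokens[torch.tensor(seq_indices), torch.tensor(rank_indices)] = torch.cat(sampled)
assert all_tokens.tolist() == [[3, 1, 4], [2, -1, -1]]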