9 changes: 7 additions & 2 deletions examples/offline_inference.py
@@ -8,10 +8,15 @@
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
sampling_params = SamplingParams(temperature=0, top_p=0.95)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")
llm = LLM(model="facebook/opt-125m",
speculative_model="facebook/opt-125m",
num_speculative_tokens=5,
use_v2_block_manager=True,
enforce_eager=True)
# llm = LLM(model="facebook/opt-125m")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
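For reference, a minimal end-to-end sketch of the speculative-decoding configuration this example now exercises; it simply completes the truncated script above using the standard vLLM offline API, with the same small model acting as its own draft model purely for illustration.

from vllm import LLM, SamplingParams

prompts = ["The future of AI is"]
# temperature=0 selects greedy sampling, so repeated runs give deterministic output.
sampling_params = SamplingParams(temperature=0, top_p=0.95)

llm = LLM(model="facebook/opt-125m",
          speculative_model="facebook/opt-125m",
          num_speculative_tokens=5,
          use_v2_block_manager=True,
          enforce_eager=True)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt!r}, Generated: {output.outputs[0].text!r}")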
23 changes: 17 additions & 6 deletions tests/worker/test_model_runner.py
@@ -199,6 +199,11 @@ def test_prepare_decode_cuda_graph(batch_size):
# decode has only 1 token for query.
start_idx += 1
start_loc.append(start_idx)
# start_loc is padded to expected_bs + 1
last_loc = start_loc[-1] + 1
for _ in range(expected_bs - (len(start_loc) - 1)):
start_loc.append(last_loc)
last_loc += 1
assert torch.allclose(
attn_metadata.query_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))
@@ -208,6 +213,10 @@ def test_prepare_decode_cuda_graph(batch_size):
for seq_len in seq_lens:
start_idx += seq_len
seq_start_loc.append(start_idx)
last_loc = seq_start_loc[-1] + 1
for _ in range(expected_bs - (len(seq_start_loc) - 1)):
seq_start_loc.append(last_loc)
last_loc += 1
assert torch.allclose(
attn_metadata.seq_start_loc,
torch.tensor(seq_start_loc, dtype=torch.int32, device=device))
@@ -374,9 +383,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
attn_metadata = model_runner._prepare_model_input_tensors(
seq_group_metadata_list).attn_metadata

for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata),
vars(prefill_meta_actual)):
assert attr_expected[1] == attr_actual[1]
for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata),
vars(decode_meta_actual)):
assert attr_expected[1] == attr_actual[1]
if attn_metadata.prefill_metadata:
for attr_expected, attr_actual in zip(
vars(attn_metadata.prefill_metadata),
vars(prefill_meta_actual)):
assert attr_expected[1] == attr_actual[1]
for attr_expected, attr_actual in zip(
vars(attn_metadata.decode_metadata), vars(decode_meta_actual)):
assert attr_expected[1] == attr_actual[1]
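The padding these assertions expect mirrors how the model runner pads cumulative start locations out to the CUDA-graph batch size; a small standalone sketch of that pattern (the helper name is illustrative, not an actual model-runner function):

def pad_start_loc(start_loc, expected_bs):
    # Pad a cumulative start-location list to expected_bs + 1 entries,
    # continuing monotonically past the real data.
    padded = list(start_loc)
    last_loc = padded[-1] + 1
    while len(padded) < expected_bs + 1:
        padded.append(last_loc)
        last_loc += 1
    return padded

# Three real decode sequences (query length 1 each), padded to a graph batch of 5.
print(pad_start_loc([0, 1, 2, 3], expected_bs=5))  # [0, 1, 2, 3, 4, 5]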
179 changes: 33 additions & 146 deletions vllm/attention/backends/flash_attn.py
@@ -3,7 +3,7 @@
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from vllm_flash_attn import flash_attn_varlen_func

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -73,10 +73,8 @@ class FlashAttentionMetadata(AttentionMetadata):
updated from `CUDAGraphRunner.forward` API.
"""
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# the computed tokens + new tokens.
seq_lens: List[int]

# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
@@ -88,12 +86,6 @@ class FlashAttentionMetadata(AttentionMetadata):

# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
@@ -102,9 +94,6 @@ class FlashAttentionMetadata(AttentionMetadata):
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]

# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
@@ -119,69 +108,12 @@ class FlashAttentionMetadata(AttentionMetadata):
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool

_cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
_cached_decode_metadata: Optional["FlashAttentionMetadata"] = None

@property
def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
if self.num_prefills == 0:
return None

if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata

assert self.seq_lens is not None
assert self.seq_lens_tensor is not None
assert self.query_start_loc is not None
assert self.context_lens_tensor is not None
assert self.block_tables is not None
assert self.seq_start_loc is not None

self._cached_prefill_metadata = FlashAttentionMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False,
)
return self._cached_prefill_metadata

@property
def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
if self.num_decode_tokens == 0:
return None

if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert self.block_tables is not None
assert self.seq_lens_tensor is not None

self._cached_decode_metadata = FlashAttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
)
return self._cached_decode_metadata
# Fields that are not used in flash attention backend,
# but used in other backends
context_lens_tensor: Optional[torch.Tensor] = None
seq_lens_tensor: Optional[torch.Tensor] = None
max_prefill_seq_len: Optional[int] = None
max_decode_seq_len: Optional[int] = None


class FlashAttentionImpl(AttentionImpl):
@@ -288,7 +220,6 @@ def forward(
if kv_cache is not None:
key_cache = kv_cache[0]
value_cache = kv_cache[1]

# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
@@ -301,74 +232,30 @@ def forward(
self.kv_cache_dtype,
)

num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens

output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]

assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens

if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if (kv_cache is None or prefill_meta.block_tables is None
or prefill_meta.block_tables.numel() == 0):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
out = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else:
# prefix-enabled attention
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
output[:num_prefill_tokens] = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_k=max_seq_len,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
block_table=prefill_meta.block_tables,
)

if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
output[num_prefill_tokens:] = flash_attn_with_kvcache(
decode_query.unsqueeze(1),
key_cache,
value_cache,
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
).squeeze(1)
if (kv_cache is None or attn_metadata.block_tables is None
or attn_metadata.block_tables.numel() == 0):
k = key
v = value
block_tables = None
else:
k = kv_cache[0]
v = kv_cache[1]
block_tables = attn_metadata.block_tables

max_seq_len = max(attn_metadata.seq_lens)
output = flash_attn_varlen_func(
q=query,
k=k,
v=v,
cu_seqlens_q=attn_metadata.query_start_loc,
cu_seqlens_k=attn_metadata.seq_start_loc,
max_seqlen_q=attn_metadata.max_query_len,
max_seqlen_k=max_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
block_table=block_tables)

# Reshape the output tensor.
return output.view(num_tokens, hidden_size)
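The metadata comments above define query_start_loc and seq_start_loc as cumulative lengths ("if the subquery length is [4, 6], it is [0, 4, 10]"), and the unified path above passes them to flash_attn_varlen_func as cu_seqlens_q/cu_seqlens_k. A minimal sketch of that arithmetic (not vLLM's actual helper):

import torch


def make_cu_seqlens(lens):
    # Cumulative start locations with a leading zero, e.g. [4, 6] -> [0, 4, 10].
    lens_t = torch.tensor(lens, dtype=torch.int32)
    return torch.nn.functional.pad(
        torch.cumsum(lens_t, dim=0, dtype=torch.int32), (1, 0))


print(make_cu_seqlens([4, 6]))  # tensor([ 0,  4, 10], dtype=torch.int32)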
2 changes: 2 additions & 0 deletions vllm/engine/llm_engine.py
@@ -787,6 +787,8 @@ def _process_model_outputs(
scheduled_seq_groups, output_by_sequence_group,
seq_group_metadata_list):
seq_group = scheduled_seq_group.seq_group
# FIXME: the number of computed tokens should not be updated here;
# it should be updated when processing the output.
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
if self.model_config.embedding_mode:
13 changes: 8 additions & 5 deletions vllm/engine/output_processor/multi_step.py
@@ -87,20 +87,22 @@ def process_outputs(self, sequence_group: SequenceGroup,
]
assert valid_samples

self._process_seq_outputs(seq, valid_samples,
sequence_group.sampling_params)
output_token_ids = self._process_seq_outputs(
seq, valid_samples, sequence_group.sampling_params)

# FIXME: subtracting 1 is incorrect; it should be scheduled_seq_group.token_chunk_size
sequence_group.update_num_computed_tokens(len(output_token_ids) - 1)

def _process_seq_outputs(self, seq: Sequence,
valid_samples: List[SequenceOutput],
sampling_params: SamplingParams) -> None:
sampling_params: SamplingParams) -> List[int]:
output_token_ids = [sample.output_token for sample in valid_samples]
output_logprobs = [sample.logprobs for sample in valid_samples]

# Truncate to max_tokens if necessary.
remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
len(output_token_ids))
if remaining_tokens < 0:
valid_samples = valid_samples[:remaining_tokens]
output_token_ids = output_token_ids[:remaining_tokens]

# Truncate any tokens after EOS. This is required as spec decode
@@ -114,7 +116,6 @@ def _process_seq_outputs(self, seq: Sequence,
for i in range(len(output_token_ids)):
if output_token_ids[i] == eos_token_id:
output_token_ids = output_token_ids[:i + 1]
valid_samples = valid_samples[:i + 1]
break

# Incrementally append tokens to the sequence, as if we had only one new
@@ -143,3 +144,5 @@ def _process_seq_outputs(self, seq: Sequence,
if seq.is_finished():
for scheduler in self.scheduler:
scheduler.free_seq(seq)

return output_token_ids
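A standalone sketch of the truncation rules _process_seq_outputs applies to the speculative samples (first cap at max_tokens, then cut everything after the first EOS); the names here are illustrative, not the processor's real helpers:

def truncate_outputs(output_token_ids, current_output_len, max_tokens,
                     eos_token_id):
    # Cap at max_tokens: drop tokens that would exceed the budget.
    remaining = max_tokens - (current_output_len + len(output_token_ids))
    if remaining < 0:
        output_token_ids = output_token_ids[:remaining]
    # Keep tokens up to and including the first EOS, drop the rest.
    for i, token_id in enumerate(output_token_ids):
        if token_id == eos_token_id:
            return output_token_ids[:i + 1]
    return output_token_ids


# 4 proposed tokens, but only 2 fit within max_tokens=16 given 14 already generated.
print(truncate_outputs([5, 7, 2, 9], current_output_len=14,
                       max_tokens=16, eos_token_id=2))  # [5, 7]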
27 changes: 18 additions & 9 deletions vllm/model_executor/sampling_metadata.py
@@ -115,10 +115,13 @@ def prepare(
num_prompts,
) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
device)
# print("seq_lens", seq_lens)
# print("query_lens", query_lens)
selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=device,
pin_memory=pin_memory)
# print("selected_token_indices", selected_token_indices)
categorized_sample_indices = {
t: maybe_expand_dim(
async_tensor_h2d(seq_ids,
@@ -220,9 +223,10 @@ def _prepare_seq_groups(
if do_sample else query_len)
sample_len = num_prefill_sample if do_sample else 0
else:
assert query_lens is not None
# Decode
prompt_logprob_len = 0
sample_len = len(seq_ids) if do_sample else 0
sample_len = query_lens[i] if do_sample else 0

# Update indices to select from the model output.
"""
@@ -389,14 +393,19 @@ def from_sampling_metadata(

if seq_group.do_sample:
sample_lens = len(seq_group.sample_indices)
assert sample_lens == len(seq_ids)
temperatures += [temperature] * len(seq_ids)
top_ps += [top_p] * len(seq_ids)
top_ks += [top_k] * len(seq_ids)
min_ps += [min_p] * len(seq_ids)
presence_penalties += [p] * len(seq_ids)
frequency_penalties += [f] * len(seq_ids)
repetition_penalties += [r] * len(seq_ids)
assert sample_lens >= len(seq_ids)
for seq_id in seq_ids:
seq_data = seq_group.seq_data[seq_id]
if do_penalties:
prompt_tokens.append(seq_data.prompt_token_ids)
output_tokens.append(seq_data.output_token_ids)
temperatures += [temperature] * sample_lens
top_ps += [top_p] * sample_lens
top_ks += [top_k] * sample_lens
min_ps += [min_p] * sample_lens
presence_penalties += [p] * sample_lens
frequency_penalties += [f] * sample_lens
repetition_penalties += [r] * sample_lens

if is_prompt:
prompt_best_of.append(sampling_params.best_of)
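The from_sampling_metadata change above replicates each group's sampling parameters once per sample index rather than once per sequence, since with speculative decoding a single sequence can be sampled at several positions in one step; a tiny sketch of the difference under that assumption:

# One sequence, but four sample positions (e.g. several speculative tokens
# scored in the same step).
seq_ids = [0]
sample_indices = [3, 4, 5, 6]
temperature = 0.7

sample_lens = len(sample_indices)
assert sample_lens >= len(seq_ids)

old_temperatures = [temperature] * len(seq_ids)   # [0.7] -- too short
new_temperatures = [temperature] * sample_lens    # [0.7, 0.7, 0.7, 0.7]
print(old_temperatures, new_temperatures)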