Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion tests/v1/sample/test_rejection_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
temperature=0.0,
all_greedy=True,
all_random=False,
rejection_sampling=True,
spec_token_ids=spec_tokens,
top_p=None,
top_k=None,
Expand Down
42 changes: 17 additions & 25 deletions tests/v1/sample/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,9 @@ def _create_default_sampling_metadata(
temperature=torch.full((batch_size, ), 0.0),
all_greedy=True,
all_random=False,
rejection_sampling=False,
top_p=torch.empty(batch_size, ),
top_k=torch.empty(batch_size, ),
no_top_p=True,
no_top_k=True,
min_p=torch.empty(batch_size, ),
no_min_p=True,
top_p=None,
top_k=None,
min_p=None,
generators={},
max_num_logprobs=0,
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
Expand All @@ -94,8 +90,7 @@ def _create_default_sampling_metadata(
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
no_penalties=True,
min_tokens=[],
stop_token_ids=[],
min_tokens={},
logit_bias=[None] * batch_size,
)
return fake_sampling_metadata
Expand All @@ -104,33 +99,30 @@ def _create_default_sampling_metadata(
def _generate_min_token_penalties_and_stop_tokens(
num_output_tokens: int, batch_size: int, vocab_size: int,
batch_indices_for_min_token_penalty: List[int]
) -> Tuple[List[int], List[Set[int]]]:
) -> Dict[int, Tuple[int, Set[int]]]:
"""
Generates and returns a list of minimum token penalties (`min_tokens`)
and a corresponding list of stop token IDs (`stop_token_ids`) for each
Generates and returns a dict of minimum token penalties and
corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each
batch.

If a batch index is included in `batch_indices_for_min_token_penalty`,
a higher `min_tokens` value is assigned (within a randomized range),
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
"""
stop_token_ids: List[Set[int]] = []
min_tokens: List[int] = []
min_tokens: Dict[int, Tuple[int, Set[int]]] = {}
for index in range(batch_size):
if index in batch_indices_for_min_token_penalty:
min_tokens.append(
min_tokens[index] = (
np.random.randint(num_output_tokens + 1,
2 * num_output_tokens))
stop_token_ids.append(
2 * num_output_tokens),
set(
np.random.randint(0, vocab_size - 1)
for _ in range(np.random.randint(0, vocab_size))))

else:
min_tokens.append(np.random.randint(0, num_output_tokens))
stop_token_ids.append(set())
return (min_tokens, stop_token_ids)
min_tokens[index] = (np.random.randint(0,
num_output_tokens), set())
return min_tokens


def _create_weighted_output_token_list(
Expand Down Expand Up @@ -165,7 +157,7 @@ def _create_weighted_output_token_list(
output_token_ids_for_batch.extend(
[token_id for _ in range(index + 1)])
output_token_ids.append(output_token_ids_for_batch)
return (output_token_ids, sorted_token_ids_in_output)
return output_token_ids, sorted_token_ids_in_output


@pytest.mark.parametrize("device", CUDA_DEVICES)
Expand All @@ -182,17 +174,17 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
batch_indices_for_min_token_penalty = np.random.randint(
0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist()
min_tokens, stop_token_ids = _generate_min_token_penalties_and_stop_tokens(
min_tokens = _generate_min_token_penalties_and_stop_tokens(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE,
batch_indices_for_min_token_penalty)
sampling_metadata.min_tokens = min_tokens
sampling_metadata.stop_token_ids = stop_token_ids
sampler = Sampler()
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
for token_id in range(VOCAB_SIZE):
if token_id in stop_token_ids[batch_idx]:
_, stop_token_ids = min_tokens.get(batch_idx, (0, set()))
if token_id in stop_token_ids:
assert logits[batch_idx][token_id] == -float("inf")
else:
assert logits[batch_idx][token_id] != -float("inf")
Expand Down
43 changes: 19 additions & 24 deletions tests/v1/worker/test_gpu_input_batch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, List, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple

import numpy as np
import pytest
Expand Down Expand Up @@ -64,8 +64,7 @@ def _construct_expected_sampling_metadata(
top_p = [0.0 for _ in range(num_reqs)]
min_p = [0.0 for _ in range(num_reqs)]
temperature = [0.0 for _ in range(num_reqs)]
stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)]
min_tokens = [0 for _ in range(num_reqs)]
min_tokens = {}
logit_bias = [None] * num_reqs
for req in reqs:
if req.req_id not in req_ids_retained:
Expand All @@ -83,22 +82,21 @@ def _construct_expected_sampling_metadata(
top_p[index_in_input_batch] = req.sampling_params.top_p
min_p[index_in_input_batch] = req.sampling_params.min_p
temperature[index_in_input_batch] = req.sampling_params.temperature
stop_token_ids[
index_in_input_batch] = req.sampling_params.all_stop_token_ids
min_tokens[index_in_input_batch] = req.sampling_params.min_tokens
min_tokens[index_in_input_batch] = (
req.sampling_params.min_tokens,
req.sampling_params.all_stop_token_ids)
logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
return SamplingMetadata(
temperature=torch.tensor(temperature, dtype=torch.float,
device=device),
all_greedy=False,
all_random=True,
rejection_sampling=False,
top_p=torch.tensor(top_p, dtype=torch.float, device=device),
top_k=torch.tensor(top_k, dtype=torch.int, device=device),
no_top_p=all(x == 1.0 for x in top_p),
no_top_k=all(x == 0 for x in top_k),
min_p=torch.tensor(min_p, dtype=torch.float, device=device),
no_min_p=all(x == 0.0 for x in min_p),
top_p=None if all(x == 1.0 for x in top_p) else torch.tensor(
top_p, dtype=torch.float, device=device),
top_k=None if all(x == 0 for x in top_k) else torch.tensor(
top_k, dtype=torch.int, device=device),
min_p=None if all(x == 0.0 for x in min_p) else torch.tensor(
min_p, dtype=torch.float, device=device),
generators={},
max_num_logprobs=0,
prompt_token_ids=make_tensor_with_pad(
Expand All @@ -119,7 +117,6 @@ def _construct_expected_sampling_metadata(
output_token_ids=output_token_ids,
spec_token_ids=[],
min_tokens=min_tokens,
stop_token_ids=stop_token_ids,
no_penalties=(all(x == 0 for x in presence_penalties)
and all(x == 0 for x in frequency_penalties)
and all(x == 1 for x in repetition_penalties)),
Expand Down Expand Up @@ -206,8 +203,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
input_batch.condense(req_indices_to_remove)

# Generate the sampling metadata
sampling_metadata = input_batch.make_sampling_metadata(
req_id_output_token_ids, req_id_to_spec_token_ids={}, skip_copy=False)
sampling_metadata = input_batch._make_sampling_metadata()

# Create expected output.
expected_sampling_metadata = _construct_expected_sampling_metadata(
Expand All @@ -216,13 +212,16 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
input_batch.req_id_to_index,
device=torch.device(device))

def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
return (t1 is None
and t2 is None) or (t1 is not None and t2 is not None
and torch.allclose(t1, t2))

# Assert the actual and expected output.
assert torch.allclose(expected_sampling_metadata.temperature,
sampling_metadata.temperature)
assert torch.allclose(expected_sampling_metadata.top_p,
sampling_metadata.top_p)
assert torch.allclose(expected_sampling_metadata.top_k,
sampling_metadata.top_k)
assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p)
assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k)
assert torch.allclose(
expected_sampling_metadata.frequency_penalties,
sampling_metadata.frequency_penalties,
Expand All @@ -240,10 +239,6 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
assert (expected_sampling_metadata.output_token_ids ==
sampling_metadata.output_token_ids)
assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens
assert expected_sampling_metadata.stop_token_ids == \
sampling_metadata.stop_token_ids
assert expected_sampling_metadata.no_penalties == \
sampling_metadata.no_penalties
assert expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p
assert expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k
assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
6 changes: 3 additions & 3 deletions vllm/model_executor/layers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
vocab_size, num_seqs)
output_bin_counts, output_mask = get_token_bin_counts_and_mask(
output_tokens_tensor, vocab_size, num_seqs)
repetition_penalties = repetition_penalties.unsqueeze_(dim=1).repeat(
repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
1, vocab_size)
logits[logits > 0] /= torch.where(prompt_mask | output_mask,
repetition_penalties, 1.0)[logits > 0]
logits[logits <= 0] *= torch.where(prompt_mask | output_mask,
repetition_penalties, 1.0)[logits <= 0]
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
logits -= presence_penalties.unsqueeze(dim=1) * output_mask
return logits
20 changes: 13 additions & 7 deletions vllm/v1/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import time
from collections import deque
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import (Deque, Dict, Iterable, List, Optional, Sequence, Set,
Tuple, Union)

from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig,
SpeculativeConfig)
Expand Down Expand Up @@ -117,10 +118,10 @@ def schedule(self) -> "SchedulerOutput":
num_scheduled_tokens: Dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
# Encoder-related.
scheduled_encoder_inputs: Dict[str, List[int]] = {}
scheduled_encoder_inputs: Dict[str, Sequence[int]] = {}
encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: Dict[str, List[int]] = {}
scheduled_spec_decode_tokens: Dict[str, Sequence[int]] = {}

# For logging.
scheduled_timestamp = time.monotonic()
Expand Down Expand Up @@ -195,8 +196,13 @@ def schedule(self) -> "SchedulerOutput":
request.num_computed_tokens -
request.num_tokens)
if num_scheduled_spec_tokens > 0:
if isinstance(request.spec_token_ids, list):
del request.spec_token_ids[num_scheduled_spec_tokens:]
else:
request.spec_token_ids = (
request.spec_token_ids[:num_scheduled_spec_tokens])
scheduled_spec_decode_tokens[request.request_id] = (
request.spec_token_ids[:num_scheduled_spec_tokens])
request.spec_token_ids)

# Encoder-related.
if encoder_inputs_to_schedule:
Expand Down Expand Up @@ -404,7 +410,7 @@ def _try_schedule_encoder_inputs(
num_computed_tokens: int,
num_new_tokens: int,
encoder_budget: int,
) -> Tuple[List[int], int, int]:
) -> Tuple[Sequence[int], int, int]:
"""
Determine which encoder inputs need to be scheduled in the current step,
and update `num_new_tokens` and encoder token budget accordingly.
Expand All @@ -422,7 +428,7 @@ def _try_schedule_encoder_inputs(
decoder tokens up to just before the unschedulable encoder input.
"""
if not request.has_encoder_inputs():
return [], num_new_tokens, encoder_budget
return (), num_new_tokens, encoder_budget

encoder_inputs_to_schedule: List[int] = []
mm_positions = request.mm_positions
Expand Down Expand Up @@ -567,7 +573,7 @@ def update_from_output(
outputs.append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=new_token_ids or [],
new_token_ids=new_token_ids,
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
Expand Down
6 changes: 3 additions & 3 deletions vllm/v1/core/scheduler_output.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple

if TYPE_CHECKING:
from vllm.lora.request import LoRARequest
Expand Down Expand Up @@ -95,11 +95,11 @@ class SchedulerOutput:
# req_id -> spec_token_ids
# If a request does not have any spec decode tokens, it will not be
# included in the dictionary.
scheduled_spec_decode_tokens: Dict[str, List[int]]
scheduled_spec_decode_tokens: Dict[str, Sequence[int]]
# req_id -> encoder input indices that need processing.
# E.g., if a request has [0, 1], it could mean the vision encoder needs
# to process that the request's 0-th and 1-th images in the current step.
scheduled_encoder_inputs: Dict[str, List[int]]
scheduled_encoder_inputs: Dict[str, Sequence[int]]
# Number of common prefix blocks for all requests.
# This can be used for cascade attention.
num_common_prefix_blocks: int
Expand Down
4 changes: 2 additions & 2 deletions vllm/v1/outputs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import Dict, List, NamedTuple, Optional
from typing import Dict, List, NamedTuple, Optional, Sequence

import torch

Expand Down Expand Up @@ -68,7 +68,7 @@ class ModelRunnerOutput:
sampled_token_ids: List[List[int]]

# num_reqs x num_spec_tokens
spec_token_ids: Optional[List[List[int]]]
spec_token_ids: Optional[List[Sequence[int]]]

# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
Expand Down
4 changes: 2 additions & 2 deletions vllm/v1/request.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import enum
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Sequence, Union

from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
Expand Down Expand Up @@ -46,7 +46,7 @@ def __init__(
self.num_prompt_tokens = len(self.prompt_token_ids)
self._output_token_ids: List[int] = []
self._all_token_ids: List[int] = self.prompt_token_ids.copy()
self.spec_token_ids: List[int] = []
self.spec_token_ids: Sequence[int] = []
self.num_computed_tokens = 0

# Multi-modal related
Expand Down
21 changes: 10 additions & 11 deletions vllm/v1/sample/metadata.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional, Sequence, Set, Tuple

import torch

Expand All @@ -12,15 +12,13 @@ class SamplingMetadata:
temperature: torch.Tensor
all_greedy: bool
all_random: bool
rejection_sampling: bool
spec_token_ids: List[List[int]]

top_p: torch.Tensor
top_k: torch.Tensor
no_top_p: bool
no_top_k: bool
min_p: torch.Tensor
no_min_p: bool
# The list will empty if no requests have spec tokens.
spec_token_ids: List[Sequence[int]]

top_p: Optional[torch.Tensor]
top_k: Optional[torch.Tensor]
min_p: Optional[torch.Tensor]

generators: Dict[int, torch.Generator]

Expand All @@ -34,7 +32,8 @@ class SamplingMetadata:
repetition_penalties: torch.Tensor

output_token_ids: List[List[int]]
min_tokens: List[int]
stop_token_ids: List[Set[int]]

# req_index -> (min_tokens, stop_token_ids)
min_tokens: Dict[int, Tuple[int, Set[int]]]

logit_bias: List[Optional[Dict[int, float]]]
Loading