Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/spec_decode/e2e/test_mlp_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
@pytest.mark.parametrize("output_len", [64])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("temperature", [0.1, 1.0])
@pytest.mark.parametrize("temperature", [1.0])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
Expand Down
8 changes: 8 additions & 0 deletions tests/spec_decode/test_batch_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ def test_create_single_target_seq_group_metadata(k: int):
)

assert output.request_id == input_seq_group_metadata.request_id
assert output.sampling_params.repetition_penalty == \
input_seq_group_metadata.sampling_params.repetition_penalty
assert output.sampling_params.temperature == \
input_seq_group_metadata.sampling_params.temperature
assert output.sampling_params.top_p == \
input_seq_group_metadata.sampling_params.top_p
assert output.sampling_params.top_k == \
input_seq_group_metadata.sampling_params.top_k
assert len(output.seq_data) == 1
assert output.seq_data[target_seq_id].get_prompt_token_ids() == tuple(
prompt_tokens)
Expand Down
14 changes: 1 addition & 13 deletions vllm/spec_decode/batch_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,28 +307,16 @@ def _create_target_seq_group_metadata(
token_ids_to_score = self._get_token_ids_to_score(
proposal_token_ids[batch_index])

# Use simpler sampling parameters apart from for final token
# (in particular don't do seeded sampling) since those sampled tokens
# aren't used.
# We don't replace the sampling_params in the greedy case because
# this also controls whether the probs get modified in the sampler
# (see use of _modify_greedy_probs_inplace there).
sampling_params = input_seq_group_metadata.sampling_params
non_bonus_sampling_params = DEFAULT_SIMPLE_SAMPLING_PARAMS \
if sampling_params.temperature else sampling_params

target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
last_index = len(token_ids_to_score) - 1
for i, token_ids in enumerate(token_ids_to_score):
target_sampling_params = sampling_params if i == last_index \
else non_bonus_sampling_params
target_seq_group_metadata_list.append(
self._create_single_target_seq_group_metadata(
input_seq_group_metadata,
input_seq_id,
next(target_seq_ids_iter),
token_ids,
sampling_params=target_sampling_params,
sampling_params=sampling_params,
))

return target_seq_group_metadata_list
Expand Down