Commit 22f4638

Support allowed_token_ids in V1 sampler
Signed-off-by: Catherine Lee <[email protected]>
1 parent: 3809458 · commit: 22f4638

3 files changed: 52 additions & 0 deletions

tests/v1/sample/test_sampler.py

Lines changed: 29 additions & 0 deletions
@@ -412,3 +412,32 @@ def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
                                                                  1e-2)
             else:
                 assert logits_for_req[token_id] == pytest.approx(1e-2)
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("bias_value", [-0.1, 1.2])
+def test_sampler_allowed_token_ids(device: str, batch_size: int,
+                                   bias_value: float):
+    """
+    Test to verify that when allowed_token_ids is set on the sampling
+    metadata, every token outside the allowed set has its logit masked
+    to -inf, while the allowed tokens keep finite logits.
+    """
+    torch.set_default_device(device)
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+    allowed_token_ids = set([0, 1, 2])
+    sampling_metadata.allowed_token_ids = list(allowed_token_ids)
+    # https://github.com/vllm-project/vllm/blob/38094584566b89210a6f72a408eba1fae43c3d81/tests/entrypoints/openai/test_completion.py#L620
+    sampler = Sampler()
+    logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata)
+    logits = logits.cpu()
+    for batch_idx in range(batch_size):
+        logits_for_req = logits[batch_idx]
+        for token_id in range(VOCAB_SIZE):
+            if token_id not in allowed_token_ids:
+                assert logits_for_req[token_id] == float("-inf")
+            else:
+                assert logits_for_req[token_id] > float("-inf")

vllm/v1/sample/metadata.py

Lines changed: 1 addition & 0 deletions
@@ -38,3 +38,4 @@ class SamplingMetadata:
     stop_token_ids: List[Set[int]]
 
     logit_bias: List[Optional[Dict[int, float]]]
+    allowed_token_ids: Optional[List[int]] = None
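
As an aside, a minimal sketch of the new field's semantics; the ToyMetadata dataclass below is a hypothetical stand-in for illustration only, not vLLM's SamplingMetadata. Because the sampler guards with `if not sampling_metadata.allowed_token_ids`, both None and an empty list leave the logits unrestricted, and only a non-empty list triggers masking.

# Hypothetical stand-in for SamplingMetadata, used only to illustrate the
# "falsy means unrestricted" semantics of the new field.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ToyMetadata:
    allowed_token_ids: Optional[List[int]] = None

# The default (None) and an explicit empty list are both treated as
# "no restriction" by an `if not metadata.allowed_token_ids` guard.
assert not ToyMetadata().allowed_token_ids
assert not ToyMetadata(allowed_token_ids=[]).allowed_token_ids
# Only a non-empty list actually restricts sampling.
assert ToyMetadata(allowed_token_ids=[0, 1, 2]).allowed_token_ids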

vllm/v1/sample/sampler.py

Lines changed: 22 additions & 0 deletions
@@ -49,6 +49,8 @@ def forward(
         logits = logits.to(torch.float32)
         # Apply logits bias.
         logits = self.apply_logits_bias(logits, sampling_metadata)
+        # Apply allowed token ids.
+        logits = self.apply_allowed_token_ids(logits, sampling_metadata)
         # Apply penalties (e.g., min_tokens, freq_penalties).
         logits = self.apply_penalties(logits, sampling_metadata)
         # Sample the next token.
@@ -227,3 +229,23 @@ def apply_logits_bias(
                 for token_id, bias in logit_bias.items():
                     logits[i, token_id] += bias
         return logits
+
+    def apply_allowed_token_ids(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> torch.Tensor:
+        if not sampling_metadata.allowed_token_ids:
+            return logits
+        vocab_size = logits.size(dim=1)
+        if not all(0 <= tid < vocab_size
+                   for tid in sampling_metadata.allowed_token_ids):
+            raise ValueError("allowed_token_ids contains "
+                             "out-of-vocab token id")
+        allowed_ids = list(sampling_metadata.allowed_token_ids)
+        mask = torch.ones((logits.shape[-1], ),
+                          dtype=torch.bool,
+                          device=logits.device)
+        mask[allowed_ids] = False
+        logits.masked_fill_(mask, float("-inf"))
+        return logits
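
For reference, a minimal standalone sketch of the masking technique used by apply_allowed_token_ids above, assuming a toy vocabulary of 8 tokens and a batch of 2; the tensor names and sizes here are illustrative, not part of the commit.

import torch

vocab_size = 8
allowed_ids = [0, 1, 2]
logits = torch.randn(2, vocab_size)  # (batch, vocab) logits

# Boolean mask that is True for every *disallowed* token id.
mask = torch.ones((vocab_size, ), dtype=torch.bool, device=logits.device)
mask[allowed_ids] = False

# The 1-D mask broadcasts across the batch dimension; disallowed logits
# become -inf in place.
logits.masked_fill_(mask, float("-inf"))

# After softmax, disallowed tokens carry exactly zero probability, so any
# downstream sampling strategy (greedy or stochastic) can only return
# tokens from the allowed set.
probs = torch.softmax(logits, dim=-1)
assert torch.all(probs[:, 3:] == 0)
assert torch.allclose(probs[:, :3].sum(dim=-1), torch.ones(2))

Masking at the logit level means the softmax automatically renormalizes the remaining probability mass over the allowed tokens, preserving their relative weights.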
