
Commit b052507

improve gbnf guide performance

Signed-off-by: cj <[email protected]>

1 parent cfbb8c9 commit b052507

7 files changed: +131 −37 lines changed

vllm/model_executor/guided_decoding/xgrammar_decoding.py
Lines changed: 12 additions & 10 deletions

@@ -325,6 +325,18 @@ def _ensure_ctx(self):
                 raise ValueError(
                     "Invalid configuration for xgrammar logits processor")
 
+    def accept(self, token_ids: int) -> bool:
+        if self.ctx is None:
+            self._ensure_ctx()
+        if len(self.matchers) == 0:
+            self.matchers = [
+                xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
+            ]
+            self.token_bitmask = xgr.allocate_token_bitmask(
+                self.batch_size, self.tokenizer_info.vocab_size)
+        return self.matchers[0].accept_token(
+            token_ids) or self.matchers[0].is_terminated()
+
     def __call__(self, input_ids: list[int],
                  scores: torch.Tensor) -> torch.Tensor:

@@ -345,15 +357,6 @@ def __call__(self, input_ids: list[int],
             self.token_bitmask = xgr.allocate_token_bitmask(
                 self.batch_size, self.tokenizer_info.vocab_size)
 
-        if not self.prefilled:
-            # Have not sampled a token yet
-            self.prefilled = True
-        else:
-            for i, matcher in enumerate(self.matchers):
-                if not matcher.is_terminated():
-                    sampled_token = input_ids[-1]
-                    assert self.matchers[i].accept_token(sampled_token)
-
         for i, matcher in enumerate(self.matchers):
             if not matcher.is_terminated():
                 # @ubospica: ideally, fill_next_token_bitmask should be

@@ -402,5 +405,4 @@ def clone(self) -> XGrammarLogitsProcessor:
         new_processor.batch_size = self.batch_size
         # Reset prefilled state for new sequence
         new_processor.prefilled = False
-
         return new_processor
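
Note on the new accept() hook: the second hunk removes the replay of the last sampled token from __call__, so callers are now expected to advance the grammar FSM explicitly through accept(), which also lazily builds the matchers and token bitmask on first use. A minimal sketch of the intended call pattern follows; the greedy_step driver and its argument names are illustrative, not part of this commit:

def greedy_step(processor, input_ids, logits):
    # processor is an XGrammarLogitsProcessor with the accept()/__call__()
    # split introduced above; logits is a 1-D tensor over the vocabulary.
    token = int(logits.argmax())
    if processor.accept(token):
        # The FSM advanced (or is already terminated): keep the token as-is.
        return token
    # The token violates the grammar: apply the grammar bitmask via
    # __call__() and pick again from the constrained distribution.
    masked = processor(input_ids, logits)
    return int(masked.argmax())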

vllm/model_executor/layers/logits_processor.py
Lines changed: 46 additions & 15 deletions

@@ -21,6 +21,27 @@
     envs.VLLM_LOGITS_PROCESSOR_THREADS)
 
 
+def accept_grammar(token_ids: torch.Tensor,
+                   sampling_metadata: SamplingMetadata,
+                   mask: torch.Tensor = None) -> torch.Tensor:
+    if mask is None:
+        mask = torch.ones_like(token_ids, dtype=torch.bool)
+
+    accept = torch.ones_like(token_ids, dtype=torch.bool)
+    for seq_group in sampling_metadata.seq_groups:
+        logits_processors = seq_group.sampling_params.logits_processors
+        if logits_processors:
+            for row_idx, logits_processor in zip(seq_group.sample_indices,
+                                                 logits_processors):
+                tkid = token_ids[row_idx].item()
+                tkmask = mask[row_idx].item()
+                # The FSM accepts the token only where the mask entry is set.
+                if tkmask:
+                    accept[row_idx] = accept[row_idx].item(
+                    ) and logits_processor.accept(tkid)
+    return accept
+
+
 class LogitsProcessor(nn.Module):
     """Process logits and apply logits processors from sampling metadata.

@@ -52,13 +73,12 @@ def __init__(self,
         # Whether to use gather or all-gather to gather the logits.
         self.use_all_gather = current_platform.use_all_gather()
 
-    def forward(
-        self,
-        lm_head: VocabParallelEmbedding,
-        hidden_states: torch.Tensor,
-        sampling_metadata: Optional[SamplingMetadata] = None,
-        embedding_bias: Optional[torch.Tensor] = None,
-    ) -> Optional[torch.Tensor]:
+    def forward(self,
+                lm_head: VocabParallelEmbedding,
+                hidden_states: torch.Tensor,
+                sampling_metadata: Optional[SamplingMetadata] = None,
+                embedding_bias: Optional[torch.Tensor] = None,
+                skip_grammar: bool = False) -> Optional[torch.Tensor]:
         if self.logits_as_input:
             logits = hidden_states
         else:

@@ -77,6 +97,9 @@ def forward(
         if self.scale != 1.0:
             logits *= self.scale
 
+        if skip_grammar:
+            return logits
+
         # Apply logits processors (if any).
         if sampling_metadata is not None and \
                 sampling_metadata.seq_groups is not None:

@@ -138,12 +161,15 @@ def _prune_hidden_states(
     return hidden_states
 
 
-def _apply_logits_processors(
-    logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> torch.Tensor:
+def _apply_logits_processors(logits: torch.Tensor,
+                             sampling_metadata: SamplingMetadata,
+                             mask: torch.Tensor = None,
+                             accept_last: bool = True) -> torch.Tensor:
     found_logits_processors = False
     logits_processed = 0
+
+    if mask is None:
+        mask = torch.ones(size=(logits.shape[0], ), dtype=torch.bool)
     logits_row_ids_and_logits_row_futures = []
     for seq_group in sampling_metadata.seq_groups:
         seq_ids = seq_group.seq_ids

@@ -154,6 +180,9 @@ def _apply_logits_processors(
 
         for seq_id, logits_row_idx in zip(seq_ids,
                                           seq_group.sample_indices):
+            if not mask[logits_row_idx].item():
+                continue
+
             logits_row = logits[logits_row_idx]
             past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
             prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids

@@ -164,12 +193,12 @@ def _apply_logits_processors(
                     _logits_processor_threadpool.submit(
                         _apply_logits_processors_single_seq, logits_row,
                         logits_processors, past_tokens_ids,
-                        prompt_tokens_ids)))
+                        prompt_tokens_ids, accept_last)))
             else:
                 logits[logits_row_idx] = \
                     _apply_logits_processors_single_seq(
                         logits_row, logits_processors, past_tokens_ids,
-                        prompt_tokens_ids)
+                        prompt_tokens_ids, accept_last)
 
         logits_processed += len(seq_group.sample_indices) + len(
             seq_group.prompt_logprob_indices)

@@ -184,13 +213,15 @@ def _apply_logits_processors(
 
 
 def _apply_logits_processors_single_seq(logits_row, logits_processors,
-                                        past_tokens_ids,
-                                        prompt_tokens_ids) -> torch.Tensor:
+                                        past_tokens_ids, prompt_tokens_ids,
+                                        accept_last) -> torch.Tensor:
     for logits_processor in logits_processors:
         parameters = inspect.signature(logits_processor).parameters
         if len(parameters) == 3:
            logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids,
                                          logits_row)
        else:
+            if accept_last and len(past_tokens_ids) > 0:
+                logits_processor.accept(past_tokens_ids[-1])
            logits_row = logits_processor(past_tokens_ids, logits_row)
    return logits_row
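
The mask semantics of accept_grammar and the new mask/accept_last arguments of _apply_logits_processors drive the resampling path added in interfaces.py below. A simplified, self-contained restatement of accept_grammar may make the control flow easier to follow; it replaces SamplingMetadata with a plain per-row list of processors, so the names here are illustrative only:

import torch

def accept_grammar_simplified(token_ids: torch.Tensor,
                              processors: list,
                              mask: torch.Tensor = None) -> torch.Tensor:
    # Rows without a grammar processor, or rows not selected by the mask,
    # stay accepted by default; only masked rows advance their FSM.
    if mask is None:
        mask = torch.ones_like(token_ids, dtype=torch.bool)
    accept = torch.ones_like(token_ids, dtype=torch.bool)
    for row_idx, proc in enumerate(processors):
        if proc is not None and bool(mask[row_idx]):
            accept[row_idx] = proc.accept(int(token_ids[row_idx]))
    return accept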

vllm/model_executor/models/interfaces.py
Lines changed: 52 additions & 0 deletions

@@ -12,6 +12,9 @@
                                                QuantizationConfig)
 from vllm.utils import supports_kw
 
+from .. import SamplingMetadata
+from ..layers.logits_processor import _apply_logits_processors, accept_grammar
+from ..layers.sampler import SamplerOutput
 from .interfaces_base import is_pooling_model
 
 if TYPE_CHECKING:

@@ -221,6 +224,55 @@ def forward(
         ...
 
 
+class SupportsSampleV2:
+
+    def compute_logits_v2(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head,
+                                       hidden_states,
+                                       sampling_metadata,
+                                       skip_grammar=True)
+        return logits
+
+    def samplev2(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        # Sample once from the unconstrained logits.
+        next_tokens: SamplerOutput = self.sampler(logits, sampling_metadata)
+
+        # Check whether the sampled tokens satisfy their grammars.
+        tks = torch.tensor(
+            [o.samples[0].output_token for o in next_tokens.outputs])
+        accepted = accept_grammar(tks, sampling_metadata)
+        need_resample = torch.logical_not(accepted)
+        if accepted.all():
+            return next_tokens
+        # Resample: if a sampled token is not valid under its grammar,
+        # sample that row again, but first apply the grammar bitmask to
+        # the logits. The logits processors are applied only to the rows
+        # flagged in need_resample.
+        logits = _apply_logits_processors(logits, sampling_metadata,
+                                          need_resample, False)
+        new_next_tokens: SamplerOutput = self.sampler(logits,
+                                                      sampling_metadata)
+
+        for i, replace in enumerate(need_resample.tolist()):
+            if replace:
+                next_tokens.outputs[i] = new_next_tokens.outputs[i]
+
+        tks = torch.tensor(
+            [o.samples[0].output_token for o in next_tokens.outputs])
+        # The matcher accepts the new token only where the first round was rejected.
+        accepted = accept_grammar(tks, sampling_metadata, need_resample)
+        assert accepted.all()
+        return next_tokens
+
+
 # We can't use runtime_checkable with ClassVar for issubclass checks
 # so we need to treat the class as an instance and use isinstance instead
 @runtime_checkable
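
Taken together, samplev2 is an optimistic fast path: sample once without any grammar mask, and only pay for the bitmask fill plus a second sampling pass on the rows whose tokens the grammar rejects, which is where the performance gain over masking every step comes from. A compressed sketch of that control flow, with purely illustrative arguments (sample_fn, accept_fn, mask_fn are stand-ins, not vLLM APIs):

def optimistic_guided_sample(sample_fn, accept_fn, mask_fn, logits):
    tokens = sample_fn(logits)              # unconstrained first pass
    rejected = ~accept_fn(tokens)           # rows the grammar refuses
    if not rejected.any():
        return tokens                       # fast path: no bitmask work at all
    logits = mask_fn(logits, rejected)      # grammar bitmask on rejected rows only
    retry = sample_fn(logits)               # constrained second pass
    tokens[rejected] = retry[rejected]      # splice the valid tokens back in
    return tokens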

vllm/model_executor/models/llama.py
Lines changed: 2 additions & 2 deletions

@@ -48,7 +48,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP, SupportsSampleV2
 from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,

@@ -433,7 +433,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
         return loaded_params
 
 
-class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsSampleV2):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"]

vllm/model_executor/models/qwen.py
Lines changed: 3 additions & 2 deletions

@@ -31,7 +31,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP, SupportsSampleV2
 from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -335,7 +335,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
         return loaded_params
 
 
-class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA):
+class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA,
+                      SupportsSampleV2):
     packed_modules_mapping = {
         "c_attn": ["c_attn"],
         "gate_up_proj": [

vllm/model_executor/models/qwen2.py
Lines changed: 2 additions & 2 deletions

@@ -52,7 +52,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
 
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP, SupportsSampleV2
 from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,

@@ -405,7 +405,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
         return loaded_params
 
 
-class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsSampleV2):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",

vllm/worker/model_runner.py
Lines changed: 14 additions & 6 deletions

@@ -38,6 +38,7 @@
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 from vllm.model_executor.models import supports_lora, supports_multimodal
+from vllm.model_executor.models.interfaces import SupportsSampleV2
 from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalKwargs, MultiModalPlaceholderMap,

@@ -1785,8 +1786,12 @@ def execute_model(
                 torch.tensor(model_forward_time + orig_model_forward_time))
             return hidden_or_intermediate_states
 
-        logits = self.model.compute_logits(hidden_or_intermediate_states,
-                                           model_input.sampling_metadata)
+        if isinstance(self.model, SupportsSampleV2):
+            logits = self.model.compute_logits_v2(
+                hidden_or_intermediate_states, model_input.sampling_metadata)
+        else:
+            logits = self.model.compute_logits(hidden_or_intermediate_states,
+                                               model_input.sampling_metadata)
 
         if not self.is_driver_worker:
             return []

@@ -1795,10 +1800,13 @@ def execute_model(
             model_input.async_callback()
 
         # Sample the next token.
-        output: SamplerOutput = self.model.sample(
-            logits=logits,
-            sampling_metadata=model_input.sampling_metadata,
-        )
+        if isinstance(self.model, SupportsSampleV2):
+            output: SamplerOutput = self.model.samplev2(
+                logits=logits, sampling_metadata=model_input.sampling_metadata)
+        else:
+            output = self.model.sample(
+                logits=logits, sampling_metadata=model_input.sampling_metadata)
+
         if (self.observability_config is not None
                 and self.observability_config.collect_model_forward_time
                 and output is not None):
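
With this dispatch in place, any model that mixes in SupportsSampleV2 takes the optimistic sample-then-validate path whenever a request carries a grammar-backed logits processor. A hedged end-to-end example of exercising that path through the offline API is sketched below; the guided-decoding parameter names and the model ID are illustrative and may differ between vLLM versions:

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# A tiny grammar: the answer must be exactly "yes" or "no".
grammar = 'root ::= "yes" | "no"'

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
          guided_decoding_backend="xgrammar")

params = SamplingParams(
    max_tokens=4,
    guided_decoding=GuidedDecodingParams(grammar=grammar))

outputs = llm.generate(["Is water wet? Answer yes or no:"], params)
print(outputs[0].outputs[0].text)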
