[Optimization] xgrammar async compile, multi thread, speed up #4835
base: develop
File 1: guided decoding backend base (defines class BackendBase)
@@ -14,9 +14,10 @@
 # limitations under the License.
 """

+import multiprocessing
 import os
 import traceback
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import Future, ThreadPoolExecutor

 from fastdeploy.config import ErnieArchitectures, FDConfig
 from fastdeploy.engine.request import Request

@@ -135,9 +136,9 @@ class BackendBase:
     """

     def __init__(self, fd_config: FDConfig):
         self.cache = {}
         self.fd_config = fd_config
-        self.executor = ThreadPoolExecutor()
+        max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
+        self.executor = ThreadPoolExecutor(max_workers=max_workers)
         self.max_cache_size = 2048
         self.reasoning_parser = None

@@ -263,7 +264,7 @@ def get_logits_processor(
         self,
         schemata_key: tuple[str, str],
         enable_thinking: bool = False,
-    ) -> tuple[LogitsProcessorBase, bool]:
+    ) -> Future[LogitsProcessorBase]:
         """
         get logits processor by key from cache or create new one.

@@ -275,13 +276,8 @@ def get_logits_processor(
             - LogitsProcessorBase: The logits processor instance
             - bool: True if processor was from cache, False if newly created
         """
-        value = self.cache.get(schemata_key, None)
-        if value:
-            value_copy = value.copy()
-            value_copy.enable_reasoning = enable_thinking
-            return value_copy, True
         value = self.executor.submit(self._init_logits_processor, schemata_key, enable_thinking)
-        return value, False
+        return value

     def _get_tokenizer_hf(self):
         """

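For context on how the async compile is consumed: get_logits_processor now returns a concurrent.futures.Future rather than a (processor, from_cache) tuple, so call sites resolve the future only when the processor is actually needed and the grammar compilation overlaps with other request work. A minimal, self-contained sketch of that pattern (the stub _init_logits_processor below is a placeholder for the real compilation, not code from this PR):

```python
from concurrent.futures import Future, ThreadPoolExecutor

def _init_logits_processor(schemata_key, enable_thinking):
    # stand-in for the real grammar compilation, which can take tens of milliseconds
    return f"processor for {schemata_key} (thinking={enable_thinking})"

executor = ThreadPoolExecutor(max_workers=4)

# BackendBase.get_logits_processor now just submits and returns the Future:
future: Future = executor.submit(_init_logits_processor, ("json", "{...}"), False)

# The caller can schedule/prefill the request while compilation runs in the
# background, and block on the result only at the first constrained decode step.
processor = future.result(timeout=30)
print(processor)
```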
@@ -303,7 +299,7 @@ def _get_tokenizer_hf(self):

 # NOTE: Setting use_fast=True switches to the fast (Rust) tokenizer, which may produce different tokenization results
 # compared to the slow (Python) tokenizer. This can affect model outputs and downstream processing.
 # Please ensure this change is compatible with your use case.
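A quick way to sanity-check the concern in that note is to compare the fast and slow tokenizers on representative structured-output text. A hedged sketch (the model path is a placeholder; substitute the tokenizer the backend actually loads):

```python
from transformers import AutoTokenizer

model = "path/to/your/model"  # placeholder, not a path from this PR

slow = AutoTokenizer.from_pretrained(model, use_fast=False)
fast = AutoTokenizer.from_pretrained(model, use_fast=True)

text = '{"name": "value", "nested": {"k": [1, 2, 3]}}'
ids_slow = slow.encode(text)
ids_fast = fast.encode(text)

# If the two disagree, the grammar's token bitmask is built over a different
# tokenization than the one the model is served with.
print("identical:", ids_slow == ids_fast)
```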
File 2: xgrammar_backend.py
@@ -29,6 +29,7 @@
     BaseChecker,
     LogitsProcessorBase,
 )
+from fastdeploy.platforms import current_platform
 from fastdeploy.utils import llm_logger

 try:

@@ -86,6 +87,8 @@ def __init__(
             terminate_without_stop_token=terminate_without_stop_token,
             override_stop_tokens=override_stop_tokens,
         )
+        # when matcher accept eos_token_id, is_terminated = True
+        self.is_terminated: bool = False

     def allocate_token_bitmask(self) -> torch.Tensor:
         """

@@ -109,40 +112,6 @@ def fill_token_bitmask(self, token_bitmask: torch.Tensor, idx: int) -> None:
         """
         self.matcher.fill_next_token_bitmask(token_bitmask, idx)

-    def apply_token_mask(
-        self,
-        logits: paddle.Tensor,
-        token_bitmask: torch.Tensor,
-        indices: Optional[List[int]] = None,
-    ) -> paddle.Tensor:
-        """
-        Apply the token mask to the logits, modifying probabilities of invalid tokens.
-
-        Args:
-            logits (paddle.Tensor): The logits tensor to modify
-            token_bitmask (torch.Tensor): The token bitmask indicating allowed tokens
-            indices (Optional[List[int]]): Optional list of batch indices to apply mask to
-
-        Returns:
-            paddle.Tensor: The modified logits tensor
-        """
-        origin_place = logits.place
-        origin_dtype = logits.dtype
-        logits = torch.from_numpy(logits.numpy())
-
-        logits = logits.float()  # cpu
-        apply_token_bitmask_inplace(
-            logits=logits,
-            bitmask=token_bitmask.to(logits.device, non_blocking=True),
-            indices=indices,
-        )
-
-        return paddle.to_tensor(
-            logits.numpy(),
-            dtype=origin_dtype,
-            place=origin_place,
-        )

     def reset(self) -> None:
         """
         Reset the grammar matcher state to initial conditions.

@@ -155,23 +124,21 @@ def reset(self) -> None:
     def accept_token(self, token: int) -> None:
         """
         Validate and accept a generated token against the grammar constraints.
+        when accept eos_token, is_terminated = True

         Args:
             token (int): The token ID to validate

         Raises:
             AssertionError: If token is not allowed by the grammar
         """
-        assert self.matcher.accept_token(token), f"Failed to accept token {token}"
-
-    def is_terminated(self) -> bool:
-        """
-        Check if the grammar matching process has terminated.
-
-        Returns:
-            bool: True if matching has terminated, False otherwise
-        """
-        return self.matcher.is_terminated()
+        if self.is_terminated or self.matcher.is_terminated():
+            self.is_terminated = True
+            return False
+        if not self.matcher.accept_token(token):
+            self.matcher.reset()
+            return False
+        if self.matcher.is_terminated():
+            self.is_terminated = True
+            return True

     def copy(self) -> "XGrammarProcessor":
         """

Review comments on the new accept_token body (lines 124 to 141):

Collaborator: Where is the eos_token actually checked here? And how is the over-length output case (generation exceeding the length limit) handled?

Author: Once eos has been accepted, the matcher's state is is_terminated, and it gets reset right below; tokens generated after that are no longer format-constrained. With ignore_eos enabled, generation can also continue.
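To make the new contract concrete: accept_token no longer asserts, and once eos has been accepted the processor latches is_terminated and stops constraining later tokens, which is what allows generation to keep going under ignore_eos. A small self-contained sketch of that state machine, using a stub in place of the real xgrammar GrammarMatcher (the stub and its token IDs are illustrative only):

```python
class _StubMatcher:
    """Stands in for xgrammar's GrammarMatcher in this sketch."""
    def __init__(self, allowed, eos_id):
        self.allowed, self.eos_id, self.done = allowed, eos_id, False
    def accept_token(self, tok):
        if tok not in self.allowed and tok != self.eos_id:
            return False          # token not permitted by the grammar
        self.done = tok == self.eos_id
        return True
    def is_terminated(self):
        return self.done
    def reset(self):
        self.done = False

class Processor:
    def __init__(self, matcher):
        self.matcher = matcher
        self.is_terminated = False  # latched once eos is accepted

    def accept_token(self, token: int) -> bool:
        # follows the new XGrammarProcessor.accept_token control flow
        # (with an explicit True for the ordinary accept case)
        if self.is_terminated or self.matcher.is_terminated():
            self.is_terminated = True
            return False          # already terminated: stop constraining output
        if not self.matcher.accept_token(token):
            self.matcher.reset()  # illegal token: drop the grammar state
            return False
        if self.matcher.is_terminated():
            self.is_terminated = True
        return True

p = Processor(_StubMatcher(allowed={1, 2}, eos_id=0))
print(p.accept_token(1))   # True  - legal token under the grammar
print(p.accept_token(0))   # True  - eos accepted, processor terminates
print(p.accept_token(2))   # False - past eos (e.g. ignore_eos), no longer constrained
```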
@@ -216,7 +183,13 @@ def __init__(

         try:
             tokenizer_info = TokenizerInfo.from_huggingface(self.hf_tokenizer, vocab_size=self.vocab_size)
-            self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
+            llm_logger.info(f"xgrammar_backend.py tokenzer_info={tokenizer_info.dump_metadata()}")
+            self.grammar_compiler = GrammarCompiler(
+                tokenizer_info=tokenizer_info,
+                max_threads=8,
+                cache_enabled=True,
+                cache_limit_bytes=4 * 1024 * 1024,
+            )  # TODO cfg
         except Exception as e:
             raise Exception(f"Failed to load XGrammar tokenizer: {e}")

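Note that these knobs are separate from the Python ThreadPoolExecutor above: max_threads drives xgrammar's own internal compilation threads, and the cache keeps recently compiled grammars in memory so repeated schemata skip recompilation. A hedged sketch of how the compiler is typically exercised (the model path and schema are placeholders; the parameter names follow this diff and xgrammar's public API):

```python
from transformers import AutoTokenizer
from xgrammar import GrammarCompiler, TokenizerInfo

tokenizer = AutoTokenizer.from_pretrained("path/to/your/model")  # placeholder path
tokenizer_info = TokenizerInfo.from_huggingface(tokenizer)

compiler = GrammarCompiler(
    tokenizer_info=tokenizer_info,
    max_threads=8,                      # xgrammar's internal compile threads
    cache_enabled=True,
    cache_limit_bytes=4 * 1024 * 1024,  # ~4 MiB of compiled grammars kept hot
)

schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
compiled = compiler.compile_json_schema(schema)        # first call pays the compile cost
compiled_again = compiler.compile_json_schema(schema)  # served from the compiler cache
```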
@@ -467,3 +440,49 @@ def schema_format(self, request: Request):
         else:
             # regex is not format
             return request, None
+
+
+def apply_token_mask(
+    logits: paddle.Tensor,
+    token_bitmask: torch.Tensor,
+    indices: Optional[List[int]] = None,
+) -> paddle.Tensor:
+    """
+    Apply the token mask to the logits, modifying probabilities of invalid tokens.
+
+    Args:
+        logits (paddle.Tensor): The logits tensor to modify
+        token_bitmask (torch.Tensor): The token bitmask indicating allowed tokens
+        indices (Optional[List[int]]): Optional list of batch indices to apply mask to
+
+    Returns:
+        paddle.Tensor: The modified logits tensor
+    """
+
+    if current_platform.is_cuda():
+        dlpack = paddle.utils.dlpack.to_dlpack(logits)
+        t_logits = torch.from_dlpack(dlpack)
+        apply_token_bitmask_inplace(
+            logits=t_logits,
+            bitmask=token_bitmask.to(t_logits.device, non_blocking=True),
+            indices=indices,
+        )
+        dlpack2 = torch.utils.dlpack.to_dlpack(t_logits)
+        return paddle.utils.dlpack.from_dlpack(dlpack2)
+    else:
+        origin_place = logits.place
+        origin_dtype = logits.dtype
+        logits = torch.from_numpy(logits.numpy())
+
+        logits = logits.float()  # cpu
+        apply_token_bitmask_inplace(
+            logits=logits,
+            bitmask=token_bitmask.to(logits.device, non_blocking=True),
+            indices=indices,
+        )
+
+        return paddle.to_tensor(
+            logits.numpy(),
+            dtype=origin_dtype,
+            place=origin_place,
+        )

Review comments on the CUDA path (lines 468 to 476):

Collaborator: This operator supports paddle.Tensor, doesn't it? Why convert to torch.Tensor?

Author: This still calls the native xgr.apply_token_bitmask_inplace interface, which only accepts torch.Tensor.

Review comments on the CPU fallback path:

Collaborator: This operator does not seem to have been validated on other hardware; I'm not sure it can be used there.

Author: This branch is a pure CPU operation: bitmask=token_bitmask.to(logits.device, non_blocking=True).
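The CUDA branch avoids the numpy round trip by viewing the paddle logits as a torch tensor through DLPack, so xgrammar's torch-only apply_token_bitmask_inplace can write the mask directly into the shared buffer. A minimal CPU sketch of that round trip (runs without a GPU; the PR's hot path does the same thing on CUDA tensors):

```python
import numpy as np
import paddle
import torch

# Start with a paddle tensor (CPU here so the sketch runs anywhere).
logits = paddle.to_tensor(np.arange(8, dtype="float32"))

# View the same storage as a torch tensor via DLPack (no copy expected).
capsule = paddle.utils.dlpack.to_dlpack(logits)
t_logits = torch.from_dlpack(capsule)

# An in-place edit through the torch view...
t_logits[2] = float("-inf")  # what the bitmask kernel does to banned tokens

# ...and back to paddle, again through DLPack.
masked = paddle.utils.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t_logits))
print(masked.numpy())  # index 2 is -inf; the other values are untouched
```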
Review comment:

The logic for determining max_workers using (multiprocessing.cpu_count() + 1) // 2 seems arbitrary without documentation. Consider adding a comment explaining why half the CPU count plus one is chosen, or make this configurable through the FDConfig.
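One possible shape for that suggestion is sketched below: a documented heuristic plus a config override. The structured_outputs_compile_workers field is hypothetical and not part of FDConfig in this PR.

```python
import multiprocessing
from concurrent.futures import ThreadPoolExecutor

def _resolve_compile_workers(fd_config) -> int:
    """Pick the thread-pool size for async grammar compilation.

    Defaults to roughly half the logical CPUs: compilation is CPU-bound work
    that releases the GIL, but the serving process still needs cores for
    tokenization and scheduling. (Heuristic, not a tuned value.)
    """
    # Hypothetical config field -- not present in FDConfig as of this PR.
    configured = getattr(fd_config, "structured_outputs_compile_workers", None)
    if configured:
        return max(1, int(configured))
    return max(1, (multiprocessing.cpu_count() + 1) // 2)

# In BackendBase.__init__ this would replace the inline computation:
# self.executor = ThreadPoolExecutor(max_workers=_resolve_compile_workers(fd_config))
```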