Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 21 additions & 74 deletions cpp/grammar_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,81 +19,15 @@
#include "fsm.h"
#include "grammar_functor.h"
#include "grammar_impl.h"
#include "grammar_matcher_for_cache.h"
#include "support/logging.h"
#include "support/thread_pool.h"
#include "support/thread_safe_cache.h"
#include "support/utils.h"

namespace xgrammar {

/************** AdaptiveTokenMaskCache Generator **************/

/*! \brief The concrete implementation of GrammarMatcherNode. */
class GrammarMatcherForTokenMaskCache : public EarleyParser {
 public:
  /*!
   * \brief Construct a matcher that explores the grammar from a given parser state
   * to build entries of the adaptive token mask cache.
   * \param grammar The grammar to match tokens against.
   * \param init_state The parser state the exploration starts from.
   * \param need_expand NOTE(review): accepted but not referenced by this inline
   * constructor — presumably consumed by the base parser or later logic; confirm.
   * Passed by value: `bool` is a cheap scalar, so `const bool&` only added an
   * unnecessary indirection (and bound a reference to the default literal).
   */
  GrammarMatcherForTokenMaskCache(
      const Grammar& grammar, const ParserState& init_state, bool need_expand = true
  )
      : EarleyParser(grammar, init_state),
        init_rule_id(init_state.rule_id),
        initial_state(init_state) {}

  /*!
   * \brief Get the adaptive token mask for the given ParserState.
   * \param vocab_size The (padded) vocabulary size of the tokenizer.
   * \param sorted_decoded_vocab The decoded vocabulary, sorted; each entry is a
   * (token id, token string) pair.
   * \param subtree_nodes_range The trie subtree node range per vocabulary entry,
   * used to skip whole subtrees during matching.
   * \param is_root_rule Whether to consider the parent rule. If false, there will be
   * no uncertain tokens. Useful for the root rule.
   */
  AdaptiveTokenMask GetAdaptiveTokenMask(
      size_t vocab_size,
      const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
      const std::vector<int32_t>& subtree_nodes_range,
      bool is_root_rule
  );

  /*!
   * \brief Get the token mask for the given ParserState.
   * \param sorted_decoded_vocab The sorted decoded vocabulary.
   * \param first_char_mask The first character mask.
   * \param subtree_nodes_range The trie subtree node range per vocabulary entry.
   * NOTE(review): declared as std::vector<int> here but std::vector<int32_t> in
   * GetAdaptiveTokenMask — consider unifying the element type.
   * \param is_root_rule Whether to consider the parent rule. If false, there will be
   * no uncertain tokens. Useful for the root rule.
   * \returns True if the rejected indices are filled as usual, False otherwise.
   * It's used to determine which construction function will be used.
   */
  bool GetTokenMaskWithFirstCharacterCheck(
      const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
      const std::bitset<256>& first_char_mask,
      const std::vector<int>& subtree_nodes_range,
      bool is_root_rule
  );

 private:
  /*! \brief Check if a token can pass the lookahead assertion. */
  std::pair</*acceptable*/ bool, /*can reach end*/ bool> IsTokenPassLookaheadAssertion(
      const std::string& token, const std::vector<bool>& can_reach_end_stack
  );

  /*!
   * \brief Check if speculative calculation will be applied.
   * \return first: whether speculative calculation is applicable.
   * \return second: part of the first character mask,
   * which can be used in speculative calculation.
   */
  std::pair<bool, std::bitset<256>> GetSpeculativeCalculation(
      const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab
  );

  // The id of the initial rule.
  int32_t init_rule_id;

  // The initial state of the parser.
  ParserState initial_state;

  // Temporary data for GetAdaptiveTokenMask, kept as members so repeated calls
  // reuse the buffers' capacity instead of reallocating.
  std::vector<int32_t> tmp_accepted_indices_;
  std::vector<int32_t> tmp_rejected_indices_;
  std::vector<int32_t> tmp_uncertain_indices_;
  std::vector<bool> tmp_can_reach_end_stack_;
  std::vector<bool> tmp_can_reach_end_prefix_or_stack_;
};
/************** Use GrammarMatcher to generate the AdaptiveTokenMaskCache **************/

std::pair<bool, bool> GrammarMatcherForTokenMaskCache::IsTokenPassLookaheadAssertion(
const std::string& token, const std::vector<bool>& can_reach_end_stack
Expand Down Expand Up @@ -543,8 +477,8 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
*/
class GrammarCompilerNoCache {
public:
GrammarCompilerNoCache(const TokenizerInfo& tokenizer_info, int max_threads)
: tokenizer_info_(tokenizer_info), max_threads_(max_threads) {}
GrammarCompilerNoCache(const TokenizerInfo& tokenizer_info, int max_threads, bool is_jit)
: tokenizer_info_(tokenizer_info), max_threads_(max_threads), is_jit_(is_jit) {}

CompiledGrammar CompileBuiltinJSONGrammar();

Expand Down Expand Up @@ -573,6 +507,8 @@ class GrammarCompilerNoCache {
const TokenizerInfo tokenizer_info_;
/*! \brief The maximum number of threads to use. */
const int max_threads_;
/*! \brief Whether the jit mode is enabled.*/
const bool is_jit_;
};

CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar grammar) {
Expand All @@ -588,6 +524,9 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
if (tokenizer_info_.GetVocabSize() == 0) {
return CompiledGrammar(compiled_grammar_impl);
}
if (is_jit_) {
return CompiledGrammar(compiled_grammar_impl);
}
// Step 3. Compute the adaptive token mask cache
// The token mask cache is computed for these positions in the grammar:
// 1. All character class or character class star (with last_utf8_bytes=0, 1, 2, 3)
Expand Down Expand Up @@ -827,9 +766,10 @@ class GrammarCompiler::Impl {
const TokenizerInfo& tokenizer_info,
int max_threads,
bool cache_enabled,
int64_t max_memory_bytes
int64_t max_memory_bytes,
bool is_jit
)
: no_cache_compiler_(tokenizer_info, max_threads),
: no_cache_compiler_(tokenizer_info, max_threads, is_jit),
cache_enabled_(cache_enabled),
compile_cache_(static_cast<std::size_t>(max_memory_bytes), Computer(*this)) {
if (max_memory_bytes < -1) {
Expand Down Expand Up @@ -997,9 +937,16 @@ GrammarCompiler::GrammarCompiler(
const TokenizerInfo& tokenizer_info,
int max_threads,
bool cache_enabled,
int64_t max_memory_bytes
int64_t max_memory_bytes,
bool is_jit
)
: pimpl_(std::make_shared<Impl>(tokenizer_info, max_threads, cache_enabled, max_memory_bytes)) {
: pimpl_(std::make_shared<Impl>(
tokenizer_info, max_threads, cache_enabled, max_memory_bytes, is_jit
)) {
if (max_memory_bytes < -1) {
XGRAMMAR_LOG(FATAL) << "Invalid max_memory_bytes: " << max_memory_bytes << ". "
<< "It should be -1 (unlimited) or a non-negative integer.";
}
}

CompiledGrammar GrammarCompiler::CompileJSONSchema(
Expand Down
26 changes: 25 additions & 1 deletion cpp/grammar_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "compiled_grammar_impl.h"
#include "earley_parser.h"
#include "grammar_impl.h"
#include "grammar_matcher_for_cache.h"
#include "support/dynamic_bitset.h"
#include "support/encoding.h"
#include "support/int_set.h"
Expand Down Expand Up @@ -509,7 +510,7 @@ bool GrammarMatcher::Impl::FillNextTokenBitmask(
CheckAndGetBitmaskPtr(*next_token_bitmask, tokenizer_info_.GetVocabSize(), index);
const auto& sorted_decoded_vocab = tokenizer_info_.GetSortedDecodedVocab();
const auto& subtree_range = tokenizer_info_.GetTrieSubtreeNodesRange();
const auto& adaptive_token_mask_cache = compiled_grammar_->adaptive_token_mask_cache;
auto& adaptive_token_mask_cache = compiled_grammar_->adaptive_token_mask_cache;
// We need to have a copy, because scanable_state_history_ will be modified during the
// FillNextTokenBitmask process, which can lead to undefined behavior.
auto latest_states = GetLatestScanableStates();
Expand All @@ -532,8 +533,31 @@ bool GrammarMatcher::Impl::FillNextTokenBitmask(
std::vector<std::pair<ParserState, decltype(adaptive_token_mask_cache.cbegin())>>
latest_states_with_masks;

auto add_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
auto grammar_matcher = GrammarMatcherForTokenMaskCache(grammar_, state, false);
auto cur_adaptive_token_mask_cache = grammar_matcher.GetAdaptiveTokenMask(
tokenizer_info_.GetVocabSize(),
tokenizer_info_.GetSortedDecodedVocab(),
tokenizer_info_.GetTrieSubtreeNodesRange(),
is_root_rule
);
return adaptive_token_mask_cache.emplace(state, std::move(cur_adaptive_token_mask_cache)).first;
};

for (const auto& state : latest_states) {
auto adaptive_token_mask_it = adaptive_token_mask_cache.find(state);
if (adaptive_token_mask_it == adaptive_token_mask_cache.end()) {
// It means that the grammar is jit.
bool is_root_rule = state.rule_id == grammar_->GetRootRuleId();
ParserState state_to_check = ParserState{
state.rule_id,
state.sequence_id,
state.element_id,
ParserState::kNoPrevInputPos,
state.sub_element_id
};
adaptive_token_mask_it = add_adaptive_token_mask(state_to_check, is_root_rule);
}
XGRAMMAR_CHECK(adaptive_token_mask_it != adaptive_token_mask_cache.end()) << state;
const auto& adaptive_token_mask = adaptive_token_mask_it->second;
latest_states_with_masks.push_back(std::make_pair(state, adaptive_token_mask_it));
Expand Down
84 changes: 84 additions & 0 deletions cpp/grammar_matcher_for_cache.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*!
* Copyright (c) 2025 by Contributors
* \file xgrammar/grammar_matcher_for_cache.h
* \brief The header for the grammar matcher for the cache.
*/

#ifndef XGRAMMAR_GRAMMAR_MATCHER_FOR_CACHE_H_
#define XGRAMMAR_GRAMMAR_MATCHER_FOR_CACHE_H_

#include <bitset>

#include "compiled_grammar_impl.h"
#include "earley_parser.h"

namespace xgrammar {
/*! \brief The concrete implementation of GrammarMatcherNode. */
class GrammarMatcherForTokenMaskCache : public EarleyParser {
 public:
  /*!
   * \brief Construct a matcher that explores the grammar from a given parser state
   * to build entries of the adaptive token mask cache.
   * \param grammar The grammar to match tokens against.
   * \param init_state The parser state the exploration starts from.
   * \param need_expand NOTE(review): accepted but not referenced by this inline
   * constructor — presumably consumed by the base parser or later logic; confirm.
   * Passed by value: `bool` is a cheap scalar, so `const bool&` only added an
   * unnecessary indirection (and bound a reference to the default literal).
   */
  GrammarMatcherForTokenMaskCache(
      const Grammar& grammar, const ParserState& init_state, bool need_expand = true
  )
      : EarleyParser(grammar, init_state),
        init_rule_id(init_state.rule_id),
        initial_state(init_state) {}

  /*!
   * \brief Get the adaptive token mask for the given ParserState.
   * \param vocab_size The (padded) vocabulary size of the tokenizer.
   * \param sorted_decoded_vocab The decoded vocabulary, sorted; each entry is a
   * (token id, token string) pair.
   * \param subtree_nodes_range The trie subtree node range per vocabulary entry,
   * used to skip whole subtrees during matching.
   * \param is_root_rule Whether to consider the parent rule. If false, there will be
   * no uncertain tokens. Useful for the root rule.
   */
  AdaptiveTokenMask GetAdaptiveTokenMask(
      size_t vocab_size,
      const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
      const std::vector<int32_t>& subtree_nodes_range,
      bool is_root_rule
  );

  /*!
   * \brief Get the token mask for the given ParserState.
   * \param sorted_decoded_vocab The sorted decoded vocabulary.
   * \param first_char_mask The first character mask.
   * \param subtree_nodes_range The trie subtree node range per vocabulary entry.
   * NOTE(review): declared as std::vector<int> here but std::vector<int32_t> in
   * GetAdaptiveTokenMask — consider unifying the element type.
   * \param is_root_rule Whether to consider the parent rule. If false, there will be
   * no uncertain tokens. Useful for the root rule.
   * \returns True if the rejected indices are filled as usual, False otherwise.
   * It's used to determine which construction function will be used.
   */
  bool GetTokenMaskWithFirstCharacterCheck(
      const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
      const std::bitset<256>& first_char_mask,
      const std::vector<int>& subtree_nodes_range,
      bool is_root_rule
  );

 private:
  /*! \brief Check if a token can pass the lookahead assertion. */
  std::pair</*acceptable*/ bool, /*can reach end*/ bool> IsTokenPassLookaheadAssertion(
      const std::string& token, const std::vector<bool>& can_reach_end_stack
  );

  /*!
   * \brief Check if speculative calculation will be applied.
   * \return first: whether speculative calculation is applicable.
   * \return second: part of the first character mask,
   * which can be used in speculative calculation.
   */
  std::pair<bool, std::bitset<256>> GetSpeculativeCalculation(
      const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab
  );

  // The id of the initial rule.
  int32_t init_rule_id;

  // The initial state of the parser.
  ParserState initial_state;

  // Temporary data for GetAdaptiveTokenMask, kept as members so repeated calls
  // reuse the buffers' capacity instead of reallocating.
  std::vector<int32_t> tmp_accepted_indices_;
  std::vector<int32_t> tmp_rejected_indices_;
  std::vector<int32_t> tmp_uncertain_indices_;
  std::vector<bool> tmp_can_reach_end_stack_;
  std::vector<bool> tmp_can_reach_end_prefix_or_stack_;
};
} // namespace xgrammar

#endif // XGRAMMAR_GRAMMAR_MATCHER_FOR_CACHE_H_
2 changes: 1 addition & 1 deletion cpp/nanobind/nanobind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ NB_MODULE(xgrammar_bindings, m) {
.def_static("deserialize_json", &CompiledGrammar_DeserializeJSON);

auto pyGrammarCompiler = nb::class_<GrammarCompiler>(m, "GrammarCompiler");
pyGrammarCompiler.def(nb::init<const TokenizerInfo&, int, bool, int64_t>())
pyGrammarCompiler.def(nb::init<const TokenizerInfo&, int, bool, int64_t, bool>())
.def(
"compile_json_schema",
&GrammarCompiler::CompileJSONSchema,
Expand Down
3 changes: 2 additions & 1 deletion include/xgrammar/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ class GrammarCompiler {
const TokenizerInfo& tokenizer_info,
int max_threads = 8,
bool cache_enabled = true,
int64_t max_memory_bytes = -1 // unlimited
int64_t max_memory_bytes = -1, // unlimited
bool is_jit = false
);

/*! \brief Get the compiled grammar for a JSON schema string. */
Expand Down
6 changes: 5 additions & 1 deletion python/xgrammar/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(
max_threads: int = 8,
cache_enabled: bool = True,
cache_limit_bytes: int = -1,
is_jit: bool = False,
):
"""Construct the compiler.

Expand All @@ -128,6 +129,9 @@ def __init__(
cache_limit_bytes : int, default: -1
The maximum memory usage for the cache in the specified unit.
Note that the actual memory usage may slightly exceed this value.

is_jit : bool, default: False
Whether to enable Just-In-Time (JIT) compilation.
"""
if not isinstance(tokenizer_info, TokenizerInfo):
raise ValueError(
Expand All @@ -137,7 +141,7 @@ def __init__(

self._init_handle(
_core.GrammarCompiler(
tokenizer_info._handle, max_threads, cache_enabled, cache_limit_bytes
tokenizer_info._handle, max_threads, cache_enabled, cache_limit_bytes, is_jit
)
)

Expand Down
13 changes: 9 additions & 4 deletions python/xgrammar/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ def _ebnf_to_grammar_no_normalization(ebnf_string: str, root_rule_name: str = "r
)


def _get_matcher_from_grammar(grammar: Union[Grammar, str], **kwargs) -> GrammarMatcher:
def _get_matcher_from_grammar(
grammar: Union[Grammar, str], is_jit: bool = False, **kwargs
) -> GrammarMatcher:
"""Create a GrammarMatcher from a grammar. The tokenizer info will be set to an empty
TokenizerInfo. The result matcher can only accept strings, and cannot accept tokens.

Expand All @@ -127,7 +129,7 @@ def _get_matcher_from_grammar(grammar: Union[Grammar, str], **kwargs) -> Grammar
The created grammar matcher.
"""
tokenizer_info = TokenizerInfo([])
grammar_compiler = GrammarCompiler(tokenizer_info, cache_enabled=False)
grammar_compiler = GrammarCompiler(tokenizer_info, cache_enabled=False, is_jit=is_jit)
compiled_grammar = grammar_compiler.compile_grammar(grammar)
return GrammarMatcher(compiled_grammar, terminate_without_stop_token=True, **kwargs)

Expand Down Expand Up @@ -299,7 +301,10 @@ def bitmask_to_bool_mask(bit_mask: torch.Tensor, vocab_size: Optional[int] = Non


def _get_matcher_from_grammar_and_tokenizer_info(
grammar: Union[Grammar, str], tokenizer_info: Optional[TokenizerInfo] = None, **kwargs
grammar: Union[Grammar, str],
tokenizer_info: Optional[TokenizerInfo] = None,
is_jit: bool = False,
**kwargs,
) -> GrammarMatcher:
"""Create a GrammarMatcher from a grammar and tokenizer info.

Expand All @@ -321,7 +326,7 @@ def _get_matcher_from_grammar_and_tokenizer_info(
"""
if tokenizer_info is None:
tokenizer_info = TokenizerInfo([])
grammar_compiler = GrammarCompiler(tokenizer_info, cache_enabled=False)
grammar_compiler = GrammarCompiler(tokenizer_info, cache_enabled=False, is_jit=is_jit)
compiled_grammar = grammar_compiler.compile_grammar(grammar)
return GrammarMatcher(compiled_grammar, **kwargs)

Expand Down
11 changes: 11 additions & 0 deletions tests/python/test_grammar_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,5 +329,16 @@ def make_schema(name_str: str):
assert grammar_compiler.get_cache_size_bytes() == 0


@pytest.mark.hf_token_required
def test_grammar_compiler_jit():
    """JIT compilation should finish quickly and still yield a usable compiled grammar.

    The original test discarded the compile result and only printed a timing, so
    it could never fail; keep the timing output but add a minimal sanity check.
    """
    grammar = xgr.Grammar.builtin_json_grammar()
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
    compiler = xgr.GrammarCompiler(xgr.TokenizerInfo.from_huggingface(tokenizer), is_jit=True)
    time_start = time.monotonic_ns()
    compiled_grammar = compiler.compile_grammar(grammar)
    time_end = time.monotonic_ns()
    # JIT skips the token-mask precomputation, but must still return a compiled grammar.
    assert compiled_grammar is not None
    print(f"JIT compilation time: {(time_end - time_start) / 1e6} ms")


if __name__ == "__main__":
pytest.main(sys.argv)
Loading
Loading