diff --git a/cpp/json_schema_converter.cc b/cpp/json_schema_converter.cc index 9de124fc..4a0d78d3 100644 --- a/cpp/json_schema_converter.cc +++ b/cpp/json_schema_converter.cc @@ -339,7 +339,9 @@ class JSONSchemaConverter { std::string VisitConst(const picojson::object& schema, const std::string& rule_name); /*! \brief Visit an enum schema. */ - std::string VisitEnum(const picojson::object& schema, const std::string& rule_name); + std::string VisitEnum( + const picojson::object& schema, const std::string& rule_name, const JSONFormat json_format + ); /*! \brief Convert the JSON string to a printable string that can be shown in BNF. */ std::string JSONStrToPrintableStr(const std::string& json_str); @@ -873,7 +875,7 @@ std::string JSONSchemaConverter::VisitSchema( } else if (schema_obj.count("const")) { return VisitConst(schema_obj, rule_name); } else if (schema_obj.count("enum")) { - return VisitEnum(schema_obj, rule_name); + return VisitEnum(schema_obj, rule_name, json_format); } else if (schema_obj.count("anyOf") || schema_obj.count("oneOf")) { return VisitAnyOf(schema_obj, rule_name); } else if (schema_obj.count("allOf")) { @@ -979,7 +981,7 @@ std::string JSONSchemaConverter::VisitConst( } std::string JSONSchemaConverter::VisitEnum( - const picojson::object& schema, const std::string& rule_name + const picojson::object& schema, const std::string& rule_name, const JSONFormat json_format ) { XGRAMMAR_CHECK(schema.count("enum")); std::string result = ""; @@ -989,7 +991,17 @@ std::string JSONSchemaConverter::VisitEnum( result += " | "; } ++idx; - result += "(\"" + JSONStrToPrintableStr(value.serialize()) + "\")"; + if (json_format == JSONFormat::kJSON) { + result += "(\"" + JSONStrToPrintableStr(value.serialize()) + "\")"; + } else if (json_format == JSONFormat::kXML) { + auto inner = JSONStrToPrintableStr(value.serialize()); + // If the inner is a json style string, remove the quotation marks. 
+ if (inner.size() >= 4 && inner.substr(0, 2) == "\\\"" && + inner.substr(inner.size() - 2, 2) == "\\\"") { + inner = inner.substr(2, inner.size() - 4); + } + result += "(\"" + inner + "\")"; + } } return result; } diff --git a/cpp/nanobind/nanobind.cc b/cpp/nanobind/nanobind.cc index d77a695b..82f57027 100644 --- a/cpp/nanobind/nanobind.cc +++ b/cpp/nanobind/nanobind.cc @@ -25,12 +25,61 @@ #include "../testing.h" #include "python_methods.h" #include "xgrammar/exception.h" +#include "xgrammar/grammar.h" #include "xgrammar/matcher.h" namespace nb = nanobind; namespace xgrammar { +Grammar Grammar_ApplyStructuralTagTemplate( + const std::string& structural_tag_template, const nb::kwargs& kwargs +) { + std::unordered_map>> values; + for (const auto& [key, value] : kwargs) { + nb::str key_str; + if (!nb::try_cast(key, key_str)) { + throw nb::type_error("Expected a string key for structural tag template values"); + } + nb::list value_list; + if (!nb::try_cast(value, value_list)) { + throw nb::type_error("Expected a list of dictionaries for structural tag template values"); + } + std::vector> value_vec; + value_vec.reserve(value_list.size()); + for (const auto& item : value_list) { + nb::dict item_dict; + if (!nb::try_cast(item, item_dict)) { + throw nb::type_error( + "Expected a dictionary for each item in the list of structural tag template values" + ); + } + std::unordered_map item_map; + for (const auto& [item_key, item_value] : item_dict) { + nb::str item_key_str, item_value_str; + if (!nb::try_cast(item_key, item_key_str)) { + throw nb::type_error( + "Expected a string key for each item in the structural tag template dictionary" + ); + } + if (!nb::try_cast(item_value, item_value_str)) { + throw nb::type_error( + "Expected a string for each value in the structural tag template dictionary" + ); + } + item_map[item_key_str.c_str()] = item_value_str.c_str(); + } + value_vec.push_back(std::move(item_map)); + } + values[key_str.c_str()] = std::move(value_vec); + } + auto 
result = ApplyStructuralTagTemplate(structural_tag_template, values).ToVariant(); + if (std::holds_alternative(result)) { + ThrowVariantError(std::get(result)); + } + return std::get(result); +} + std::vector CommonEncodedVocabType( const nb::typed> encoded_vocab ) { @@ -203,6 +252,7 @@ NB_MODULE(xgrammar_bindings, m) { &Grammar_FromStructuralTag, nb::call_guard() ) + .def_static("apply_structural_tag_template", &Grammar_ApplyStructuralTagTemplate) .def_static("builtin_json_grammar", &Grammar::BuiltinJSONGrammar) .def_static("union", &Grammar::Union, nb::call_guard()) .def_static("concat", &Grammar::Concat, nb::call_guard()) diff --git a/cpp/nanobind/python_methods.cc b/cpp/nanobind/python_methods.cc index 0fa137c7..80126a44 100644 --- a/cpp/nanobind/python_methods.cc +++ b/cpp/nanobind/python_methods.cc @@ -15,6 +15,7 @@ #include #include "../grammar_impl.h" +#include "../structural_tag.h" #include "../support/logging.h" #include "../support/utils.h" #include "xgrammar/exception.h" diff --git a/cpp/nanobind/python_methods.h b/cpp/nanobind/python_methods.h index 29dec070..073f6391 100644 --- a/cpp/nanobind/python_methods.h +++ b/cpp/nanobind/python_methods.h @@ -12,8 +12,10 @@ #include #include #include +#include #include +#include "../structural_tag.h" #include "xgrammar/tokenizer_info.h" namespace xgrammar { diff --git a/cpp/structural_tag.cc b/cpp/structural_tag.cc index 4d3debf7..667fca94 100644 --- a/cpp/structural_tag.cc +++ b/cpp/structural_tag.cc @@ -7,8 +7,16 @@ #include #include +#include +#include +#include +#include #include #include +#include +#include +#include +#include #include "grammar_functor.h" #include "grammar_impl.h" @@ -23,6 +31,8 @@ namespace xgrammar { // Short alias for the error type. 
using ISTError = InvalidStructuralTagError; +std::optional FullyTemplatePlaceholder(const std::string& str); + /************** StructuralTag Parser **************/ class StructuralTagParser { @@ -189,6 +199,10 @@ Result StructuralTagParser::ParseJSONSchemaFormat( auto json_schema_it = obj.find("json_schema"); if (json_schema_it == obj.end() || !(json_schema_it->second.is() || json_schema_it->second.is())) { + if (json_schema_it != obj.end() && json_schema_it->second.is() && + FullyTemplatePlaceholder(json_schema_it->second.to_str()).has_value()) { + return ResultOk(json_schema_it->second.to_str()); + } return ResultErr( "JSON schema format must have a json_schema field with a object or boolean value" ); @@ -204,6 +218,10 @@ Result StructuralTagParser::ParseQwenXmlParame auto json_schema_it = obj.find("json_schema"); if (json_schema_it == obj.end() || !(json_schema_it->second.is() || json_schema_it->second.is())) { + if (json_schema_it != obj.end() && json_schema_it->second.is() && + FullyTemplatePlaceholder(json_schema_it->second.to_str()).has_value()) { + return ResultOk(json_schema_it->second.to_str()); + } return ResultErr( "Qwen XML Parameter format must have a json_schema field with a object or boolean value" ); @@ -1057,6 +1075,1131 @@ Result StructuralTagGrammarConverter::VisitSub(const TagsWithSepa return ResultOk(rule_id); } +/************** StructuralTag Template Filler **************/ + +/*! + * \brief Detect all template placeholder names in the given string. + * \param str The string to detect. + * \return The detected template placeholder name. If no placeholder is found, return std::nullopt. + * \details A template placeholder is in the format of {function_name[].arg_name}. 
+ */ +Result, StructuralTagError> DetectTemplatePlaceholderNames( + const std::string& str +) { + static const std::regex placeholder_regex( + R"(\{([a-zA-Z_][a-zA-Z0-9_]*)\[\]\.([a-zA-Z_][a-zA-Z0-9_]*)\})" + ); + + std::optional placeholder_name_opt = std::nullopt; + auto iter = std::sregex_iterator(str.begin(), str.end(), placeholder_regex); + for (; iter != std::sregex_iterator(); ++iter) { + const auto& match = *iter; + std::string function_name = match[1].str(); + if (placeholder_name_opt.has_value() && placeholder_name_opt.value() != function_name) { + return ResultErr(InvalidStructuralTagError( + "Multiple different placeholder names found in the same string: '" + str + "'" + )); + } else { + placeholder_name_opt = function_name; + } + } + return ResultOk>(placeholder_name_opt); +} + +std::optional FullyTemplatePlaceholder(const std::string& str) { + static const std::regex full_placeholder_regex( + R"(^\{([a-zA-Z_][a-zA-Z0-9_]*)\[\]\.([a-zA-Z_][a-zA-Z0-9_]*)\}$)" + ); + std::smatch match; + if (std::regex_match(str, match, full_placeholder_regex)) { + return match[1].str(); + } else { + return std::nullopt; + } +} + +const auto FormatToString = [](const Format& format) -> std::string { + return std::visit([&](auto&& arg) -> std::string { return arg.ToString(); }, format); +}; + +/*! + * \brief A structural tag template filler, used to fill the structral tags with the given values. 
+ */ +class StructuralTagTemplateFiller { + public: + Result Apply( + const StructuralTag& template_structural_tag, + const std::unordered_map< + std::string, + std::vector>>& values + ); + + bool HasUnfilledPlaceholders(const StructuralTag& structural_tag); + + const std::regex placeholder_regex = + std::regex(R"(\{([a-zA-Z_][a-zA-Z0-9_]*)\[\]\.([a-zA-Z_][a-zA-Z0-9_]*)\})"); + + private: + static const int kDefaultExpansionMode = -1; + std::unordered_map> format_to_placeholder_names_; + const std::unordered_map>>* + values_ = nullptr; + + Result, StructuralTagError> Visit(const Format& format); + Result, StructuralTagError> VisitExpand( + const Format& format_template_to_expand, + const int index = kDefaultExpansionMode, + const std::string& current_placeholder_name = "" + ); + + Result, StructuralTagError> VisitSub(const ConstStringFormat& format); + Result, StructuralTagError> VisitSub(const JSONSchemaFormat& format); + Result, StructuralTagError> VisitSub(const QwenXmlParameterFormat& format + ); + Result, StructuralTagError> VisitSub(const AnyTextFormat& format); + Result, StructuralTagError> VisitSub(const GrammarFormat& format); + Result, StructuralTagError> VisitSub(const RegexFormat& format); + Result, StructuralTagError> VisitSub(const SequenceFormat& format); + Result, StructuralTagError> VisitSub(const OrFormat& format); + Result, StructuralTagError> VisitSub(const TagFormat& format); + Result, StructuralTagError> VisitSub(const TriggeredTagsFormat& format); + Result, StructuralTagError> VisitSub( + const TagsWithSeparatorFormat& format + ); + + Result, StructuralTagError> VisitExpandSub( + const ConstStringFormat& format_template_to_expand, + const int index = kDefaultExpansionMode, + const std::string& current_placeholder_name = "" + ); + Result, StructuralTagError> VisitExpandSub( + const JSONSchemaFormat& format_template_to_expand, + const int index = kDefaultExpansionMode, + const std::string& current_placeholder_name = "" + ); + Result, 
StructuralTagError> VisitExpandSub( + const QwenXmlParameterFormat& format_template_to_expand, + const int index = kDefaultExpansionMode, + const std::string& current_placeholder_name = "" + ); + Result, StructuralTagError> VisitExpandSub( + const AnyTextFormat& format_template_to_expand, + const int index = kDefaultExpansionMode, + const std::string& current_placeholder_name = "" + ); + Result, StructuralTagError> VisitExpandSub( + const GrammarFormat& format_template_to_expand, + const int index = kDefaultExpansionMode, + const std::string& current_placeholder_name = "" + ); + Result, StructuralTagError> VisitExpandSub( + const RegexFormat& format_template_to_expand, + const int index = kDefaultExpansionMode, + const std::string& current_placeholder_name = "" + ); + Result, StructuralTagError> VisitExpandSub( + const SequenceFormat& format_template_to_expand, + const int index = kDefaultExpansionMode, + const std::string& current_placeholder_name = "" + ); + Result, StructuralTagError> VisitExpandSub( + const OrFormat& format_template_to_expand, + const int index = kDefaultExpansionMode, + const std::string& current_placeholder_name = "" + ); + Result, StructuralTagError> VisitExpandSub( + const TagFormat& format_template_to_expand, + const int index = kDefaultExpansionMode, + const std::string& current_placeholder_name = "" + ); + Result, StructuralTagError> VisitExpandSub( + const TriggeredTagsFormat& format_template_to_expand, + const int index = kDefaultExpansionMode, + const std::string& current_placeholder_name = "" + ); + Result, StructuralTagError> VisitExpandSub( + const TagsWithSeparatorFormat& format_template_to_expand, + const int index = kDefaultExpansionMode, + const std::string& current_placeholder_name = "" + ); + + Result ReplacePlaceHolder( + const std::string& str, const std::string& placeholder_name, int index + ); +}; + +Result StructuralTagTemplateFiller::ReplacePlaceHolder( + const std::string& str, const std::string& placeholder_name, int 
index +) { + XGRAMMAR_DCHECK(values_->find(placeholder_name) != values_->end()); + const auto& placeholder_values_vector = values_->at(placeholder_name); + XGRAMMAR_DCHECK(static_cast(index) < placeholder_values_vector.size()); + const auto& placeholder_values = placeholder_values_vector[index]; + const std::regex placeholder_regex( + R"(\{)" + placeholder_name + R"(\[\]\.([a-zA-Z_][a-zA-Z0-9_]*)\})" + ); + std::string result_str = ""; + int last_match_pos = 0; + + // Replace each placeholder, and add the left part. + auto iter = std::sregex_iterator(str.begin(), str.end(), placeholder_regex); + for (; iter != std::sregex_iterator(); ++iter) { + result_str.append(str.begin() + last_match_pos, str.begin() + iter->position()); + const auto& arg_name = (*iter)[1]; + if (placeholder_values.find(arg_name) == placeholder_values.end()) { + return ResultErr(InvalidStructuralTagError( + "The value " + std::string(arg_name) + " is not defined in the array " + placeholder_name + )); + } + result_str.append(placeholder_values.at(arg_name)); + last_match_pos = iter->position() + iter->length(); + } + + // Add the last piece. 
+ result_str.append(str.begin() + last_match_pos, str.end()); + return ResultOk(result_str); +} + +Result, StructuralTagError> StructuralTagTemplateFiller::Visit( + const Format& format +) { + auto result = std::visit( + [&](auto&& arg) -> Result, StructuralTagError> { + return VisitSub(arg); + }, + format + ); + if (result.IsErr()) { + return result; + } + auto placeholder_names = std::move(result).Unwrap(); + auto serialized_format = + std::visit([&](auto&& arg) -> std::string { return arg.ToString(); }, format); + format_to_placeholder_names_[serialized_format] = placeholder_names; + return ResultOk>(placeholder_names); +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitSub( + const ConstStringFormat& format +) { + auto result = DetectTemplatePlaceholderNames(format.value); + if (result.IsErr()) { + return ResultErr(std::move(result).UnwrapErr()); + } + auto placeholder_name_opt = std::move(result).Unwrap(); + if (placeholder_name_opt.has_value()) { + return ResultOk>({placeholder_name_opt.value()}); + } else { + return ResultOk>({}); + } +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitSub( + const JSONSchemaFormat& format +) { + auto placeholder_name_opt = FullyTemplatePlaceholder(format.json_schema); + if (placeholder_name_opt.has_value()) { + return ResultOk>({placeholder_name_opt.value()}); + } else { + return ResultOk>({}); + } +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitSub( + const QwenXmlParameterFormat& format +) { + auto placeholder_name_opt = FullyTemplatePlaceholder(format.xml_schema); + if (placeholder_name_opt.has_value()) { + return ResultOk>({placeholder_name_opt.value()}); + } else { + return ResultOk>({}); + } +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitSub( + const AnyTextFormat& format +) { + return ResultOk>({}); +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitSub( + const GrammarFormat& format +) { + auto placeholder_name_opt = 
FullyTemplatePlaceholder(format.grammar); + if (placeholder_name_opt.has_value()) { + return ResultOk>({placeholder_name_opt.value()}); + } else { + return ResultOk>({}); + } +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitSub( + const RegexFormat& format +) { + auto placeholder_name_opt = FullyTemplatePlaceholder(format.pattern); + if (placeholder_name_opt.has_value()) { + return ResultOk>({placeholder_name_opt.value()}); + } else { + return ResultOk>({}); + } +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitSub( + const SequenceFormat& format +) { + std::unordered_map placeholder_names; + + for (const auto& element : format.elements) { + auto result = Visit(element); + if (result.IsErr()) { + return result; + } + auto sub_placeholder_names = std::move(result).Unwrap(); + for (const auto& sub_place_holder_name : sub_placeholder_names) { + if (placeholder_names.find(sub_place_holder_name) == placeholder_names.end()) { + placeholder_names[sub_place_holder_name] = 1; + } else { + placeholder_names[sub_place_holder_name]++; + } + } + } + if (placeholder_names.size() > 1) { + bool multiple_defined = false; + std::string multiple_defined_name = ""; + for (const auto& [placeholder_name, times] : placeholder_names) { + if (times > 1) { + if (multiple_defined) { + return ResultErr(InvalidStructuralTagError( + "Mingled template detected: " + multiple_defined_name + " and " + placeholder_name + )); + } else { + multiple_defined = true; + multiple_defined_name = placeholder_name; + } + } + } + } + + std::vector return_value; + return_value.reserve(placeholder_names.size()); + for (const auto& [placeholder_name, _] : placeholder_names) { + return_value.push_back(placeholder_name); + } + return ResultOk>(return_value); +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitSub( + const OrFormat& format +) { + std::vector placeholder_names; + for (const auto& element : format.elements) { + auto result = Visit(element); + if 
(result.IsErr()) { + return result; + } + auto sub_placeholder_names = std::move(result).Unwrap(); + placeholder_names.insert( + placeholder_names.end(), sub_placeholder_names.begin(), sub_placeholder_names.end() + ); + } + return ResultOk>(placeholder_names); +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitSub( + const TagFormat& format +) { + std::unordered_map placeholder_names; + + auto result = Visit(*format.content); + if (result.IsErr()) { + return result; + } + auto sub_placeholder_names = std::move(result).Unwrap(); + for (const auto& sub_place_holder_name : sub_placeholder_names) { + if (placeholder_names.find(sub_place_holder_name) == placeholder_names.end()) { + placeholder_names[sub_place_holder_name] = 1; + } else { + placeholder_names[sub_place_holder_name]++; + } + } + + auto begin_result = DetectTemplatePlaceholderNames(format.begin); + if (begin_result.IsErr()) { + return ResultErr(std::move(begin_result).UnwrapErr()); + } + auto begin_placeholder_names = std::move(begin_result).Unwrap(); + if (begin_placeholder_names.has_value()) { + placeholder_names[begin_placeholder_names.value()]++; + } + + auto end_result = DetectTemplatePlaceholderNames(format.end); + if (end_result.IsErr()) { + return ResultErr(std::move(end_result).UnwrapErr()); + } + auto end_placeholder_names = std::move(end_result).Unwrap(); + if (end_placeholder_names.has_value()) { + placeholder_names[end_placeholder_names.value()]++; + } + + if (begin_placeholder_names.has_value() && end_placeholder_names.has_value() && + begin_placeholder_names.value() != end_placeholder_names.value()) { + return ResultErr(InvalidStructuralTagError( + "Multiple different placeholder names found in the tag format: '" + + begin_placeholder_names.value() + "' and '" + end_placeholder_names.value() + "'" + )); + } + + if (placeholder_names.size() > 1) { + bool multiple_defined = false; + std::string multiple_defined_name = ""; + for (const auto& [placeholder_name, times] : 
placeholder_names) { + if (times > 1) { + if (multiple_defined) { + return ResultErr(InvalidStructuralTagError( + "Mingled template detected: " + multiple_defined_name + " and " + placeholder_name + )); + } else { + multiple_defined = true; + multiple_defined_name = placeholder_name; + } + } + } + } + std::vector return_value; + return_value.reserve(placeholder_names.size()); + for (const auto& [placeholder_name, _] : placeholder_names) { + return_value.push_back(placeholder_name); + } + return ResultOk>(return_value); +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitSub( + const TriggeredTagsFormat& format +) { + std::vector placeholder_names; + for (const auto& tag : format.tags) { + auto result = Visit(tag); + if (result.IsErr()) { + return result; + } + auto sub_placeholder_names = std::move(result).Unwrap(); + placeholder_names.insert( + placeholder_names.end(), sub_placeholder_names.begin(), sub_placeholder_names.end() + ); + } + return ResultOk>(placeholder_names); +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitSub( + const TagsWithSeparatorFormat& format +) { + std::vector placeholder_names; + for (const auto& tag : format.tags) { + auto result = Visit(tag); + if (result.IsErr()) { + return result; + } + auto sub_placeholder_names = std::move(result).Unwrap(); + placeholder_names.insert( + placeholder_names.end(), sub_placeholder_names.begin(), sub_placeholder_names.end() + ); + } + return ResultOk>(placeholder_names); +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitExpand( + const Format& format_template_to_expand, + const int index, + const std::string& current_placeholder_name +) { + XGRAMMAR_DCHECK( + format_to_placeholder_names_.find(FormatToString(format_template_to_expand)) != + format_to_placeholder_names_.end() + ); + const auto& placeholder_names = + format_to_placeholder_names_[FormatToString(format_template_to_expand)]; + bool is_dummy_placeholder = + 
std::all_of(placeholder_names.begin(), placeholder_names.end(), [&](const std::string& name) { + return name != current_placeholder_name; + }); + if (is_dummy_placeholder) { + return std::visit( + [&](auto&& arg) -> Result, StructuralTagError> { + return VisitExpandSub(arg); + }, + format_template_to_expand + ); + } else { + return std::visit( + [&](auto&& arg) -> Result, StructuralTagError> { + return VisitExpandSub(arg, index, current_placeholder_name); + }, + format_template_to_expand + ); + } +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitExpandSub( + const ConstStringFormat& format_template_to_expand, + const int index, + const std::string& current_placeholder_name +) { + auto serialized_format = format_template_to_expand.ToString(); + XGRAMMAR_DCHECK( + format_to_placeholder_names_.find(serialized_format) != format_to_placeholder_names_.end() + ) << "Format not visited before expansion"; + const auto& placeholder_names = format_to_placeholder_names_[serialized_format]; + // If there are no placeholders, return the original format as the only expansion. 
+ if (placeholder_names.empty()) { + return ResultOk>({format_template_to_expand}); + } + XGRAMMAR_DCHECK(placeholder_names.size() == 1) + << "Multiple different placeholders in ConstStringFormat is not supported"; + + const auto& placeholder_name = placeholder_names[0]; + std::vector expansions; + if (index == kDefaultExpansionMode) { + const auto& value = values_->at(placeholder_name); + expansions.reserve(value.size()); + for (int i = 0; i < static_cast(value.size()); ++i) { + auto result = ReplacePlaceHolder(format_template_to_expand.value, placeholder_name, i); + if (result.IsErr()) { + return ResultErr(std::move(result).UnwrapErr()); + } + expansions.push_back(ConstStringFormat(std::move(result).Unwrap())); + } + } else { + XGRAMMAR_DCHECK(current_placeholder_name == placeholder_name) + << "Index provided for a different placeholder name"; + auto result = ReplacePlaceHolder(format_template_to_expand.value, placeholder_name, index); + if (result.IsErr()) { + return ResultErr(std::move(result).UnwrapErr()); + } + expansions.push_back(ConstStringFormat(std::move(result).Unwrap())); + } + return ResultOk>(expansions); +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitExpandSub( + const JSONSchemaFormat& format_template, + const int index, + const std::string& current_placeholder_name +) { + auto serialized_format = format_template.ToString(); + XGRAMMAR_DCHECK( + format_to_placeholder_names_.find(serialized_format) != format_to_placeholder_names_.end() + ) << "Format not visited before expansion"; + const auto& placeholder_names = format_to_placeholder_names_[serialized_format]; + // If there are no placeholders, return the original format as the only expansion. 
+ if (placeholder_names.empty()) { + return ResultOk>({format_template}); + } + XGRAMMAR_DCHECK(placeholder_names.size() == 1) + << "Multiple different placeholders in JSONSchemaFormat is not supported"; + const auto& placeholder_name = placeholder_names[0]; + std::vector expansions; + if (index == kDefaultExpansionMode) { + const auto& value = values_->at(placeholder_name); + expansions.reserve(value.size()); + for (int i = 0; i < static_cast(value.size()); ++i) { + auto result = ReplacePlaceHolder(format_template.json_schema, placeholder_name, i); + if (result.IsErr()) { + return ResultErr(std::move(result).UnwrapErr()); + } + expansions.push_back(JSONSchemaFormat(std::move(result).Unwrap())); + } + } else { + XGRAMMAR_DCHECK(current_placeholder_name == placeholder_name) + << "Index provided for a different placeholder name"; + auto result = ReplacePlaceHolder(format_template.json_schema, placeholder_name, index); + if (result.IsErr()) { + return ResultErr(std::move(result).UnwrapErr()); + } + expansions.push_back(JSONSchemaFormat(std::move(result).Unwrap())); + } + return ResultOk>(expansions); +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitExpandSub( + const QwenXmlParameterFormat& format_template, + const int index, + const std::string& current_placeholder_name +) { + auto serialized_format = format_template.ToString(); + XGRAMMAR_DCHECK( + format_to_placeholder_names_.find(serialized_format) != format_to_placeholder_names_.end() + ) << "Format not visited before expansion"; + const auto& placeholder_names = format_to_placeholder_names_[serialized_format]; + // If there are no placeholders, return the original format as the only expansion. 
+ if (placeholder_names.empty()) { + return ResultOk>({format_template}); + } + XGRAMMAR_DCHECK(placeholder_names.size() == 1) + << "Multiple different placeholders in QwenXmlParameterFormat is not supported"; + + const auto& placeholder_name = placeholder_names[0]; + std::vector expansions; + if (index == kDefaultExpansionMode) { + const auto& value = values_->at(placeholder_name); + expansions.reserve(value.size()); + for (int i = 0; i < static_cast(value.size()); ++i) { + auto result = ReplacePlaceHolder(format_template.xml_schema, placeholder_name, i); + if (result.IsErr()) { + return ResultErr(std::move(result).UnwrapErr()); + } + expansions.push_back(QwenXmlParameterFormat(std::move(result).Unwrap())); + } + } else { + XGRAMMAR_DCHECK(current_placeholder_name == placeholder_name) + << "Index provided for a different placeholder name"; + auto result = ReplacePlaceHolder(format_template.xml_schema, placeholder_name, index); + if (result.IsErr()) { + return ResultErr(std::move(result).UnwrapErr()); + } + expansions.push_back(QwenXmlParameterFormat(std::move(result).Unwrap())); + } + return ResultOk>(expansions); +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitExpandSub( + const AnyTextFormat& format_template, + const int /*index*/, + const std::string& /*current_placeholder_name*/ +) { + return ResultOk>({format_template}); +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitExpandSub( + const GrammarFormat& format_template, + const int index, + const std::string& current_placeholder_name +) { + auto serialized_format = format_template.ToString(); + XGRAMMAR_DCHECK( + format_to_placeholder_names_.find(serialized_format) != format_to_placeholder_names_.end() + ) << "Format not visited before expansion"; + const auto& placeholder_names = format_to_placeholder_names_[serialized_format]; + // If there are no placeholders, return the original format as the only expansion. 
+ if (placeholder_names.empty()) { + return ResultOk>({format_template}); + } + XGRAMMAR_DCHECK(placeholder_names.size() == 1) + << "Multiple different placeholders in GrammarFormat is not supported"; + + const auto& placeholder_name = placeholder_names[0]; + std::vector expansions; + if (index == kDefaultExpansionMode) { + const auto& value = values_->at(placeholder_name); + expansions.reserve(value.size()); + for (int i = 0; i < static_cast(value.size()); ++i) { + auto result = ReplacePlaceHolder(format_template.grammar, placeholder_name, i); + if (result.IsErr()) { + return ResultErr(std::move(result).UnwrapErr()); + } + expansions.push_back(GrammarFormat(std::move(result).Unwrap())); + } + } else { + XGRAMMAR_DCHECK(current_placeholder_name == placeholder_name) + << "Index provided for a different placeholder name"; + auto result = ReplacePlaceHolder(format_template.grammar, placeholder_name, index); + if (result.IsErr()) { + return ResultErr(std::move(result).UnwrapErr()); + } + expansions.push_back(GrammarFormat(std::move(result).Unwrap())); + } + return ResultOk>(expansions); +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitExpandSub( + const RegexFormat& format_template, const int index, const std::string& current_placeholder_name +) { + auto serialized_format = format_template.ToString(); + XGRAMMAR_DCHECK( + format_to_placeholder_names_.find(serialized_format) != format_to_placeholder_names_.end() + ) << "Format not visited before expansion"; + const auto& placeholder_names = format_to_placeholder_names_[serialized_format]; + // If there are no placeholders, return the original format as the only expansion. 
+ if (placeholder_names.empty()) { + return ResultOk>({format_template}); + } + XGRAMMAR_DCHECK(placeholder_names.size() == 1) + << "Multiple different placeholders in RegexFormat is not supported"; + + const auto& placeholder_name = placeholder_names[0]; + std::vector expansions; + if (index == kDefaultExpansionMode) { + const auto& value = values_->at(placeholder_name); + expansions.reserve(value.size()); + for (int i = 0; i < static_cast(value.size()); ++i) { + auto result = ReplacePlaceHolder(format_template.pattern, placeholder_name, i); + if (result.IsErr()) { + return ResultErr(std::move(result).UnwrapErr()); + } + expansions.push_back(RegexFormat(std::move(result).Unwrap())); + } + } else { + XGRAMMAR_DCHECK(current_placeholder_name == placeholder_name) + << "Index provided for a different placeholder name"; + auto result = ReplacePlaceHolder(format_template.pattern, placeholder_name, index); + if (result.IsErr()) { + return ResultErr(std::move(result).UnwrapErr()); + } + expansions.push_back(RegexFormat(std::move(result).Unwrap())); + } + return ResultOk>(expansions); +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitExpandSub( + const SequenceFormat& format_template, + const int index, + const std::string& current_placeholder_name +) { + auto serialized_format = format_template.ToString(); + XGRAMMAR_DCHECK( + format_to_placeholder_names_.find(serialized_format) != format_to_placeholder_names_.end() + ) << "Format not visited before expansion"; + + // Check each subformat's placeholder names. 
+ bool multiple_defined = false; + std::string multiple_defined_name = ""; + std::unordered_set placeholder_name_set; + std::vector element_iterators; + for (const auto& subformat : format_template.elements) { + auto serialized_subformat = FormatToString(subformat); + XGRAMMAR_DCHECK( + format_to_placeholder_names_.find(serialized_subformat) != + format_to_placeholder_names_.end() + ); + element_iterators.push_back(format_to_placeholder_names_.find(serialized_subformat)); + const auto& sub_placeholder_names = format_to_placeholder_names_[serialized_subformat]; + for (const auto& sub_placeholder_name : sub_placeholder_names) { + if (placeholder_name_set.find(sub_placeholder_name) != placeholder_name_set.end()) { + multiple_defined = true; + multiple_defined_name = sub_placeholder_name; + } else { + placeholder_name_set.insert(sub_placeholder_name); + } + } + } + + // Step 1. If not multiple defined, or multiple defined placeholder is expanded in the + // previous level, expand normally. + if ((!multiple_defined) || + (multiple_defined && multiple_defined_name == current_placeholder_name)) { + std::vector expansions; + for (int i = 0; i < static_cast(format_template.elements.size()); i++) { + if (element_iterators[i]->second.size() == 0) { + // Simple subformat. + expansions.push_back(format_template.elements[i]); + continue; + } + + // Expand subformat. 
+ auto sub_expansions_result = + VisitExpand(format_template.elements[i], index, current_placeholder_name); + if (sub_expansions_result.IsErr()) { + return ResultErr(std::move(sub_expansions_result).UnwrapErr()); + } + auto sub_expansions = std::move(sub_expansions_result).Unwrap(); + if (sub_expansions.size() == 1) { + expansions.push_back(sub_expansions[0]); + } else if (sub_expansions.size() > 1) { + OrFormat or_format(std::move(sub_expansions)); + expansions.push_back(or_format); + } + } + if (expansions.empty()) { + return ResultOk>(); + } else { + return ResultOk>(std::vector{SequenceFormat(std::move(expansions)) + }); + } + } + + // Step 2. multiple defined, and the multiple defined placeholder is not expanded in the previous + // level. + + // Initialization for expansion. + std::vector all_expansions; + std::vector is_multiple_defined_in_subformat; + std::vector is_current_placeholder_in_subformat; + const auto& values = values_->at(multiple_defined_name); + for (const auto& element_iterator : element_iterators) { + bool is_multiple_defined = + std::all_of( + element_iterator->second.begin(), + element_iterator->second.end(), + [&](const std::string& name) { return name != multiple_defined_name; } + ) == false; + is_multiple_defined_in_subformat.push_back(is_multiple_defined); + bool is_current_placeholder = + std::all_of( + element_iterator->second.begin(), + element_iterator->second.end(), + [&](const std::string& name) { return name != current_placeholder_name; } + ) == false; + is_current_placeholder_in_subformat.push_back(is_current_placeholder); + if (is_current_placeholder && is_multiple_defined) { + return ResultErr(InvalidStructuralTagError( + "Mingled template detected when expanding multiple defined placeholder: " + + multiple_defined_name + " and " + current_placeholder_name + )); + } + } + + for (int value_index = 0; value_index < static_cast(values.size()); ++value_index) { + std::vector expansions; + for (int i = 0; i < 
static_cast(format_template.elements.size()); i++) { + if (element_iterators[i]->second.size() == 0) { + // Simple subformat. + expansions.push_back(format_template.elements[i]); + continue; + } + + // Expand subformat. + auto sub_expansions_result = + is_multiple_defined_in_subformat[i] + ? VisitExpand(format_template.elements[i], value_index, multiple_defined_name) + : VisitExpand(format_template.elements[i], index, current_placeholder_name); + + if (sub_expansions_result.IsErr()) { + return ResultErr(std::move(sub_expansions_result).UnwrapErr()); + } + auto sub_expansions = std::move(sub_expansions_result).Unwrap(); + if (sub_expansions.size() == 1) { + expansions.push_back(sub_expansions[0]); + } else { + OrFormat or_format(std::move(sub_expansions)); + expansions.push_back(or_format); + } + } + all_expansions.push_back(SequenceFormat(std::move(expansions))); + } + return ResultOk>(std::move(all_expansions)); +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitExpandSub( + const OrFormat& format_template, const int index, const std::string& current_placeholder_name +) { + std::vector all_expansions; + for (const auto& element : format_template.elements) { + auto sub_expansions_result = VisitExpand(element, index, current_placeholder_name); + if (sub_expansions_result.IsErr()) { + return ResultErr(std::move(sub_expansions_result).UnwrapErr()); + } + auto sub_expansions = std::move(sub_expansions_result).Unwrap(); + all_expansions.insert(all_expansions.end(), sub_expansions.begin(), sub_expansions.end()); + } + return ResultOk>(all_expansions); +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitExpandSub( + const TagFormat& format_template, const int index, const std::string& current_placeholder_name +) { + auto begin_result = DetectTemplatePlaceholderNames(format_template.begin); + auto end_result = DetectTemplatePlaceholderNames(format_template.end); + XGRAMMAR_DCHECK(begin_result.IsOk() && end_result.IsOk()); + auto 
begin_placeholder_name_opt = std::move(begin_result).Unwrap(); + auto end_placeholder_name_opt = std::move(end_result).Unwrap(); + + if (begin_placeholder_name_opt == std::nullopt && end_placeholder_name_opt == std::nullopt) { + // No placeholders in begin and end, only expand content. + auto content_expansions_result = + VisitExpand(*format_template.content, index, current_placeholder_name); + if (content_expansions_result.IsErr()) { + return ResultErr(std::move(content_expansions_result).UnwrapErr()); + } + auto content_expansions = std::move(content_expansions_result).Unwrap(); + std::vector expanded_formats; + for (const auto& content_expansion : content_expansions) { + TagFormat expanded_format{ + format_template.begin, + std::make_shared(content_expansion), + format_template.end, + }; + expanded_formats.push_back(expanded_format); + } + return ResultOk>(expanded_formats); + } + + // Otherwise, if the placeholder has been expanded in the previous level, expand normally. + std::string placeholder_name = begin_placeholder_name_opt.has_value() + ? 
begin_placeholder_name_opt.value() + : end_placeholder_name_opt.value(); + if (placeholder_name == current_placeholder_name) { + auto begin_result = ReplacePlaceHolder(format_template.begin, placeholder_name, index); + if (begin_result.IsErr()) { + return ResultErr(std::move(begin_result).UnwrapErr()); + } + std::string replaced_begin = std::move(begin_result).Unwrap(); + + auto end_result = ReplacePlaceHolder(format_template.end, placeholder_name, index); + if (end_result.IsErr()) { + return ResultErr(std::move(end_result).UnwrapErr()); + } + std::string replaced_end = std::move(end_result).Unwrap(); + + auto content_expansions_result = + VisitExpand(*format_template.content, index, current_placeholder_name); + if (content_expansions_result.IsErr()) { + return ResultErr(std::move(content_expansions_result).UnwrapErr()); + } + auto content_expansions = std::move(content_expansions_result).Unwrap(); + if (content_expansions.size() == 1) { + TagFormat expanded_format{ + replaced_begin, + std::make_shared(content_expansions[0]), + replaced_end, + }; + return ResultOk>({expanded_format}); + } else { + TagFormat expanded_format{ + replaced_begin, + std::make_shared(OrFormat{std::move(content_expansions)}), + replaced_end + }; + return ResultOk>({expanded_format}); + } + } + + // Otherwise, if it is expanding a different placeholder, raise an error. + if (index != kDefaultExpansionMode) { + return ResultErr(InvalidStructuralTagError( + "Mingled Template Expansion: " + placeholder_name + " and " + current_placeholder_name + )); + } + + // Expand for all values of the placeholder. 
+ std::vector expanded_formats; + const auto& values = values_->at(placeholder_name); + for (int value_index = 0; value_index < static_cast(values.size()); ++value_index) { + auto begin_result = ReplacePlaceHolder(format_template.begin, placeholder_name, value_index); + if (begin_result.IsErr()) { + return ResultErr(std::move(begin_result).UnwrapErr()); + } + std::string replaced_begin = std::move(begin_result).Unwrap(); + + auto end_result = ReplacePlaceHolder(format_template.end, placeholder_name, value_index); + if (end_result.IsErr()) { + return ResultErr(std::move(end_result).UnwrapErr()); + } + std::string replaced_end = std::move(end_result).Unwrap(); + + auto content_expansions_result = + VisitExpand(*format_template.content, value_index, placeholder_name); + if (content_expansions_result.IsErr()) { + return ResultErr(std::move(content_expansions_result).UnwrapErr()); + } + auto content_expansions = std::move(content_expansions_result).Unwrap(); + if (content_expansions.size() == 1) { + TagFormat expanded_format{ + replaced_begin, + std::make_shared(content_expansions[0]), + replaced_end, + }; + expanded_formats.push_back(expanded_format); + } else { + TagFormat expanded_format{ + replaced_begin, + std::make_shared(OrFormat{std::move(content_expansions)}), + replaced_end + }; + expanded_formats.push_back(expanded_format); + } + } + return ResultOk>(std::move(expanded_formats)); +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitExpandSub( + const TriggeredTagsFormat& format_template, + const int index, + const std::string& current_placeholder_name +) { + std::vector expanded_tags; + for (const auto& tag : format_template.tags) { + auto sub_expansions_result = VisitExpand(tag, index, current_placeholder_name); + if (sub_expansions_result.IsErr()) { + return ResultErr(std::move(sub_expansions_result).UnwrapErr()); + } + auto sub_expansions = std::move(sub_expansions_result).Unwrap(); + for (const auto& sub_expansion : sub_expansions) { + 
expanded_tags.push_back(std::get(sub_expansion)); + } + } + // If no tags are expanded, return an empty vector. + if (expanded_tags.empty()) { + XGRAMMAR_LOG(WARNING) << "No tags expanded in TriggeredTagsFormat, possibly due to no values " + "provided for the triggers."; + return ResultOk>({}); + } + TriggeredTagsFormat expanded_format{ + format_template.triggers, + std::move(expanded_tags), + format_template.at_least_one, + format_template.stop_after_first + }; + return ResultOk>({expanded_format}); +} + +Result, StructuralTagError> StructuralTagTemplateFiller::VisitExpandSub( + const TagsWithSeparatorFormat& format_template, + const int index, + const std::string& current_placeholder_name +) { + std::vector expanded_tags; + for (const auto& tag : format_template.tags) { + auto sub_expansions_result = VisitExpand(tag, index, current_placeholder_name); + if (sub_expansions_result.IsErr()) { + return ResultErr(std::move(sub_expansions_result).UnwrapErr()); + } + auto sub_expansions = std::move(sub_expansions_result).Unwrap(); + for (const auto& sub_expansion : sub_expansions) { + expanded_tags.push_back(std::get(sub_expansion)); + } + } + // If no tags are expanded, return an empty vector. + if (expanded_tags.empty()) { + XGRAMMAR_LOG(WARNING + ) << "No tags expanded in TagsWithSeparatorFormat, possibly due to no values " + "provided for the tags."; + return ResultOk>({}); + } + TagsWithSeparatorFormat expanded_format{ + std::move(expanded_tags), + format_template.separator, + format_template.at_least_one, + format_template.stop_after_first + }; + return ResultOk>({expanded_format}); +} + +Result StructuralTagTemplateFiller::Apply( + const StructuralTag& template_structural_tag, + const std::unordered_map< + std::string, + std::vector>>& values +) { + format_to_placeholder_names_.clear(); + values_ = &values; + + // Step 1. Visit the template structural tag to collect all placeholder names. 
+ auto result = Visit(template_structural_tag.format); + if (result.IsErr()) { + return ResultErr(std::move(result).UnwrapErr()); + } + + // Step 2. Analyze if all placeholder names have corresponding values. + auto placeholder_names = std::move(result).Unwrap(); + for (const auto& placeholder_name : placeholder_names) { + if (values.find(placeholder_name) == values.end()) { + return ResultErr(InvalidStructuralTagError( + "No values provided for template placeholder: " + placeholder_name + )); + } + } + + // Step 3. Fill the template structural tag with the values. + auto expanded_formats_result = VisitExpand(template_structural_tag.format); + if (expanded_formats_result.IsErr()) { + return ResultErr(std::move(expanded_formats_result).UnwrapErr()); + } + auto expanded_formats = std::move(expanded_formats_result).Unwrap(); + if (expanded_formats.size() == 1) { + return ResultOk(StructuralTag{std::move(expanded_formats[0])}); + } else if (expanded_formats.size() > 1) { + OrFormat or_format(std::move(expanded_formats)); + return ResultOk(StructuralTag{std::move(or_format)}); + } + XGRAMMAR_LOG(WARNING + ) << "No formats expanded from the template structural tag, any text will be allowed."; + return ResultOk({StructuralTag{AnyTextFormat()}}); +} + +bool StructuralTagTemplateFiller::HasUnfilledPlaceholders(const StructuralTag& structural_tag) { + auto result = Visit(structural_tag.format); + if (result.IsErr()) { + return false; + } + auto placeholder_names = std::move(result).Unwrap(); + return !placeholder_names.empty(); +} + +Result FillTemplateWithValues( + const StructuralTag& template_structural_tag, + const std::unordered_map< + std::string, + std::vector>>& values +) { + return StructuralTagTemplateFiller().Apply(template_structural_tag, values); +} + +/************** StructuralTag To Strings **************/ + +std::string SequenceFormat::ToString() const { + std::string repr = "{type: sequence, elements: ["; + for (size_t i = 0; i < elements.size(); ++i) { + 
repr += std::visit([&](auto&& arg) -> std::string { return arg.ToString(); }, elements[i]); + if (i != elements.size() - 1) { + repr += ", "; + } + } + repr += "]}"; + return repr; +} + +std::string OrFormat::ToString() const { + std::string repr = "{type: or, elements: ["; + for (size_t i = 0; i < elements.size(); ++i) { + repr += std::visit([&](auto&& arg) -> std::string { return arg.ToString(); }, elements[i]); + if (i != elements.size() - 1) { + repr += ", "; + } + } + repr += "]}"; + return repr; +} + +std::string TagFormat::ToString() const { + std::string repr = "{type: tag, begin: " + begin + ", end: " + end + ", content: "; + repr += std::visit([&](auto&& arg) -> std::string { return arg.ToString(); }, *content); + repr += "}"; + return repr; +} + +std::string TriggeredTagsFormat::ToString() const { + std::string repr = "{type: triggered_tags, triggers: ["; + for (size_t i = 0; i < triggers.size(); ++i) { + repr += triggers[i]; + if (i != triggers.size() - 1) { + repr += ", "; + } + } + repr += "], tags: ["; + for (size_t i = 0; i < tags.size(); ++i) { + repr += tags[i].ToString(); + if (i != tags.size() - 1) { + repr += ", "; + } + } + repr += "], at_least_one: " + std::string(at_least_one ? "true" : "false"); + repr += ", stop_after_first: " + std::string(stop_after_first ? "true" : "false") + "}"; + return repr; +} + +std::string TagsWithSeparatorFormat::ToString() const { + std::string repr = "{type: tags_with_separator, tags: ["; + for (size_t i = 0; i < tags.size(); ++i) { + repr += tags[i].ToString(); + if (i != tags.size() - 1) { + repr += ", "; + } + } + repr += "], separator: " + separator; + repr += ", at_least_one: " + std::string(at_least_one ? "true" : "false"); + repr += ", stop_after_first: " + std::string(stop_after_first ? 
"true" : "false") + "}"; + return repr; +} + /************** StructuralTag Conversion Public API **************/ Result StructuralTagToGrammar(const std::string& structural_tag_json) { @@ -1076,4 +2219,33 @@ Result StructuralTagToGrammar(const std::string& st return ResultOk(GrammarNormalizer::Apply(std::move(result).Unwrap())); } +Result ApplyStructuralTagTemplate( + const std::string& structural_tag_template_json, + const std::unordered_map< + std::string, + std::vector>>& values +) { + // Step 1. Parse the template. + auto structural_tag_result_raw = StructuralTagParser::FromJSON(structural_tag_template_json); + if (structural_tag_result_raw.IsErr()) { + return ResultErr(std::move(structural_tag_result_raw).UnwrapErr()); + } + auto structural_tag_raw = std::move(structural_tag_result_raw).Unwrap(); + // Step 2. Replace the elements. + auto filled_stag_result = FillTemplateWithValues(structural_tag_raw, values); + if (filled_stag_result.IsErr()) { + return ResultErr(std::move(filled_stag_result).UnwrapErr()); + } + auto filled_stag = std::move(filled_stag_result).Unwrap(); + auto err = StructuralTagAnalyzer().Analyze(&filled_stag); + if (err.has_value()) { + return ResultErr(std::move(err).value()); + } + auto result = StructuralTagGrammarConverter().Convert(filled_stag); + if (result.IsErr()) { + return ResultErr(std::move(result).UnwrapErr()); + } + return ResultOk(GrammarNormalizer::Apply(std::move(result).Unwrap())); +} + } // namespace xgrammar diff --git a/cpp/structural_tag.h b/cpp/structural_tag.h index fe99edb8..e6713cc7 100644 --- a/cpp/structural_tag.h +++ b/cpp/structural_tag.h @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -55,35 +56,41 @@ struct ConstStringFormat { static constexpr const char* type = "const_string"; std::string value; ConstStringFormat(std::string value) : value(std::move(value)) {} + std::string ToString() const { return "{type: const_string, value: " + value + "}"; } }; struct JSONSchemaFormat { static 
constexpr const char* type = "json_schema"; std::string json_schema; JSONSchemaFormat(std::string json_schema) : json_schema(std::move(json_schema)) {} + std::string ToString() const { return "{type: json_schema, json_schema: " + json_schema + "}"; } }; struct QwenXmlParameterFormat { static constexpr const char* type = "qwen_xml"; std::string xml_schema; QwenXmlParameterFormat(std::string xml_schema) : xml_schema(std::move(xml_schema)) {} + std::string ToString() const { return "{type: qwen_xml, xml_schema: " + xml_schema + "}"; } }; struct GrammarFormat { static constexpr const char* type = "grammar"; std::string grammar; GrammarFormat(std::string grammar) : grammar(std::move(grammar)) {} + std::string ToString() const { return "{type: grammar, grammar: " + grammar + "}"; } }; struct RegexFormat { static constexpr const char* type = "regex"; std::string pattern; RegexFormat(std::string pattern) : pattern(std::move(pattern)) {} + std::string ToString() const { return "{type: regex, pattern: " + pattern + "}"; } }; struct AnyTextFormat { static constexpr const char* type = "any_text"; AnyTextFormat() {} + std::string ToString() const { return "{type: any_text}"; } private: // Detected in StructuralTagAnalyzer @@ -98,6 +105,7 @@ struct SequenceFormat { static constexpr const char* type = "sequence"; std::vector elements; SequenceFormat(std::vector elements) : elements(std::move(elements)) {} + std::string ToString() const; private: // Detected in StructuralTagAnalyzer @@ -110,6 +118,7 @@ struct OrFormat { static constexpr const char* type = "or"; std::vector elements; OrFormat(std::vector elements) : elements(std::move(elements)) {} + std::string ToString() const; private: // Detected in StructuralTagAnalyzer @@ -126,6 +135,7 @@ struct TagFormat { TagFormat(std::string begin, std::shared_ptr content, std::string end) : begin(std::move(begin)), content(std::move(content)), end(std::move(end)) {} + std::string ToString() const; }; struct TriggeredTagsFormat { @@ 
-145,6 +155,7 @@ struct TriggeredTagsFormat { tags(std::move(tags)), at_least_one(at_least_one), stop_after_first(stop_after_first) {} + std::string ToString() const; private: // Detected in StructuralTagAnalyzer @@ -167,6 +178,7 @@ struct TagsWithSeparatorFormat { separator(std::move(separator)), at_least_one(at_least_one), stop_after_first(stop_after_first) {} + std::string ToString() const; private: // Detected in StructuralTagAnalyzer @@ -193,6 +205,21 @@ struct StructuralTag { */ Result StructuralTagToGrammar(const std::string& structural_tag_json); +/*! + * \brief Apply a structural tag template with given values to generate a structural tag, and + * convert it to a grammar. + * \param structural_tag_template_json The JSON string of the structural tag template. + * \param values The values to apply to the template. + * \return A grammar if the application is successful, otherwise an error message in + * StructuralTagError. + */ +Result ApplyStructuralTagTemplate( + const std::string& structural_tag_template_json, + const std::unordered_map< + std::string, + std::vector>>& values +); + } // namespace xgrammar #endif // XGRAMMAR_STRUCTURAL_TAG_H_ diff --git a/docs/api/python/structural_tag.rst b/docs/api/python/structural_tag.rst index 59b634da..8fa5ab1c 100644 --- a/docs/api/python/structural_tag.rst +++ b/docs/api/python/structural_tag.rst @@ -72,3 +72,8 @@ Combinatorial Formats .. autoclass:: TagsWithSeparatorFormat :show-inheritance: :exclude-members: model_config + +Template Structural Tag +----------------------- + +.. 
autofunction:: get_builtin_template_structural_tag diff --git a/docs/index.rst b/docs/index.rst index 223364bb..48533f2a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -20,6 +20,7 @@ The mission of this project is to bring flexible zero-overhead structure generat tutorials/advanced_topics tutorials/structural_tag tutorials/advanced_structural_tag + tutorials/template_structural_tag tutorials/engine_integration tutorials/json_generation tutorials/ebnf_guided_generation diff --git a/docs/tutorials/template_structural_tag.md b/docs/tutorials/template_structural_tag.md new file mode 100644 index 00000000..5215cce1 --- /dev/null +++ b/docs/tutorials/template_structural_tag.md @@ -0,0 +1,133 @@ +# Template Structural Tags + +Based on the [structural tags](./structural_tag.md), template structural tags provide a more convenient way for users to generate grammars that constrain an LLM's output. +In general, template structural tags support placeholders in the structural tags, which are in the form of `{lists[].arg}`. These placeholders will be automatically expanded with the user's input values. Users can call `xgrammar.Grammar.apply_structural_tag_template(template_json_str: Union[str, Dict[str, Any]], **kwargs: List[Dict[str, Any]])` to generate a `xgrammar.Grammar`. Here, `template_json_str` is the template structural tag, and `kwargs` are a series of values. For example, if the template structural tag `stag` contains `{A[].name}` and `{A[].age}`, users can call `xgrammar.Grammar.apply_structural_tag_template(stag, A=A)` to generate the grammar, where `A=[{"name":..., "age":...}, {"name":..., "age":...}, ...]`. This function will replace the placeholders automatically. + +## Template Placeholders in Formats + +1. `const_string` + +Each `const_string` format can contain multiple placeholders, but they must be from the same value mapping. 
For example, given `A=[{"begin": "It is", "end":"."}, {"begin": "Is it", "end": "?"}]`, the template + +```json +{ + "type": "const_string", + "value": "{A[].begin} a dog{A[].end}" +} +``` + +is allowed, and the output is constrained to `It is a dog.` or `Is it a dog?`. However, if the provided values are `Begin=[{"word": "It is"}, {"word": "Is it"}], End=[{"word": "."}, {"word": "?"}]`, the template + +```json +{ + "type": "const_string", + "value": "{Begin[].word} a dog{End[].word}" +} +``` + +cannot be compiled because its meaning is ambiguous. We call such templates **mingled**. This template format will be expanded into a `const_string` format, an `or` format, or nothing. + +2. `grammar`, `json_schema`, `qwen_xml_parameter`, `regex` + +A template placeholder in these formats' value is replaced **only if** the value is exactly the placeholder. For example, + +```json +{ + "type": "json_schema", + "json_schema": "{schemas[].schema}" +} +``` + +can be automatically replaced with the given `schemas`. However, this format will not be replaced: + +```json +{ + "type": "json_schema", + "json_schema": { + "type": "{schemas[].schema}" + } +} +``` + +The same rule holds for all four formats. This template format will be expanded into a `json_schema` format, an `or` format, or nothing. + +3. `tag` + +A tag is allowed to contain placeholders in its `begin` and `end` fields. The `content` field may also be a template format. However, as with `const_string`, `begin` and `end` can contain multiple placeholders, but they must be from the same value mapping. Otherwise, the template is mingled and cannot be compiled. For example, this is a valid tag template: + +```json +{ + "type": "tag", + "begin": "" +} +``` + +This format will be expanded into a `tag` format, an `or` format, a series of `tag` formats inside `triggered_tags` or `tags_with_separator`, or nothing. + +4. 
`triggered_tags`, `tags_with_separator` + +Template placeholders are not allowed in `triggers` or `separator`. + +5. `sequence`, `or`, `any_text` + +Template placeholders cannot be directly contained in these formats. +## Valid Template Structural Tags +Not all template structural tags are valid. For example, the mingled formats mentioned above are not valid template structural tags. Besides, there are some other situations where the template structural tag cannot be compiled. For example: + +```json +{ + "type": "sequence", + "elements": [ + { + "type": "const_string", + "value": "{A[].value1}" + }, + { + "type": "const_string", + "value": "{B[].value1}" + }, + { + "type": "const_string", + "value": "{A[].value2}" + }, + { + "type": "const_string", + "value": "{B[].value2}" + } + ] +} +``` + +cannot be compiled because we cannot analyze the meaning of the sequence. However, + +```json +{ + "type": "sequence", + "elements": [ + { + "type": "const_string", + "value": "{A[].value1}" + }, + { + "type": "const_string", + "value": "{B[].value}" + }, + { + "type": "const_string", + "value": "{A[].value2}" + } + ] +} +``` + +can be compiled. It will be interpreted as a pair of `A.value1`, `A.value2`, and an arbitrary `B.value`. Basically, all the invalid template structural tags are in similar situations: we cannot divide each template placeholder's scope properly. In the former invalid one, both `A` and `B`'s scopes are the `sequence` format. In the latter valid one, `A`'s scope is the `sequence` format, while `B`'s scope is the `const_string` format. + +## Builtin Template Structural Tags + +There are several builtin template structural tags, designed for different LLMs' tool-calling formats. Users can get the builtin template structural tags with `xgrammar.structural_tag.get_builtin_template_structural_tag(format_type: str)`. Currently, these `format_type`s are supported: `Llama`, `Kimi`, `Deepseek`, `Qwen_Coder`, `Qwen`, `Harmony`. 
+For `Llama`, `Kimi`, `Deepseek`, `Qwen_Coder`, `Qwen`, users need to provide a value `tools=[{"name":..., "parameters":...}, ...]` for the template structural tags. For `Harmony` format, users must provide both `tools=[{"name":..., "parameters":...}]` and `builtin_tools=[{"name":..., "parameters":...}, ...]` for the template. These templates will force the LLMs to output the correct function-calling formats, with other natural language outputs. diff --git a/python/xgrammar/__init__.py b/python/xgrammar/__init__.py index cf81929f..58fdbcc7 100644 --- a/python/xgrammar/__init__.py +++ b/python/xgrammar/__init__.py @@ -23,5 +23,5 @@ get_bitmask_shape, reset_token_bitmask, ) -from .structural_tag import StructuralTag +from .structural_tag import StructuralTag, get_builtin_template_structural_tag from .tokenizer_info import TokenizerInfo, VocabType diff --git a/python/xgrammar/grammar.py b/python/xgrammar/grammar.py index 93f8e348..1e560cf3 100644 --- a/python/xgrammar/grammar.py +++ b/python/xgrammar/grammar.py @@ -1,6 +1,7 @@ """This module provides classes representing grammars.""" import json +import sys from typing import Any, Dict, List, Optional, Tuple, Type, Union, overload from pydantic import BaseModel @@ -431,3 +432,65 @@ def deserialize_json(json_string: str) -> "Grammar": When the __VERSION__ field in the JSON string is not the same as the current version. """ return Grammar._create_from_handle(_core.Grammar.deserialize_json(json_string)) + + @staticmethod + def apply_structural_tag_template( + template_json_str: Union[str, Dict[str, Any]], **kwargs: List[Dict[str, Any]] + ) -> "Grammar": + """Apply a structural tag template to create a grammar. The template is a JSON string + representing a structural tag, with placeholders for the tag items, or a dictionary. + The placeholders are specified as keyword arguments, where the key is the placeholder + name, and the value is a list of tag items to replace the placeholder. 
+ + Parameters + ---------- + template_json_str : Union[str, Dict[str, Any]] + The structural tag template as a JSON string or a dictionary. + **kwargs : List[Dict[str, Any]] + The placeholders and their corresponding tag items. + + Returns + ------- + grammar : Grammar + The constructed grammar from the structural tag template. + + Raises + ------ + ValueError + When the template_json_str is not a string or a dictionary. + TypeError + When the values in kwargs are not lists of dictionaries. + InvalidJSONError + When the template_json_str is not a valid JSON string. + InvalidStructuralTagError + When the structural tag template is not valid, or the values for + the placeholders are not found in kwargs. + """ + + if isinstance(template_json_str, dict): + template_json_str = json.dumps(template_json_str) + if not isinstance(template_json_str, str): + raise ValueError("template_json_str must be a string or a dictionary") + + for key, values in kwargs.items(): + if not isinstance(values, list): + raise TypeError(f"Value for {key} must be a list, got {type(values)}") + for item in values: + if not isinstance(item, dict): + raise TypeError(f"Items in {key} must be dictionaries, got {type(item)}") + for item_key, value in item.items(): + if isinstance(value, str): + continue + if isinstance(value, dict): + item[item_key] = json.dumps(value) + else: + item[item_key] = str(value) + # warning + print( + f"Warning: {item_key} value {value} is not a string or dict, converted to string", + file=sys.stderr, + ) + + return Grammar._create_from_handle( + _core.Grammar.apply_structural_tag_template(template_json_str, **kwargs) + ) diff --git a/python/xgrammar/structural_tag.py b/python/xgrammar/structural_tag.py index bcde2bc3..05862fea 100644 --- a/python/xgrammar/structural_tag.py +++ b/python/xgrammar/structural_tag.py @@ -340,3 +340,259 @@ def from_json(json_str: Union[str, Dict[str, Any]]) -> "StructuralTag": "StructuralTagItem", "StructuralTag", ] + +# ---------- Template 
Structural Tag ---------- + +_structural_tag_registry = {} + + +def _register_template_structural_tag_format(name: str): + """Register a structural tag format.""" + + def decorator(func): + _structural_tag_registry[name] = func + return func + + return decorator + + +def get_builtin_template_structural_tag(format_type: str) -> Dict[str, Any]: + """Get builtin template structural tag format by format type. + In all the template structural tag formats, users should provide + a list of tools, each tool should have a "name" and "parameters" field, + to use the template structural tag format. Besides, for the OpenAI Harmony Response Format, + users should also provide a list of builtin tools, each builtin tool should have a "name" + and "parameters" field. + + Examples + -------- + + .. code-block:: python + + from xgrammar import structural_tag, Grammar + tools = [ + {"name": "tool1", "parameters": {"param1": {"type": "string"}}}, + {"name": "tool2", "parameters": {"param2": {"type": "integer"}}}, + ] + builtin_tools = [ + {"name": "builtin_tool1", "parameters": {"param1": {"type": "string"}}}, + {"name": "builtin_tool2", "parameters": {"param2": {"type": "integer"}}}, + ] + template_structural_tag = structural_tag.get_builtin_template_structural_tag("Harmony") + grammar = Grammar.apply_structural_tag_template(template_structural_tag, tools=tools, builtin_tools=builtin_tools) + + The above grammar can be used to construct a grammar that matches the function calling + format of the specified model. + + + + Parameters + ---------- + format_type : str + The format type, must be one of the registered format types: + "Llama", "Kimi", "Deepseek", "Qwen_Coder", "Qwen", "Harmony". + + Returns + ------- + Dict[str, Any] + A template structural tag format dictionary. + + Raises + ------ + ValueError + If the format type is unknown. 
+ + """ + func = _structural_tag_registry.get(format_type) + if func is None: + support_types = list(_structural_tag_registry.keys()) + raise ValueError(f"Unknown format type: {format_type}, support types: {support_types}") + return func() + + +@_register_template_structural_tag_format("Llama") +def get_llama_style_template_structural_tag() -> Dict[str, Any]: + """Get Llama style structural tag format. + + Returns + ------- + Dict[str, Any] + A template structural tag format dictionary. + This format is used by Llama 3 and other models that follow the same style. + """ + return { + "type": "structural_tag", + "format": { + "type": "triggered_tags", + "triggers": ['{"name": '], + "tags": [ + { + "begin": '{"name": "{tools[].name}", "parameters": ', + "content": {"type": "json_schema", "json_schema": "{tools[].parameters}"}, + "end": "}", + } + ], + }, + } + + +@_register_template_structural_tag_format("Kimi") +def get_kimi_style_template_structural_tag() -> Dict[str, Any]: + """Get Kimi style structural tag format. + + Returns + ------- + Dict[str, Any] + A template structural tag format dictionary. + This format is used by Kimi-v2 and other models that follow the same style. + """ + return { + "type": "structural_tag", + "format": { + "type": "triggered_tags", + "triggers": ["<|tool_call_begin|>"], + "tags": [ + { + "begin": "<|tool_call_begin|>{tools[].name}<|tool_call_argument_begin|>", + "content": {"type": "json_schema", "json_schema": "{tools[].parameters}"}, + "end": "<|tool_call_end|>", + } + ], + }, + } + + +@_register_template_structural_tag_format("Deepseek") +def get_deepseek_style_template_structural_tag() -> Dict[str, Any]: + """Get Deepseek style structural tag format. + + Returns + ------- + Dict[str, Any] + A template structural tag format dictionary. + This format is used by Deepseek-v3.1 and other models that follow the same style. 
+ """ + return { + "type": "structural_tag", + "format": { + "type": "triggered_tags", + "triggers": ["<|tool▁calls▁begin|><|tool▁call▁begin|>"], + "tags": [ + { + "begin": "<|tool▁calls▁begin|><|tool▁call▁begin|>{tools[].name}<|tool▁sep|>", + "content": {"type": "json_schema", "json_schema": "{tools[].parameters}"}, + "end": "<|tool▁call▁end|>", + } + ], + }, + } + + +@_register_template_structural_tag_format("Qwen_Coder") +def get_qwen_coder_style_template_structural_tag() -> Dict: + """Get Qwen Coder style structural tag format. + + Returns + ------- + Dict[str, Any] + A template structural tag format dictionary. + This format is used by Qwen3 Coder and other models that follow the same style. + """ + return { + "type": "structural_tag", + "format": { + "type": "triggered_tags", + "triggers": ["", + "content": { + "type": "qwen_xml_parameter", + "json_schema": "{tools[].parameters}", + }, + "end": "", + } + ], + }, + } + + +@_register_template_structural_tag_format("Qwen") +def get_qwen_style_template_structural_tag() -> Dict: + """Get Qwen style structural tag format. + + Returns + ------- + Dict[str, Any] + A template structural tag format dictionary. + This format is used by Qwen3 and other models that follow the same style. + """ + return { + "type": "structural_tag", + "format": { + "type": "triggered_tags", + "triggers": [""], + "tags": [ + { + "begin": '{"name": "{tools[].name}", "arguments": ', + "content": {"type": "json_schema", "json_schema": "{tools[].parameters}"}, + "end": "}", + } + ], + }, + } + + +@_register_template_structural_tag_format("Harmony") +def get_harmony_style_template_structural_tag() -> Dict: + """Get harmony style structural tag format. + + Returns + ------- + Dict[str, Any] + A template structural tag format dictionary. + This format is in OpenAI Harmony Response Format, which is used by GPT-oss + and other models that follow the same style. 
+ """ + return { + "type": "structural_tag", + "format": { + "type": "triggered_tags", + "triggers": ["<|start|>"], + "tags": [ + { + "type": "tag", + "begin": "<|start|>assistant<|channel|>analysis<|message|>", + "content": {"type": "any_text"}, + "end": "<|end|>", + }, + { + "type": "tag", + "begin": "<|start|>assistant<|channel|>final<|message|>", + "content": {"type": "any_text"}, + "end": "<|return|>", + }, + { + "type": "tag", + "begin": "<|start|>assistant<|channel|>final<|message|>", + "content": {"type": "any_text"}, + "end": "<|call|>", + }, + { + "type": "tag", + "begin": "<|start|>assistant<|channel|>commentary to={tools[].name}<|constrain|>json<|message|>", + "content": {"type": "json_schema", "json_schema": "{tools[].parameters}"}, + "end": "<|end|>", + }, + { + "type": "tag", + "begin": "<|start|>assistant<|channel|>analysis to={builtin_tools[].name}<|message|>", + "content": { + "type": "json_schema", + "json_schema": "{builtin_tools[].parameters}", + }, + "end": "<|end|>", + }, + ], + }, + } diff --git a/tests/python/test_structural_tag_converter.py b/tests/python/test_structural_tag_converter.py index 1b0ba3d2..d68437b5 100644 --- a/tests/python/test_structural_tag_converter.py +++ b/tests/python/test_structural_tag_converter.py @@ -90,6 +90,39 @@ def check_stag_with_instance( profiler.profile_stag(structural_tag_format, instance) +def check_template_stag_with_grammar( + structural_tag_format: Dict[str, Any], + expected_grammar_ebnf: str, + **kwargs: List[Dict[str, str]], +): + if structural_tag_format["type"] == "structural_tag": + grammar = xgr.Grammar.apply_structural_tag_template(structural_tag_format, **kwargs) + else: + structural_tag = {"type": "structural_tag", "format": structural_tag_format} + grammar = xgr.Grammar.apply_structural_tag_template(structural_tag, **kwargs) + assert str(grammar) == expected_grammar_ebnf + + +def check_template_stag_with_instance( + structural_tag_format: Union[Dict[str, Any], StructuralTag], + instance: str, 
+ is_accepted: bool = True, + debug_print: bool = False, + **kwargs: List[Dict[str, str]], +): + if isinstance(structural_tag_format, StructuralTag) or ( + "type" in structural_tag_format and structural_tag_format["type"] == "structural_tag" + ): + stag_grammar = xgr.Grammar.apply_structural_tag_template(structural_tag_format, **kwargs) + else: + structural_tag = {"type": "structural_tag", "format": structural_tag_format} + stag_grammar = xgr.Grammar.apply_structural_tag_template(structural_tag, **kwargs) + accepted = _is_grammar_accept_string(stag_grammar, instance, debug_print=debug_print) + assert accepted == is_accepted + if PROFILER_ON: + profiler.profile_stag(structural_tag_format, instance) + + const_string_stag_grammar = [ ( {"type": "const_string", "value": "Hello!"}, @@ -1979,5 +2012,1617 @@ def test_from_structural_tag_with_structural_tag_instance( check_stag_with_instance(stag, instance, is_accepted) +const_string_template_values_stag_grammar_instance_accepted = [ + ( + {"type": "const_string", "value": "The value is: {strings[].value}."}, + [{"value": "a"}, {"value": "b"}, {"value": "c"}], + r"""const_string ::= (("The value is: a.")) +const_string_1 ::= (("The value is: b.")) +const_string_2 ::= (("The value is: c.")) +or ::= ((const_string) | (const_string_1) | (const_string_2)) +root ::= ((or)) +""", + [ + ("The value is: a.", True), + ("The value is: b.", True), + ("The value is: c.", True), + ("The value is: d.", False), + ], + ), + ( + {"type": "const_string", "value": "{strings[].value}"}, + [{"value": "a"}, {"value": "b"}, {"value": "c"}], + r"""const_string ::= (("a")) +const_string_1 ::= (("b")) +const_string_2 ::= (("c")) +or ::= ((const_string) | (const_string_1) | (const_string_2)) +root ::= ((or)) +""", + [("a", True), ("b", True), ("c", True), ("d", False)], + ), + ( + {"type": "const_string", "value": "The value is: {strings[].value}"}, + [{"value": "a"}, {"value": "b"}, {"value": "c"}], + r"""const_string ::= (("The value is: a")) 
+const_string_1 ::= (("The value is: b")) +const_string_2 ::= (("The value is: c")) +or ::= ((const_string) | (const_string_1) | (const_string_2)) +root ::= ((or)) +""", + [ + ("The value is: a", True), + ("The value is: b", True), + ("The value is: c", True), + ("The value is: d", False), + ], + ), + ( + {"type": "const_string", "value": "{strings[].value}是"}, + [{"value": "a"}, {"value": "b"}, {"value": "c"}], + r"""const_string ::= (("a\u662f")) +const_string_1 ::= (("b\u662f")) +const_string_2 ::= (("c\u662f")) +or ::= ((const_string) | (const_string_1) | (const_string_2)) +root ::= ((or)) +""", + [("a是", True), ("b是", True), ("c是", True), ("d是", False)], + ), +] + + +@pytest.mark.parametrize( + "template_stag_format, template_values, expected_grammar, instance_is_accepted_tuples", + const_string_template_values_stag_grammar_instance_accepted, +) +def test_const_string_template_values( + template_stag_format: Dict[str, Any], + template_values: List[Dict[str, Any]], + expected_grammar: str, + instance_is_accepted_tuples: List[Tuple[str, bool]], +): + """Test const_string format with template values""" + check_template_stag_with_grammar( + template_stag_format, expected_grammar, strings=template_values + ) + + for instance, is_accepted in instance_is_accepted_tuples: + check_template_stag_with_instance( + template_stag_format, instance, is_accepted, strings=template_values + ) + + +def test_const_string_template_values_with_mingled_templates(): + mingled_format = { + "type": "const_string", + "value": "{string_a[].value} and {string_b[].value} are mingled!", + } + structural_tag = {"type": "structural_tag", "format": mingled_format} + with pytest.raises(Exception) as exc_info: + xgr.Grammar.apply_structural_tag_template( + structural_tag, string_a=[{"value": "1"}], string_b=[{"value": "2"}] + ) + expected_info = ( + "Invalid structural tag error: Multiple different placeholder names " + "found in the same string: '{string_a[].value} and {string_b[].value} " + 
"are mingled!'" + ) + assert str(exc_info.value) == expected_info + + +json_schema_template_values_stag_grammar_instance_accepted = [ + ( + {"type": "json_schema", "json_schema": "{schemas[].value}"}, + [ + { + "value": r"""{"type":"object", "properties": {"arg": {"type": "string"}}, "required": ["arg"]}""" + }, + {"value": r"""{"type":"string"}"""}, + ], + r"""basic_escape ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9])) +basic_string_sub ::= (("\"") | ([^\0-\x1f\"\\\r\n] basic_string_sub) | ("\\" basic_escape basic_string_sub)) (=([ \n\t]* [,}\]:])) +basic_any ::= ((basic_number) | (basic_string) | (basic_boolean) | (basic_null) | (basic_array) | (basic_object)) +basic_integer ::= (("0") | (basic_integer_1 [1-9] [0-9]*)) +basic_number ::= ((basic_number_1 basic_number_7 basic_number_3 basic_number_6)) +basic_string ::= (("\"" basic_string_sub)) +basic_boolean ::= (("true") | ("false")) +basic_null ::= (("null")) +basic_array ::= (("[" [ \n\t]* basic_any basic_array_1 [ \n\t]* "]") | ("[" [ \n\t]* "]")) +basic_object ::= (("{" [ \n\t]* basic_string [ \n\t]* ":" [ \n\t]* basic_any basic_object_1 [ \n\t]* "}") | ("{" [ \n\t]* "}")) +root ::= (("{" [ \n\t]* "\"arg\"" [ \n\t]* ":" [ \n\t]* basic_string [ \n\t]* "}")) +basic_integer_1 ::= ("" | ("-")) +basic_number_1 ::= ("" | ("-")) +basic_number_2 ::= (([0-9] basic_number_2) | ([0-9])) +basic_number_3 ::= ("" | ("." 
basic_number_2)) +basic_number_4 ::= ("" | ([+\-])) +basic_number_5 ::= (([0-9] basic_number_5) | ([0-9])) +basic_number_6 ::= ("" | ([eE] basic_number_4 basic_number_5)) +basic_array_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_any basic_array_1)) +basic_object_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_string [ \n\t]* ":" [ \n\t]* basic_any basic_object_1)) +basic_number_7 ::= (("0") | ([1-9] [0-9]*)) +basic_escape_1 ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9])) +basic_string_sub_1 ::= (("\"") | ([^\0-\x1f\"\\\r\n] basic_string_sub_1) | ("\\" basic_escape_1 basic_string_sub_1)) (=([ \n\t]* [,}\]:])) +basic_any_1 ::= ((basic_number_8) | (basic_string_1) | (basic_boolean_1) | (basic_null_1) | (basic_array_2) | (basic_object_2)) +basic_integer_2 ::= (("0") | (basic_integer_1_1 [1-9] [0-9]*)) +basic_number_8 ::= ((basic_number_1_1 basic_number_7_1 basic_number_3_1 basic_number_6_1)) +basic_string_1 ::= (("\"" basic_string_sub_1)) +basic_boolean_1 ::= (("true") | ("false")) +basic_null_1 ::= (("null")) +basic_array_2 ::= (("[" [ \n\t]* basic_any_1 basic_array_1_1 [ \n\t]* "]") | ("[" [ \n\t]* "]")) +basic_object_2 ::= (("{" [ \n\t]* basic_string_1 [ \n\t]* ":" [ \n\t]* basic_any_1 basic_object_1_1 [ \n\t]* "}") | ("{" [ \n\t]* "}")) +root_1 ::= ((basic_string_1)) +basic_integer_1_1 ::= ("" | ("-")) +basic_number_1_1 ::= ("" | ("-")) +basic_number_2_1 ::= (([0-9] basic_number_2_1) | ([0-9])) +basic_number_3_1 ::= ("" | ("." 
basic_number_2_1)) +basic_number_4_1 ::= ("" | ([+\-])) +basic_number_5_1 ::= (([0-9] basic_number_5_1) | ([0-9])) +basic_number_6_1 ::= ("" | ([eE] basic_number_4_1 basic_number_5_1)) +basic_array_1_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_any_1 basic_array_1_1)) +basic_object_1_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_string_1 [ \n\t]* ":" [ \n\t]* basic_any_1 basic_object_1_1)) +basic_number_7_1 ::= (("0") | ([1-9] [0-9]*)) +or ::= ((root) | (root_1)) +root_2 ::= ((or)) +""", + [ + ('{"arg": "value"}', True), + ('{"arg": "another value"}', True), + ('{"arg": 123}', False), + ('{"arg": "value", "extra": "field"}', False), + ('{"arg": "value"', False), + ('"just a string"', True), + ], + ) +] + + +@pytest.mark.parametrize( + "template_stag_format, template_values, expected_grammar, instance_is_accepted_tuples", + json_schema_template_values_stag_grammar_instance_accepted, +) +def test_json_schema_template_values( + template_stag_format: Dict[str, Any], + template_values: List[Dict[str, Any]], + expected_grammar: str, + instance_is_accepted_tuples: List[Tuple[str, bool]], +): + """Test json_schema format with template values""" + check_template_stag_with_grammar( + template_stag_format, expected_grammar, schemas=template_values + ) + + for instance, is_accepted in instance_is_accepted_tuples: + check_template_stag_with_instance( + template_stag_format, instance, is_accepted, schemas=template_values + ) + + +def test_part_json_schema_template_failure(): + template_format = {"type": "json_schema", "json_schema": r"""{"type": types[].value}"""} + structural_tag = {"type": "structural_tag", "format": template_format} + types = [{"value": "object"}, {"value": "string"}, {"value": "integer"}] + with pytest.raises(Exception) as exc_info: + xgr.Grammar.apply_structural_tag_template(structural_tag, types=types) + expected_info = ( + "Invalid structural tag error: JSON schema format must have a json_schema field with " + "a object or boolean value" + ) + assert 
str(exc_info.value) == expected_info + + +qwen_template_values_stag_grammar_instance_accepted = [ + ( + {"type": "qwen_xml_parameter", "json_schema": "{schemas[].value}"}, + [ + { + "value": r"""{"type":"object", "properties": {"name": {"type": "string"}}, "required": ["name"]}""" + }, + { + "value": r"""{"type":"object", "properties": {"age": {"type": "integer"}}, "required": ["age"]}""" + }, + ], + r"""basic_escape ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9])) +basic_string_sub ::= (("\"") | ([^\0-\x1f\"\\\r\n] basic_string_sub) | ("\\" basic_escape basic_string_sub)) (=([ \n\t]* [,}\]:])) +xml_escape ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9])) +xml_entity ::= (("<") | (">") | ("&") | (""") | ("'")) +xml_string ::= ("" | ([^<>&\0-\x1f\\\r\n] xml_string) | ("\\" xml_escape xml_string) | (xml_entity xml_string)) (=([ \n\t]*)) +xml_variable_name ::= (([a-zA-Z_] [a-zA-Z0-9_]*)) +xml_string_0 ::= ((xml_string)) +xml_any ::= ((basic_number) | (xml_string) | (basic_boolean) | (basic_null) | (basic_array) | (basic_object)) +basic_any ::= ((basic_number) | (basic_string) | (basic_boolean) | (basic_null) | (basic_array) | (basic_object)) +basic_integer ::= (("0") | (basic_integer_1 [1-9] [0-9]*)) +basic_number ::= ((basic_number_1 basic_number_7 basic_number_3 basic_number_6)) +basic_string ::= (("\"" basic_string_sub)) +basic_boolean ::= (("true") | ("false")) +basic_null ::= (("null")) +basic_array ::= (("[" [ \n\t]* basic_any basic_array_1 [ \n\t]* "]") | ("[" [ \n\t]* "]")) +basic_object ::= (("{" [ \n\t]* basic_string [ \n\t]* ":" [ \n\t]* basic_any basic_object_1 [ \n\t]* "}") | ("{" [ \n\t]* "}")) +root ::= (([ \n\t]* "" [ \n\t]* xml_string_0 [ \n\t]* "")) +basic_integer_1 ::= ("" | ("-")) +basic_number_1 ::= ("" | ("-")) +basic_number_2 ::= (([0-9] basic_number_2) | ([0-9])) +basic_number_3 ::= ("" | ("." 
basic_number_2)) +basic_number_4 ::= ("" | ([+\-])) +basic_number_5 ::= (([0-9] basic_number_5) | ([0-9])) +basic_number_6 ::= ("" | ([eE] basic_number_4 basic_number_5)) +basic_array_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_any basic_array_1)) +basic_object_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_string [ \n\t]* ":" [ \n\t]* basic_any basic_object_1)) +basic_number_7 ::= (("0") | ([1-9] [0-9]*)) +basic_escape_1 ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9])) +basic_string_sub_1 ::= (("\"") | ([^\0-\x1f\"\\\r\n] basic_string_sub_1) | ("\\" basic_escape_1 basic_string_sub_1)) (=([ \n\t]* [,}\]:])) +xml_escape_1 ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9])) +xml_entity_1 ::= (("<") | (">") | ("&") | (""") | ("'")) +xml_string_1 ::= ("" | ([^<>&\0-\x1f\\\r\n] xml_string_1) | ("\\" xml_escape_1 xml_string_1) | (xml_entity_1 xml_string_1)) (=([ \n\t]*)) +xml_variable_name_1 ::= (([a-zA-Z_] [a-zA-Z0-9_]*)) +xml_string_0_1 ::= ((xml_string_1)) +xml_any_1 ::= ((basic_number_8) | (xml_string_1) | (basic_boolean_1) | (basic_null_1) | (basic_array_2) | (basic_object_2)) +basic_any_1 ::= ((basic_number_8) | (basic_string_1) | (basic_boolean_1) | (basic_null_1) | (basic_array_2) | (basic_object_2)) +basic_integer_2 ::= (("0") | (basic_integer_1_1 [1-9] [0-9]*)) +basic_number_8 ::= ((basic_number_1_1 basic_number_7_1 basic_number_3_1 basic_number_6_1)) +basic_string_1 ::= (("\"" basic_string_sub_1)) +basic_boolean_1 ::= (("true") | ("false")) +basic_null_1 ::= (("null")) +basic_array_2 ::= (("[" [ \n\t]* basic_any_1 basic_array_1_1 [ \n\t]* "]") | ("[" [ \n\t]* "]")) +basic_object_2 ::= (("{" [ \n\t]* basic_string_1 [ \n\t]* ":" [ \n\t]* basic_any_1 basic_object_1_1 [ \n\t]* "}") | ("{" [ \n\t]* "}")) +root_prop_0 ::= (("0") | (root_prop_0_1 [1-9] [0-9]*)) +root_1 ::= (([ \n\t]* "" [ \n\t]* root_prop_0 [ \n\t]* "")) +basic_integer_1_1 ::= ("" | ("-")) +basic_number_1_1 ::= ("" | ("-")) +basic_number_2_1 ::= (([0-9] 
basic_number_2_1) | ([0-9])) +basic_number_3_1 ::= ("" | ("." basic_number_2_1)) +basic_number_4_1 ::= ("" | ([+\-])) +basic_number_5_1 ::= (([0-9] basic_number_5_1) | ([0-9])) +basic_number_6_1 ::= ("" | ([eE] basic_number_4_1 basic_number_5_1)) +basic_array_1_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_any_1 basic_array_1_1)) +basic_object_1_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_string_1 [ \n\t]* ":" [ \n\t]* basic_any_1 basic_object_1_1)) +root_prop_0_1 ::= ("" | ("-")) +basic_number_7_1 ::= (("0") | ([1-9] [0-9]*)) +or ::= ((root) | (root_1)) +root_2 ::= ((or)) +""", + [ + ("value", True), + ("another value", True), + ("123", True), + ("value", True), + ("just a string", False), + ("25", True), + ("-5", True), + ("abc", False), + ], + ) +] + + +@pytest.mark.parametrize( + "template_stag_format, template_values, expected_grammar, instance_is_accepted_tuples", + qwen_template_values_stag_grammar_instance_accepted, +) +def test_qwen_template_values( + template_stag_format: Dict[str, Any], + template_values: List[Dict[str, Any]], + expected_grammar: str, + instance_is_accepted_tuples: List[Tuple[str, bool]], +): + """Test qwen_xml_parameter format with template values""" + check_template_stag_with_grammar( + template_stag_format, expected_grammar, schemas=template_values + ) + + for instance, is_accepted in instance_is_accepted_tuples: + check_template_stag_with_instance( + template_stag_format, instance, is_accepted, schemas=template_values + ) + + +def test_part_qwen_template_failure(): + mingled_format = {"type": "qwen_xml_parameter", "json_schema": r"""{"type": {types[].value}}"""} + structural_tag = {"type": "structural_tag", "format": mingled_format} + types = [{"value": "object"}, {"value": "string"}] + with pytest.raises(Exception) as exc_info: + xgr.Grammar.apply_structural_tag_template(structural_tag, types=types) + expected_info = ( + "Invalid structural tag error: Qwen XML Parameter format must have a json_schema field " + "with a object or boolean 
value" + ) + + assert str(exc_info.value) == expected_info + + +regex_template_values_stag_grammar_instance_accepted = [ + ( + {"type": "regex", "pattern": r"{patterns[].value}"}, + [{"value": r"123"}, {"value": r"[a-zA-Z]+"}], + r"""root ::= (("1" "2" "3")) +root_1 ::= ((root_1_1)) +root_1_1 ::= (([a-zA-Z] root_1_1) | ([a-zA-Z])) +or ::= ((root) | (root_1)) +root_2 ::= ((or)) +""", + [("123", True), ("abc", True), ("123abc", False), ("123abc456", False), ("", False)], + ) +] + + +@pytest.mark.parametrize( + "template_stag_format, template_values, expected_grammar, instance_is_accepted_tuples", + regex_template_values_stag_grammar_instance_accepted, +) +def test_regex_template_values( + template_stag_format: Dict[str, Any], + template_values: List[Dict[str, Any]], + expected_grammar: str, + instance_is_accepted_tuples: List[Tuple[str, bool]], +): + """Test regex format with template values""" + check_template_stag_with_grammar( + template_stag_format, expected_grammar, patterns=template_values + ) + + for instance, is_accepted in instance_is_accepted_tuples: + check_template_stag_with_instance( + template_stag_format, instance, is_accepted, patterns=template_values + ) + + +def test_part_regex_template_failure(): + mingled_format = {"type": "grammar", "pattern": r"{patterns[].value}!!!"} + structural_tag = {"type": "structural_tag", "format": mingled_format} + patterns = [{"value": r"123"}, {"value2": r"[a-zA-Z]+"}] + with pytest.raises(RuntimeError) as exc_info: + xgr.Grammar.apply_structural_tag_template(structural_tag, patterns=patterns) + assert exc_info is not None + + +grammar_template_values_stag_grammar_instance_accepted = [ + ( + {"type": "grammar", "grammar": "{grammars[].value}"}, + [{"value": 'root::= "a" | "b"'}, {"value": 'root ::= a+\na::= "c" | "d"'}], + r"""root ::= (("a") | ("b")) +root_1 ::= ((root_1_1)) +a ::= (("c") | ("d")) +root_1_1 ::= ((a root_1_1) | (a)) +or ::= ((root) | (root_1)) +root_2 ::= ((or)) +""", + [ + ("a", True), + ("b", True), 
+ ("c", True), + ("d", True), + ("aa", False), + ("ab", False), + ("cc", True), + ], + ) +] + + +@pytest.mark.parametrize( + "template_stag_format, template_values, expected_grammar, instance_is_accepted_tuples", + grammar_template_values_stag_grammar_instance_accepted, +) +def test_grammar_template_values( + template_stag_format: Dict[str, Any], + template_values: List[Dict[str, Any]], + expected_grammar: str, + instance_is_accepted_tuples: List[Tuple[str, bool]], +): + """Test grammar format with template values""" + check_template_stag_with_grammar( + template_stag_format, expected_grammar, grammars=template_values + ) + + for instance, is_accepted in instance_is_accepted_tuples: + check_template_stag_with_instance( + template_stag_format, instance, is_accepted, grammars=template_values + ) + + +def test_part_grammar_template_failure(): + format = {"type": "grammar", "grammar": "root ::= {grammars[].value}!!!"} + structural_tag = {"type": "structural_tag", "format": format} + grammars = [{"value": 'root ::= "a" | "b"'}, {"value": 'root ::= "c" | "d"'}] + with pytest.raises(RuntimeError) as exc_info: + xgr.Grammar.apply_structural_tag_template(structural_tag, grammars=grammars) + assert exc_info is not None + + +any_text_instance_is_accepted_template = [("abc", True), ("好", True)] + + +@pytest.mark.parametrize("instance, is_accepted", any_text_instance_is_accepted_template) +def test_any_text_compatible(instance: str, is_accepted: bool): + """Test that AnyTextFormat is compatible with all structural tag formats""" + any_text_format = {"type": "any_text"} + dummy_grammars = [{"value": 'root ::= "a" | "b"'}, {"value": 'root ::= "c" | "d"'}] + + expected_grammar = r"""any_text ::= (([\0-\U0010ffff]*)) +root ::= ((any_text)) +""" + + check_template_stag_with_grammar(any_text_format, expected_grammar, grammars=dummy_grammars) + + check_template_stag_with_instance( + any_text_format, instance, is_accepted, grammars=dummy_grammars + ) + + +def test_no_parameter_error(): 
+ format = {"type": "regex", "pattern": "{patterns[].value}"} + structural_tag = {"type": "structural_tag", "format": format} + grammars = [{"value": 'root ::= "a" | "b"'}, {"value": 'root ::= "c" | "d"'}] + expected_error_info = ( + "Invalid structural tag error: No values provided for " "template placeholder: patterns" + ) + + with pytest.raises(RuntimeError) as exc_info: + xgr.Grammar.apply_structural_tag_template(structural_tag, grammars=grammars) + assert str(exc_info.value) == expected_error_info + + +or_template_values_stag_grammar_instance_accepted = [ + ( + { + "type": "or", + "elements": [ + {"type": "const_string", "value": "{strings[].value}"}, + {"type": "const_string", "value": "{numbers[].value}"}, + ], + }, + { + "strings": [{"value": "hello"}, {"value": "world"}], + "numbers": [{"value": "1"}, {"value": "2"}], + }, + r"""const_string ::= (("hello")) +const_string_1 ::= (("world")) +const_string_2 ::= (("1")) +const_string_3 ::= (("2")) +or ::= ((const_string) | (const_string_1) | (const_string_2) | (const_string_3)) +root ::= ((or)) +""", + [ + ("hello", True), + ("world", True), + ("1", True), + ("2", True), + ("3", False), + ("hello world", False), + ], + ) +] + + +@pytest.mark.parametrize( + "template_stag_format, template_values, expected_grammar, instance_is_accepted_tuples", + or_template_values_stag_grammar_instance_accepted, +) +def test_or_template_values( + template_stag_format: Dict[str, Any], + template_values: Dict[str, List[Dict[str, Any]]], + expected_grammar: str, + instance_is_accepted_tuples: List[Tuple[str, bool]], +): + """Test or format with template values""" + check_template_stag_with_grammar(template_stag_format, expected_grammar, **template_values) + + for instance, is_accepted in instance_is_accepted_tuples: + check_template_stag_with_instance( + template_stag_format, instance, is_accepted, **template_values + ) + + +sequence_template_values_stag_grammar_instance_accepted = [ + ( + { + "type": "sequence", + 
"elements": [ + {"type": "const_string", "value": "{first[].value}"}, + {"type": "const_string", "value": "{second[].value}"}, + ], + }, + { + "first": [{"value": "I'm "}, {"value": "You're "}], + "second": [{"value": "Alice"}, {"value": "Bob"}], + }, + r"""const_string ::= (("I\'m ")) +const_string_1 ::= (("You\'re ")) +or ::= ((const_string) | (const_string_1)) +const_string_2 ::= (("Alice")) +const_string_3 ::= (("Bob")) +or_1 ::= ((const_string_2) | (const_string_3)) +sequence ::= ((or or_1)) +root ::= ((sequence)) +""", + [ + ("I'm Alice", True), + ("You're Bob", True), + ("I'm Bob", True), + ("You're Alice", True), + ("Alice I'm", False), + ("Bob You're", False), + ], + ), + ( + { + "type": "sequence", + "elements": [ + {"type": "const_string", "value": "{outter[].first}"}, + {"type": "or", "elements": [{"type": "const_string", "value": "{inner[].animal}"}]}, + {"type": "const_string", "value": "{outter[].symbol}"}, + ], + }, + { + "outter": [{"first": "It is a ", "symbol": "!"}, {"first": "Is it a ", "symbol": "?"}], + "inner": [{"animal": "dog"}, {"animal": "cat"}], + }, + r"""const_string ::= (("It is a ")) +const_string_1 ::= (("dog")) +const_string_2 ::= (("cat")) +or ::= ((const_string_1) | (const_string_2)) +const_string_3 ::= (("!")) +sequence ::= ((const_string or const_string_3)) +const_string_4 ::= (("Is it a ")) +const_string_5 ::= (("dog")) +const_string_6 ::= (("cat")) +or_1 ::= ((const_string_5) | (const_string_6)) +const_string_7 ::= (("\?")) +sequence_1 ::= ((const_string_4 or_1 const_string_7)) +or_2 ::= ((sequence) | (sequence_1)) +root ::= ((or_2)) +""", + [ + ("Is it a cat?", True), + ("It is a cat!", True), + ("Is it a dog?", True), + ("It is a dog!", True), + ("Is it a cat!", False), + ("Is it a dog!", False), + ("It is a cat?", False), + ("It is a dog?", False), + ], + ), +] + + +@pytest.mark.parametrize( + "template_stag_format, template_values, expected_grammar, instance_is_accepted_tuples", + 
sequence_template_values_stag_grammar_instance_accepted, +) +def test_sequence_template_values( + template_stag_format: Dict[str, Any], + template_values: Dict[str, List[Dict[str, Any]]], + expected_grammar: str, + instance_is_accepted_tuples: List[Tuple[str, bool]], +): + """Test sequence format with template values""" + check_template_stag_with_grammar(template_stag_format, expected_grammar, **template_values) + + for instance, is_accepted in instance_is_accepted_tuples: + check_template_stag_with_instance( + template_stag_format, instance, is_accepted, **template_values + ) + + +tag_template_values_stag_grammar_instance_accepted = [ + ( + { + "type": "tag", + "begin": "{outter[].first}", + "content": { + "type": "or", + "elements": [{"type": "const_string", "value": "{inner[].animal}"}], + }, + "end": "{outter[].symbol}", + }, + { + "outter": [{"first": "It is a ", "symbol": "!"}, {"first": "Is it a ", "symbol": "?"}], + "inner": [{"animal": "dog"}, {"animal": "cat"}], + }, + r"""const_string ::= (("dog")) +const_string_1 ::= (("cat")) +or ::= ((const_string) | (const_string_1)) +tag ::= (("It is a " or "!")) +const_string_2 ::= (("dog")) +const_string_3 ::= (("cat")) +or_1 ::= ((const_string_2) | (const_string_3)) +tag_1 ::= (("Is it a " or_1 "\?")) +or_2 ::= ((tag) | (tag_1)) +root ::= ((or_2)) +""", + [ + ("It is a dog!", True), + ("Is it a cat?", True), + ("It is a cat!", True), + ("Is it a dog?", True), + ("It is a dog?", False), + ("It is a cat?", False), + ("Is it a dog!", False), + ("Is it a cat!", False), + ], + ) +] + + +@pytest.mark.parametrize( + "template_stag_format, template_values, expected_grammar, instance_is_accepted_tuples", + tag_template_values_stag_grammar_instance_accepted, +) +def test_tag_template_values( + template_stag_format: Dict[str, Any], + template_values: Dict[str, List[Dict[str, Any]]], + expected_grammar: str, + instance_is_accepted_tuples: List[Tuple[str, bool]], +): + """Test tag format with template values""" + 
check_template_stag_with_grammar(template_stag_format, expected_grammar, **template_values) + + for instance, is_accepted in instance_is_accepted_tuples: + check_template_stag_with_instance( + template_stag_format, instance, is_accepted, **template_values + ) + + +def test_mingled_tag_template(): + format = { + "type": "tag", + "begin": "{outter[].first}", + "content": { + "type": "or", + "elements": [{"type": "const_string", "value": "{inner[].animal}"}], + }, + "end": "{inner[].animal}", + } + structural_tag = {"type": "structural_tag", "format": format} + outter = [{"first": "It is a "}, {"first": "Is it a "}] + inner = [{"animal": "dog"}, {"animal": "cat"}] + with pytest.raises(Exception) as exc_info: + xgr.Grammar.apply_structural_tag_template(structural_tag, outter=outter, inner=inner) + expected_error_info = ( + "Invalid structural tag error: Multiple different placeholder names found in the tag " + "format: 'outter' and 'inner'" + ) + assert str(exc_info.value) == expected_error_info + + +triggered_tag_template_values_stag_grammar_instance_accepted = [ + ( + { + "type": "triggered_tags", + "triggers": ["I"], + "tags": [ + { + "type": "tag", + "begin": "{outter[].first}", + "content": { + "type": "or", + "elements": [{"type": "const_string", "value": "{inner[].animal}"}], + }, + "end": "{outter[].symbol}", + } + ], + "at_least_one": False, + "stop_after_first": False, + }, + { + "outter": [{"first": "It is a ", "symbol": "!"}, {"first": "Is it a ", "symbol": "?"}], + "inner": [{"animal": "dog"}, {"animal": "cat"}], + }, + r"""const_string ::= (("dog")) +const_string_1 ::= (("cat")) +or ::= ((const_string) | (const_string_1)) +const_string_2 ::= (("dog")) +const_string_3 ::= (("cat")) +or_1 ::= ((const_string_2) | (const_string_3)) +triggered_tags_group ::= (("t is a " or "!") | ("s it a " or_1 "\?")) +triggered_tags ::= TagDispatch( + ("I", triggered_tags_group), + stop_eos=true, + stop_str=(), + loop_after_dispatch=true +) +root ::= ((triggered_tags)) +""", 
+ [ + ("It is a dog!", True), + ("Is it a cat?", True), + ("It is a cat!", True), + ("Is it a dog?", True), + ("It is a dog?", False), + ("It is a cat?", False), + ("Is it a dog!", False), + ("Is it a cat!", False), + ("Hello world", True), + ("I am happy", False), + ], + ) +] + + +@pytest.mark.parametrize( + "template_stag_format, template_values, expected_grammar, instance_is_accepted_tuples", + triggered_tag_template_values_stag_grammar_instance_accepted, +) +def test_triggered_tag_template_values( + template_stag_format: Dict[str, Any], + template_values: Dict[str, List[Dict[str, Any]]], + expected_grammar: str, + instance_is_accepted_tuples: List[Tuple[str, bool]], +): + """Test triggered_tags format with template values""" + check_template_stag_with_grammar(template_stag_format, expected_grammar, **template_values) + + for instance, is_accepted in instance_is_accepted_tuples: + check_template_stag_with_instance( + template_stag_format, instance, is_accepted, **template_values + ) + + +tag_with_separator_template_values_stag_grammar_instance_accepted = [ + ( + { + "type": "tags_with_separator", + "tags": [ + { + "type": "tag", + "begin": "{outter[].first}", + "content": { + "type": "or", + "elements": [{"type": "const_string", "value": "{inner[].animal}"}], + }, + "end": "{outter[].symbol}", + } + ], + "separator": "\n", + "at_least_one": False, + "stop_after_first": False, + }, + { + "outter": [{"first": "It is a ", "symbol": "!"}, {"first": "Is it a ", "symbol": "?"}], + "inner": [{"animal": "dog"}, {"animal": "cat"}], + }, + r"""const_string ::= (("dog")) +const_string_1 ::= (("cat")) +or ::= ((const_string) | (const_string_1)) +tag ::= (("It is a " or "!")) +const_string_2 ::= (("dog")) +const_string_3 ::= (("cat")) +or_1 ::= ((const_string_2) | (const_string_3)) +tag_1 ::= (("Is it a " or_1 "\?")) +tags_with_separator_tags ::= ((tag) | (tag_1)) +tags_with_separator_sub ::= ("" | ("\n" tags_with_separator_tags tags_with_separator_sub)) +tags_with_separator 
::= ("" | (tags_with_separator_tags tags_with_separator_sub)) +root ::= ((tags_with_separator)) +""", + [ + ("It is a dog!", True), + ("Is it a cat?", True), + ("It is a cat!", True), + ("Is it a dog?", True), + ("It is a dog!\nIs it a cat?", True), + ("It is a cat!\nIt is a dog!", True), + ("Is it a dog?\nIt is a cat!", True), + ("Is it a cat?\nIs it a dog?\nIs it a cat?", True), + ("It is a dog?", False), + ("It is a cat!\nIt is a dog?", False), + ("Is it a dog!", False), + ], + ) +] + + +@pytest.mark.parametrize( + "template_stag_format, template_values, expected_grammar, instance_is_accepted_tuples", + tag_with_separator_template_values_stag_grammar_instance_accepted, +) +def test_tag_with_separator_template_values( + template_stag_format: Dict[str, Any], + template_values: Dict[str, List[Dict[str, Any]]], + expected_grammar: str, + instance_is_accepted_tuples: List[Tuple[str, bool]], +): + """Test tags_with_separator format with template values""" + check_template_stag_with_grammar(template_stag_format, expected_grammar, **template_values) + + for instance, is_accepted in instance_is_accepted_tuples: + check_template_stag_with_instance( + template_stag_format, instance, is_accepted, **template_values + ) + + +builtin_template_values_stag_grammar_instance_accepted = [ + ( + "Llama", + { + "tools": [ + { + "name": "Calculator", + "description": "A calculator that can perform basic arithmetic operations.", + "parameters": { + "type": "object", + "properties": { + "operation": { + "type": "string", + "enum": ["add", "subtract", "multiply", "divide"], + }, + "a": {"type": "number"}, + "b": {"type": "number"}, + }, + "required": ["operation", "a", "b"], + }, + }, + { + "name": "Weather", + "description": "A tool to get the current weather in a specified location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The name of the city or location to get the weather for.", + } + }, + "required": 
["location"], + }, + }, + ] + }, + r"""basic_escape ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9])) +basic_string_sub ::= (("\"") | ([^\0-\x1f\"\\\r\n] basic_string_sub) | ("\\" basic_escape basic_string_sub)) (=([ \n\t]* [,}\]:])) +basic_any ::= ((basic_number) | (basic_string) | (basic_boolean) | (basic_null) | (basic_array) | (basic_object)) +basic_integer ::= (("0") | (basic_integer_1 [1-9] [0-9]*)) +basic_number ::= ((basic_number_1 basic_number_7 basic_number_3 basic_number_6)) +basic_string ::= (("\"" basic_string_sub)) +basic_boolean ::= (("true") | ("false")) +basic_null ::= (("null")) +basic_array ::= (("[" [ \n\t]* basic_any basic_array_1 [ \n\t]* "]") | ("[" [ \n\t]* "]")) +basic_object ::= (("{" [ \n\t]* basic_string [ \n\t]* ":" [ \n\t]* basic_any basic_object_1 [ \n\t]* "}") | ("{" [ \n\t]* "}")) +root_prop_0 ::= (("\"add\"") | ("\"subtract\"") | ("\"multiply\"") | ("\"divide\"")) +root_part_1 ::= (([ \n\t]* "," [ \n\t]* "\"b\"" [ \n\t]* ":" [ \n\t]* basic_number)) +root_part_0 ::= (([ \n\t]* "," [ \n\t]* "\"a\"" [ \n\t]* ":" [ \n\t]* basic_number root_part_1)) +root ::= (("{" [ \n\t]* "\"operation\"" [ \n\t]* ":" [ \n\t]* root_prop_0 root_part_0 [ \n\t]* "}")) +basic_integer_1 ::= ("" | ("-")) +basic_number_1 ::= ("" | ("-")) +basic_number_2 ::= (([0-9] basic_number_2) | ([0-9])) +basic_number_3 ::= ("" | ("." 
basic_number_2)) +basic_number_4 ::= ("" | ([+\-])) +basic_number_5 ::= (([0-9] basic_number_5) | ([0-9])) +basic_number_6 ::= ("" | ([eE] basic_number_4 basic_number_5)) +basic_array_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_any basic_array_1)) +basic_object_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_string [ \n\t]* ":" [ \n\t]* basic_any basic_object_1)) +basic_number_7 ::= (("0") | ([1-9] [0-9]*)) +basic_escape_1 ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9])) +basic_string_sub_1 ::= (("\"") | ([^\0-\x1f\"\\\r\n] basic_string_sub_1) | ("\\" basic_escape_1 basic_string_sub_1)) (=([ \n\t]* [,}\]:])) +basic_any_1 ::= ((basic_number_8) | (basic_string_1) | (basic_boolean_1) | (basic_null_1) | (basic_array_2) | (basic_object_2)) +basic_integer_2 ::= (("0") | (basic_integer_1_1 [1-9] [0-9]*)) +basic_number_8 ::= ((basic_number_1_1 basic_number_7_1 basic_number_3_1 basic_number_6_1)) +basic_string_1 ::= (("\"" basic_string_sub_1)) +basic_boolean_1 ::= (("true") | ("false")) +basic_null_1 ::= (("null")) +basic_array_2 ::= (("[" [ \n\t]* basic_any_1 basic_array_1_1 [ \n\t]* "]") | ("[" [ \n\t]* "]")) +basic_object_2 ::= (("{" [ \n\t]* basic_string_1 [ \n\t]* ":" [ \n\t]* basic_any_1 basic_object_1_1 [ \n\t]* "}") | ("{" [ \n\t]* "}")) +root_1 ::= (("{" [ \n\t]* "\"location\"" [ \n\t]* ":" [ \n\t]* basic_string_1 [ \n\t]* "}")) +basic_integer_1_1 ::= ("" | ("-")) +basic_number_1_1 ::= ("" | ("-")) +basic_number_2_1 ::= (([0-9] basic_number_2_1) | ([0-9])) +basic_number_3_1 ::= ("" | ("." 
basic_number_2_1)) +basic_number_4_1 ::= ("" | ([+\-])) +basic_number_5_1 ::= (([0-9] basic_number_5_1) | ([0-9])) +basic_number_6_1 ::= ("" | ([eE] basic_number_4_1 basic_number_5_1)) +basic_array_1_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_any_1 basic_array_1_1)) +basic_object_1_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_string_1 [ \n\t]* ":" [ \n\t]* basic_any_1 basic_object_1_1)) +basic_number_7_1 ::= (("0") | ([1-9] [0-9]*)) +triggered_tags_group ::= (("\"Calculator\", \"parameters\": " root "}") | ("\"Weather\", \"parameters\": " root_1 "}")) +triggered_tags ::= TagDispatch( + ("{\"name\": ", triggered_tags_group), + stop_eos=true, + stop_str=(), + loop_after_dispatch=true +) +root_2 ::= ((triggered_tags)) +""", + [ + ( + 'OK, I will use the Calculator tool to perform the operation. {"name": "Calculator", "parameters": {"operation": "add", "a": 5, "b": 3}}', + True, + ), + ( + 'I need to know the weather in Paris. {"name": "Weather", "parameters": {"location": "Paris"}}', + True, + ), + ( + 'Can you calculate Paris add 5? {"name": "Calculator", "parameters": {"operation": "add", "a": 5, "b": "Paris"}}', + False, + ), + ( + 'I want to know the weather in 1. 
{"name": "Weather", "parameters": {"location": 1}}', + False, + ), + ("Some random text", True), + ], + ), + ( + "Kimi", + { + "tools": [ + { + "name": "Calculator", + "description": "A calculator that can perform basic arithmetic operations.", + "parameters": { + "type": "object", + "properties": { + "operation": { + "type": "string", + "enum": ["add", "subtract", "multiply", "divide"], + }, + "a": {"type": "number"}, + "b": {"type": "number"}, + }, + "required": ["operation", "a", "b"], + }, + }, + { + "name": "Weather", + "description": "A tool to get the current weather in a specified location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The name of the city or location to get the weather for.", + } + }, + "required": ["location"], + }, + }, + ] + }, + r"""basic_escape ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9])) +basic_string_sub ::= (("\"") | ([^\0-\x1f\"\\\r\n] basic_string_sub) | ("\\" basic_escape basic_string_sub)) (=([ \n\t]* [,}\]:])) +basic_any ::= ((basic_number) | (basic_string) | (basic_boolean) | (basic_null) | (basic_array) | (basic_object)) +basic_integer ::= (("0") | (basic_integer_1 [1-9] [0-9]*)) +basic_number ::= ((basic_number_1 basic_number_7 basic_number_3 basic_number_6)) +basic_string ::= (("\"" basic_string_sub)) +basic_boolean ::= (("true") | ("false")) +basic_null ::= (("null")) +basic_array ::= (("[" [ \n\t]* basic_any basic_array_1 [ \n\t]* "]") | ("[" [ \n\t]* "]")) +basic_object ::= (("{" [ \n\t]* basic_string [ \n\t]* ":" [ \n\t]* basic_any basic_object_1 [ \n\t]* "}") | ("{" [ \n\t]* "}")) +root_prop_0 ::= (("\"add\"") | ("\"subtract\"") | ("\"multiply\"") | ("\"divide\"")) +root_part_1 ::= (([ \n\t]* "," [ \n\t]* "\"b\"" [ \n\t]* ":" [ \n\t]* basic_number)) +root_part_0 ::= (([ \n\t]* "," [ \n\t]* "\"a\"" [ \n\t]* ":" [ \n\t]* basic_number root_part_1)) +root ::= (("{" [ \n\t]* "\"operation\"" [ \n\t]* ":" [ \n\t]* root_prop_0 
root_part_0 [ \n\t]* "}")) +basic_integer_1 ::= ("" | ("-")) +basic_number_1 ::= ("" | ("-")) +basic_number_2 ::= (([0-9] basic_number_2) | ([0-9])) +basic_number_3 ::= ("" | ("." basic_number_2)) +basic_number_4 ::= ("" | ([+\-])) +basic_number_5 ::= (([0-9] basic_number_5) | ([0-9])) +basic_number_6 ::= ("" | ([eE] basic_number_4 basic_number_5)) +basic_array_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_any basic_array_1)) +basic_object_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_string [ \n\t]* ":" [ \n\t]* basic_any basic_object_1)) +basic_number_7 ::= (("0") | ([1-9] [0-9]*)) +basic_escape_1 ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9])) +basic_string_sub_1 ::= (("\"") | ([^\0-\x1f\"\\\r\n] basic_string_sub_1) | ("\\" basic_escape_1 basic_string_sub_1)) (=([ \n\t]* [,}\]:])) +basic_any_1 ::= ((basic_number_8) | (basic_string_1) | (basic_boolean_1) | (basic_null_1) | (basic_array_2) | (basic_object_2)) +basic_integer_2 ::= (("0") | (basic_integer_1_1 [1-9] [0-9]*)) +basic_number_8 ::= ((basic_number_1_1 basic_number_7_1 basic_number_3_1 basic_number_6_1)) +basic_string_1 ::= (("\"" basic_string_sub_1)) +basic_boolean_1 ::= (("true") | ("false")) +basic_null_1 ::= (("null")) +basic_array_2 ::= (("[" [ \n\t]* basic_any_1 basic_array_1_1 [ \n\t]* "]") | ("[" [ \n\t]* "]")) +basic_object_2 ::= (("{" [ \n\t]* basic_string_1 [ \n\t]* ":" [ \n\t]* basic_any_1 basic_object_1_1 [ \n\t]* "}") | ("{" [ \n\t]* "}")) +root_1 ::= (("{" [ \n\t]* "\"location\"" [ \n\t]* ":" [ \n\t]* basic_string_1 [ \n\t]* "}")) +basic_integer_1_1 ::= ("" | ("-")) +basic_number_1_1 ::= ("" | ("-")) +basic_number_2_1 ::= (([0-9] basic_number_2_1) | ([0-9])) +basic_number_3_1 ::= ("" | ("." 
basic_number_2_1)) +basic_number_4_1 ::= ("" | ([+\-])) +basic_number_5_1 ::= (([0-9] basic_number_5_1) | ([0-9])) +basic_number_6_1 ::= ("" | ([eE] basic_number_4_1 basic_number_5_1)) +basic_array_1_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_any_1 basic_array_1_1)) +basic_object_1_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_string_1 [ \n\t]* ":" [ \n\t]* basic_any_1 basic_object_1_1)) +basic_number_7_1 ::= (("0") | ([1-9] [0-9]*)) +triggered_tags_group ::= (("Calculator<|tool_call_argument_begin|>" root "<|tool_call_end|>") | ("Weather<|tool_call_argument_begin|>" root_1 "<|tool_call_end|>")) +triggered_tags ::= TagDispatch( + ("<|tool_call_begin|>", triggered_tags_group), + stop_eos=true, + stop_str=(), + loop_after_dispatch=true +) +root_2 ::= ((triggered_tags)) +""", + [ + ( + 'OK, I will use the Calculator tool to perform the operation. <|tool_call_begin|>Calculator<|tool_call_argument_begin|>{"operation": "add", "a": 5, "b": 3}<|tool_call_end|>', + True, + ), + ( + 'I need to know the weather in Paris. <|tool_call_begin|>Weather<|tool_call_argument_begin|>{"location": "Paris"}<|tool_call_end|>', + True, + ), + ( + 'Can you calculate Paris add 5? <|tool_call_begin|>Calculator<|tool_call_argument_begin|>{"operation": "add", "a": 5, "b": "Paris"}<|tool_call_end|>', + False, + ), + ( + 'I want to know the weather in 1. 
<|tool_call_begin|>Weather<|tool_call_argument_begin|>{"location": 1}<|tool_call_end|>', + False, + ), + ("Some random text", True), + ], + ), + ( + "Deepseek", + { + "tools": [ + { + "name": "Calculator", + "description": "A calculator that can perform basic arithmetic operations.", + "parameters": { + "type": "object", + "properties": { + "operation": { + "type": "string", + "enum": ["add", "subtract", "multiply", "divide"], + }, + "a": {"type": "number"}, + "b": {"type": "number"}, + }, + "required": ["operation", "a", "b"], + }, + }, + { + "name": "Weather", + "description": "A tool to get the current weather in a specified location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The name of the city or location to get the weather for.", + } + }, + "required": ["location"], + }, + }, + ] + }, + r"""basic_escape ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9])) +basic_string_sub ::= (("\"") | ([^\0-\x1f\"\\\r\n] basic_string_sub) | ("\\" basic_escape basic_string_sub)) (=([ \n\t]* [,}\]:])) +basic_any ::= ((basic_number) | (basic_string) | (basic_boolean) | (basic_null) | (basic_array) | (basic_object)) +basic_integer ::= (("0") | (basic_integer_1 [1-9] [0-9]*)) +basic_number ::= ((basic_number_1 basic_number_7 basic_number_3 basic_number_6)) +basic_string ::= (("\"" basic_string_sub)) +basic_boolean ::= (("true") | ("false")) +basic_null ::= (("null")) +basic_array ::= (("[" [ \n\t]* basic_any basic_array_1 [ \n\t]* "]") | ("[" [ \n\t]* "]")) +basic_object ::= (("{" [ \n\t]* basic_string [ \n\t]* ":" [ \n\t]* basic_any basic_object_1 [ \n\t]* "}") | ("{" [ \n\t]* "}")) +root_prop_0 ::= (("\"add\"") | ("\"subtract\"") | ("\"multiply\"") | ("\"divide\"")) +root_part_1 ::= (([ \n\t]* "," [ \n\t]* "\"b\"" [ \n\t]* ":" [ \n\t]* basic_number)) +root_part_0 ::= (([ \n\t]* "," [ \n\t]* "\"a\"" [ \n\t]* ":" [ \n\t]* basic_number root_part_1)) +root ::= (("{" [ \n\t]* 
"\"operation\"" [ \n\t]* ":" [ \n\t]* root_prop_0 root_part_0 [ \n\t]* "}")) +basic_integer_1 ::= ("" | ("-")) +basic_number_1 ::= ("" | ("-")) +basic_number_2 ::= (([0-9] basic_number_2) | ([0-9])) +basic_number_3 ::= ("" | ("." basic_number_2)) +basic_number_4 ::= ("" | ([+\-])) +basic_number_5 ::= (([0-9] basic_number_5) | ([0-9])) +basic_number_6 ::= ("" | ([eE] basic_number_4 basic_number_5)) +basic_array_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_any basic_array_1)) +basic_object_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_string [ \n\t]* ":" [ \n\t]* basic_any basic_object_1)) +basic_number_7 ::= (("0") | ([1-9] [0-9]*)) +basic_escape_1 ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9])) +basic_string_sub_1 ::= (("\"") | ([^\0-\x1f\"\\\r\n] basic_string_sub_1) | ("\\" basic_escape_1 basic_string_sub_1)) (=([ \n\t]* [,}\]:])) +basic_any_1 ::= ((basic_number_8) | (basic_string_1) | (basic_boolean_1) | (basic_null_1) | (basic_array_2) | (basic_object_2)) +basic_integer_2 ::= (("0") | (basic_integer_1_1 [1-9] [0-9]*)) +basic_number_8 ::= ((basic_number_1_1 basic_number_7_1 basic_number_3_1 basic_number_6_1)) +basic_string_1 ::= (("\"" basic_string_sub_1)) +basic_boolean_1 ::= (("true") | ("false")) +basic_null_1 ::= (("null")) +basic_array_2 ::= (("[" [ \n\t]* basic_any_1 basic_array_1_1 [ \n\t]* "]") | ("[" [ \n\t]* "]")) +basic_object_2 ::= (("{" [ \n\t]* basic_string_1 [ \n\t]* ":" [ \n\t]* basic_any_1 basic_object_1_1 [ \n\t]* "}") | ("{" [ \n\t]* "}")) +root_1 ::= (("{" [ \n\t]* "\"location\"" [ \n\t]* ":" [ \n\t]* basic_string_1 [ \n\t]* "}")) +basic_integer_1_1 ::= ("" | ("-")) +basic_number_1_1 ::= ("" | ("-")) +basic_number_2_1 ::= (([0-9] basic_number_2_1) | ([0-9])) +basic_number_3_1 ::= ("" | ("." 
basic_number_2_1)) +basic_number_4_1 ::= ("" | ([+\-])) +basic_number_5_1 ::= (([0-9] basic_number_5_1) | ([0-9])) +basic_number_6_1 ::= ("" | ([eE] basic_number_4_1 basic_number_5_1)) +basic_array_1_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_any_1 basic_array_1_1)) +basic_object_1_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_string_1 [ \n\t]* ":" [ \n\t]* basic_any_1 basic_object_1_1)) +basic_number_7_1 ::= (("0") | ([1-9] [0-9]*)) +triggered_tags_group ::= (("Calculator<\uff5ctool\u2581sep\uff5c>" root "<\uff5ctool\u2581call\u2581end\uff5c>") | ("Weather<\uff5ctool\u2581sep\uff5c>" root_1 "<\uff5ctool\u2581call\u2581end\uff5c>")) +triggered_tags ::= TagDispatch( + ("<\uff5ctool\u2581calls\u2581begin\uff5c><\uff5ctool\u2581call\u2581begin\uff5c>", triggered_tags_group), + stop_eos=true, + stop_str=(), + loop_after_dispatch=true +) +root_2 ::= ((triggered_tags)) +""", + [ + ( + 'OK, I will use the Calculator tool to perform the operation. <|tool▁calls▁begin|><|tool▁call▁begin|>Calculator<|tool▁sep|>{"operation": "add", "a": 5, "b": 3}<|tool▁call▁end|>', + True, + ), + ( + 'I need to know the weather in Paris. <|tool▁calls▁begin|><|tool▁call▁begin|>Weather<|tool▁sep|>{"location": "Paris"}<|tool▁call▁end|>', + True, + ), + ( + 'Can you calculate Paris add 5? <|tool▁calls▁begin|><|tool▁call▁begin|>Calculator<|tool▁sep|>{"operation": "add", "a": 5, "b": "Paris"}<|tool▁call▁end|>', + False, + ), + ( + 'I want to know the weather in 1. 
<|tool▁calls▁begin|><|tool▁call▁begin|>Weather<|tool▁sep|>{"location": 1}<<|tool▁call▁end|>', + False, + ), + ("Some random text", True), + ], + ), + ( + "Qwen_Coder", + { + "tools": [ + { + "name": "Calculator", + "description": "A calculator that can perform basic arithmetic operations.", + "parameters": { + "type": "object", + "properties": { + "operation": { + "type": "string", + "enum": ["add", "subtract", "multiply", "divide"], + }, + "a": {"type": "number"}, + "b": {"type": "number"}, + }, + "required": ["operation", "a", "b"], + }, + }, + { + "name": "Weather", + "description": "A tool to get the current weather in a specified location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The name of the city or location to get the weather for.", + } + }, + "required": ["location"], + }, + }, + ] + }, + r"""basic_escape ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9])) +basic_string_sub ::= (("\"") | ([^\0-\x1f\"\\\r\n] basic_string_sub) | ("\\" basic_escape basic_string_sub)) (=([ \n\t]* [,}\]:])) +xml_escape ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9])) +xml_entity ::= (("<") | (">") | ("&") | (""") | ("'")) +xml_string ::= ("" | ([^<>&\0-\x1f\\\r\n] xml_string) | ("\\" xml_escape xml_string) | (xml_entity xml_string)) (=([ \n\t]*)) +xml_variable_name ::= (([a-zA-Z_] [a-zA-Z0-9_]*)) +xml_string_0 ::= ((xml_string)) +xml_any ::= ((basic_number) | (xml_string) | (basic_boolean) | (basic_null) | (basic_array) | (basic_object)) +basic_any ::= ((basic_number) | (basic_string) | (basic_boolean) | (basic_null) | (basic_array) | (basic_object)) +basic_integer ::= (("0") | (basic_integer_1 [1-9] [0-9]*)) +basic_number ::= ((basic_number_1 basic_number_7 basic_number_3 basic_number_6)) +basic_string ::= (("\"" basic_string_sub)) +basic_boolean ::= (("true") | ("false")) +basic_null ::= (("null")) +basic_array ::= (("[" [ \n\t]* basic_any 
basic_array_1 [ \n\t]* "]") | ("[" [ \n\t]* "]")) +basic_object ::= (("{" [ \n\t]* basic_string [ \n\t]* ":" [ \n\t]* basic_any basic_object_1 [ \n\t]* "}") | ("{" [ \n\t]* "}")) +root_prop_0 ::= (("add") | ("subtract") | ("multiply") | ("divide")) +root_prop_1 ::= ((root_prop_1_1 root_prop_1_7 root_prop_1_3 root_prop_1_6)) +root_prop_2 ::= ((root_prop_2_1 root_prop_2_7 root_prop_2_3 root_prop_2_6)) +root_part_1 ::= (([ \n\t]* "" [ \n\t]* root_prop_2 [ \n\t]* "")) +root_part_0 ::= (([ \n\t]* "" [ \n\t]* root_prop_1 [ \n\t]* "" root_part_1)) +root ::= (([ \n\t]* "" [ \n\t]* root_prop_0 [ \n\t]* "" root_part_0)) +basic_integer_1 ::= ("" | ("-")) +basic_number_1 ::= ("" | ("-")) +basic_number_2 ::= (([0-9] basic_number_2) | ([0-9])) +basic_number_3 ::= ("" | ("." basic_number_2)) +basic_number_4 ::= ("" | ([+\-])) +basic_number_5 ::= (([0-9] basic_number_5) | ([0-9])) +basic_number_6 ::= ("" | ([eE] basic_number_4 basic_number_5)) +basic_array_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_any basic_array_1)) +basic_object_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_string [ \n\t]* ":" [ \n\t]* basic_any basic_object_1)) +root_prop_1_1 ::= ("" | ("-")) +root_prop_1_2 ::= (([0-9] root_prop_1_2) | ([0-9])) +root_prop_1_3 ::= ("" | ("." root_prop_1_2)) +root_prop_1_4 ::= ("" | ([+\-])) +root_prop_1_5 ::= (([0-9] root_prop_1_5) | ([0-9])) +root_prop_1_6 ::= ("" | ([eE] root_prop_1_4 root_prop_1_5)) +root_prop_2_1 ::= ("" | ("-")) +root_prop_2_2 ::= (([0-9] root_prop_2_2) | ([0-9])) +root_prop_2_3 ::= ("" | ("." 
root_prop_2_2)) +root_prop_2_4 ::= ("" | ([+\-])) +root_prop_2_5 ::= (([0-9] root_prop_2_5) | ([0-9])) +root_prop_2_6 ::= ("" | ([eE] root_prop_2_4 root_prop_2_5)) +basic_number_7 ::= (("0") | ([1-9] [0-9]*)) +root_prop_1_7 ::= (("0") | ([1-9] [0-9]*)) +root_prop_2_7 ::= (("0") | ([1-9] [0-9]*)) +basic_escape_1 ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9])) +basic_string_sub_1 ::= (("\"") | ([^\0-\x1f\"\\\r\n] basic_string_sub_1) | ("\\" basic_escape_1 basic_string_sub_1)) (=([ \n\t]* [,}\]:])) +xml_escape_1 ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9])) +xml_entity_1 ::= (("<") | (">") | ("&") | (""") | ("'")) +xml_string_1 ::= ("" | ([^<>&\0-\x1f\\\r\n] xml_string_1) | ("\\" xml_escape_1 xml_string_1) | (xml_entity_1 xml_string_1)) (=([ \n\t]*)) +xml_variable_name_1 ::= (([a-zA-Z_] [a-zA-Z0-9_]*)) +xml_string_0_1 ::= ((xml_string_1)) +xml_any_1 ::= ((basic_number_8) | (xml_string_1) | (basic_boolean_1) | (basic_null_1) | (basic_array_2) | (basic_object_2)) +basic_any_1 ::= ((basic_number_8) | (basic_string_1) | (basic_boolean_1) | (basic_null_1) | (basic_array_2) | (basic_object_2)) +basic_integer_2 ::= (("0") | (basic_integer_1_1 [1-9] [0-9]*)) +basic_number_8 ::= ((basic_number_1_1 basic_number_7_1 basic_number_3_1 basic_number_6_1)) +basic_string_1 ::= (("\"" basic_string_sub_1)) +basic_boolean_1 ::= (("true") | ("false")) +basic_null_1 ::= (("null")) +basic_array_2 ::= (("[" [ \n\t]* basic_any_1 basic_array_1_1 [ \n\t]* "]") | ("[" [ \n\t]* "]")) +basic_object_2 ::= (("{" [ \n\t]* basic_string_1 [ \n\t]* ":" [ \n\t]* basic_any_1 basic_object_1_1 [ \n\t]* "}") | ("{" [ \n\t]* "}")) +root_1 ::= (([ \n\t]* "" [ \n\t]* xml_string_0_1 [ \n\t]* "")) +basic_integer_1_1 ::= ("" | ("-")) +basic_number_1_1 ::= ("" | ("-")) +basic_number_2_1 ::= (([0-9] basic_number_2_1) | ([0-9])) +basic_number_3_1 ::= ("" | ("." 
basic_number_2_1)) +basic_number_4_1 ::= ("" | ([+\-])) +basic_number_5_1 ::= (([0-9] basic_number_5_1) | ([0-9])) +basic_number_6_1 ::= ("" | ([eE] basic_number_4_1 basic_number_5_1)) +basic_array_1_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_any_1 basic_array_1_1)) +basic_object_1_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_string_1 [ \n\t]* ":" [ \n\t]* basic_any_1 basic_object_1_1)) +basic_number_7_1 ::= (("0") | ([1-9] [0-9]*)) +triggered_tags_group ::= (("Calculator>" root "") | ("Weather>" root_1 "")) +triggered_tags ::= TagDispatch( + ("add53", + True, + ), + ( + "I need to know the weather in Paris. Paris", + True, + ), + ( + "Can you calculate Paris add 5? add5Paris", + False, + ), + ("Some random text", True), + ], + ), + ( + "Qwen", + { + "tools": [ + { + "name": "Calculator", + "description": "A calculator that can perform basic arithmetic operations.", + "parameters": { + "type": "object", + "properties": { + "operation": { + "type": "string", + "enum": ["add", "subtract", "multiply", "divide"], + }, + "a": {"type": "number"}, + "b": {"type": "number"}, + }, + "required": ["operation", "a", "b"], + }, + }, + { + "name": "Weather", + "description": "A tool to get the current weather in a specified location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The name of the city or location to get the weather for.", + } + }, + "required": ["location"], + }, + }, + ] + }, + r"""basic_escape ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9])) +basic_string_sub ::= (("\"") | ([^\0-\x1f\"\\\r\n] basic_string_sub) | ("\\" basic_escape basic_string_sub)) (=([ \n\t]* [,}\]:])) +basic_any ::= ((basic_number) | (basic_string) | (basic_boolean) | (basic_null) | (basic_array) | (basic_object)) +basic_integer ::= (("0") | (basic_integer_1 [1-9] [0-9]*)) +basic_number ::= ((basic_number_1 basic_number_7 basic_number_3 basic_number_6)) +basic_string ::= (("\"" basic_string_sub)) 
+basic_boolean ::= (("true") | ("false")) +basic_null ::= (("null")) +basic_array ::= (("[" [ \n\t]* basic_any basic_array_1 [ \n\t]* "]") | ("[" [ \n\t]* "]")) +basic_object ::= (("{" [ \n\t]* basic_string [ \n\t]* ":" [ \n\t]* basic_any basic_object_1 [ \n\t]* "}") | ("{" [ \n\t]* "}")) +root_prop_0 ::= (("\"add\"") | ("\"subtract\"") | ("\"multiply\"") | ("\"divide\"")) +root_part_1 ::= (([ \n\t]* "," [ \n\t]* "\"b\"" [ \n\t]* ":" [ \n\t]* basic_number)) +root_part_0 ::= (([ \n\t]* "," [ \n\t]* "\"a\"" [ \n\t]* ":" [ \n\t]* basic_number root_part_1)) +root ::= (("{" [ \n\t]* "\"operation\"" [ \n\t]* ":" [ \n\t]* root_prop_0 root_part_0 [ \n\t]* "}")) +basic_integer_1 ::= ("" | ("-")) +basic_number_1 ::= ("" | ("-")) +basic_number_2 ::= (([0-9] basic_number_2) | ([0-9])) +basic_number_3 ::= ("" | ("." basic_number_2)) +basic_number_4 ::= ("" | ([+\-])) +basic_number_5 ::= (([0-9] basic_number_5) | ([0-9])) +basic_number_6 ::= ("" | ([eE] basic_number_4 basic_number_5)) +basic_array_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_any basic_array_1)) +basic_object_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_string [ \n\t]* ":" [ \n\t]* basic_any basic_object_1)) +basic_number_7 ::= (("0") | ([1-9] [0-9]*)) +basic_escape_1 ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9])) +basic_string_sub_1 ::= (("\"") | ([^\0-\x1f\"\\\r\n] basic_string_sub_1) | ("\\" basic_escape_1 basic_string_sub_1)) (=([ \n\t]* [,}\]:])) +basic_any_1 ::= ((basic_number_8) | (basic_string_1) | (basic_boolean_1) | (basic_null_1) | (basic_array_2) | (basic_object_2)) +basic_integer_2 ::= (("0") | (basic_integer_1_1 [1-9] [0-9]*)) +basic_number_8 ::= ((basic_number_1_1 basic_number_7_1 basic_number_3_1 basic_number_6_1)) +basic_string_1 ::= (("\"" basic_string_sub_1)) +basic_boolean_1 ::= (("true") | ("false")) +basic_null_1 ::= (("null")) +basic_array_2 ::= (("[" [ \n\t]* basic_any_1 basic_array_1_1 [ \n\t]* "]") | ("[" [ \n\t]* "]")) +basic_object_2 ::= (("{" [ \n\t]* 
basic_string_1 [ \n\t]* ":" [ \n\t]* basic_any_1 basic_object_1_1 [ \n\t]* "}") | ("{" [ \n\t]* "}")) +root_1 ::= (("{" [ \n\t]* "\"location\"" [ \n\t]* ":" [ \n\t]* basic_string_1 [ \n\t]* "}")) +basic_integer_1_1 ::= ("" | ("-")) +basic_number_1_1 ::= ("" | ("-")) +basic_number_2_1 ::= (([0-9] basic_number_2_1) | ([0-9])) +basic_number_3_1 ::= ("" | ("." basic_number_2_1)) +basic_number_4_1 ::= ("" | ([+\-])) +basic_number_5_1 ::= (([0-9] basic_number_5_1) | ([0-9])) +basic_number_6_1 ::= ("" | ([eE] basic_number_4_1 basic_number_5_1)) +basic_array_1_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_any_1 basic_array_1_1)) +basic_object_1_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_string_1 [ \n\t]* ":" [ \n\t]* basic_any_1 basic_object_1_1)) +basic_number_7_1 ::= (("0") | ([1-9] [0-9]*)) +triggered_tags_group ::= (("{\"name\": \"Calculator\", \"arguments\": " root "}") | ("{\"name\": \"Weather\", \"arguments\": " root_1 "}")) +triggered_tags ::= TagDispatch( + ("", triggered_tags_group), + stop_eos=true, + stop_str=(), + loop_after_dispatch=true +) +root_2 ::= ((triggered_tags)) +""", + [ + ( + 'OK, I will use the Calculator tool to perform the operation. {"name": "Calculator", "arguments": {"operation": "add", "a": 5, "b": 3}}', + True, + ), + ( + 'I need to know the weather in Paris. {"name": "Weather", "arguments": {"location": "Paris"}}', + True, + ), + ( + 'Can you calculate Paris add 5? {"name": "Calculator", "arguments": {"operation": "add", "a": 5, "b": "Paris"}}', + False, + ), + ( + 'I want to know the weather in 1. 
{"name": "Weather", "arguments": {"location": 1}}', + False, + ), + ("Some random text", True), + ], + ), + ( + "Harmony", + { + "tools": [ + { + "name": "Calculator", + "description": "A calculator that can perform basic arithmetic operations.", + "parameters": { + "type": "object", + "properties": { + "operation": { + "type": "string", + "enum": ["add", "subtract", "multiply", "divide"], + }, + "a": {"type": "number"}, + "b": {"type": "number"}, + }, + "required": ["operation", "a", "b"], + }, + }, + { + "name": "Weather", + "description": "A tool to get the current weather in a specified location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The name of the city or location to get the weather for.", + } + }, + "required": ["location"], + }, + }, + ], + "builtin_tools": [ + { + "name": "Python", + "description": "A Python interpreter that can execute Python code.", + "parameters": { + "type": "object", + "properties": { + "code": {"type": "string", "description": "The Python code to execute."} + }, + "required": ["code"], + }, + } + ], + }, + r"""any_text ::= TagDispatch( + stop_eos=false, + stop_str=("<|end|>"), + loop_after_dispatch=false +) +any_text_1 ::= TagDispatch( + stop_eos=false, + stop_str=("<|return|>"), + loop_after_dispatch=false +) +any_text_2 ::= TagDispatch( + stop_eos=false, + stop_str=("<|call|>"), + loop_after_dispatch=false +) +basic_escape ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9])) +basic_string_sub ::= (("\"") | ([^\0-\x1f\"\\\r\n] basic_string_sub) | ("\\" basic_escape basic_string_sub)) (=([ \n\t]* [,}\]:])) +basic_any ::= ((basic_number) | (basic_string) | (basic_boolean) | (basic_null) | (basic_array) | (basic_object)) +basic_integer ::= (("0") | (basic_integer_1 [1-9] [0-9]*)) +basic_number ::= ((basic_number_1 basic_number_7 basic_number_3 basic_number_6)) +basic_string ::= (("\"" basic_string_sub)) +basic_boolean ::= (("true") | 
("false")) +basic_null ::= (("null")) +basic_array ::= (("[" [ \n\t]* basic_any basic_array_1 [ \n\t]* "]") | ("[" [ \n\t]* "]")) +basic_object ::= (("{" [ \n\t]* basic_string [ \n\t]* ":" [ \n\t]* basic_any basic_object_1 [ \n\t]* "}") | ("{" [ \n\t]* "}")) +root_prop_0 ::= (("\"add\"") | ("\"subtract\"") | ("\"multiply\"") | ("\"divide\"")) +root_part_1 ::= (([ \n\t]* "," [ \n\t]* "\"b\"" [ \n\t]* ":" [ \n\t]* basic_number)) +root_part_0 ::= (([ \n\t]* "," [ \n\t]* "\"a\"" [ \n\t]* ":" [ \n\t]* basic_number root_part_1)) +root ::= (("{" [ \n\t]* "\"operation\"" [ \n\t]* ":" [ \n\t]* root_prop_0 root_part_0 [ \n\t]* "}")) +basic_integer_1 ::= ("" | ("-")) +basic_number_1 ::= ("" | ("-")) +basic_number_2 ::= (([0-9] basic_number_2) | ([0-9])) +basic_number_3 ::= ("" | ("." basic_number_2)) +basic_number_4 ::= ("" | ([+\-])) +basic_number_5 ::= (([0-9] basic_number_5) | ([0-9])) +basic_number_6 ::= ("" | ([eE] basic_number_4 basic_number_5)) +basic_array_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_any basic_array_1)) +basic_object_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_string [ \n\t]* ":" [ \n\t]* basic_any basic_object_1)) +basic_number_7 ::= (("0") | ([1-9] [0-9]*)) +basic_escape_1 ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9])) +basic_string_sub_1 ::= (("\"") | ([^\0-\x1f\"\\\r\n] basic_string_sub_1) | ("\\" basic_escape_1 basic_string_sub_1)) (=([ \n\t]* [,}\]:])) +basic_any_1 ::= ((basic_number_8) | (basic_string_1) | (basic_boolean_1) | (basic_null_1) | (basic_array_2) | (basic_object_2)) +basic_integer_2 ::= (("0") | (basic_integer_1_1 [1-9] [0-9]*)) +basic_number_8 ::= ((basic_number_1_1 basic_number_7_1 basic_number_3_1 basic_number_6_1)) +basic_string_1 ::= (("\"" basic_string_sub_1)) +basic_boolean_1 ::= (("true") | ("false")) +basic_null_1 ::= (("null")) +basic_array_2 ::= (("[" [ \n\t]* basic_any_1 basic_array_1_1 [ \n\t]* "]") | ("[" [ \n\t]* "]")) +basic_object_2 ::= (("{" [ \n\t]* basic_string_1 [ \n\t]* ":" [ \n\t]* 
basic_any_1 basic_object_1_1 [ \n\t]* "}") | ("{" [ \n\t]* "}")) +root_1 ::= (("{" [ \n\t]* "\"location\"" [ \n\t]* ":" [ \n\t]* basic_string_1 [ \n\t]* "}")) +basic_integer_1_1 ::= ("" | ("-")) +basic_number_1_1 ::= ("" | ("-")) +basic_number_2_1 ::= (([0-9] basic_number_2_1) | ([0-9])) +basic_number_3_1 ::= ("" | ("." basic_number_2_1)) +basic_number_4_1 ::= ("" | ([+\-])) +basic_number_5_1 ::= (([0-9] basic_number_5_1) | ([0-9])) +basic_number_6_1 ::= ("" | ([eE] basic_number_4_1 basic_number_5_1)) +basic_array_1_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_any_1 basic_array_1_1)) +basic_object_1_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_string_1 [ \n\t]* ":" [ \n\t]* basic_any_1 basic_object_1_1)) +basic_number_7_1 ::= (("0") | ([1-9] [0-9]*)) +basic_escape_2 ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9])) +basic_string_sub_2 ::= (("\"") | ([^\0-\x1f\"\\\r\n] basic_string_sub_2) | ("\\" basic_escape_2 basic_string_sub_2)) (=([ \n\t]* [,}\]:])) +basic_any_2 ::= ((basic_number_9) | (basic_string_2) | (basic_boolean_2) | (basic_null_2) | (basic_array_3) | (basic_object_3)) +basic_integer_3 ::= (("0") | (basic_integer_1_2 [1-9] [0-9]*)) +basic_number_9 ::= ((basic_number_1_2 basic_number_7_2 basic_number_3_2 basic_number_6_2)) +basic_string_2 ::= (("\"" basic_string_sub_2)) +basic_boolean_2 ::= (("true") | ("false")) +basic_null_2 ::= (("null")) +basic_array_3 ::= (("[" [ \n\t]* basic_any_2 basic_array_1_2 [ \n\t]* "]") | ("[" [ \n\t]* "]")) +basic_object_3 ::= (("{" [ \n\t]* basic_string_2 [ \n\t]* ":" [ \n\t]* basic_any_2 basic_object_1_2 [ \n\t]* "}") | ("{" [ \n\t]* "}")) +root_2 ::= (("{" [ \n\t]* "\"code\"" [ \n\t]* ":" [ \n\t]* basic_string_2 [ \n\t]* "}")) +basic_integer_1_2 ::= ("" | ("-")) +basic_number_1_2 ::= ("" | ("-")) +basic_number_2_2 ::= (([0-9] basic_number_2_2) | ([0-9])) +basic_number_3_2 ::= ("" | ("." 
basic_number_2_2)) +basic_number_4_2 ::= ("" | ([+\-])) +basic_number_5_2 ::= (([0-9] basic_number_5_2) | ([0-9])) +basic_number_6_2 ::= ("" | ([eE] basic_number_4_2 basic_number_5_2)) +basic_array_1_2 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_any_2 basic_array_1_2)) +basic_object_1_2 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_string_2 [ \n\t]* ":" [ \n\t]* basic_any_2 basic_object_1_2)) +basic_number_7_2 ::= (("0") | ([1-9] [0-9]*)) +triggered_tags_group ::= (("assistant<|channel|>analysis<|message|>" any_text "") | ("assistant<|channel|>final<|message|>" any_text_1 "") | ("assistant<|channel|>final<|message|>" any_text_2 "") | ("assistant<|channel|>commentary to=Calculator<|constrain|>json<|message|>" root "<|end|>") | ("assistant<|channel|>commentary to=Weather<|constrain|>json<|message|>" root_1 "<|end|>") | ("assistant<|channel|>analysis to=Python<|message|>" root_2 "<|end|>")) +triggered_tags ::= TagDispatch( + ("<|start|>", triggered_tags_group), + stop_eos=true, + stop_str=(), + loop_after_dispatch=true +) +root_3 ::= ((triggered_tags)) +""", + [ + ( + '<|start|>assistant<|channel|>analysis<|message|>OK, I will use the Calculator tool to perform the operation.<|end|><|start|>assistant<|channel|>commentary to=Calculator<|constrain|>json<|message|>{"operation": "add", "a": 5, "b": 3}<|end|>', + True, + ), + ( + '<|start|>assistant<|channel|>analysis<|message|>I need to know the weather in Paris.<|end|><|start|>assistant<|channel|>commentary to=Weather<|constrain|>json<|message|>{"location": "Paris"}<|end|>', + True, + ), + ( + '<|start|>assistant<|channel|>analysis<|message|>Can you calculate Paris add 5?<|end|><|start|>assistant<|channel|>commentary to=Calculator<|constrain|>json<|message|>{"operation": "add", "a": 5, "b": "Paris"}<|end|>', + False, + ), + ( + '<|start|>assistant<|channel|>analysis<|message|>I want to know the weather in 1.<|end|><|start|>assistant<|channel|>commentary to=Weather<|constrain|>json<|message|>{"location": 1}<|end|>', + False, + ), + 
("<|start|>assistant<|channel|>analysis<|message|>Some random text<|end|>", True), + ( + '<|start|>assistant<|channel|>analysis to=Python<|message|>{"code": "print(\\"Hello, World!\\")"}<|end|>', + True, + ), + ( + "<|start|>assistant<|channel|>final<|message|>The function should be called.<|call|>", + True, + ), + ("<|start|>assistant<|channel|>final<|message|>All tasks done.<|return|>", True), + ], + ), + ( + "Llama", + {"tools": []}, + r"""any_text ::= (([\0-\U0010ffff]*)) +root ::= ((any_text)) +""", + [ + ( + 'OK, I will use the Calculator tool to perform the operation. {"name": "Calculator", "parameters": {"operation": "add", "a": 5, "b": 3}}', + True, + ), + ( + 'I need to know the weather in Paris. {"name": "Weather", "parameters": {"location": "Paris"}}', + True, + ), + ( + 'Can you calculate Paris add 5? {"name": "Calculator", "parameters": {"operation": "add", "a": 5, "b": "Paris"}}', + True, + ), + ( + 'I want to know the weather in 1. {"name": "Weather", "parameters": {"location": 1}}', + True, + ), + ("Some random text", True), + ], + ), +] + + +@pytest.mark.parametrize( + "builtin_format_type, template_values, expected_grammar, instance_is_accepted_tuples", + builtin_template_values_stag_grammar_instance_accepted, +) +def test_builtin_template_values( + builtin_format_type: str, + template_values: Dict[str, List[Dict[str, Any]]], + expected_grammar: str, + instance_is_accepted_tuples: List[Tuple[str, bool]], +): + """Run the grammar check and per-instance checks for one builtin template structural tag.""" + template_stag_format = xgr.structural_tag.get_builtin_template_structural_tag( + builtin_format_type + ) + + check_template_stag_with_grammar(template_stag_format, expected_grammar, **template_values) + + for instance, is_accepted in instance_is_accepted_tuples: + check_template_stag_with_instance( + template_stag_format, instance, is_accepted, **template_values + ) + + if __name__ == "__main__": pytest.main(sys.argv)