Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions common/chat-peg-parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@

using json = nlohmann::json;

static std::string_view trim_trailing_space(std::string_view sv) {
static std::string_view trim_trailing_space(std::string_view sv, int max = -1) {
int count = 0;
while (!sv.empty() && std::isspace(static_cast<unsigned char>(sv.back()))) {
if (max != -1 && count <= max) {
break;
}
sv.remove_suffix(1);
count++;
}
return sv;
}
Expand Down Expand Up @@ -93,14 +98,15 @@ void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {

if (is_arg_string && current_tool) {
// Serialize to JSON, but exclude the end quote
std::string dumped = json(node.text).dump();
std::string dumped = json(trim_trailing_space(node.text)).dump();
current_tool->arguments += dumped.substr(0, dumped.size() - 1);
needs_closing_quote = true;
}

if (is_arg_close && current_tool) {
if (needs_closing_quote) {
current_tool->arguments += "\"";
needs_closing_quote = false;
}
}

Expand All @@ -109,6 +115,10 @@ void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
}

if (is_tool_close && current_tool) {
if (needs_closing_quote) {
current_tool->arguments += "\"";
needs_closing_quote = false;
}
current_tool->arguments += "}";
}
}
140 changes: 140 additions & 0 deletions common/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,25 @@ static void foreach_function(const json & tools, const std::function<void(const
}
}

static void foreach_parameter(const json & function, const std::function<void(const std::string &, const json &, bool)> & fn) {
if (!function.contains("parameters") || !function.at("parameters").is_object()) {
return;
}
const auto & params = function.at("parameters");
if (!params.contains("properties") || !params.at("properties").is_object()) {
return;
}
const auto & props = params.at("properties");
std::set<std::string> required;
if (params.contains("required") && params.at("required").is_array()) {
params.at("required").get_to(required);
}
for (const auto & [name, prop] : props.items()) {
bool is_required = (required.find(name) != required.end());
fn(name, prop, is_required);
}
}

static std::string apply(
const common_chat_template & tmpl,
const struct templates_params & inputs,
Expand Down Expand Up @@ -1409,6 +1428,123 @@ static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_
return data;
}

static common_chat_params common_chat_params_init_nemotron_v3(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;

data.prompt = apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_CONSTRUCTED;

// Handle thinking tags appropriately based on inputs.enable_thinking
if (string_ends_with(data.prompt, "<think>\n")) {
if (!inputs.enable_thinking) {
data.prompt += "</think>";
} else {
data.thinking_forced_open = true;
}
}

data.preserved_tokens = {
"<think>",
"</think>",
"<tool_call>",
"</tool_call>",
};

auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
auto include_grammar = true;

auto parser = build_chat_peg_constructed_parser([&](auto & p) {
auto reasoning = p.eps();
if (inputs.enable_thinking && extract_reasoning) {
auto reasoning_content = p.reasoning(p.until("</think>")) + ("</think>" | p.end());
if (data.thinking_forced_open) {
reasoning = reasoning_content;
}
}

// Response format parser
if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
return reasoning << p.content(p.schema(p.json(), "response-format", inputs.json_schema));
}

// Tool call parser
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
auto tool_choice = p.choice();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");

auto schema_info = common_schema_info();
schema_info.resolve_refs(parameters);

auto tool_open = "<function=" + p.tool_name(p.literal(name)) + ">\n";
auto tool_close = p.literal("</function>\n");
auto args = p.sequence();
auto arg_string = p.rule("xml-arg-string", p.until_one_of({
"\n</parameter>",
"\n<parameter=",
"\n</function>"
}));

foreach_parameter(function, [&](const auto & param_name, const json & param_schema, bool is_required) {
auto rule_name = "tool-" + name + "-arg-" + param_name;

auto arg_open = "<parameter=" + p.tool_arg_name(p.literal(param_name)) + ">\n";
auto arg_close = p.literal("</parameter>\n");
auto arg_value = p.eps();

if (schema_info.resolves_to_string(param_schema)) {
arg_value = p.tool_arg_string_value(arg_string) + "\n";
} else {
arg_value = p.tool_arg_json_value(p.schema(p.json(), rule_name + "-schema", param_schema));
}

// Model may or my not close with </parameter>
auto arg_rule = p.rule(rule_name, p.tool_arg_open(arg_open) + arg_value + p.optional(p.tool_arg_close(arg_close)));
args += p.repeat(arg_rule, /* min = */ is_required ? 1 : 0, /* max = */ 1);
});

tool_choice |= p.rule("tool-" + name, p.tool_open(tool_open) + args + p.tool_close(tool_close));
});

auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0;
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
auto tool_call = p.rule("tool-call", "<tool_call>\n" + tool_choice + "</tool_call>" + p.space());
auto tool_calls = p.trigger_rule("tool-call-root", p.repeat(tool_call, /* min = */ min_calls, /* max = */ max_calls));

return reasoning << p.content(p.until("<tool_call>")) << tool_calls;
}

// Content only parser
include_grammar = false;
return reasoning << p.content(p.rest());
});

data.parser = parser.save();

if (include_grammar) {
data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;

data.grammar = build_grammar([&](const common_grammar_builder & builder) {
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
auto schema = function.at("parameters");
builder.resolve_refs(schema);
});
parser.build_grammar(builder, data.grammar_lazy);
});

data.grammar_triggers = {
{COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<tool_call>"}
};
}

return data;
}


static common_chat_params common_chat_params_init_apertus(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;

Expand Down Expand Up @@ -2534,6 +2670,10 @@ static common_chat_params common_chat_templates_apply_jinja(
src.find("<function=") != std::string::npos &&
src.find("<parameters>") != std::string::npos &&
src.find("<parameter=") != std::string::npos) {
// Nemotron 3 Nano 30B A3B
if (src.find("<think>") != std::string::npos) {
return common_chat_params_init_nemotron_v3(tmpl, params);
}
return common_chat_params_init_qwen3_coder_xml(tmpl, params);
}

Expand Down
135 changes: 132 additions & 3 deletions common/json-schema-to-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,9 @@ static std::string format_literal(const std::string & literal) {

std::string gbnf_format_literal(const std::string & literal) { return format_literal(literal); }

class SchemaConverter {
class common_schema_converter {
private:
friend class common_schema_info;
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
std::function<json(const std::string &)> _fetch_json;
bool _dotall;
Expand Down Expand Up @@ -729,7 +730,7 @@ class SchemaConverter {
}

public:
SchemaConverter(
common_schema_converter(
const std::function<json(const std::string &)> & fetch_json,
bool dotall)
: _fetch_json(fetch_json), _dotall(dotall)
Expand Down Expand Up @@ -990,6 +991,134 @@ class SchemaConverter {
}
};

// common_schema_info implementation (pimpl)

common_schema_info::common_schema_info()
: impl_(std::make_unique<common_schema_converter>(
[](const std::string &) { return json(); },
false)) {}

common_schema_info::~common_schema_info() = default;

common_schema_info::common_schema_info(common_schema_info &&) noexcept = default;
common_schema_info & common_schema_info::operator=(common_schema_info &&) noexcept = default;

void common_schema_info::resolve_refs(nlohmann::ordered_json & schema) {
impl_->resolve_refs(schema, "");
}

// Determines if a JSON schema can resolve to a string type through any path.
// Some models emit raw string values rather than JSON-encoded strings for string parameters.
// If any branch of the schema (via oneOf, anyOf, $ref, etc.) permits a string, this returns
// true, allowing callers to handle the value as a raw string for simplicity.
bool common_schema_info::resolves_to_string(const nlohmann::ordered_json & schema) {
std::unordered_set<std::string> visited_refs;

std::function<bool(const json &)> check = [&](const json & s) -> bool {
if (!s.is_object()) {
return false;
}

// Handle $ref
if (s.contains("$ref")) {
const std::string & ref = s["$ref"];
if (visited_refs.find(ref) != visited_refs.end()) {
// Circular reference, assume not a string to be safe
return false;
}
visited_refs.insert(ref);
auto it = impl_->_refs.find(ref);
if (it != impl_->_refs.end()) {
return check(it->second);
}
return false;
}

// Check type field
if (s.contains("type")) {
const json & schema_type = s["type"];
if (schema_type.is_string()) {
if (schema_type == "string") {
return true;
}
} else if (schema_type.is_array()) {
// Type can be an array like ["string", "null"]
for (const auto & t : schema_type) {
if (t == "string") {
return true;
}
}
}
}

// Check oneOf/anyOf - if any alternative can be a string
if (s.contains("oneOf")) {
for (const auto & alt : s["oneOf"]) {
if (check(alt)) {
return true;
}
}
}
if (s.contains("anyOf")) {
for (const auto & alt : s["anyOf"]) {
if (check(alt)) {
return true;
}
}
}

// Check allOf - all components must be compatible with string type
if (s.contains("allOf")) {
bool all_string = true;
for (const auto & component : s["allOf"]) {
if (!check(component)) {
all_string = false;
break;
}
}
if (all_string) {
return true;
}
}

// Check const - if the constant value is a string
if (s.contains("const")) {
if (s["const"].is_string()) {
return true;
}
}

// Check enum - if any enum value is a string
if (s.contains("enum")) {
for (const auto & val : s["enum"]) {
if (val.is_string()) {
return true;
}
}
}

// String-specific keywords imply string type
if (s.contains("pattern") || s.contains("minLength") || s.contains("maxLength")) {
return true;
}

// Check format - many formats imply string
if (s.contains("format")) {
const std::string & fmt = s["format"];
if (fmt == "date" || fmt == "time" || fmt == "date-time" ||
fmt == "uri" || fmt == "email" || fmt == "hostname" ||
fmt == "ipv4" || fmt == "ipv6" || fmt == "uuid" ||
fmt.find("uuid") == 0) {
return true;
}
}

return false;
};

return check(schema);
}

std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
#ifdef LLAMA_USE_LLGUIDANCE
if (!force_gbnf) {
Expand All @@ -1006,7 +1135,7 @@ std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
}

std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
common_schema_converter converter([&](const std::string &) { return json(); }, options.dotall);
common_grammar_builder builder {
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
return converter._add_rule(name, rule);
Expand Down
Loading
Loading