diff --git a/backend/backend.proto b/backend/backend.proto
index f7bcf79726ff..a367523de5c6 100644
--- a/backend/backend.proto
+++ b/backend/backend.proto
@@ -154,6 +154,8 @@ message PredictOptions {
   repeated string Videos = 45;
   repeated string Audios = 46;
   string CorrelationId = 47;
+  string Tools = 48; // JSON array of available tools/functions for tool calling
+  string ToolChoice = 49; // JSON string or object specifying tool choice behavior
 }
 
 // The response message containing the result
@@ -382,6 +384,11 @@ message StatusResponse {
 message Message {
   string role = 1;
   string content = 2;
+  // Optional fields for OpenAI-compatible message format
+  string name = 3; // Tool name (for tool messages)
+  string tool_call_id = 4; // Tool call ID (for tool messages)
+  string reasoning_content = 5; // Reasoning content (for thinking models)
+  string tool_calls = 6; // Tool calls as JSON string (for assistant messages with tool calls)
 }
 
 message DetectOptions {
diff --git a/backend/cpp/llama-cpp/grpc-server.cpp b/backend/cpp/llama-cpp/grpc-server.cpp
index a6c610106d4c..a33dc5c20da3 100644
--- a/backend/cpp/llama-cpp/grpc-server.cpp
+++ b/backend/cpp/llama-cpp/grpc-server.cpp
@@ -27,8 +27,6 @@ using grpc::Status;
 // END LocalAI
 
-
-
 /////////////////////////////////
 ////////////////////////////////
 //////// LOCALAI code starts below here
@@ -37,6 +35,14 @@ using grpc::Status;
 bool loaded_model; // TODO: add a mutex for this, but happens only once loading the model
 
+// Forward declarations
+static void start_llama_server(server_context& ctx_server);
+static json parse_options(bool streaming, const backend::PredictOptions* predict, const server_context& ctx_server);
+static ggml_type kv_cache_type_from_str(const std::string & s);
+static std::string get_all_kv_cache_types();
+static void add_rpc_devices(std::string servers);
+static void params_parse(server_context& ctx_server, const backend::ModelOptions* request, common_params & params);
+
 static void start_llama_server(server_context& ctx_server) {
     LOG_INF("%s: starting llama server\n", __func__);
 
@@ -57,9 +63,8 @@ static void start_llama_server(server_context& ctx_server) {
     //        common_chat_templates_source(ctx_server.chat_templates.get()),
     //        common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str(), ctx_server.params_base.default_template_kwargs);
-    // Reset the chat templates
-    // TODO: We should make this configurable by respecting the option that is already present in LocalAI for vLLM
-    ctx_server.chat_templates.reset();
+    // Keep the chat templates initialized in load_model() so they can be used when UseTokenizerTemplate is enabled
+    // Templates will only be used conditionally in Predict/PredictStream when UseTokenizerTemplate is true and Messages are provided
 
     ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) {
         ctx_server.process_single_task(std::move(task));
@@ -114,12 +119,55 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
     data["mirostat_eta"] = predict->mirostateta();
     data["n_keep"] = predict->nkeep();
     data["seed"] = predict->seed();
-    data["grammar"] = predict->grammar();
-    data["prompt"] = predict->prompt();
+
+
+    std::string grammar_str = predict->grammar();
+
+
+
+    if (!grammar_str.empty()) {
+        data["grammar"] = grammar_str;
+        SRV_INF("Using grammar: %s\n", grammar_str.c_str());
+    }
+
+    // Only set prompt if UseTokenizerTemplate is false or if no Messages are provided
+    // When UseTokenizerTemplate is true and Messages are provided, prompt will be set via chat templates in Predict/PredictStream
+    if (!predict->usetokenizertemplate() || predict->messages_size() == 0) {
+        data["prompt"] = predict->prompt();
+    }
+
+    // Extract tools and tool_choice from proto and add to data JSON
+    if (!predict->tools().empty()) {
+        try {
+            // Parse tools JSON string and add to data
+            json tools_json = json::parse(predict->tools());
+            data["tools"] = tools_json;
+            SRV_INF("Extracted tools from proto: %s\n", predict->tools().c_str());
+        } catch (const json::parse_error& e) {
+            SRV_WRN("Failed to parse tools JSON from proto: %s\n", e.what());
+        }
+    }
+    if (!predict->toolchoice().empty()) {
+        try {
+            // Parse tool_choice JSON string
+            json tool_choice_json = json::parse(predict->toolchoice());
+            // tool_choice can be a string ("auto", "none", "required") or an object
+            // Store it as-is (string or object) so we can convert object to "required" later when adding to body_json
+            if (tool_choice_json.is_string()) {
+                data["tool_choice"] = tool_choice_json.get<std::string>();
+            } else {
+                // Store object as-is so we can detect it later and convert to "required"
+                data["tool_choice"] = tool_choice_json;
+            }
+            SRV_INF("Extracted tool_choice from proto: %s\n", predict->toolchoice().c_str());
+        } catch (const json::parse_error& e) {
+            // If parsing fails, treat as string
+            data["tool_choice"] = predict->toolchoice();
+            SRV_INF("Extracted tool_choice as string: %s\n", predict->toolchoice().c_str());
+        }
+    }
     data["ignore_eos"] = predict->ignoreeos();
     data["embeddings"] = predict->embeddings();
-    // TODO: add back json_schema and let this be controlled by the user
-    // data["json_schema"] = predict->jsonschema();
 
     // Add the correlationid to json data
     data["correlation_id"] = predict->correlationid();
@@ -253,27 +301,19 @@ static void params_parse(server_context& ctx_server, const backend::ModelOptions
     params.cpuparams.n_threads = request->threads();
     params.n_gpu_layers = request->ngpulayers();
     params.n_batch = request->nbatch();
+    //params.verbosity = INT_MAX;
+    // Enable all debug logs by setting verbosity threshold to maximum
+    //common_log_set_verbosity_thold(INT_MAX);
     params.n_ubatch = request->nbatch(); // fixes issue with reranking models being limited to 512 tokens (the default n_ubatch size); allows for setting the maximum input amount of tokens thereby avoiding this error "input is too large to process. increase the physical batch size"
-    // Set params.n_parallel by environment variable (LLAMA_PARALLEL), defaults to 1
-    //params.n_parallel = 1;
-    const char *env_parallel = std::getenv("LLAMACPP_PARALLEL");
-    if (env_parallel != NULL) {
-        params.n_parallel = std::stoi(env_parallel);
-        params.cont_batching = true;
-    } else {
-        params.n_parallel = 1;
-    }
-
-
-    const char *llama_grpc_servers = std::getenv("LLAMACPP_GRPC_SERVERS");
-    if (llama_grpc_servers != NULL) {
-        add_rpc_devices(std::string(llama_grpc_servers));
-    }
     // Initialize ctx_shift to false by default (can be overridden by options)
     params.ctx_shift = false;
     // Initialize cache_ram_mib to -1 by default (no limit, can be overridden by options)
     params.cache_ram_mib = -1;
+    // Initialize n_parallel to 1 by default (can be overridden by options)
+    params.n_parallel = 1;
+    // Initialize grpc_servers to empty (can be overridden by options)
+    std::string grpc_servers_option = "";
 
     // decode options. Options are in form optname:optvale, or if booleans only optname.
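+    // Example option strings (illustrative only; the option names below are the ones recognized
+    // in this loop, the values are hypothetical):
+    //   "use_jinja:true"           -> params.use_jinja = true ("jinja" is accepted as an alias)
+    //   "parallel:4"               -> params.n_parallel = 4 and cont_batching is enabled ("n_parallel" alias)
+    //   "grpc_servers:<addresses>" -> forwarded to add_rpc_devices(), same format as LLAMACPP_GRPC_SERVERS ("rpc_servers" alias)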
for (int i = 0; i < request->options_size(); i++) { @@ -290,6 +330,12 @@ static void params_parse(server_context& ctx_server, const backend::ModelOptions } else if (!strcmp(optval, "false") || !strcmp(optval, "0") || !strcmp(optval, "no") || !strcmp(optval, "off") || !strcmp(optval, "disabled")) { params.ctx_shift = false; } + } else if (!strcmp(optname, "use_jinja") || !strcmp(optname, "jinja")) { + if (!strcmp(optval, "true") || !strcmp(optval, "1") || !strcmp(optval, "yes") || !strcmp(optval, "on") || !strcmp(optval, "enabled")) { + params.use_jinja = true; + } else if (!strcmp(optval, "false") || !strcmp(optval, "0") || !strcmp(optval, "no") || !strcmp(optval, "off") || !strcmp(optval, "disabled")) { + params.use_jinja = false; + } } else if (!strcmp(optname, "cache_ram")) { if (optval != NULL) { try { @@ -298,6 +344,46 @@ static void params_parse(server_context& ctx_server, const backend::ModelOptions // If conversion fails, keep default value (-1) } } + } else if (!strcmp(optname, "parallel") || !strcmp(optname, "n_parallel")) { + if (optval != NULL) { + try { + params.n_parallel = std::stoi(optval); + if (params.n_parallel > 1) { + params.cont_batching = true; + } + } catch (const std::exception& e) { + // If conversion fails, keep default value (1) + } + } + } else if (!strcmp(optname, "grpc_servers") || !strcmp(optname, "rpc_servers")) { + if (optval != NULL) { + grpc_servers_option = std::string(optval); + } + } + } + + // Set params.n_parallel from environment variable if not set via options (fallback) + if (params.n_parallel == 1) { + const char *env_parallel = std::getenv("LLAMACPP_PARALLEL"); + if (env_parallel != NULL) { + try { + params.n_parallel = std::stoi(env_parallel); + if (params.n_parallel > 1) { + params.cont_batching = true; + } + } catch (const std::exception& e) { + // If conversion fails, keep default value (1) + } + } + } + + // Add RPC devices from option or environment variable (fallback) + if (!grpc_servers_option.empty()) { + add_rpc_devices(grpc_servers_option); + } else { + const char *llama_grpc_servers = std::getenv("LLAMACPP_GRPC_SERVERS"); + if (llama_grpc_servers != NULL) { + add_rpc_devices(std::string(llama_grpc_servers)); } } @@ -422,6 +508,8 @@ class BackendServiceImpl final : public backend::Backend::Service { params_parse(ctx_server, request, params); common_init(); + // Ensure debug logs are enabled after common_init() sets up logging + common_log_set_verbosity_thold(params.verbosity); llama_backend_init(); llama_numa_init(params.numa); @@ -495,46 +583,213 @@ class BackendServiceImpl final : public backend::Backend::Service { try { std::vector tasks; - const auto & prompt = data.at("prompt"); + std::string prompt_str; + std::vector files; // Declare files early so it's accessible in both branches + // Handle chat templates when UseTokenizerTemplate is enabled and Messages are provided + if (request->usetokenizertemplate() && request->messages_size() > 0 && ctx_server.chat_templates != nullptr) { + // Convert proto Messages to JSON format compatible with oaicompat_chat_params_parse + json body_json; + json messages_json = json::array(); + for (int i = 0; i < request->messages_size(); i++) { + const auto& msg = request->messages(i); + json msg_json; + msg_json["role"] = msg.role(); + + // Handle content - can be string, null, or array + // For multimodal content, we'll embed images/audio from separate fields + if (!msg.content().empty()) { + msg_json["content"] = msg.content(); + } else if (request->images_size() > 0 || request->audios_size() > 
0) { + // If no content but has images/audio, create content array + json content_array = json::array(); + if (request->images_size() > 0) { + for (int j = 0; j < request->images_size(); j++) { + json image_chunk; + image_chunk["type"] = "image_url"; + json image_url; + image_url["url"] = "data:image/jpeg;base64," + request->images(j); + image_chunk["image_url"] = image_url; + content_array.push_back(image_chunk); + } + } + if (request->audios_size() > 0) { + for (int j = 0; j < request->audios_size(); j++) { + json audio_chunk; + audio_chunk["type"] = "input_audio"; + json input_audio; + input_audio["data"] = request->audios(j); + input_audio["format"] = "wav"; // default, could be made configurable + audio_chunk["input_audio"] = input_audio; + content_array.push_back(audio_chunk); + } + } + msg_json["content"] = content_array; + } + + // Add optional fields for OpenAI-compatible message format + if (!msg.name().empty()) { + msg_json["name"] = msg.name(); + } + if (!msg.tool_call_id().empty()) { + msg_json["tool_call_id"] = msg.tool_call_id(); + } + if (!msg.reasoning_content().empty()) { + msg_json["reasoning_content"] = msg.reasoning_content(); + } + if (!msg.tool_calls().empty()) { + // Parse tool_calls JSON string and add to message + try { + json tool_calls = json::parse(msg.tool_calls()); + msg_json["tool_calls"] = tool_calls; + } catch (const json::parse_error& e) { + SRV_WRN("Failed to parse tool_calls JSON: %s\n", e.what()); + } + } + + messages_json.push_back(msg_json); + } + + body_json["messages"] = messages_json; + body_json["stream"] = true; // PredictStream is always streaming + + // Check if grammar is provided from Go layer (NoGrammar=false) + // If grammar is provided, we must use it and NOT let template generate grammar from tools + // oaicompat_chat_params_parse throws an error if both grammar and tools are provided + bool has_grammar_from_go = data.contains("grammar") && + data["grammar"].is_string() && + !data["grammar"].get().empty(); + + // Copy other relevant fields from data that oaicompat_chat_params_parse expects + // Tools and tool_choice are only passed when NoGrammar is true (grammar not provided) + // When grammar is provided from Go layer, we use it instead of template-generated grammar + if (!has_grammar_from_go) { + // NoGrammar=true: pass tools and let template generate grammar + if (data.contains("tools")) { + body_json["tools"] = data["tools"]; + std::string tools_str = data["tools"].dump(); + SRV_INF("Using tools from data (NoGrammar=true): %s\n", tools_str.c_str()); + } else { + SRV_WRN("%s", "No tools found in data - tool calls will not work without tools field\n"); + } + if (data.contains("tool_choice")) { + // tool_choice can be a string or object, but oaicompat_chat_params_parse expects a string + // Convert object tool_choice to "required" (since a specific function is requested) + if (data["tool_choice"].is_string()) { + body_json["tool_choice"] = data["tool_choice"].get(); + } else if (data["tool_choice"].is_object()) { + // Object tool_choice means a specific function is requested, use "required" + body_json["tool_choice"] = "required"; + std::string tool_choice_obj_str = data["tool_choice"].dump(); + SRV_INF("Converted object tool_choice to 'required': %s\n", tool_choice_obj_str.c_str()); + } else { + // Fallback: convert to string + body_json["tool_choice"] = data["tool_choice"].dump(); + } + std::string tool_choice_str = body_json["tool_choice"].get(); + SRV_INF("Using tool_choice: %s\n", tool_choice_str.c_str()); + } else { + // Default 
to "auto" if not specified + body_json["tool_choice"] = "auto"; + } + } else { + // Grammar is provided from Go layer (NoGrammar=false) - use it, don't pass tools + SRV_INF("%s", "Grammar provided from Go layer - using it instead of template-generated grammar\n"); + // Grammar will be copied from data after parsing (it's already in data) + } + + if (data.contains("json_schema")) { + body_json["json_schema"] = data["json_schema"]; + } + // If grammar is provided from Go layer, copy it to body_json so it's preserved + // (though oaicompat_chat_params_parse may not use it if tools are present) + if (has_grammar_from_go) { + body_json["grammar"] = data["grammar"]; + } + if (data.contains("response_format")) { + body_json["response_format"] = data["response_format"]; + } + if (data.contains("chat_template_kwargs")) { + body_json["chat_template_kwargs"] = data["chat_template_kwargs"]; + } + + // Use the same approach as server.cpp: call oaicompat_chat_params_parse + // This handles all template application, grammar merging, etc. automatically + // Files extracted from multimodal content in messages will be added to the files vector + // Create parser options with current chat_templates to ensure tmpls is not null + oaicompat_parser_options parser_opt = ctx_server.oai_parser_opt; + parser_opt.tmpls = ctx_server.chat_templates.get(); // Ensure tmpls is set to current chat_templates + json parsed_data = oaicompat_chat_params_parse(body_json, parser_opt, files); + + // Extract the prompt from parsed data + prompt_str = parsed_data.at("prompt").get(); + + // Preserve grammar from Go layer if it was provided (NoGrammar=false) + // Otherwise, use grammar from parsed_data (template-generated when NoGrammar=true) + json preserved_grammar; + if (has_grammar_from_go && data.contains("grammar")) { + preserved_grammar = data["grammar"]; + } + + // Merge all fields from parsed_data into data (grammar, grammar_triggers, preserved_tokens, etc.) + // This ensures all template-generated fields are included + for (const auto& item : parsed_data.items()) { + if (item.key() != "prompt") { // Don't overwrite prompt_str, we already extracted it + // If grammar was provided from Go layer, preserve it instead of template-generated grammar + if (item.key() == "grammar" && has_grammar_from_go && !preserved_grammar.is_null()) { + data["grammar"] = preserved_grammar; + } else { + data[item.key()] = item.value(); + } + } + } + } else { + // Use prompt directly from data + if (data.contains("prompt") && data["prompt"].is_string()) { + prompt_str = data["prompt"].get(); + } else { + prompt_str = request->prompt(); + } + } + + const auto & prompt = prompt_str; const auto type = SERVER_TASK_TYPE_COMPLETION; // TODO: this log can become very long, put it behind a flag or think about a more compact format //SRV_DBG("Prompt: %s\n", prompt.is_string() ? 
prompt.get().c_str() : prompt.dump(2).c_str()); - std::vector files; - const auto &images_data = data.find("image_data"); - if (images_data != data.end() && images_data->is_array()) - { - for (const auto &img : *images_data) + // If not using chat templates, extract files from image_data/audio_data fields + // (If using chat templates, files were already extracted by oaicompat_chat_params_parse) + //if (!request->usetokenizertemplate() || request->messages_size() == 0 || ctx_server.chat_templates == nullptr) { + const auto &images_data = data.find("image_data"); + if (images_data != data.end() && images_data->is_array()) { - auto decoded_data = base64_decode(img["data"].get()); - files.push_back(decoded_data); + for (const auto &img : *images_data) + { + auto decoded_data = base64_decode(img["data"].get()); + files.push_back(decoded_data); + } } - } - const auto &audio_data = data.find("audio_data"); - if (audio_data != data.end() && audio_data->is_array()) - { - for (const auto &audio : *audio_data) + const auto &audio_data = data.find("audio_data"); + if (audio_data != data.end() && audio_data->is_array()) { - auto decoded_data = base64_decode(audio["data"].get()); - files.push_back(decoded_data); + for (const auto &audio : *audio_data) + { + auto decoded_data = base64_decode(audio["data"].get()); + files.push_back(decoded_data); + } } - } + // } const bool has_mtmd = ctx_server.mctx != nullptr; // process prompt std::vector inputs; - if (!prompt.is_string()) { - throw std::runtime_error("prompt must be a string"); - } - if (has_mtmd) { // multimodal - inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); + inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt_str, files)); } else { // Everything else, including multimodal completions. 
- inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); + inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt_str, true, true); } tasks.reserve(inputs.size()); @@ -644,52 +899,219 @@ class BackendServiceImpl final : public backend::Backend::Service { try { std::vector tasks; - const auto & prompt = data.at("prompt"); + std::string prompt_str; + std::vector files; // Declare files early so it's accessible in both branches + // Handle chat templates when UseTokenizerTemplate is enabled and Messages are provided + if (request->usetokenizertemplate() && request->messages_size() > 0 && ctx_server.chat_templates != nullptr) { + // Convert proto Messages to JSON format compatible with oaicompat_chat_params_parse + json body_json; + json messages_json = json::array(); + for (int i = 0; i < request->messages_size(); i++) { + const auto& msg = request->messages(i); + json msg_json; + msg_json["role"] = msg.role(); + + // Handle content - can be string, null, or array + // For multimodal content, we'll embed images/audio from separate fields + if (!msg.content().empty()) { + msg_json["content"] = msg.content(); + } else if (request->images_size() > 0 || request->audios_size() > 0) { + // If no content but has images/audio, create content array + json content_array = json::array(); + if (request->images_size() > 0) { + for (int j = 0; j < request->images_size(); j++) { + json image_chunk; + image_chunk["type"] = "image_url"; + json image_url; + image_url["url"] = "data:image/jpeg;base64," + request->images(j); + image_chunk["image_url"] = image_url; + content_array.push_back(image_chunk); + } + } + if (request->audios_size() > 0) { + for (int j = 0; j < request->audios_size(); j++) { + json audio_chunk; + audio_chunk["type"] = "input_audio"; + json input_audio; + input_audio["data"] = request->audios(j); + input_audio["format"] = "wav"; // default, could be made configurable + audio_chunk["input_audio"] = input_audio; + content_array.push_back(audio_chunk); + } + } + msg_json["content"] = content_array; + } else if (!msg.tool_calls().empty()) { + // Tool call messages may have null content + msg_json["content"] = json(); + } + + // Add optional fields for OpenAI-compatible message format + if (!msg.name().empty()) { + msg_json["name"] = msg.name(); + } + if (!msg.tool_call_id().empty()) { + msg_json["tool_call_id"] = msg.tool_call_id(); + } + if (!msg.reasoning_content().empty()) { + msg_json["reasoning_content"] = msg.reasoning_content(); + } + if (!msg.tool_calls().empty()) { + // Parse tool_calls JSON string and add to message + try { + json tool_calls = json::parse(msg.tool_calls()); + msg_json["tool_calls"] = tool_calls; + } catch (const json::parse_error& e) { + SRV_WRN("Failed to parse tool_calls JSON: %s\n", e.what()); + } + } + + messages_json.push_back(msg_json); + } + + body_json["messages"] = messages_json; + body_json["stream"] = false; + + // Check if grammar is provided from Go layer (NoGrammar=false) + // If grammar is provided, we must use it and NOT let template generate grammar from tools + // oaicompat_chat_params_parse throws an error if both grammar and tools are provided + bool has_grammar_from_go = data.contains("grammar") && + data["grammar"].is_string() && + !data["grammar"].get().empty(); + + // Copy other relevant fields from data that oaicompat_chat_params_parse expects + // Tools and tool_choice are only passed when NoGrammar is true (grammar not provided) + // When grammar is provided from Go layer, we use it 
instead of template-generated grammar + if (!has_grammar_from_go) { + // NoGrammar=true: pass tools and let template generate grammar + if (data.contains("tools")) { + body_json["tools"] = data["tools"]; + std::string tools_str = data["tools"].dump(); + SRV_INF("Using tools from data (NoGrammar=true): %s\n", tools_str.c_str()); + } else { + SRV_WRN("%s", "No tools found in data - tool calls will not work without tools field\n"); + } + if (data.contains("tool_choice")) { + // tool_choice can be a string or object, but oaicompat_chat_params_parse expects a string + // Convert object tool_choice to "required" (since a specific function is requested) + if (data["tool_choice"].is_string()) { + body_json["tool_choice"] = data["tool_choice"].get(); + } else if (data["tool_choice"].is_object()) { + // Object tool_choice means a specific function is requested, use "required" + body_json["tool_choice"] = "required"; + std::string tool_choice_obj_str = data["tool_choice"].dump(); + SRV_INF("Converted object tool_choice to 'required': %s\n", tool_choice_obj_str.c_str()); + } else { + // Fallback: convert to string + body_json["tool_choice"] = data["tool_choice"].dump(); + } + std::string tool_choice_str = body_json["tool_choice"].get(); + SRV_INF("Using tool_choice: %s\n", tool_choice_str.c_str()); + } else { + // Default to "auto" if not specified + body_json["tool_choice"] = "auto"; + } + } else { + // Grammar is provided from Go layer (NoGrammar=false) - use it, don't pass tools + SRV_INF("%s", "Grammar provided from Go layer - using it instead of template-generated grammar\n"); + // Grammar will be copied from data after parsing (it's already in data) + } + + if (data.contains("json_schema")) { + body_json["json_schema"] = data["json_schema"]; + } + // If grammar is provided from Go layer, copy it to body_json so it's preserved + // (though oaicompat_chat_params_parse may not use it if tools are present) + if (has_grammar_from_go) { + body_json["grammar"] = data["grammar"]; + } + if (data.contains("response_format")) { + body_json["response_format"] = data["response_format"]; + } + if (data.contains("chat_template_kwargs")) { + body_json["chat_template_kwargs"] = data["chat_template_kwargs"]; + } + + // Use the same approach as server.cpp: call oaicompat_chat_params_parse + // This handles all template application, grammar merging, etc. automatically + // Files extracted from multimodal content in messages will be added to the files vector + // Create parser options with current chat_templates to ensure tmpls is not null + oaicompat_parser_options parser_opt = ctx_server.oai_parser_opt; + parser_opt.tmpls = ctx_server.chat_templates.get(); // Ensure tmpls is set to current chat_templates + json parsed_data = oaicompat_chat_params_parse(body_json, parser_opt, files); + + // Extract the prompt from parsed data + prompt_str = parsed_data.at("prompt").get(); + + // Preserve grammar from Go layer if it was provided (NoGrammar=false) + // Otherwise, use grammar from parsed_data (template-generated when NoGrammar=true) + json preserved_grammar; + if (has_grammar_from_go && data.contains("grammar")) { + preserved_grammar = data["grammar"]; + } + + // Merge all fields from parsed_data into data (grammar, grammar_triggers, preserved_tokens, etc.) 
+ // This ensures all template-generated fields are included + for (const auto& item : parsed_data.items()) { + if (item.key() != "prompt") { // Don't overwrite prompt_str, we already extracted it + // If grammar was provided from Go layer, preserve it instead of template-generated grammar + if (item.key() == "grammar" && has_grammar_from_go && !preserved_grammar.is_null()) { + data["grammar"] = preserved_grammar; + } else { + data[item.key()] = item.value(); + } + } + } + } else { + // Use prompt directly from data + if (data.contains("prompt") && data["prompt"].is_string()) { + prompt_str = data["prompt"].get(); + } else { + prompt_str = request->prompt(); + } + } + + const auto & prompt = prompt_str; const auto type = SERVER_TASK_TYPE_COMPLETION; // TODO: this log can become very long, put it behind a flag or think about a more compact format //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); - std::vector files; - const auto &images_data = data.find("image_data"); - // std::cout << "[PREDICT] Images data: " << images_data->dump(2) << std::endl; - - if (images_data != data.end() && images_data->is_array()) - { - std::cout << "[PREDICT] Processing " << images_data->size() << " images" << std::endl; - for (const auto &img : *images_data) + // If not using chat templates, extract files from image_data/audio_data fields + // (If using chat templates, files were already extracted by oaicompat_chat_params_parse) + // if (!request->usetokenizertemplate() || request->messages_size() == 0 || ctx_server.chat_templates == nullptr) { + const auto &images_data = data.find("image_data"); + if (images_data != data.end() && images_data->is_array()) { - std::cout << "[PREDICT] Processing image" << std::endl; - auto decoded_data = base64_decode(img["data"].get()); - files.push_back(decoded_data); + std::cout << "[PREDICT] Processing " << images_data->size() << " images" << std::endl; + for (const auto &img : *images_data) + { + std::cout << "[PREDICT] Processing image" << std::endl; + auto decoded_data = base64_decode(img["data"].get()); + files.push_back(decoded_data); + } } - } - const auto &audio_data = data.find("audio_data"); - if (audio_data != data.end() && audio_data->is_array()) - { - for (const auto &audio : *audio_data) + const auto &audio_data = data.find("audio_data"); + if (audio_data != data.end() && audio_data->is_array()) { - auto decoded_data = base64_decode(audio["data"].get()); - files.push_back(decoded_data); + for (const auto &audio : *audio_data) + { + auto decoded_data = base64_decode(audio["data"].get()); + files.push_back(decoded_data); + } } - } + // } // process files const bool has_mtmd = ctx_server.mctx != nullptr; // process prompt std::vector inputs; - if (!prompt.is_string()) { - std::cout << "[PREDICT] Prompt must be a string" << std::endl; - throw std::runtime_error("prompt must be a string"); - } - if (has_mtmd) { // multimodal - inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); + inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt_str, files)); } else { // Everything else, including multimodal completions. 
- inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); + inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt_str, true, true); } tasks.reserve(inputs.size()); diff --git a/core/backend/llm.go b/core/backend/llm.go index ffc71497522f..d6c7bc736e93 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -2,8 +2,6 @@ package backend import ( "context" - "encoding/json" - "fmt" "regexp" "slices" "strings" @@ -35,7 +33,7 @@ type TokenUsage struct { TimingTokenGeneration float64 } -func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { +func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string) (func() (LLMResponse, error), error) { modelFile := c.Model // Check if the modelFile exists, if it doesn't try to load it from the gallery @@ -65,29 +63,8 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im var protoMessages []*proto.Message // if we are using the tokenizer template, we need to convert the messages to proto messages // unless the prompt has already been tokenized (non-chat endpoints + functions) - if c.TemplateConfig.UseTokenizerTemplate && s == "" { - protoMessages = make([]*proto.Message, len(messages), len(messages)) - for i, message := range messages { - protoMessages[i] = &proto.Message{ - Role: message.Role, - } - switch ct := message.Content.(type) { - case string: - protoMessages[i].Content = ct - case []interface{}: - // If using the tokenizer template, in case of multimodal we want to keep the multimodal content as and return only strings here - data, _ := json.Marshal(ct) - resultData := []struct { - Text string `json:"text"` - }{} - json.Unmarshal(data, &resultData) - for _, r := range resultData { - protoMessages[i].Content += r.Text - } - default: - return nil, fmt.Errorf("unsupported type for schema.Message.Content for inference: %T", ct) - } - } + if c.TemplateConfig.UseTokenizerTemplate && len(messages) > 0 { + protoMessages = messages.ToProto() } // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported @@ -99,6 +76,8 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im opts.Images = images opts.Videos = videos opts.Audios = audios + opts.Tools = tools + opts.ToolChoice = toolChoice tokenUsage := TokenUsage{} diff --git a/core/config/gguf.go b/core/config/gguf.go index edc7d523083f..6d67d798bd9b 100644 --- a/core/config/gguf.go +++ b/core/config/gguf.go @@ -1,151 +1,17 @@ package config import ( - "strings" - "github.com/mudler/LocalAI/pkg/xsysinfo" "github.com/rs/zerolog/log" gguf "github.com/gpustack/gguf-parser-go" ) -type familyType uint8 - -const ( - Unknown familyType = iota - LLaMa3 - CommandR - Phi3 - ChatML - Mistral03 - Gemma - DeepSeek2 -) - const ( defaultContextSize = 1024 defaultNGPULayers = 99999999 ) -type settingsConfig struct { - StopWords []string - TemplateConfig TemplateConfig - RepeatPenalty float64 -} - -// default settings to adopt with a given model family -var 
defaultsSettings map[familyType]settingsConfig = map[familyType]settingsConfig{ - Gemma: { - RepeatPenalty: 1.0, - StopWords: []string{"<|im_end|>", "", ""}, - TemplateConfig: TemplateConfig{ - Chat: "{{.Input }}\nmodel\n", - ChatMessage: "{{if eq .RoleName \"assistant\" }}model{{else}}{{ .RoleName }}{{end}}\n{{ if .Content -}}\n{{.Content -}}\n{{ end -}}", - Completion: "{{.Input}}", - }, - }, - DeepSeek2: { - StopWords: []string{"<|end▁of▁sentence|>"}, - TemplateConfig: TemplateConfig{ - ChatMessage: `{{if eq .RoleName "user" -}}User: {{.Content }} -{{ end -}} -{{if eq .RoleName "assistant" -}}Assistant: {{.Content}}<|end▁of▁sentence|>{{end}} -{{if eq .RoleName "system" -}}{{.Content}} -{{end -}}`, - Chat: "{{.Input -}}\nAssistant: ", - }, - }, - LLaMa3: { - StopWords: []string{"<|eot_id|>"}, - TemplateConfig: TemplateConfig{ - Chat: "<|begin_of_text|>{{.Input }}\n<|start_header_id|>assistant<|end_header_id|>", - ChatMessage: "<|start_header_id|>{{ .RoleName }}<|end_header_id|>\n\n{{.Content }}<|eot_id|>", - }, - }, - CommandR: { - TemplateConfig: TemplateConfig{ - Chat: "{{.Input -}}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", - Functions: `<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|> -You are a function calling AI model, you can call the following functions: -## Available Tools -{{range .Functions}} -- {"type": "function", "function": {"name": "{{.Name}}", "description": "{{.Description}}", "parameters": {{toJson .Parameters}} }} -{{end}} -When using a tool, reply with JSON, for instance {"name": "tool_name", "arguments": {"param1": "value1", "param2": "value2"}} -<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{{.Input -}}`, - ChatMessage: `{{if eq .RoleName "user" -}} -<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{.Content}}<|END_OF_TURN_TOKEN|> -{{- else if eq .RoleName "system" -}} -<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{.Content}}<|END_OF_TURN_TOKEN|> -{{- else if eq .RoleName "assistant" -}} -<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{{.Content}}<|END_OF_TURN_TOKEN|> -{{- else if eq .RoleName "tool" -}} -<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{.Content}}<|END_OF_TURN_TOKEN|> -{{- else if .FunctionCall -}} -<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{{toJson .FunctionCall}}}<|END_OF_TURN_TOKEN|> -{{- end -}}`, - }, - StopWords: []string{"<|END_OF_TURN_TOKEN|>"}, - }, - Phi3: { - TemplateConfig: TemplateConfig{ - Chat: "{{.Input}}\n<|assistant|>", - ChatMessage: "<|{{ .RoleName }}|>\n{{.Content}}<|end|>", - Completion: "{{.Input}}", - }, - StopWords: []string{"<|end|>", "<|endoftext|>"}, - }, - ChatML: { - TemplateConfig: TemplateConfig{ - Chat: "{{.Input -}}\n<|im_start|>assistant", - Functions: `<|im_start|>system -You are a function calling AI model. You are provided with functions to execute. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. 
Here are the available tools: -{{range .Functions}} -{'type': 'function', 'function': {'name': '{{.Name}}', 'description': '{{.Description}}', 'parameters': {{toJson .Parameters}} }} -{{end}} -For each function call return a json object with function name and arguments -<|im_end|> -{{.Input -}} -<|im_start|>assistant`, - ChatMessage: `<|im_start|>{{ .RoleName }} -{{ if .FunctionCall -}} -Function call: -{{ else if eq .RoleName "tool" -}} -Function response: -{{ end -}} -{{ if .Content -}} -{{.Content }} -{{ end -}} -{{ if .FunctionCall -}} -{{toJson .FunctionCall}} -{{ end -}}<|im_end|>`, - }, - StopWords: []string{"<|im_end|>", "", ""}, - }, - Mistral03: { - TemplateConfig: TemplateConfig{ - Chat: "{{.Input -}}", - Functions: `[AVAILABLE_TOOLS] [{{range .Functions}}{"type": "function", "function": {"name": "{{.Name}}", "description": "{{.Description}}", "parameters": {{toJson .Parameters}} }}{{end}} ] [/AVAILABLE_TOOLS]{{.Input }}`, - ChatMessage: `{{if eq .RoleName "user" -}} -[INST] {{.Content }} [/INST] -{{- else if .FunctionCall -}} -[TOOL_CALLS] {{toJson .FunctionCall}} [/TOOL_CALLS] -{{- else if eq .RoleName "tool" -}} -[TOOL_RESULTS] {{.Content}} [/TOOL_RESULTS] -{{- else -}} -{{ .Content -}} -{{ end -}}`, - }, - StopWords: []string{"<|im_end|>", "", "", "<|eot_id|>", "<|end_of_text|>", "", "[/TOOL_CALLS]", "[/ACTIONS]"}, - }, -} - -// this maps well known template used in HF to model families defined above -var knownTemplates = map[string]familyType{ - `{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\n' + content + '<|im_end|>\n<|im_start|>assistant\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\n' }}{% endif %}{% endfor %}`: ChatML, - `{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}`: Mistral03, -} - func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) { if defaultCtx == 0 && cfg.ContextSize == nil { @@ -216,81 +82,9 @@ func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) { cfg.Name = f.Metadata().Name } - family := identifyFamily(f) - - if family == Unknown { - log.Debug().Msgf("guessDefaultsFromFile: %s", "family not identified") - return - } - - // identify template - settings, ok := defaultsSettings[family] - if ok { - cfg.TemplateConfig = settings.TemplateConfig - log.Debug().Any("family", family).Msgf("guessDefaultsFromFile: guessed template %+v", cfg.TemplateConfig) - if len(cfg.StopWords) == 0 { - cfg.StopWords = settings.StopWords - } - if cfg.RepeatPenalty == 0.0 { - cfg.RepeatPenalty = settings.RepeatPenalty - } - } else { - log.Debug().Any("family", family).Msgf("guessDefaultsFromFile: no template found for family") - } - - if cfg.HasTemplate() { - return - } - - // identify from well known templates first, otherwise use the raw jinja template - chatTemplate, found := f.Header.MetadataKV.Get("tokenizer.chat_template") - if found { - // try to use 
the jinja template - cfg.TemplateConfig.JinjaTemplate = true - cfg.TemplateConfig.ChatMessage = chatTemplate.ValueString() - } - -} - -func identifyFamily(f *gguf.GGUFFile) familyType { - - // identify from well known templates first - chatTemplate, found := f.Header.MetadataKV.Get("tokenizer.chat_template") - if found && chatTemplate.ValueString() != "" { - if family, ok := knownTemplates[chatTemplate.ValueString()]; ok { - return family - } - } - - // otherwise try to identify from the model properties - arch := f.Architecture().Architecture - eosTokenID := f.Tokenizer().EOSTokenID - bosTokenID := f.Tokenizer().BOSTokenID - - isYI := arch == "llama" && bosTokenID == 1 && eosTokenID == 2 - // WTF! Mistral0.3 and isYi have same bosTokenID and eosTokenID - - llama3 := arch == "llama" && eosTokenID == 128009 - commandR := arch == "command-r" && eosTokenID == 255001 - qwen2 := arch == "qwen2" - phi3 := arch == "phi-3" - gemma := strings.HasPrefix(arch, "gemma") || strings.Contains(strings.ToLower(f.Metadata().Name), "gemma") - deepseek2 := arch == "deepseek2" - - switch { - case deepseek2: - return DeepSeek2 - case gemma: - return Gemma - case llama3: - return LLaMa3 - case commandR: - return CommandR - case phi3: - return Phi3 - case qwen2, isYI: - return ChatML - default: - return Unknown - } + // Instruct to use template from llama.cpp + cfg.TemplateConfig.UseTokenizerTemplate = true + cfg.FunctionsConfig.GrammarConfig.NoGrammar = true + cfg.Options = append(cfg.Options, "use_jinja:true") + cfg.KnownUsecaseStrings = append(cfg.KnownUsecaseStrings, "FLAG_CHAT") } diff --git a/core/config/model_config.go b/core/config/model_config.go index a5bd65cdcd69..87fa05fce8ca 100644 --- a/core/config/model_config.go +++ b/core/config/model_config.go @@ -265,19 +265,10 @@ type TemplateConfig struct { Multimodal string `yaml:"multimodal" json:"multimodal"` - JinjaTemplate bool `yaml:"jinja_template" json:"jinja_template"` - ReplyPrefix string `yaml:"reply_prefix" json:"reply_prefix"` } -func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error { - type BCAlias ModelConfig - var aux BCAlias - if err := value.Decode(&aux); err != nil { - return err - } - *c = ModelConfig(aux) - +func (c *ModelConfig) syncKnownUsecasesFromString() { c.KnownUsecases = GetUsecasesFromYAML(c.KnownUsecaseStrings) // Make sure the usecases are valid, we rewrite with what we identified c.KnownUsecaseStrings = []string{} @@ -286,6 +277,17 @@ func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error { c.KnownUsecaseStrings = append(c.KnownUsecaseStrings, k) } } +} + +func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error { + type BCAlias ModelConfig + var aux BCAlias + if err := value.Decode(&aux); err != nil { + return err + } + *c = ModelConfig(aux) + + c.syncKnownUsecasesFromString() return nil } @@ -462,6 +464,7 @@ func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) { } guessDefaultsFromFile(cfg, lo.modelPath, ctx) + cfg.syncKnownUsecasesFromString() } func (c *ModelConfig) Validate() bool { @@ -492,7 +495,7 @@ func (c *ModelConfig) Validate() bool { } func (c *ModelConfig) HasTemplate() bool { - return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != "" + return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != "" || c.TemplateConfig.UseTokenizerTemplate } func (c *ModelConfig) GetModelConfigFile() string { @@ -573,7 +576,7 @@ func (c 
*ModelConfig) HasUsecases(u ModelConfigUsecases) bool { // This avoids the maintenance burden of updating this list for each new backend - but unfortunately, that's the best option for some services currently. func (c *ModelConfig) GuessUsecases(u ModelConfigUsecases) bool { if (u & FLAG_CHAT) == FLAG_CHAT { - if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" { + if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" && !c.TemplateConfig.UseTokenizerTemplate { return false } } diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index 4e91a532983c..d1ce156215c4 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -217,6 +217,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator noActionDescription = config.FunctionsConfig.NoActionDescriptionName } + // If we are using a response format, we need to generate a grammar for it if config.ResponseFormatMap != nil { d := schema.ChatCompletionResponseFormat{} dat, err := json.Marshal(config.ResponseFormatMap) @@ -260,6 +261,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator } switch { + // Generates grammar with internal's LocalAI engine case (!config.FunctionsConfig.GrammarConfig.NoGrammar || strictMode) && shouldUseFn: noActionGrammar := functions.Function{ Name: noActionName, @@ -283,7 +285,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator funcs = funcs.Select(config.FunctionToCall()) } - // Update input grammar + // Update input grammar or json_schema based on use_llama_grammar option jsStruct := funcs.ToJSONStructure(config.FunctionsConfig.FunctionNameKey, config.FunctionsConfig.FunctionNameKey) g, err := jsStruct.Grammar(config.FunctionsConfig.GrammarOptions()...) if err == nil { @@ -298,6 +300,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator } else { log.Error().Err(err).Msg("Failed generating grammar") } + default: // Force picking one of the functions by the request if config.FunctionToCall() != "" { @@ -316,7 +319,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator // If we are using the tokenizer template, we don't need to process the messages // unless we are processing functions - if !config.TemplateConfig.UseTokenizerTemplate || shouldUseFn { + if !config.TemplateConfig.UseTokenizerTemplate { predInput = evaluator.TemplateMessages(*input, input.Messages, config, funcs, shouldUseFn) log.Debug().Msgf("Prompt (after templating): %s", predInput) @@ -597,7 +600,23 @@ func handleQuestion(config *config.ModelConfig, cl *config.ModelConfigLoader, in audios = append(audios, m.StringAudios...) 
} - predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil) + // Serialize tools and tool_choice to JSON strings + toolsJSON := "" + if len(input.Tools) > 0 { + toolsBytes, err := json.Marshal(input.Tools) + if err == nil { + toolsJSON = string(toolsBytes) + } + } + toolChoiceJSON := "" + if input.ToolsChoice != nil { + toolChoiceBytes, err := json.Marshal(input.ToolsChoice) + if err == nil { + toolChoiceJSON = string(toolChoiceBytes) + } + } + + predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil, toolsJSON, toolChoiceJSON) if err != nil { log.Error().Err(err).Msg("model inference failed") return "", err diff --git a/core/http/endpoints/openai/inference.go b/core/http/endpoints/openai/inference.go index b7b256bad0c4..95d3ee24671d 100644 --- a/core/http/endpoints/openai/inference.go +++ b/core/http/endpoints/openai/inference.go @@ -1,6 +1,8 @@ package openai import ( + "encoding/json" + "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" @@ -37,8 +39,25 @@ func ComputeChoices( audios = append(audios, m.StringAudios...) } + // Serialize tools and tool_choice to JSON strings + toolsJSON := "" + if len(req.Tools) > 0 { + toolsBytes, err := json.Marshal(req.Tools) + if err == nil { + toolsJSON = string(toolsBytes) + } + } + toolChoiceJSON := "" + if req.ToolsChoice != nil { + toolChoiceBytes, err := json.Marshal(req.ToolsChoice) + if err == nil { + toolChoiceJSON = string(toolChoiceBytes) + } + } + // get the model function to call for the result - predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback) + predFunc, err := backend.ModelInference( + req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON) if err != nil { return result, backend.TokenUsage{}, err } diff --git a/core/schema/message.go b/core/schema/message.go new file mode 100644 index 000000000000..793f5fca234b --- /dev/null +++ b/core/schema/message.go @@ -0,0 +1,85 @@ +package schema + +import ( + "encoding/json" + + "github.com/rs/zerolog/log" + + "github.com/mudler/LocalAI/pkg/grpc/proto" +) + +type Message struct { + // The message role + Role string `json:"role,omitempty" yaml:"role"` + + // The message name (used for tools calls) + Name string `json:"name,omitempty" yaml:"name"` + + // The message content + Content interface{} `json:"content" yaml:"content"` + + StringContent string `json:"string_content,omitempty" yaml:"string_content,omitempty"` + StringImages []string `json:"string_images,omitempty" yaml:"string_images,omitempty"` + StringVideos []string `json:"string_videos,omitempty" yaml:"string_videos,omitempty"` + StringAudios []string `json:"string_audios,omitempty" yaml:"string_audios,omitempty"` + + // A result of a function call + FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"` + + ToolCalls []ToolCall `json:"tool_calls,omitempty" yaml:"tool_call,omitempty"` +} + +type ToolCall struct { + Index int `json:"index"` + ID string `json:"id"` + Type string `json:"type"` + FunctionCall FunctionCall `json:"function"` +} + +type FunctionCall struct { + Name string `json:"name,omitempty"` + Arguments string `json:"arguments"` +} + +type Messages []Message + +// MessagesToProto converts 
schema.Message slice to proto.Message slice +// It handles content conversion, tool_calls serialization, and optional fields +func (messages Messages) ToProto() []*proto.Message { + protoMessages := make([]*proto.Message, len(messages)) + for i, message := range messages { + protoMessages[i] = &proto.Message{ + Role: message.Role, + Name: message.Name, // needed by function calls + } + + switch ct := message.Content.(type) { + case string: + protoMessages[i].Content = ct + case []interface{}: + // If using the tokenizer template, in case of multimodal we want to keep the multimodal content as and return only strings here + data, _ := json.Marshal(ct) + resultData := []struct { + Text string `json:"text"` + }{} + json.Unmarshal(data, &resultData) + for _, r := range resultData { + protoMessages[i].Content += r.Text + } + } + + // Serialize tool_calls to JSON string if present + if len(message.ToolCalls) > 0 { + toolCallsJSON, err := json.Marshal(message.ToolCalls) + if err != nil { + log.Warn().Err(err).Msg("failed to marshal tool_calls to JSON") + } else { + protoMessages[i].ToolCalls = string(toolCallsJSON) + } + } + + // Note: tool_call_id and reasoning_content are not in schema.Message yet + // They may need to be added to schema.Message if needed in the future + } + return protoMessages +} diff --git a/core/schema/message_test.go b/core/schema/message_test.go new file mode 100644 index 000000000000..1dd586f7685b --- /dev/null +++ b/core/schema/message_test.go @@ -0,0 +1,265 @@ +package schema_test + +import ( + "encoding/json" + + . "github.com/mudler/LocalAI/core/schema" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("LLM tests", func() { + + Context("ToProtoMessages conversion", func() { + It("should convert basic message with string content", func() { + messages := Messages{ + { + Role: "user", + Content: "Hello, world!", + }, + } + + protoMessages := messages.ToProto() + + Expect(protoMessages).To(HaveLen(1)) + Expect(protoMessages[0].Role).To(Equal("user")) + Expect(protoMessages[0].Content).To(Equal("Hello, world!")) + Expect(protoMessages[0].Name).To(BeEmpty()) + Expect(protoMessages[0].ToolCalls).To(BeEmpty()) + }) + + It("should convert message with nil content to empty string", func() { + messages := Messages{ + { + Role: "assistant", + Content: nil, + }, + } + + protoMessages := messages.ToProto() + + Expect(protoMessages).To(HaveLen(1)) + Expect(protoMessages[0].Role).To(Equal("assistant")) + Expect(protoMessages[0].Content).To(Equal("")) + }) + + It("should convert message with array content (multimodal)", func() { + messages := Messages{ + { + Role: "user", + Content: []interface{}{ + map[string]interface{}{ + "type": "text", + "text": "Hello", + }, + map[string]interface{}{ + "type": "text", + "text": " World", + }, + }, + }, + } + + protoMessages := messages.ToProto() + + Expect(protoMessages).To(HaveLen(1)) + Expect(protoMessages[0].Role).To(Equal("user")) + Expect(protoMessages[0].Content).To(Equal("Hello World")) + }) + + It("should convert message with tool_calls", func() { + messages := Messages{ + { + Role: "assistant", + Content: "I'll call a function", + ToolCalls: []ToolCall{ + { + Index: 0, + ID: "call_123", + Type: "function", + FunctionCall: FunctionCall{ + Name: "get_weather", + Arguments: `{"location": "San Francisco"}`, + }, + }, + }, + }, + } + + protoMessages := messages.ToProto() + + Expect(protoMessages).To(HaveLen(1)) + 
Expect(protoMessages[0].Role).To(Equal("assistant")) + Expect(protoMessages[0].Content).To(Equal("I'll call a function")) + Expect(protoMessages[0].ToolCalls).NotTo(BeEmpty()) + + // Verify tool_calls JSON is valid + var toolCalls []ToolCall + err := json.Unmarshal([]byte(protoMessages[0].ToolCalls), &toolCalls) + Expect(err).NotTo(HaveOccurred()) + Expect(toolCalls).To(HaveLen(1)) + Expect(toolCalls[0].ID).To(Equal("call_123")) + Expect(toolCalls[0].FunctionCall.Name).To(Equal("get_weather")) + }) + + It("should convert message with name field", func() { + messages := Messages{ + { + Role: "tool", + Content: "Function result", + Name: "get_weather", + }, + } + + protoMessages := messages.ToProto() + + Expect(protoMessages).To(HaveLen(1)) + Expect(protoMessages[0].Role).To(Equal("tool")) + Expect(protoMessages[0].Content).To(Equal("Function result")) + Expect(protoMessages[0].Name).To(Equal("get_weather")) + }) + + It("should convert message with tool_calls and nil content", func() { + messages := Messages{ + { + Role: "assistant", + Content: nil, + ToolCalls: []ToolCall{ + { + Index: 0, + ID: "call_456", + Type: "function", + FunctionCall: FunctionCall{ + Name: "search", + Arguments: `{"query": "test"}`, + }, + }, + }, + }, + } + + protoMessages := messages.ToProto() + + Expect(protoMessages).To(HaveLen(1)) + Expect(protoMessages[0].Role).To(Equal("assistant")) + Expect(protoMessages[0].Content).To(Equal("")) + Expect(protoMessages[0].ToolCalls).NotTo(BeEmpty()) + + var toolCalls []ToolCall + err := json.Unmarshal([]byte(protoMessages[0].ToolCalls), &toolCalls) + Expect(err).NotTo(HaveOccurred()) + Expect(toolCalls).To(HaveLen(1)) + Expect(toolCalls[0].FunctionCall.Name).To(Equal("search")) + }) + + It("should convert multiple messages", func() { + messages := Messages{ + { + Role: "user", + Content: "Hello", + }, + { + Role: "assistant", + Content: "Hi there!", + }, + { + Role: "user", + Content: "How are you?", + }, + } + + protoMessages := messages.ToProto() + + Expect(protoMessages).To(HaveLen(3)) + Expect(protoMessages[0].Role).To(Equal("user")) + Expect(protoMessages[0].Content).To(Equal("Hello")) + Expect(protoMessages[1].Role).To(Equal("assistant")) + Expect(protoMessages[1].Content).To(Equal("Hi there!")) + Expect(protoMessages[2].Role).To(Equal("user")) + Expect(protoMessages[2].Content).To(Equal("How are you?")) + }) + + It("should handle empty messages slice", func() { + messages := Messages{} + + protoMessages := messages.ToProto() + + Expect(protoMessages).To(HaveLen(0)) + }) + + It("should handle message with all optional fields", func() { + messages := Messages{ + { + Role: "assistant", + Content: "I'll help you", + Name: "test_tool", + ToolCalls: []ToolCall{ + { + Index: 0, + ID: "call_789", + Type: "function", + FunctionCall: FunctionCall{ + Name: "test_function", + Arguments: `{"param": "value"}`, + }, + }, + }, + }, + } + + protoMessages := messages.ToProto() + + Expect(protoMessages).To(HaveLen(1)) + Expect(protoMessages[0].Role).To(Equal("assistant")) + Expect(protoMessages[0].Content).To(Equal("I'll help you")) + Expect(protoMessages[0].Name).To(Equal("test_tool")) + Expect(protoMessages[0].ToolCalls).NotTo(BeEmpty()) + + var toolCalls []ToolCall + err := json.Unmarshal([]byte(protoMessages[0].ToolCalls), &toolCalls) + Expect(err).NotTo(HaveOccurred()) + Expect(toolCalls).To(HaveLen(1)) + }) + + It("should handle message with empty string content", func() { + messages := Messages{ + { + Role: "user", + Content: "", + }, + } + + protoMessages := messages.ToProto() + 
+ Expect(protoMessages).To(HaveLen(1)) + Expect(protoMessages[0].Role).To(Equal("user")) + Expect(protoMessages[0].Content).To(Equal("")) + }) + + It("should handle message with array content containing non-text parts", func() { + messages := Messages{ + { + Role: "user", + Content: []interface{}{ + map[string]interface{}{ + "type": "text", + "text": "Hello", + }, + map[string]interface{}{ + "type": "image", + "url": "https://example.com/image.jpg", + }, + }, + }, + } + + protoMessages := messages.ToProto() + + Expect(protoMessages).To(HaveLen(1)) + Expect(protoMessages[0].Role).To(Equal("user")) + // Should only extract text parts + Expect(protoMessages[0].Content).To(Equal("Hello")) + }) + }) +}) diff --git a/core/schema/openai.go b/core/schema/openai.go index 5506231e560b..49e18642f541 100644 --- a/core/schema/openai.go +++ b/core/schema/openai.go @@ -76,39 +76,6 @@ type InputAudio struct { Data string `json:"data" yaml:"data"` } -type Message struct { - // The message role - Role string `json:"role,omitempty" yaml:"role"` - - // The message name (used for tools calls) - Name string `json:"name,omitempty" yaml:"name"` - - // The message content - Content interface{} `json:"content" yaml:"content"` - - StringContent string `json:"string_content,omitempty" yaml:"string_content,omitempty"` - StringImages []string `json:"string_images,omitempty" yaml:"string_images,omitempty"` - StringVideos []string `json:"string_videos,omitempty" yaml:"string_videos,omitempty"` - StringAudios []string `json:"string_audios,omitempty" yaml:"string_audios,omitempty"` - - // A result of a function call - FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"` - - ToolCalls []ToolCall `json:"tool_calls,omitempty" yaml:"tool_call,omitempty"` -} - -type ToolCall struct { - Index int `json:"index"` - ID string `json:"id"` - Type string `json:"type"` - FunctionCall FunctionCall `json:"function"` -} - -type FunctionCall struct { - Name string `json:"name,omitempty"` - Arguments string `json:"arguments"` -} - type OpenAIModel struct { ID string `json:"id"` Object string `json:"object"` diff --git a/core/schema/schema_suite_test.go b/core/schema/schema_suite_test.go new file mode 100644 index 000000000000..685a23309451 --- /dev/null +++ b/core/schema/schema_suite_test.go @@ -0,0 +1,13 @@ +package schema_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestSchema(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "LocalAI Schema test suite") +} diff --git a/core/templates/cache.go b/core/templates/cache.go index 1efce6606e8f..a9780284a784 100644 --- a/core/templates/cache.go +++ b/core/templates/cache.go @@ -11,9 +11,6 @@ import ( "github.com/mudler/LocalAI/pkg/utils" "github.com/Masterminds/sprig/v3" - - "github.com/nikolalohinski/gonja/v2" - "github.com/nikolalohinski/gonja/v2/exec" ) // Keep this in sync with config.TemplateConfig. Is there a more idiomatic way to accomplish this in go? 
@@ -21,17 +18,15 @@ import ( type TemplateType int type templateCache struct { - mu sync.Mutex - templatesPath string - templates map[TemplateType]map[string]*template.Template - jinjaTemplates map[TemplateType]map[string]*exec.Template + mu sync.Mutex + templatesPath string + templates map[TemplateType]map[string]*template.Template } func newTemplateCache(templatesPath string) *templateCache { tc := &templateCache{ - templatesPath: templatesPath, - templates: make(map[TemplateType]map[string]*template.Template), - jinjaTemplates: make(map[TemplateType]map[string]*exec.Template), + templatesPath: templatesPath, + templates: make(map[TemplateType]map[string]*template.Template), } return tc } @@ -85,78 +80,6 @@ func (tc *templateCache) loadTemplateIfExists(templateType TemplateType, templat return nil } -func (tc *templateCache) initializeJinjaTemplateMapKey(tt TemplateType) { - if _, ok := tc.jinjaTemplates[tt]; !ok { - tc.jinjaTemplates[tt] = make(map[string]*exec.Template) - } -} - -func (tc *templateCache) loadJinjaTemplateIfExists(templateType TemplateType, templateName string) error { - // Check if the template was already loaded - if _, ok := tc.jinjaTemplates[templateType][templateName]; ok { - return nil - } - - // Check if the model path exists - // skip any error here - we run anyway if a template does not exist - modelTemplateFile := fmt.Sprintf("%s.tmpl", templateName) - - dat := "" - file := filepath.Join(tc.templatesPath, modelTemplateFile) - - // Security check - if err := utils.VerifyPath(modelTemplateFile, tc.templatesPath); err != nil { - return fmt.Errorf("template file outside path: %s", file) - } - - // can either be a file in the system or a string with the template - if utils.ExistsInPath(tc.templatesPath, modelTemplateFile) { - d, err := os.ReadFile(file) - if err != nil { - return err - } - dat = string(d) - } else { - dat = templateName - } - - tmpl, err := gonja.FromString(dat) - if err != nil { - return err - } - tc.jinjaTemplates[templateType][templateName] = tmpl - - return nil -} - -func (tc *templateCache) evaluateJinjaTemplate(templateType TemplateType, templateNameOrContent string, in map[string]interface{}) (string, error) { - tc.mu.Lock() - defer tc.mu.Unlock() - - tc.initializeJinjaTemplateMapKey(templateType) - m, ok := tc.jinjaTemplates[templateType][templateNameOrContent] - if !ok { - // return "", fmt.Errorf("template not loaded: %s", templateName) - loadErr := tc.loadJinjaTemplateIfExists(templateType, templateNameOrContent) - if loadErr != nil { - return "", loadErr - } - m = tc.jinjaTemplates[templateType][templateNameOrContent] // ok is not important since we check m on the next line, and wealready checked - } - if m == nil { - return "", fmt.Errorf("failed loading a template for %s", templateNameOrContent) - } - - var buf bytes.Buffer - - data := exec.NewContext(in) - - if err := m.Execute(&buf, data); err != nil { - return "", err - } - return buf.String(), nil -} - func (tc *templateCache) evaluateTemplate(templateType TemplateType, templateNameOrContent string, in interface{}) (string, error) { tc.mu.Lock() defer tc.mu.Unlock() diff --git a/core/templates/evaluator.go b/core/templates/evaluator.go index 12c2080555f1..a3b46a1aa0ff 100644 --- a/core/templates/evaluator.go +++ b/core/templates/evaluator.go @@ -86,10 +86,6 @@ func (e *Evaluator) EvaluateTemplateForPrompt(templateType TemplateType, config return in.Input, nil } - if config.TemplateConfig.JinjaTemplate { - return e.evaluateJinjaTemplateForPrompt(templateType, template, in) - } - 
return e.cache.evaluateTemplate(templateType, template, in) } @@ -97,72 +93,7 @@ func (e *Evaluator) evaluateTemplateForChatMessage(templateName string, messageD return e.cache.evaluateTemplate(ChatMessageTemplate, templateName, messageData) } -func (e *Evaluator) templateJinjaChat(templateName string, messageData []ChatMessageTemplateData, funcs []functions.Function) (string, error) { - - conversation := make(map[string]interface{}) - messages := make([]map[string]interface{}, len(messageData)) - - // convert from ChatMessageTemplateData to what the jinja template expects - - for _, message := range messageData { - // TODO: this seems to cover minimum text templates. Can be expanded to cover more complex interactions - var data []byte - data, _ = json.Marshal(message.FunctionCall) - messages = append(messages, map[string]interface{}{ - "role": message.RoleName, - "content": message.Content, - "tool_call": string(data), - }) - } - - conversation["messages"] = messages - - // if tools are detected, add these - if len(funcs) > 0 { - conversation["tools"] = funcs - } - - return e.cache.evaluateJinjaTemplate(ChatMessageTemplate, templateName, conversation) -} - -func (e *Evaluator) evaluateJinjaTemplateForPrompt(templateType TemplateType, templateName string, in PromptTemplateData) (string, error) { - - conversation := make(map[string]interface{}) - - conversation["system_prompt"] = in.SystemPrompt - conversation["content"] = in.Input - - return e.cache.evaluateJinjaTemplate(templateType, templateName, conversation) -} - func (e *Evaluator) TemplateMessages(input schema.OpenAIRequest, messages []schema.Message, config *config.ModelConfig, funcs []functions.Function, shouldUseFn bool) string { - - if config.TemplateConfig.JinjaTemplate { - var messageData []ChatMessageTemplateData - for messageIndex, i := range messages { - fcall := i.FunctionCall - if len(i.ToolCalls) > 0 { - fcall = i.ToolCalls - } - messageData = append(messageData, ChatMessageTemplateData{ - SystemPrompt: config.SystemPrompt, - Role: config.Roles[i.Role], - RoleName: i.Role, - Content: i.StringContent, - FunctionCall: fcall, - FunctionName: i.Name, - LastMessage: messageIndex == (len(messages) - 1), - Function: config.Grammar != "" && (messageIndex == (len(messages) - 1)), - MessageIndex: messageIndex, - }) - } - - templatedInput, err := e.templateJinjaChat(config.TemplateConfig.ChatMessage, messageData, funcs) - if err == nil { - return templatedInput - } - } - var predInput string suppressConfigSystemPrompt := false mess := []string{} diff --git a/core/templates/evaluator_test.go b/core/templates/evaluator_test.go index 6d29c876b519..91a750a3514e 100644 --- a/core/templates/evaluator_test.go +++ b/core/templates/evaluator_test.go @@ -191,25 +191,6 @@ var chatMLTestMatch map[string]map[string]interface{} = map[string]map[string]in }, } -var jinjaTest map[string]map[string]interface{} = map[string]map[string]interface{}{ - "user": { - "expected": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nA long time ago in a galaxy far, far away...<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", - "config": &config.ModelConfig{ - TemplateConfig: config.TemplateConfig{ - ChatMessage: toolCallJinja, - JinjaTemplate: true, - }, - }, - "functions": []functions.Function{}, - "shouldUseFn": false, - "messages": []schema.Message{ - { - Role: "user", - StringContent: "A long time ago in a galaxy far, far away...", - }, - }, - }, -} var _ = Describe("Templates", func() { Context("chat message ChatML", func() { var 
evaluator *Evaluator @@ -237,17 +218,4 @@ var _ = Describe("Templates", func() { }) } }) - Context("chat message jinja", func() { - var evaluator *Evaluator - BeforeEach(func() { - evaluator = NewEvaluator("") - }) - for key := range jinjaTest { - foo := jinjaTest[key] - It("renders correctly `"+key+"`", func() { - templated := evaluator.TemplateMessages(schema.OpenAIRequest{}, foo["messages"].([]schema.Message), foo["config"].(*config.ModelConfig), foo["functions"].([]functions.Function), foo["shouldUseFn"].(bool)) - Expect(templated).To(Equal(foo["expected"]), templated) - }) - } - }) }) diff --git a/docs/content/docs/features/text-generation.md b/docs/content/docs/features/text-generation.md index c4e637f7040c..70c2c7524b2e 100644 --- a/docs/content/docs/features/text-generation.md +++ b/docs/content/docs/features/text-generation.md @@ -128,16 +128,44 @@ Models can be also preloaded or downloaded on demand. To learn about model galle #### YAML configuration -To use the `llama.cpp` backend, specify `llama` as the backend in the YAML file: +To use the `llama.cpp` backend, specify `llama-cpp` as the backend in the YAML file: ```yaml name: llama -backend: llama +backend: llama-cpp parameters: # Relative to the models path model: file.gguf ``` +#### Backend Options + +The `llama.cpp` backend supports additional configuration options that can be specified in the `options` field of your model YAML configuration. These options allow fine-tuning of the backend behavior: + +| Option | Type | Description | Example | +|--------|------|-------------|---------| +| `use_jinja` or `jinja` | boolean | Enable Jinja2 template processing for chat templates. When enabled, the backend uses Jinja2-based chat templates from the model for formatting messages. | `use_jinja:true` | +| `context_shift` | boolean | Enable context shifting, which allows the model to dynamically adjust context window usage. | `context_shift:true` | +| `cache_ram` | integer | Set the maximum RAM cache size in MiB for the KV cache. Use `-1` for unlimited (default). | `cache_ram:2048` | +| `parallel` or `n_parallel` | integer | Enable parallel request processing. When set to a value greater than 1, enables continuous batching for handling multiple requests concurrently. | `parallel:4` | +| `grpc_servers` or `rpc_servers` | string | Comma-separated list of gRPC server addresses for distributed inference. Allows distributing workload across multiple llama.cpp workers. | `grpc_servers:localhost:50051,localhost:50052` | +
+**Example configuration with options:** + +```yaml +name: llama-model +backend: llama-cpp +parameters: + model: model.gguf +options: + - use_jinja:true + - context_shift:true + - cache_ram:4096 + - parallel:2 +``` +
+**Note:** The `parallel` option can also be set via the `LLAMACPP_PARALLEL` environment variable, and `grpc_servers` can be set via the `LLAMACPP_GRPC_SERVERS` environment variable. Options specified in the YAML file take precedence over environment variables.
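+
+For example, a model that distributes inference across two llama.cpp workers could combine the `grpc_servers` and `parallel` options; the model name and worker addresses below are placeholders for illustration only:
+
+```yaml
+name: llama-distributed
+backend: llama-cpp
+parameters:
+  model: model.gguf
+options:
+  # Placeholder worker addresses; point these at your own llama.cpp gRPC workers
+  - grpc_servers:localhost:50051,localhost:50052
+  # Handle up to two requests concurrently via continuous batching
+  - parallel:2
+```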
+ #### Reference - [llama](https://github.com/ggerganov/llama.cpp) diff --git a/pkg/functions/functions.go b/pkg/functions/functions.go index 477a43bb7260..aa509a82251d 100644 --- a/pkg/functions/functions.go +++ b/pkg/functions/functions.go @@ -79,6 +79,12 @@ func (f Functions) ToJSONStructure(name, args string) JSONFunctionStructure { Type: "object", Properties: property, }) + /* + js.AnyOf = append(js.OneOf, Item{ + Type: "object", + Properties: property, + }) + */ } return js } diff --git a/pkg/functions/parse.go b/pkg/functions/parse.go index 48efb819ac91..49c4970a7609 100644 --- a/pkg/functions/parse.go +++ b/pkg/functions/parse.go @@ -53,7 +53,7 @@ type GrammarConfig struct { type GrammarTrigger struct { // Trigger is the string that triggers the grammar - Word string `yaml:"word"` + Word string `yaml:"word"` } // FunctionsConfig is the configuration for the tool/function call.