Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ Mistral.rs supports several model categories:
- First X-LoRA inference platform with first class support
- [AnyMoE](docs/ANYMOE.md): Build a memory-efficient MoE model from anything, in seconds
- Various [sampling and penalty](docs/SAMPLING.md) methods
- Native tool calling support for Llama, Mistral Small, Mistral Nemo, and Hermes models: [docs](docs/TOOL_CALLING.md)
- Native tool calling support for Llama, Mistral Small, Mistral Nemo, Hermes, and DeepSeek models: [docs](docs/TOOL_CALLING.md)
- Prompt chunking: process large prompts in a more manageable way

**Advanced features**:
Expand Down
1 change: 1 addition & 0 deletions docs/TOOL_CALLING.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ We support the following models' tool calling in OpenAI-compatible and parse nat
- Mistral Nemo
- Hermes 2 Pro
- Hermes 3
- DeepSeek V2/V3/R1

All models that support tool calling will respond according to the OpenAI tool calling API.

Expand Down
5 changes: 4 additions & 1 deletion mistralrs-core/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,10 @@ impl TopKLastDimOp for Tensor {
let reorder_indices = indices.arg_sort(true)?;
#[cfg(not(feature = "cuda"))]
let reorder_indices = indices.arg_sort_last_dim(true)?;
let topk_indices_unsorted = indices.gather(&reorder_indices, D::Minus1)?;
let topk_indices_unsorted = indices
.to_dtype(DType::F32)?
.gather(&reorder_indices, D::Minus1)?
.to_dtype(DType::U32)?;
let topk_values_unsorted = values.gather(&reorder_indices, D::Minus1)?;
Ok(TopKOutput {
values: topk_values_unsorted,
Expand Down
2 changes: 1 addition & 1 deletion mistralrs-core/src/pipeline/sampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub(crate) async fn finish_or_add_toks_to_seq(
if let Some(ref t) = seq.tools {
if let Ok(Some(ref d)) = seq.peek_delta() {
(tool_use_still_possible, tool_use_is_done) =
t.prefix_could_be_tool(this, d.as_str());
t.prefix_could_be_tool(this, d.as_str())?;
}
};

Expand Down
60 changes: 46 additions & 14 deletions mistralrs-core/src/tools/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
mod request;
mod response;

use candle_core::Result;
use regex::Regex;
pub use request::*;
pub use response::*;
use serde_json::Value;
Expand All @@ -9,24 +11,54 @@ use uuid::Uuid;

use crate::Pipeline;

fn process_model_specific_message(message: &str) -> &str {
fn process_model_specific_message(message: &str) -> Result<String> {
let deepseek_regex = Regex::new(
r"<|tool▁call▁begin|>function<|tool▁sep|>(?P<name>[^\n]+)\n```json\n(?P<json>.+?)\n```<|tool▁call▁end|>",
).map_err(candle_core::Error::msg)?;

if let Some(message) = message.strip_prefix("<|python_tag|>") {
// Llama case
message
Ok(message.to_string())
} else if let Some(message) = message
.strip_prefix("<tool_call>")
.and_then(|s| s.strip_suffix("</tool_call>"))
{
// Hermes case
message
Ok(message.to_string())
} else if let Some(message) = message
.strip_prefix("[TOOL_CALLS][")
.and_then(|s| s.strip_suffix("]"))
{
// Mistral Nemo case
message
Ok(message.to_string())
} else if deepseek_regex.find(message).is_some() {
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
struct ToolCall {
name: String,
arguments: Value,
}
let mut calls = Vec::new();
for caps in deepseek_regex.captures_iter(message) {
let name = caps
.name("name")
.ok_or("Could not capture function name")
.map_err(candle_core::Error::msg)?
.as_str()
.trim()
.to_string();
let json_str = caps
.name("json")
.ok_or("Could not capture JSON arguments")
.map_err(candle_core::Error::msg)?
.as_str()
.trim();
let arguments: Value =
serde_json::from_str(json_str).map_err(candle_core::Error::msg)?;
calls.push(ToolCall { name, arguments });
}
Ok(serde_json::to_string(&calls).map_err(candle_core::Error::msg)?)
} else {
message
Ok(message.to_string())
}
}

Expand Down Expand Up @@ -58,27 +90,27 @@ impl ToolCallingMatcher {
&self,
_pipeline: &dyn Pipeline,
message_prefix: &str,
) -> (bool, bool) {
) -> Result<(bool, bool)> {
if matches!(self.tool_choice, ToolChoice::None) {
return (false, false);
return Ok((false, false));
}
let message_prefix = process_model_specific_message(message_prefix);
let message_prefix = process_model_specific_message(message_prefix)?;

// Check if the prefix could be a JSON serialization of any of the following types.
[
Ok([
could_be_json::<CalledFunctionParameters>,
could_be_json::<Vec<CalledFunctionParameters>>,
]
.iter()
.find_map(|check| {
let (could_be_tool, is_complete_tool) = check(message_prefix);
let (could_be_tool, is_complete_tool) = check(&message_prefix);
if could_be_tool || is_complete_tool {
Some((could_be_tool, is_complete_tool))
} else {
None
}
})
.unwrap_or_default()
.unwrap_or_default())
}

pub fn get_call(
Expand All @@ -89,9 +121,9 @@ impl ToolCallingMatcher {
if matches!(self.tool_choice, ToolChoice::None) {
return Ok(Vec::new());
}
let message = process_model_specific_message(message);
let message = process_model_specific_message(message)?;

if let Ok(deser) = serde_json::from_str::<CalledFunctionParameters>(message) {
if let Ok(deser) = serde_json::from_str::<CalledFunctionParameters>(&message) {
let id = format!("call-{}", Uuid::new_v4());
Ok(vec![ToolCallResponse {
id,
Expand All @@ -101,7 +133,7 @@ impl ToolCallingMatcher {
arguments: serde_json::to_string(&deser.parameters)?,
},
}])
} else if let Ok(deser) = serde_json::from_str::<Vec<CalledFunctionParameters>>(message) {
} else if let Ok(deser) = serde_json::from_str::<Vec<CalledFunctionParameters>>(&message) {
Ok(deser
.into_iter()
.map(|deser| {
Expand Down
Loading