1 change: 1 addition & 0 deletions Cargo.lock

(Generated file; diff not rendered.)

1 change: 1 addition & 0 deletions Cargo.toml
@@ -78,6 +78,7 @@ half = "2.4.0"
rayon = "1.1.0"
url = "2.5.2"
utoipa = "4.2"
walkdir = "2.5.0"
data-url = "0.3.1"
float8 = "0.2.1"
regex = "1.10.6"
42 changes: 42 additions & 0 deletions docs/WEB_SEARCH.md
@@ -23,6 +23,35 @@ Internally, we use a BERT model (Snowflake/snowflake-arctic-embed-l-v2.0)[https:
- Python: `search_bert_model` in the Runner
- Server: `search-bert-model` before the model type selector (`plain`/`vision-plain`)

## Specifying a custom search callback

By default, mistral.rs uses a DuckDuckGo-based search callback. To override this, you can provide your own search function:

- Rust: use `.with_search_callback(...)` on the model builder with an `Arc<dyn Fn(&SearchFunctionParameters) -> anyhow::Result<Vec<SearchResult>> + Send + Sync>` (see the Rust sketch after the Python example below).
- Python: pass the `search_callback` keyword argument to `Runner`, which should be a function `def search_callback(query: str) -> List[Dict[str, str]]` returning a list of result dicts, each with the keys `"title"`, `"description"`, `"url"`, and `"content"`.

Example in Python:
```py
from mistralrs import Runner, Which, Architecture


def search_callback(query: str) -> list[dict[str, str]]:
    # Implement your custom search logic here, returning a list of result dicts
    return [
        {
            "title": "Example Result",
            "description": "An example description",
            "url": "https://example.com",
            "content": "Full text content of the page",
        },
        # more results...
    ]


runner = Runner(
    which=Which.Plain(model_id="YourModel/ID", arch=Architecture.Mistral),
    enable_search=True,
    search_callback=search_callback,
)
```
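
In Rust, the equivalent looks roughly like the sketch below. Treat it as illustrative: the `SearchResult` field names are inferred from the Python dict keys above, and the types are assumed to be re-exported at the crate root as they are from `mistralrs-core`; check the crate docs for the exact definitions.

```rust
use std::sync::Arc;

use mistralrs::{SearchCallback, SearchFunctionParameters, SearchResult};

// Hypothetical callback that returns one fixed result for any query.
// Returned results must be sorted in decreasing order of relevance.
let callback: Arc<SearchCallback> = Arc::new(|_params: &SearchFunctionParameters| {
    Ok(vec![SearchResult {
        title: "Example Result".to_string(),
        description: "An example description".to_string(),
        url: "https://example.com".to_string(),
        content: "Full text content of the page".to_string(),
    }])
});

// Then, when constructing the engine (builder arguments per the signature in this PR):
// MistralRsBuilder::new(pipeline, scheduler_config, throughput_logging, search_embedding_model)
//     .with_search_callback(callback)
//     .build();
```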

## HTTP server
**Be sure to add `--enable-search`!**

@@ -80,12 +109,25 @@ from mistralrs import (
    WebSearchOptions,
)

# Define a custom search callback if desired
def my_search_callback(query: str) -> list[dict[str, str]]:
    # Fetch or compute search results here
    return [
        {
            "title": "Mistral.rs GitHub",
            "description": "Official mistral.rs repository",
            "url": "https://github.com/EricLBuehler/mistral.rs",
            "content": "mistral.rs is a fast LLM inference engine written in Rust...",
        },
    ]


runner = Runner(
    which=Which.Plain(
        model_id="NousResearch/Hermes-3-Llama-3.1-8B",
        arch=Architecture.Llama,
    ),
    enable_search=True,
    search_callback=my_search_callback,
)

res = runner.send_chat_completion_request(
52 changes: 52 additions & 0 deletions examples/python/local_search.py
@@ -0,0 +1,52 @@
from mistralrs import (
    Runner,
    Which,
    ChatCompletionRequest,
    Architecture,
    WebSearchOptions,
)
import os


def local_search(query: str):
    results = []
    for root, _, files in os.walk("."):
        for f in files:
            if query in f:
                path = os.path.join(root, f)
                try:
                    content = open(path).read()
                except Exception:
                    content = ""
Comment on lines +18 to +20

🛠️ Refactor suggestion

Improve file handling with proper resource management.

The current file opening approach lacks proper resource management and error handling specificity.

Apply this diff to use a context manager and improve error handling (binding the handle to `fh` so it does not shadow the loop variable `f`, which is used for the title below):

-                try:
-                    content = open(path).read()
-                except Exception:
-                    content = ""
+                try:
+                    with open(path, 'r', encoding='utf-8') as fh:
+                        content = fh.read()
+                except (OSError, UnicodeDecodeError):
+                    content = ""
                results.append(
                    {
                        "title": f,
                        "description": path,
                        "url": path,
                        "content": content,
                    }
                )
    results.sort(key=lambda r: r["title"], reverse=True)
    return results


runner = Runner(
    which=Which.Plain(
        model_id="NousResearch/Hermes-3-Llama-3.1-8B",
        arch=Architecture.Llama,
    ),
    enable_search=True,
    search_callback=local_search,
)

res = runner.send_chat_completion_request(
    ChatCompletionRequest(
        model="mistral",
        messages=[{"role": "user", "content": "Where is Cargo.toml in this repo?"}],
        max_tokens=64,
        web_search_options=WebSearchOptions(
            search_description="Local filesystem search"
        ),
    )
)
print(res.choices[0].message.content)
4 changes: 4 additions & 0 deletions mistralrs-core/src/engine/mod.rs
@@ -9,6 +9,7 @@ use crate::{
    prefix_cacher::PrefixCacheManagerV2,
    response::CompletionChoice,
    scheduler::{Scheduler, SchedulerOutput},
    search,
    sequence::{SeqStepType, StopReason},
    CompletionResponse, SchedulerConfig, DEBUG,
};
@@ -72,6 +73,7 @@ pub struct Engine {
    rx: Arc<Mutex<Receiver<Request>>>,
    pipeline: Arc<Mutex<dyn Pipeline>>,
    bert_pipeline: Arc<Mutex<Option<BertPipeline>>>,
    search_callback: Option<Arc<search::SearchCallback>>,
    scheduler: Arc<Mutex<dyn Scheduler>>,
    id: Arc<Mutex<usize>>,
    truncate_sequence: bool,
@@ -105,6 +107,7 @@ impl Engine {
        disable_eos_stop: bool,
        throughput_logging_enabled: bool,
        search_embedding_model: Option<BertEmbeddingModel>,
        search_callback: Option<Arc<search::SearchCallback>>,
    ) -> anyhow::Result<Self> {
        no_kv_cache |= get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache;

@@ -127,6 +130,7 @@
            rx: Arc::new(Mutex::new(rx)),
            pipeline,
            bert_pipeline: Arc::new(Mutex::new(bert_pipeline)),
            search_callback,
            scheduler: scheduler.clone(),
            id: Arc::new(Mutex::new(0)),
            truncate_sequence,
8 changes: 6 additions & 2 deletions mistralrs-core/src/engine/search_request.rs
@@ -72,8 +72,12 @@ async fn do_search(
    };
    let mut results = tokio::task::block_in_place(|| {
        tracing::dispatcher::with_default(&dispatch, || {
-           search::run_search_tool(&tool_call_params)
-               .unwrap()
+           let base_results = if let Some(cb) = &this.search_callback {
+               cb(&tool_call_params).unwrap()
+           } else {
+               search::run_search_tool(&tool_call_params).unwrap()
+           };
+           base_results
                .into_iter()
                .map(|mut result| {
                    result = result
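Note that both branches `unwrap()` the gathered results, so a callback returning `Err` would panic the engine thread. A defensive wrapper (a sketch, not part of this PR; it assumes the re-exported `SearchCallback` types) can map failures to an empty result list instead:

```rust
use std::sync::Arc;

use mistralrs::{SearchCallback, SearchFunctionParameters, SearchResult};

// Hypothetical helper: wrap a fallible search function so the engine's
// `.unwrap()` never panics; errors are reported as "no results".
fn defensive<F>(inner: F) -> Arc<SearchCallback>
where
    F: Fn(&SearchFunctionParameters) -> anyhow::Result<Vec<SearchResult>> + Send + Sync + 'static,
{
    Arc::new(move |params: &SearchFunctionParameters| Ok(inner(params).unwrap_or_default()))
}
```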
18 changes: 18 additions & 0 deletions mistralrs-core/src/lib.rs
@@ -104,6 +104,7 @@ pub use sampler::{
    CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
};
pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
use serde::Serialize;
pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
use tokio::runtime::Runtime;
@@ -159,6 +160,7 @@ struct RebootState {
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<BertEmbeddingModel>,
    search_callback: Option<Arc<search::SearchCallback>>,
}

#[derive(Debug)]
@@ -196,9 +198,13 @@ pub struct MistralRsBuilder {
    disable_eos_stop: Option<bool>,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<BertEmbeddingModel>,
    search_callback: Option<Arc<SearchCallback>>,
}

impl MistralRsBuilder {
    /// Creates a new builder with the given pipeline, scheduler method, logging flag,
    /// and optional embedding model for web search. To override the search callback,
    /// use `.with_search_callback(...)` on the builder.
    pub fn new(
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
@@ -216,6 +222,7 @@ impl MistralRsBuilder {
            disable_eos_stop: None,
            throughput_logging_enabled: throughput_logging,
            search_embedding_model,
            search_callback: None,
        }
    }
    pub fn with_log(mut self, log: String) -> Self {
@@ -247,6 +254,12 @@
        self
    }

    /// Use a custom callback to gather search results.
    pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
        self.search_callback = Some(search_callback);
        self
    }

    pub fn build(self) -> Arc<MistralRs> {
        MistralRs::new(self)
    }
@@ -274,6 +287,7 @@ impl MistralRs {
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model,
            search_callback,
        } = config;

        let category = pipeline.try_lock().unwrap().category();
@@ -297,6 +311,7 @@
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model: search_embedding_model.clone(),
            search_callback: search_callback.clone(),
        };

        let (tx, rx) = channel(10_000);
@@ -328,6 +343,7 @@
                    disable_eos_stop,
                    throughput_logging_enabled,
                    search_embedding_model,
                    search_callback.clone(),
                )
                .expect("Engine creation failed.");
                Arc::new(engine).run().await;
@@ -349,6 +365,7 @@
                    disable_eos_stop,
                    throughput_logging_enabled,
                    search_embedding_model,
                    search_callback.clone(),
                )
                .expect("Engine creation failed.");
                Arc::new(engine).run().await;
@@ -473,6 +490,7 @@ impl MistralRs {
                    reboot_state.disable_eos_stop,
                    reboot_state.throughput_logging_enabled,
                    reboot_state.search_embedding_model,
                    reboot_state.search_callback.clone(),
                )
                .expect("Engine creation failed");
                Arc::new(engine).run().await;
4 changes: 4 additions & 0 deletions mistralrs-core/src/request.rs
@@ -110,6 +110,10 @@ pub enum WebSearchUserLocation {
pub struct WebSearchOptions {
    pub search_context_size: Option<SearchContextSize>,
    pub user_location: Option<WebSearchUserLocation>,
    /// Override the description for the search tool.
    pub search_description: Option<String>,
    /// Override the description for the extraction tool.
    pub extract_description: Option<String>,
}

#[derive(Clone, Serialize, Deserialize)]
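For illustration, the new overrides can be set from Rust like this (a sketch; it assumes `WebSearchOptions` is re-exported at the crate root and that the remaining fields are simply left as `None`):

```rust
use mistralrs::WebSearchOptions;

// Override only the search tool's description; the extraction tool
// keeps its default description.
let opts = WebSearchOptions {
    search_context_size: None,
    user_location: None,
    search_description: Some("Local filesystem search".to_string()),
    extract_description: None,
};
```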
18 changes: 15 additions & 3 deletions mistralrs-core/src/search/mod.rs
@@ -14,6 +14,11 @@ use tokenizers::Tokenizer;

use crate::{Function, Tool, ToolType, WebSearchOptions, WebSearchUserLocation};

/// Callback used to override how search results are gathered. The returned
/// vector must be sorted in decreasing order of relevance.
pub type SearchCallback =
    dyn Fn(&SearchFunctionParameters) -> Result<Vec<SearchResult>> + Send + Sync;

pub(crate) fn search_tool_called(name: &str) -> bool {
    name == SEARCH_TOOL_NAME || name == EXTRACT_TOOL_NAME
}
@@ -140,11 +145,14 @@ pub fn get_search_tools(web_search_options: &WebSearchOptions) -> Result<Vec<Too
        }
        None => "".to_string(),
    };

+   let description = web_search_options
+       .search_description
+       .as_deref()
+       .unwrap_or(SEARCH_DESCRIPTION);
    Tool {
        tp: ToolType::Function,
        function: Function {
-           description: Some(format!("{SEARCH_DESCRIPTION}{location_details}")),
+           description: Some(format!("{}{}", description, location_details)),
            name: SEARCH_TOOL_NAME.to_string(),
            parameters: Some(parameters),
        },
@@ -163,10 +171,14 @@
"required": ["url"],
}))?;

let description = web_search_options
.extract_description
.as_deref()
.unwrap_or(EXTRACT_DESCRIPTION);
Tool {
tp: ToolType::Function,
function: Function {
description: Some(EXTRACT_DESCRIPTION.to_string()),
description: Some(description.to_string()),
name: EXTRACT_TOOL_NAME.to_string(),
parameters: Some(parameters),
},
5 changes: 4 additions & 1 deletion mistralrs-pyo3/mistralrs.pyi
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from enum import Enum
-from typing import Iterator, Literal, Optional
+from typing import Iterator, Literal, Optional, Callable

class SearchContextSize(Enum):
Low = "low"
@@ -345,7 +345,9 @@ class Runner:
        paged_attn: bool = False,
        prompt_batchsize: int | None = None,
        seed: int | None = None,
        enable_search: bool = False,
        search_bert_model: str | None = None,
        search_callback: Callable[[str], list[dict[str, str]]] | None = None,
        no_bert_model: bool = False,
    ) -> None:
        """
@@ -389,6 +391,7 @@
        - `seed`, used to ensure reproducible random number generation.
        - `enable_search`: Enable searching compatible with the OpenAI `web_search_options` setting. This uses the BERT model specified below or the default.
        - `search_bert_model`: specify a Hugging Face model ID for a BERT model to assist web searching. Defaults to Snowflake Arctic Embed L.
        - `search_callback`: Custom Python callable to perform web searches. Should accept a query string and return a list of dicts with keys "title", "description", "url", and "content".
        """

        ...