1 change: 1 addition & 0 deletions Cargo.toml
@@ -78,6 +78,7 @@ half = "2.4.0"
rayon = "1.1.0"
url = "2.5.2"
utoipa = "4.2"
walkdir = "2.5.0"
data-url = "0.3.1"
float8 = "0.2.1"
regex = "1.10.6"
46 changes: 46 additions & 0 deletions examples/python/local_search.py
@@ -0,0 +1,46 @@
from mistralrs import (
    Runner,
    Which,
    ChatCompletionRequest,
    Architecture,
    WebSearchOptions,
)
import os

def local_search(query: str):
    results = []
    for root, _, files in os.walk('.'):
        for f in files:
            if query in f:
                path = os.path.join(root, f)
                try:
                    content = open(path).read()
                except Exception:
                    content = ""
Comment on lines +18 to +20
🛠️ Refactor suggestion

Improve file handling with proper resource management.

The file is opened without a context manager, and the bare `except Exception` hides the real cause of any failure.

Apply this diff to use a context manager and improve error handling:

-                try:
-                    content = open(path).read()
-                except Exception:
-                    content = ""
+                try:
+                    with open(path, 'r', encoding='utf-8') as f:
+                        content = f.read()
+                except (OSError, UnicodeDecodeError):
+                    content = ""
🧰 Tools
🪛 Ruff (0.11.9)
17-17: Use a context manager for opening files (SIM115)
🪛 Pylint (3.3.7)
[warning] 18-18: Catching too general exception Exception (W0718)
[refactor] 17-17: Consider using 'with' for resource-allocating operations (R1732)
[warning] 17-17: Using open without explicitly specifying an encoding (W1514)

                results.append({
                    "title": f,
                    "description": path,
                    "url": path,
                    "content": content,
                })
    results.sort(key=lambda r: r['title'], reverse=True)
    return results

runner = Runner(
    which=Which.Plain(
        model_id="NousResearch/Hermes-3-Llama-3.1-8B",
        arch=Architecture.Llama,
    ),
    enable_search=True,
    search_callback=local_search,
)

res = runner.send_chat_completion_request(
    ChatCompletionRequest(
        model="mistral",
        messages=[{"role": "user", "content": "Where is Cargo.toml in this repo?"}],
        max_tokens=64,
        web_search_options=WebSearchOptions(search_description="Local filesystem search"),
    )
)
print(res.choices[0].message.content)
4 changes: 4 additions & 0 deletions mistralrs-core/src/engine/mod.rs
@@ -1,6 +1,7 @@
use crate::{
distributed,
embedding::bert::BertPipeline,
search,
pipeline::{
llg::{constraint_from_llg_grammar, llg_grammar_from_constraint},
text_models_inputs_processor::PagedAttentionMeta,
@@ -72,6 +73,7 @@ pub struct Engine {
rx: Arc<Mutex<Receiver<Request>>>,
pipeline: Arc<Mutex<dyn Pipeline>>,
bert_pipeline: Arc<Mutex<Option<BertPipeline>>>,
search_callback: Option<Arc<search::SearchCallback>>,
scheduler: Arc<Mutex<dyn Scheduler>>,
id: Arc<Mutex<usize>>,
truncate_sequence: bool,
@@ -105,6 +107,7 @@ impl Engine {
disable_eos_stop: bool,
throughput_logging_enabled: bool,
search_embedding_model: Option<BertEmbeddingModel>,
search_callback: Option<Arc<search::SearchCallback>>,
) -> anyhow::Result<Self> {
no_kv_cache |= get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache;

@@ -127,6 +130,7 @@
rx: Arc::new(Mutex::new(rx)),
pipeline,
bert_pipeline: Arc::new(Mutex::new(bert_pipeline)),
search_callback,
scheduler: scheduler.clone(),
id: Arc::new(Mutex::new(0)),
truncate_sequence,
8 changes: 6 additions & 2 deletions mistralrs-core/src/engine/search_request.rs
@@ -72,8 +72,12 @@ async fn do_search(
};
let mut results = tokio::task::block_in_place(|| {
tracing::dispatcher::with_default(&dispatch, || {
-search::run_search_tool(&tool_call_params)
-    .unwrap()
+let base_results = if let Some(cb) = &this.search_callback {
+    cb(&tool_call_params).unwrap()
+} else {
+    search::run_search_tool(&tool_call_params).unwrap()
+};
+base_results
.into_iter()
.map(|mut result| {
result = result
16 changes: 16 additions & 0 deletions mistralrs-core/src/lib.rs
@@ -112,6 +112,7 @@ pub use tools::{
CalledFunction, Function, Tool, ToolCallResponse, ToolCallType, ToolChoice, ToolType,
};
pub use topology::{LayerTopology, Topology};
pub use search::SearchCallback;
pub use utils::debug::initialize_logging;
pub use utils::memory_usage::MemoryUsage;
pub use utils::normal::{ModelDType, TryIntoDType};
@@ -159,6 +160,7 @@ struct RebootState {
disable_eos_stop: bool,
throughput_logging_enabled: bool,
search_embedding_model: Option<BertEmbeddingModel>,
search_callback: Option<Arc<search::SearchCallback>>,
}

#[derive(Debug)]
@@ -196,6 +198,7 @@ pub struct MistralRsBuilder {
disable_eos_stop: Option<bool>,
throughput_logging_enabled: bool,
search_embedding_model: Option<BertEmbeddingModel>,
search_callback: Option<Arc<SearchCallback>>,
}

impl MistralRsBuilder {
@@ -204,6 +207,7 @@
method: SchedulerConfig,
throughput_logging: bool,
search_embedding_model: Option<BertEmbeddingModel>,
search_callback: Option<Arc<SearchCallback>>,
) -> Self {
Self {
pipeline,
@@ -216,6 +220,7 @@
disable_eos_stop: None,
throughput_logging_enabled: throughput_logging,
search_embedding_model,
search_callback,
}
}
pub fn with_log(mut self, log: String) -> Self {
@@ -247,6 +252,12 @@ impl MistralRsBuilder {
self
}

/// Use a custom callback to gather search results.
pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
self.search_callback = Some(search_callback);
self
}

pub fn build(self) -> Arc<MistralRs> {
MistralRs::new(self)
}
@@ -274,6 +285,7 @@ impl MistralRs {
disable_eos_stop,
throughput_logging_enabled,
search_embedding_model,
search_callback,
} = config;

let category = pipeline.try_lock().unwrap().category();
@@ -297,6 +309,7 @@ impl MistralRs {
disable_eos_stop,
throughput_logging_enabled,
search_embedding_model: search_embedding_model.clone(),
search_callback: search_callback.clone(),
};

let (tx, rx) = channel(10_000);
@@ -328,6 +341,7 @@ impl MistralRs {
disable_eos_stop,
throughput_logging_enabled,
search_embedding_model,
search_callback.clone(),
)
.expect("Engine creation failed.");
Arc::new(engine).run().await;
@@ -349,6 +363,7 @@ impl MistralRs {
disable_eos_stop,
throughput_logging_enabled,
search_embedding_model,
search_callback.clone(),
)
.expect("Engine creation failed.");
Arc::new(engine).run().await;
@@ -473,6 +488,7 @@ impl MistralRs {
reboot_state.disable_eos_stop,
reboot_state.throughput_logging_enabled,
reboot_state.search_embedding_model,
reboot_state.search_callback.clone(),
)
.expect("Engine creation failed");
Arc::new(engine).run().await;
4 changes: 4 additions & 0 deletions mistralrs-core/src/request.rs
@@ -110,6 +110,10 @@ pub enum WebSearchUserLocation {
pub struct WebSearchOptions {
pub search_context_size: Option<SearchContextSize>,
pub user_location: Option<WebSearchUserLocation>,
/// Override the description for the search tool.
pub search_description: Option<String>,
/// Override the description for the extraction tool.
pub extract_description: Option<String>,
}

#[derive(Clone, Serialize, Deserialize)]
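Neither example in this PR exercises the new `extract_description` override, so here is a hedged sketch of how both fields could be set from the `mistralrs` facade crate. The description strings are illustrative, the import path follows the Rust example later in this PR, and the remaining fields are assumed to keep their `Default` values (as that example also relies on).

use mistralrs::WebSearchOptions;

fn main() {
    // Illustrative only: override both tool descriptions added in this commit;
    // context size and user location are left at their defaults.
    let opts = WebSearchOptions {
        search_description: Some("Search the local filesystem by file name".to_string()),
        extract_description: Some("Read the contents of a local file given its path".to_string()),
        ..Default::default()
    };
    // `opts` would then be attached to a request, e.g. via with_web_search_options.
    let _ = opts;
}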
17 changes: 14 additions & 3 deletions mistralrs-core/src/search/mod.rs
@@ -14,6 +14,10 @@ use tokenizers::Tokenizer;

use crate::{Function, Tool, ToolType, WebSearchOptions, WebSearchUserLocation};

/// Callback used to override how search results are gathered. The returned
/// vector must be sorted in decreasing order of relevance.
pub type SearchCallback = dyn Fn(&SearchFunctionParameters) -> Result<Vec<SearchResult>> + Send + Sync;

pub(crate) fn search_tool_called(name: &str) -> bool {
name == SEARCH_TOOL_NAME || name == EXTRACT_TOOL_NAME
}
@@ -140,11 +144,14 @@ pub fn get_search_tools(web_search_options: &WebSearchOptions) -> Result<Vec<Too
}
None => "".to_string(),
};

let description = web_search_options
.search_description
.as_deref()
.unwrap_or(SEARCH_DESCRIPTION);
Tool {
tp: ToolType::Function,
function: Function {
-description: Some(format!("{SEARCH_DESCRIPTION}{location_details}")),
+description: Some(format!("{}{}", description, location_details)),
name: SEARCH_TOOL_NAME.to_string(),
parameters: Some(parameters),
},
@@ -163,10 +170,14 @@ pub fn get_search_tools(web_search_options: &WebSearchOptions) -> Result<Vec<Too
"required": ["url"],
}))?;

let description = web_search_options
.extract_description
.as_deref()
.unwrap_or(EXTRACT_DESCRIPTION);
Tool {
tp: ToolType::Function,
function: Function {
-description: Some(EXTRACT_DESCRIPTION.to_string()),
+description: Some(description.to_string()),
name: EXTRACT_TOOL_NAME.to_string(),
parameters: Some(parameters),
},
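For reference, a minimal sketch of a callback satisfying the contract documented on the new `SearchCallback` alias (results returned best-first). It assumes the `query` field on `SearchFunctionParameters` and the `SearchResult` fields used elsewhere in this PR, imports the types through the `mistralrs` facade as the Rust example below does, and ranks by how often the query occurs in each result's content rather than by title.

use std::cmp::Reverse;
use std::sync::Arc;
use mistralrs::{SearchCallback, SearchResult};

// Illustrative only: gather candidates however you like, then return them in
// decreasing order of relevance, as the doc comment on `SearchCallback` requires.
fn make_ranked_callback() -> Arc<SearchCallback> {
    Arc::new(|params| {
        let query = params.query.clone(); // `query` field assumed, as in the examples in this PR
        let mut results: Vec<SearchResult> = Vec::new(); // populate with real candidates here
        results.sort_by_key(|r| Reverse(r.content.matches(query.as_str()).count()));
        Ok(results)
    })
}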
28 changes: 26 additions & 2 deletions mistralrs-pyo3/src/lib.rs
@@ -32,8 +32,10 @@
SpeechLoader, StopTokens, TokenSource, TokenizationRequest, Tool, Topology,
VisionLoaderBuilder, VisionSpecificConfig,
};
use mistralrs_core::{search::SearchFunctionParameters, search::SearchResult, SearchCallback};

Check failure on line 35 in mistralrs-pyo3/src/lib.rs, reported by GitHub Actions / Check (ubuntu-latest, stable) and Clippy: module `search` is private
use pyo3::prelude::*;
-use pyo3::types::PyType;
+use pyo3::types::{PyType, PyList};
use pyo3::PyObject;
use pyo3::Bound;
use std::fs::File;
mod anymoe;
@@ -90,6 +92,24 @@

static NEXT_REQUEST_ID: Mutex<RefCell<usize>> = Mutex::new(RefCell::new(0));

fn wrap_search_callback(cb: PyObject) -> Arc<SearchCallback> {
    Arc::new(move |params: &SearchFunctionParameters| {
        Python::with_gil(|py| {
            let list = cb.call1(py, (params.query.clone(),))?.downcast::<PyList>()?;

Check failure on line 98 in mistralrs-pyo3/src/lib.rs, reported by GitHub Actions / Check (ubuntu-latest, stable): no method named `downcast` found for struct `pyo3::Py` in the current scope
            let mut results = Vec::new();
            for item in list.iter() {
                let title: String = item.get_item("title")?.extract()?;
                let description: String = item.get_item("description")?.extract()?;
                let url: String = item.get_item("url")?.extract()?;
                let content: String = item.get_item("content")?.extract()?;
                results.push(SearchResult { title, description, url, content });
            }
            Ok(results)
        })
        .map_err(|e| anyhow::anyhow!(e.to_string()))
    })
}

fn parse_which(
which: Which,
no_kv_cache: bool,
@@ -502,6 +522,7 @@
seed = None,
enable_search = false,
search_bert_model = None,
search_callback = None,
))]
fn new(
which: Which,
Expand All @@ -526,6 +547,7 @@
seed: Option<u64>,
enable_search: bool,
search_bert_model: Option<String>,
search_callback: Option<PyObject>,
) -> PyApiResult<Self> {
let tgt_non_granular_index = match which {
Which::Plain { .. }
@@ -819,7 +841,8 @@
} else {
None
};
-let mistralrs = MistralRsBuilder::new(pipeline, scheduler_config, false, bert_model)
+let cb = search_callback.map(wrap_search_callback);
+let mistralrs = MistralRsBuilder::new(pipeline, scheduler_config, false, bert_model, cb)
.with_no_kv_cache(no_kv_cache)
.with_prefix_cache_n(prefix_cache_n)
.build();
@@ -879,6 +902,7 @@
None, // prompt_chunksize
seed, false, // enable_search
None, // search_bert_model
None, // search_callback
)
}

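The two CI failures above are not resolved in this commit; what follows is a hedged sketch of one possible direction, not part of the PR. The private-module error would presumably need a wider re-export from mistralrs-core (for example `pub use search::{SearchFunctionParameters, SearchResult};` next to the existing `pub use search::SearchCallback;`), with the pyo3 import switched to the re-exported paths. For the `downcast` error, under pyo3's Bound API (0.21+, an assumption) the `Py<PyAny>` returned by `call1` can be bound to the GIL before downcasting. The helper name below is illustrative; the field extraction mirrors the loop in this commit.

use pyo3::prelude::*;
use pyo3::types::PyList;

// Hedged sketch: convert the Python callback's return value (a list of dicts,
// as in the Python example earlier in this PR) into owned titles. Binding the
// `Py<PyAny>` to the GIL first gives a `Bound` reference, which supports
// `downcast`, unlike `Py` itself.
fn titles_from_callback_result(py: Python<'_>, returned: Py<PyAny>) -> PyResult<Vec<String>> {
    let list = returned.bind(py).downcast::<PyList>()?;
    let mut titles = Vec::with_capacity(list.len());
    for item in list.iter() {
        titles.push(item.get_item("title")?.extract::<String>()?);
    }
    Ok(titles)
}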
1 change: 1 addition & 0 deletions mistralrs/Cargo.toml
@@ -26,6 +26,7 @@ futures.workspace = true
reqwest.workspace = true
rand = "0.9.0"
clap.workspace = true
walkdir.workspace = true

[features]
cuda = ["mistralrs-core/cuda"]
57 changes: 57 additions & 0 deletions mistralrs/examples/local_search/main.rs
@@ -0,0 +1,57 @@
use anyhow::Result;
use walkdir::WalkDir;
use std::fs;
use std::sync::Arc;
use mistralrs::{
    BertEmbeddingModel, IsqType, RequestBuilder, SearchCallback, SearchResult,
    TextMessageRole, TextMessages, TextModelBuilder, WebSearchOptions,
};

fn local_search(query: &str) -> Result<Vec<SearchResult>> {
    let mut results = Vec::new();
    for entry in WalkDir::new(".") {
        let entry = entry?;
        if entry.file_type().is_file() {
            let name = entry.file_name().to_string_lossy();
            if name.contains(query) {
                let path = entry.path().display().to_string();
                let content = fs::read_to_string(entry.path()).unwrap_or_default();
                results.push(SearchResult {
                    title: name.into_owned(),
                    description: path.clone(),
                    url: path,
                    content,
                });
            }
        }
    }
    results.sort_by_key(|r| r.title.clone());
    results.reverse();
    Ok(results)
}
Comment on lines +10 to +31
🛠️ Refactor suggestion

Improve search logic and error handling.

The current implementation has several areas for improvement:

  1. Misleading search scope: The function only searches filenames but reads full file contents, which may confuse users about what's being searched.
  2. Questionable sorting: Sorting by title then reversing gives descending alphabetical order, which may not be the most useful for search relevance.
  3. Silent error handling: Using unwrap_or_default() for file reading silently ignores permission errors or other IO issues.

Consider this improved implementation:

 fn local_search(query: &str) -> Result<Vec<SearchResult>> {
     let mut results = Vec::new();
     for entry in WalkDir::new(".") {
         let entry = entry?;
         if entry.file_type().is_file() {
             let name = entry.file_name().to_string_lossy();
             if name.contains(query) {
                 let path = entry.path().display().to_string();
-                let content = fs::read_to_string(entry.path()).unwrap_or_default();
+                let content = fs::read_to_string(entry.path())
+                    .unwrap_or_else(|_| format!("Could not read file: {}", path));
                 results.push(SearchResult {
                     title: name.into_owned(),
                     description: path.clone(),
                     url: path,
                     content,
                 });
             }
         }
     }
-    results.sort_by_key(|r| r.title.clone());
-    results.reverse();
+    // Sort by relevance (exact matches first, then alphabetical)
+    results.sort_by(|a, b| {
+        let a_exact = a.title.eq_ignore_ascii_case(query);
+        let b_exact = b.title.eq_ignore_ascii_case(query);
+        match (a_exact, b_exact) {
+            (true, false) => std::cmp::Ordering::Less,
+            (false, true) => std::cmp::Ordering::Greater,
+            _ => a.title.cmp(&b.title),
+        }
+    });
     Ok(results)
 }



#[tokio::main]
async fn main() -> Result<()> {
    let model = TextModelBuilder::new("NousResearch/Hermes-3-Llama-3.1-8B")
        .with_isq(IsqType::Q4K)
        .with_logging()
        .with_search(BertEmbeddingModel::default())
        .with_search_callback(Arc::new(|p| local_search(&p.query)))
        .build()
        .await?;

    let messages = TextMessages::new().add_message(
        TextMessageRole::User,
        "Where is Cargo.toml in this repo?",
    );
    let messages = RequestBuilder::from(messages).with_web_search_options(
        WebSearchOptions {
            search_description: Some("Local filesystem search".to_string()),
            ..Default::default()
        },
    );

    let response = model.send_chat_request(messages).await?;
    println!("{}", response.choices[0].message.content.as_ref().unwrap());
    Ok(())
}
1 change: 1 addition & 0 deletions mistralrs/src/anymoe.rs
@@ -116,6 +116,7 @@ impl AnyMoeModelBuilder {
scheduler_method,
self.base.throughput_logging,
self.base.search_bert_model,
self.base.search_callback.clone(),
)
.with_no_kv_cache(self.base.no_kv_cache)
.with_no_prefix_cache(self.base.prefix_cache_n.is_none());