Skip to content

Commit 231d9f8

Browse files
committed
Move to builder
1 parent 1185b57 commit 231d9f8

File tree

20 files changed

+126
-72
lines changed

20 files changed

+126
-72
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/WEB_SEARCH.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ Internally, we use a BERT model (Snowflake/snowflake-arctic-embed-l-v2.0)[https:
2323
- Python: `search_bert_model` in the Runner
2424
- Server: `search-bert-model` before the model type selector (`plain`/`vision-plain`)
2525

26+
## Specifying a custom search callback
27+
28+
By default, mistral.rs uses a DuckDuckGo-based search callback. To override this, you can provide your own search function:
29+
30+
- Rust: use `.with_search_callback(...)` on the model builder with an `Arc<dyn Fn(&SearchFunctionParameters) -> anyhow::Result<Vec<SearchResult>> + Send + Sync>`.
31+
- Python: pass the `search_callback` keyword argument to `Runner`, which should be a function `def search_callback(query: str) -> List[Dict[str, str]]` returning a list of results with keys `"title"`, `"description"`, `"url"`, and `"content"`.
32+
2633
## HTTP server
2734
**Be sure to add `--enable-search`!**
2835

mistralrs-core/src/engine/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use crate::{
22
distributed,
33
embedding::bert::BertPipeline,
4-
search,
54
pipeline::{
65
llg::{constraint_from_llg_grammar, llg_grammar_from_constraint},
76
text_models_inputs_processor::PagedAttentionMeta,
@@ -10,6 +9,7 @@ use crate::{
109
prefix_cacher::PrefixCacheManagerV2,
1110
response::CompletionChoice,
1211
scheduler::{Scheduler, SchedulerOutput},
12+
search,
1313
sequence::{SeqStepType, StopReason},
1414
CompletionResponse, SchedulerConfig, DEBUG,
1515
};

mistralrs-core/src/lib.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ pub use sampler::{
104104
CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
105105
};
106106
pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
107+
pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
107108
use serde::Serialize;
108109
pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
109110
use tokio::runtime::Runtime;
@@ -112,7 +113,6 @@ pub use tools::{
112113
CalledFunction, Function, Tool, ToolCallResponse, ToolCallType, ToolChoice, ToolType,
113114
};
114115
pub use topology::{LayerTopology, Topology};
115-
pub use search::SearchCallback;
116116
pub use utils::debug::initialize_logging;
117117
pub use utils::memory_usage::MemoryUsage;
118118
pub use utils::normal::{ModelDType, TryIntoDType};
@@ -202,12 +202,14 @@ pub struct MistralRsBuilder {
202202
}
203203

204204
impl MistralRsBuilder {
205+
/// Creates a new builder with the given pipeline, scheduler method, logging flag,
206+
/// and optional embedding model for web search. To override the search callback,
207+
/// use `.with_search_callback(...)` on the builder.
205208
pub fn new(
206209
pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
207210
method: SchedulerConfig,
208211
throughput_logging: bool,
209212
search_embedding_model: Option<BertEmbeddingModel>,
210-
search_callback: Option<Arc<SearchCallback>>,
211213
) -> Self {
212214
Self {
213215
pipeline,
@@ -220,7 +222,7 @@ impl MistralRsBuilder {
220222
disable_eos_stop: None,
221223
throughput_logging_enabled: throughput_logging,
222224
search_embedding_model,
223-
search_callback,
225+
search_callback: None,
224226
}
225227
}
226228
pub fn with_log(mut self, log: String) -> Self {

mistralrs-core/src/search/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ use crate::{Function, Tool, ToolType, WebSearchOptions, WebSearchUserLocation};
1616

1717
/// Callback used to override how search results are gathered. The returned
1818
/// vector must be sorted in decreasing order of relevance.
19-
pub type SearchCallback = dyn Fn(&SearchFunctionParameters) -> Result<Vec<SearchResult>> + Send + Sync;
19+
pub type SearchCallback =
20+
dyn Fn(&SearchFunctionParameters) -> Result<Vec<SearchResult>> + Send + Sync;
2021

2122
pub(crate) fn search_tool_called(name: &str) -> bool {
2223
name == SEARCH_TOOL_NAME || name == EXTRACT_TOOL_NAME

mistralrs-pyo3/src/lib.rs

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,20 +32,19 @@ use mistralrs_core::{
3232
SpeechLoader, StopTokens, TokenSource, TokenizationRequest, Tool, Topology,
3333
VisionLoaderBuilder, VisionSpecificConfig,
3434
};
35-
use mistralrs_core::{search::SearchFunctionParameters, search::SearchResult, SearchCallback};
35+
use mistralrs_core::{SearchCallback, SearchFunctionParameters, SearchResult};
3636
use pyo3::prelude::*;
37-
use pyo3::types::{PyType, PyList};
38-
use pyo3::PyObject;
37+
use pyo3::types::{PyList, PyType};
3938
use pyo3::Bound;
39+
use pyo3::PyObject;
4040
use std::fs::File;
4141
mod anymoe;
4242
mod requests;
4343
mod stream;
4444
mod util;
4545
mod which;
46-
use which::{Architecture, DiffusionArchitecture, SpeechLoaderType, VisionArchitecture, Which};
47-
// (keep imports minimal – if needed later, re-introduce)
4846
use mistralrs_core::ModelDType;
47+
use which::{Architecture, DiffusionArchitecture, SpeechLoaderType, VisionArchitecture, Which};
4948

5049
static DEVICE: OnceLock<Result<Device>> = OnceLock::new();
5150

@@ -95,18 +94,24 @@ static NEXT_REQUEST_ID: Mutex<RefCell<usize>> = Mutex::new(RefCell::new(0));
9594
fn wrap_search_callback(cb: PyObject) -> Arc<SearchCallback> {
9695
Arc::new(move |params: &SearchFunctionParameters| {
9796
Python::with_gil(|py| {
98-
let list = cb.call1(py, (params.query.clone(),))?.downcast::<PyList>()?;
97+
let obj = cb.call1(py, (params.query.clone(),))?;
98+
let list = obj.downcast_bound::<PyList>(py)?;
9999
let mut results = Vec::new();
100100
for item in list.iter() {
101101
let title: String = item.get_item("title")?.extract()?;
102102
let description: String = item.get_item("description")?.extract()?;
103103
let url: String = item.get_item("url")?.extract()?;
104104
let content: String = item.get_item("content")?.extract()?;
105-
results.push(SearchResult { title, description, url, content });
105+
results.push(SearchResult {
106+
title,
107+
description,
108+
url,
109+
content,
110+
});
106111
}
107112
Ok(results)
108113
})
109-
.map_err(|e| anyhow::anyhow!(e.to_string()))
114+
.map_err(|e: PyErr| anyhow::anyhow!(e.to_string()))
110115
})
111116
}
112117

@@ -842,7 +847,11 @@ impl Runner {
842847
None
843848
};
844849
let cb = search_callback.map(wrap_search_callback);
845-
let mistralrs = MistralRsBuilder::new(pipeline, scheduler_config, false, bert_model, cb)
850+
let mut builder = MistralRsBuilder::new(pipeline, scheduler_config, false, bert_model);
851+
if let Some(cb) = cb {
852+
builder = builder.with_search_callback(cb);
853+
}
854+
let mistralrs = builder
846855
.with_no_kv_cache(no_kv_cache)
847856
.with_prefix_cache_n(prefix_cache_n)
848857
.build();

mistralrs-server-core/src/mistralrs_for_server_builder.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! ## mistral.rs instance for server builder.
22
3-
use std::num::NonZeroUsize;
3+
use std::{num::NonZeroUsize, sync::Arc};
44

55
use anyhow::{Context, Result};
66
use candle_core::Device;
@@ -9,7 +9,7 @@ use mistralrs_core::{
99
parse_isq_value, AutoDeviceMapParams, BertEmbeddingModel, DefaultSchedulerMethod,
1010
DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, Loader, LoaderBuilder,
1111
MemoryGpuConfig, MistralRsBuilder, ModelSelected, PagedAttentionConfig, SchedulerConfig,
12-
TokenSource,
12+
SearchCallback, TokenSource,
1313
};
1414
use tracing::info;
1515

@@ -18,6 +18,9 @@ use crate::types::{LoadedPipeline, SharedMistralRsState};
1818
pub mod defaults {
1919
//! Provides the default values used for the mistral.rs instance for server.
2020
//! These defaults can be used for CLI argument fallbacks, config loading, or general initialization.
21+
22+
use std::sync::Arc;
23+
2124
pub const DEVICE: Option<candle_core::Device> = None;
2225
pub const SEED: Option<u64> = None;
2326
pub const LOG: Option<String> = None;
@@ -42,6 +45,7 @@ pub mod defaults {
4245
pub const ENABLE_SEARCH: bool = false;
4346
pub const SEARCH_BERT_MODEL: Option<String> = None;
4447
pub const TOKEN_SOURCE: mistralrs_core::TokenSource = mistralrs_core::TokenSource::CacheToken;
48+
pub const SEARCH_CALLBACK: Option<Arc<mistralrs_core::SearchCallback>> = None;
4549
}
4650

4751
/// A builder for creating a mistral.rs instance with configured options for the mistral.rs server.
@@ -169,6 +173,9 @@ pub struct MistralRsForServerBuilder {
169173

170174
/// Specify a Hugging Face model ID for a BERT model to assist web searching. Defaults to Snowflake Arctic Embed L.
171175
search_bert_model: Option<String>,
176+
177+
/// Optional override search callback
178+
search_callback: Option<Arc<SearchCallback>>,
172179
}
173180

174181
impl Default for MistralRsForServerBuilder {
@@ -199,6 +206,7 @@ impl Default for MistralRsForServerBuilder {
199206
cpu: defaults::CPU,
200207
enable_search: defaults::ENABLE_SEARCH,
201208
search_bert_model: defaults::SEARCH_BERT_MODEL,
209+
search_callback: defaults::SEARCH_CALLBACK,
202210
}
203211
}
204212
}
@@ -460,6 +468,12 @@ impl MistralRsForServerBuilder {
460468
self
461469
}
462470

471+
/// Override the search function used when `web_search_options` is enabled.
472+
pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
473+
self.search_callback = Some(callback);
474+
self
475+
}
476+
463477
/// Builds the configured mistral.rs instance.
464478
///
465479
/// ### Examples

mistralrs/examples/local_search/main.rs

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
use anyhow::Result;
2-
use walkdir::WalkDir;
3-
use std::fs;
4-
use std::sync::Arc;
52
use mistralrs::{
6-
BertEmbeddingModel, IsqType, RequestBuilder, SearchCallback, SearchResult,
7-
TextMessageRole, TextMessages, TextModelBuilder, WebSearchOptions,
3+
BertEmbeddingModel, IsqType, RequestBuilder, SearchResult, TextMessageRole, TextMessages,
4+
TextModelBuilder, WebSearchOptions,
85
};
6+
use std::fs;
7+
use std::sync::Arc;
8+
use walkdir::WalkDir;
99

1010
fn local_search(query: &str) -> Result<Vec<SearchResult>> {
1111
let mut results = Vec::new();
@@ -40,16 +40,12 @@ async fn main() -> Result<()> {
4040
.build()
4141
.await?;
4242

43-
let messages = TextMessages::new().add_message(
44-
TextMessageRole::User,
45-
"Where is Cargo.toml in this repo?",
46-
);
47-
let messages = RequestBuilder::from(messages).with_web_search_options(
48-
WebSearchOptions {
49-
search_description: Some("Local filesystem search".to_string()),
50-
..Default::default()
51-
},
52-
);
43+
let messages =
44+
TextMessages::new().add_message(TextMessageRole::User, "Where is Cargo.toml in this repo?");
45+
let messages = RequestBuilder::from(messages).with_web_search_options(WebSearchOptions {
46+
search_description: Some("Local filesystem search".to_string()),
47+
..Default::default()
48+
});
5349

5450
let response = model.send_chat_request(messages).await?;
5551
println!("{}", response.choices[0].message.content.as_ref().unwrap());

mistralrs/src/anymoe.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,13 @@ impl AnyMoeModelBuilder {
116116
scheduler_method,
117117
self.base.throughput_logging,
118118
self.base.search_bert_model,
119-
self.base.search_callback.clone(),
120-
)
121-
.with_no_kv_cache(self.base.no_kv_cache)
122-
.with_no_prefix_cache(self.base.prefix_cache_n.is_none());
119+
);
120+
if let Some(cb) = self.base.search_callback.clone() {
121+
runner = runner.with_search_callback(cb);
122+
}
123+
runner = runner
124+
.with_no_kv_cache(self.base.no_kv_cache)
125+
.with_no_prefix_cache(self.base.prefix_cache_n.is_none());
123126

124127
if let Some(n) = self.base.prefix_cache_n {
125128
runner = runner.with_prefix_cache_n(n)

mistralrs/src/diffusion_model.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ impl DiffusionModelBuilder {
9595
method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
9696
};
9797

98-
let runner = MistralRsBuilder::new(pipeline, scheduler_method, false, None, None);
98+
let runner = MistralRsBuilder::new(pipeline, scheduler_method, false, None);
9999

100100
Ok(Model::new(runner.build()))
101101
}

0 commit comments

Comments
 (0)