Skip to content

Commit 04d9a1e

Browse files
committed
Support fallback to bm25
1 parent a20778d commit 04d9a1e

File tree

4 files changed

+157
-25
lines changed

4 files changed

+157
-25
lines changed

Cargo.lock

Lines changed: 87 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

mistralrs-core/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ parking_lot = "0.12.3"
9595
ahash = "0.8.12"
9696
num-traits = "0.2.19"
9797
libc = "0.2.172"
98+
bm25 = "2.2.1"
9899

99100
[features]
100101
pyo3_macros = ["pyo3"]

mistralrs-core/src/engine/add_request.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ impl Engine {
3333
request.messages,
3434
RequestMessage::Chat { .. } | RequestMessage::VisionChat { .. }
3535
) && request.web_search_options.is_some()
36-
&& get_mut_arcmutex!(self.bert_pipeline).is_some()
3736
{
3837
search_request::search_request(self.clone(), *request).await;
3938
} else {

mistralrs-core/src/engine/search_request.rs

Lines changed: 69 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::{borrow::Cow, sync::Arc, time::Instant};
22

3+
use bm25::{Embedder, EmbedderBuilder, Language, ScoredDocument, Scorer};
34
use either::Either;
45
use indexmap::IndexMap;
56
use tokenizers::InputSequence;
@@ -98,28 +99,76 @@ async fn do_search(
9899
results.sort_by_key(|(_, len)| *len);
99100

100101
{
101-
let device = get_mut_arcmutex!(this.pipeline).device();
102+
// Determine ranking: use embedding model if available, otherwise fallback to BM25
103+
let decreasing_indexes: Vec<usize> = if let Some(bert_pipeline) =
104+
&mut *get_mut_arcmutex!(this.bert_pipeline)
105+
{
106+
// Semantic reranking with embeddings
107+
let device = get_mut_arcmutex!(this.pipeline).device();
108+
search::rag::compute_most_similar(
109+
&device,
110+
&tool_call_params.query,
111+
results.iter().map(|(res, _)| res).collect::<Vec<_>>(),
112+
bert_pipeline,
113+
)
114+
.unwrap()
115+
} else {
116+
tracing::warn!("No embedding model loaded; falling back to BM25 ranking for web search results.");
102117

103-
let Some(bert_pipeline) = &mut *get_mut_arcmutex!(this.bert_pipeline) else {
104-
unreachable!()
105-
};
118+
// Build an Embedder over the corpus, fitting to the entire set of documents.
119+
// - Language::English is chosen here
120+
// - This computes an in-memory sparse embedding for each document.
121+
122+
let docs: Vec<String> =
123+
results.iter().map(|(res, _)| res.content.clone()).collect();
124+
let doc_refs: Vec<&str> = docs.iter().map(|s| s.as_str()).collect();
125+
126+
let embedder: Embedder =
127+
EmbedderBuilder::with_fit_to_corpus(Language::English, &doc_refs).build();
128+
129+
// Initialize a Scorer keyed by usize (document index type).
130+
let mut scorer = Scorer::<usize>::new();
106131

107-
let decreasing_indexes = search::rag::compute_most_similar(
108-
&device,
109-
&tool_call_params.query,
110-
results.iter().map(|(res, _)| res).collect::<Vec<_>>(),
111-
bert_pipeline,
112-
)
113-
.unwrap();
114-
115-
// Rerank the results
116-
let mut results_old = Vec::new();
117-
std::mem::swap(&mut results_old, &mut results);
118-
for &index in &decreasing_indexes {
119-
let mut current_result: (SearchResult, usize) = Default::default();
120-
std::mem::swap(&mut current_result, &mut results_old[index]);
121-
122-
results.push(current_result);
132+
// For each document, compute its embedding and upsert into the scorer.
133+
for (i, doc_text) in docs.iter().enumerate() {
134+
let doc_embedding = embedder.embed(doc_text);
135+
scorer.upsert(&i, doc_embedding);
136+
}
137+
138+
// Embed the query string into the same sparse embedding space.
139+
let query_embedding = embedder.embed(&tool_call_params.query);
140+
141+
// Score all documents individually
142+
let mut scored_docs: Vec<ScoredDocument<usize>> = docs
143+
.iter()
144+
.enumerate()
145+
.filter_map(|(i, _)| {
146+
scorer
147+
.score(&i, &query_embedding)
148+
.map(|score| ScoredDocument { id: i, score })
149+
})
150+
.collect();
151+
152+
// Sort the scored documents by descending `score` (f32).
153+
scored_docs.sort_by(|a, b| {
154+
b.score
155+
.partial_cmp(&a.score)
156+
.unwrap_or(std::cmp::Ordering::Equal)
157+
});
158+
159+
// Extract only the document indices (usize) in ranked order.
160+
let decreasing_indexes: Vec<usize> =
161+
scored_docs.into_iter().map(|d| d.id).collect();
162+
163+
decreasing_indexes
164+
};
165+
// Reorder results according to ranking
166+
let mut old = Vec::new();
167+
std::mem::swap(&mut old, &mut results);
168+
for &idx in &decreasing_indexes {
169+
let mut item: (SearchResult, usize) = Default::default();
170+
std::mem::swap(&mut item, &mut old[idx]);
171+
results.push(item);
123172
}
124173
}
125174

0 commit comments

Comments
 (0)