|
1 | 1 | use std::{borrow::Cow, sync::Arc, time::Instant}; |
2 | 2 |
|
| 3 | +use bm25::{Embedder, EmbedderBuilder, Language, ScoredDocument, Scorer}; |
3 | 4 | use either::Either; |
4 | 5 | use indexmap::IndexMap; |
5 | 6 | use tokenizers::InputSequence; |
@@ -98,28 +99,76 @@ async fn do_search( |
98 | 99 | results.sort_by_key(|(_, len)| *len); |
99 | 100 |
|
100 | 101 | { |
101 | | - let device = get_mut_arcmutex!(this.pipeline).device(); |
| 102 | + // Determine ranking: use embedding model if available, otherwise fallback to BM25 |
| 103 | + let decreasing_indexes: Vec<usize> = if let Some(bert_pipeline) = |
| 104 | + &mut *get_mut_arcmutex!(this.bert_pipeline) |
| 105 | + { |
| 106 | + // Semantic reranking with embeddings |
| 107 | + let device = get_mut_arcmutex!(this.pipeline).device(); |
| 108 | + search::rag::compute_most_similar( |
| 109 | + &device, |
| 110 | + &tool_call_params.query, |
| 111 | + results.iter().map(|(res, _)| res).collect::<Vec<_>>(), |
| 112 | + bert_pipeline, |
| 113 | + ) |
| 114 | + .unwrap() |
| 115 | + } else { |
| 116 | + tracing::warn!("No embedding model loaded; falling back to BM25 ranking for web search results."); |
102 | 117 |
|
103 | | - let Some(bert_pipeline) = &mut *get_mut_arcmutex!(this.bert_pipeline) else { |
104 | | - unreachable!() |
105 | | - }; |
| 118 | + // Build an Embedder over the corpus, fitting to the entire set of documents. |
| 119 | + // - Language::English is chosen here |
| 120 | + // - This computes an in‑memory sparse embedding for each document. |
| 121 | + |
| 122 | + let docs: Vec<String> = |
| 123 | + results.iter().map(|(res, _)| res.content.clone()).collect(); |
| 124 | + let doc_refs: Vec<&str> = docs.iter().map(|s| s.as_str()).collect(); |
| 125 | + |
| 126 | + let embedder: Embedder = |
| 127 | + EmbedderBuilder::with_fit_to_corpus(Language::English, &doc_refs).build(); |
| 128 | + |
| 129 | + // Initialize a Scorer keyed by usize (document index type). |
| 130 | + let mut scorer = Scorer::<usize>::new(); |
106 | 131 |
|
107 | | - let decreasing_indexes = search::rag::compute_most_similar( |
108 | | - &device, |
109 | | - &tool_call_params.query, |
110 | | - results.iter().map(|(res, _)| res).collect::<Vec<_>>(), |
111 | | - bert_pipeline, |
112 | | - ) |
113 | | - .unwrap(); |
114 | | - |
115 | | - // Rerank the results |
116 | | - let mut results_old = Vec::new(); |
117 | | - std::mem::swap(&mut results_old, &mut results); |
118 | | - for &index in &decreasing_indexes { |
119 | | - let mut current_result: (SearchResult, usize) = Default::default(); |
120 | | - std::mem::swap(&mut current_result, &mut results_old[index]); |
121 | | - |
122 | | - results.push(current_result); |
| 132 | + // For each document, compute its embedding and upsert into the scorer. |
| 133 | + for (i, doc_text) in docs.iter().enumerate() { |
| 134 | + let doc_embedding = embedder.embed(doc_text); |
| 135 | + scorer.upsert(&i, doc_embedding); |
| 136 | + } |
| 137 | + |
| 138 | + // Embed the query string into the same sparse embedding space. |
| 139 | + let query_embedding = embedder.embed(&tool_call_params.query); |
| 140 | + |
| 141 | + // Score all documents individually |
| 142 | + let mut scored_docs: Vec<ScoredDocument<usize>> = docs |
| 143 | + .iter() |
| 144 | + .enumerate() |
| 145 | + .filter_map(|(i, _)| { |
| 146 | + scorer |
| 147 | + .score(&i, &query_embedding) |
| 148 | + .map(|score| ScoredDocument { id: i, score }) |
| 149 | + }) |
| 150 | + .collect(); |
| 151 | + |
| 152 | + // Sort the scored documents by descending `score` (f32). |
| 153 | + scored_docs.sort_by(|a, b| { |
| 154 | + b.score |
| 155 | + .partial_cmp(&a.score) |
| 156 | + .unwrap_or(std::cmp::Ordering::Equal) |
| 157 | + }); |
| 158 | + |
| 159 | + // Extract only the document indices (usize) in ranked order. |
| 160 | + let decreasing_indexes: Vec<usize> = |
| 161 | + scored_docs.into_iter().map(|d| d.id).collect(); |
| 162 | + |
| 163 | + decreasing_indexes |
| 164 | + }; |
| 165 | + // Reorder results according to ranking |
| 166 | + let mut old = Vec::new(); |
| 167 | + std::mem::swap(&mut old, &mut results); |
| 168 | + for &idx in &decreasing_indexes { |
| 169 | + let mut item: (SearchResult, usize) = Default::default(); |
| 170 | + std::mem::swap(&mut item, &mut old[idx]); |
| 171 | + results.push(item); |
123 | 172 | } |
124 | 173 | } |
125 | 174 |
|
|
0 commit comments