Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions mistralrs-server/src/interactive_mode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,11 @@
const SYSTEM_CMD: &str = "\\system";
const CLEAR_CMD: &str = "\\clear";

/// Regex string used to extract image URLs from prompts without capturing
/// trailing punctuation like periods or parentheses.
const IMAGE_REGEX: &str =
r#"((?:https?://|file://)?\S+?\.(?:png|jpe?g|bmp|gif|webp)(?:\?\S+?)?)(?=[\s,.;:!?)]|$)"#;

fn interactive_sample_parameters() -> SamplingParams {
SamplingParams {
temperature: Some(0.1),
Expand Down Expand Up @@ -332,11 +337,31 @@
rl.save_history(&history_file_path()).unwrap();
}

#[cfg(test)]
mod tests {

Check failure on line 341 in mistralrs-server/src/interactive_mode.rs

View workflow job for this annotation

GitHub Actions / Clippy

items after a test module
use super::*;

#[test]
fn parse_files_and_message_trims_trailing_punctuation() {
let regex = Regex::new(IMAGE_REGEX).unwrap();
let input = "Look at this https://example.com/test.png.";
let (urls, text) = parse_files_and_message(input, &regex);
assert_eq!(urls, vec!["https://example.com/test.png"]);
assert_eq!(text, "Look at this .");
}
}

fn parse_files_and_message(input: &str, regex: &Regex) -> (Vec<String>, String) {
// Collect all URLs
let urls: Vec<String> = regex
.captures_iter(input)
.filter_map(|cap| cap.get(1).map(|m| m.as_str().to_string()))
.filter_map(|cap| {
cap.get(1).map(|m| {
m.as_str()
.trim_end_matches(|c: char| matches!(c, '.' | ',' | ';' | ':' | '!' | '?' | ')' | ']' | '}' | '"' | '\''))

Check failure on line 361 in mistralrs-server/src/interactive_mode.rs

View workflow job for this annotation

GitHub Actions / Clippy

this manual char comparison can be written more succinctly
.to_string()
})
})
.collect();
// Remove the URLs from the input to get the message text
let text = regex.replace_all(input, "").trim().to_string();
Expand All @@ -349,8 +374,7 @@
enable_thinking: Option<bool>,
) {
// Capture HTTP/HTTPS URLs and local file paths ending with common image extensions
let image_regex =
Regex::new(r#"((?:https?://|file://)?\S+\.(?:png|jpe?g|bmp|gif|webp))"#).unwrap();
let image_regex = Regex::new(IMAGE_REGEX).unwrap();

let sender = mistralrs.get_sender().unwrap();
let mut messages: Vec<IndexMap<String, MessageContent>> = Vec::new();
Expand Down
Loading