Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions mistralrs-core/src/pipeline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,17 +219,22 @@ pub enum ModelCategory {
prefixer: Arc<dyn VisionPromptPrefixer>,
},
Diffusion,
Audio,
Speech,
}

impl PartialEq for ModelCategory {
    /// Compares two categories by variant only.
    ///
    /// Fields are ignored: any two `Vision { .. }` values are equal no matter
    /// what they carry, because both sides are matched with `{ .. }`.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Text, Self::Text) => true,
            (Self::Vision { .. }, Self::Vision { .. }) => true,
            (Self::Audio, Self::Audio) => true,
            (Self::Speech, Self::Speech) => true,
            (Self::Diffusion, Self::Diffusion) => true,
            // Everything left is a pairing of two different variants. The
            // left-hand side names every variant instead of using `_`, so a
            // newly added variant fails to compile here until it is handled.
            (
                Self::Text | Self::Vision { .. } | Self::Diffusion | Self::Audio | Self::Speech,
                _,
            ) => false,
        }
    }
}
Expand Down
103 changes: 52 additions & 51 deletions mistralrs-server/src/interactive_mode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ fn read_line<H: Helper, I: History>(editor: &mut Editor<H, I>) -> String {

Err(e) => {
editor.save_history(&history_file_path()).unwrap();
eprintln!("Error reading input: {:?}", e);
eprintln!("Error reading input: {e:?}");
std::process::exit(1);
}
Ok(prompt) => {
Expand All @@ -85,12 +85,12 @@ pub async fn interactive_mode(
vision_interactive_mode(mistralrs, do_search, enable_thinking).await
}
ModelCategory::Diffusion => diffusion_interactive_mode(mistralrs, do_search).await,
ModelCategory::Audio => audio_interactive_mode(mistralrs, do_search, enable_thinking).await,
ModelCategory::Speech => speech_interactive_mode(mistralrs, do_search).await,
}
}

const TEXT_INTERACTIVE_HELP: &str = r#"
Welcome to interactive mode! Because this model is a text model, you can enter prompts and chat with the model.

const COMMAND_COMMANDS: &str = r#"
Commands:
- `\help`: Display this message.
- `\exit`: Quit interactive mode.
Expand All @@ -100,21 +100,17 @@ Commands:
- `\clear`: Clear the chat history.
"#;

/// Welcome banner for text models in interactive mode.
///
/// Printed together with `COMMAND_COMMANDS` (the shared command list) between
/// two `=` separator rules, both at startup and in response to `\help`.
const TEXT_INTERACTIVE_HELP: &str = r#"
Welcome to interactive mode! Because this model is a text model, you can enter prompts and chat with the model.
"#;

const VISION_INTERACTIVE_HELP: &str = r#"
Welcome to interactive mode! Because this model is a vision model, you can enter prompts and chat with the model.

To specify a message with one or more images, simply include the image URL or path:

- `Please describe this image: path/to/image1.jpg path/to/image2.png`
- `What is in this image: <url here>`

Commands:
- `\help`: Display this message.
- `\exit`: Quit interactive mode.
- `\system <system message here>`:
Add a system message to the chat without running the model.
Ex: `\system Always respond as a pirate.`
- `\clear`: Clear the chat history.
"#;

const DIFFUSION_INTERACTIVE_HELP: &str = r#"
Expand All @@ -130,15 +126,8 @@ const EXIT_CMD: &str = "\\exit";
const SYSTEM_CMD: &str = "\\system";
const CLEAR_CMD: &str = "\\clear";

async fn text_interactive_mode(
mistralrs: Arc<MistralRs>,
do_search: bool,
enable_thinking: Option<bool>,
) {
let sender = mistralrs.get_sender().unwrap();
let mut messages: Vec<IndexMap<String, MessageContent>> = Vec::new();

let sampling_params = SamplingParams {
fn interactive_sample_parameters() -> SamplingParams {
SamplingParams {
temperature: Some(0.1),
top_k: Some(32),
top_p: Some(0.1),
Expand All @@ -151,11 +140,22 @@ async fn text_interactive_mode(
logits_bias: None,
n_choices: 1,
dry_params: Some(DrySamplingParams::default()),
};
}
}

async fn text_interactive_mode(
mistralrs: Arc<MistralRs>,
do_search: bool,
enable_thinking: Option<bool>,
) {
let sender = mistralrs.get_sender().unwrap();
let mut messages: Vec<IndexMap<String, MessageContent>> = Vec::new();

let sampling_params = interactive_sample_parameters();

info!("Starting interactive loop with sampling params: {sampling_params:?}");
println!(
"{}{TEXT_INTERACTIVE_HELP}{}",
"{}{TEXT_INTERACTIVE_HELP}{COMMAND_COMMANDS}{}",
"=".repeat(20),
"=".repeat(20)
);
Expand All @@ -178,7 +178,7 @@ async fn text_interactive_mode(
"" => continue,
HELP_CMD => {
println!(
"{}{TEXT_INTERACTIVE_HELP}{}",
"{}{TEXT_INTERACTIVE_HELP}{COMMAND_COMMANDS}{}",
"=".repeat(20),
"=".repeat(20)
);
Expand Down Expand Up @@ -259,7 +259,7 @@ async fn text_interactive_mode(
} = &chunk.choices[0]
{
assistant_output.push_str(content);
print!("{}", content);
print!("{content}");
io::stdout().flush().unwrap();
if finish_reason.is_some() {
if matches!(finish_reason.as_ref().unwrap().as_str(), "length") {
Expand Down Expand Up @@ -314,16 +314,14 @@ async fn text_interactive_mode(
rl.save_history(&history_file_path()).unwrap();
}

fn parse_image_urls_and_message(input: &str) -> (Vec<String>, String) {
// Capture HTTP/HTTPS URLs and local file paths ending with common image extensions
let re = Regex::new(r#"((?:https?://|file://)?\S+\.(?:png|jpe?g|bmp|gif|webp))"#).unwrap();
fn parse_files_and_message(input: &str, regex: &Regex) -> (Vec<String>, String) {
// Collect all URLs
let urls: Vec<String> = re
let urls: Vec<String> = regex
.captures_iter(input)
.filter_map(|cap| cap.get(1).map(|m| m.as_str().to_string()))
.collect();
// Remove the URLs from the input to get the message text
let text = re.replace_all(input, "").trim().to_string();
let text = regex.replace_all(input, "").trim().to_string();
(urls, text)
}

Expand All @@ -332,38 +330,29 @@ async fn vision_interactive_mode(
do_search: bool,
enable_thinking: Option<bool>,
) {
// Capture HTTP/HTTPS URLs and local file paths ending with common image extensions
let image_regex =
Regex::new(r#"((?:https?://|file://)?\S+\.(?:png|jpe?g|bmp|gif|webp))"#).unwrap();

let sender = mistralrs.get_sender().unwrap();
let mut messages: Vec<IndexMap<String, MessageContent>> = Vec::new();
let mut images = Vec::new();

let prefixer = match &mistralrs.config().category {
ModelCategory::Text | ModelCategory::Diffusion => {
panic!("`add_image_message` expects a vision model.")
}
ModelCategory::Vision {
has_conv2d: _,
prefixer,
} => prefixer,
_ => {
panic!("`add_image_message` expects a vision model.")
}
};

let sampling_params = SamplingParams {
temperature: Some(0.1),
top_k: Some(32),
top_p: Some(0.1),
min_p: Some(0.05),
top_n_logprobs: 0,
frequency_penalty: Some(0.1),
presence_penalty: Some(0.1),
max_len: None,
stop_toks: None,
logits_bias: None,
n_choices: 1,
dry_params: Some(DrySamplingParams::default()),
};
let sampling_params = interactive_sample_parameters();

info!("Starting interactive loop with sampling params: {sampling_params:?}");
println!(
"{}{VISION_INTERACTIVE_HELP}{}",
"{}{VISION_INTERACTIVE_HELP}{COMMAND_COMMANDS}{}",
"=".repeat(20),
"=".repeat(20)
);
Expand All @@ -386,7 +375,7 @@ async fn vision_interactive_mode(
"" => continue,
HELP_CMD => {
println!(
"{}{VISION_INTERACTIVE_HELP}{}",
"{}{VISION_INTERACTIVE_HELP}{COMMAND_COMMANDS}{}",
"=".repeat(20),
"=".repeat(20)
);
Expand Down Expand Up @@ -418,7 +407,7 @@ async fn vision_interactive_mode(
}
// Extract any image URLs and the remaining text
_ => {
let (urls, text) = parse_image_urls_and_message(prompt.trim());
let (urls, text) = parse_files_and_message(prompt.trim(), &image_regex);
if !urls.is_empty() {
let mut image_indexes = Vec::new();
// Load all images first
Expand Down Expand Up @@ -511,7 +500,7 @@ async fn vision_interactive_mode(
} = &chunk.choices[0]
{
assistant_output.push_str(content);
print!("{}", content);
print!("{content}");
io::stdout().flush().unwrap();
if finish_reason.is_some() {
if matches!(finish_reason.as_ref().unwrap().as_str(), "length") {
Expand Down Expand Up @@ -566,6 +555,18 @@ async fn vision_interactive_mode(
rl.save_history(&history_file_path()).unwrap();
}

/// Interactive-mode handler for audio models (selected for `ModelCategory::Audio`).
///
/// Currently a stub: all arguments are ignored (hence the `_` prefixes).
///
/// # Panics
///
/// Always panics when run, via `unimplemented!`, with a message stating that
/// audio models aren't supported yet.
async fn audio_interactive_mode(
_mistralrs: Arc<MistralRs>,
_do_search: bool,
_enable_thinking: Option<bool>,
) {
unimplemented!("Using audio models isn't supported yet")
}

/// Interactive-mode handler for speech models (selected for `ModelCategory::Speech`).
///
/// Currently a stub: both arguments are ignored (hence the `_` prefixes).
///
/// # Panics
///
/// Always panics when run, via `unimplemented!`, with a message stating that
/// speech models aren't supported yet.
async fn speech_interactive_mode(_mistralrs: Arc<MistralRs>, _do_search: bool) {
unimplemented!("Using speech models isn't supported yet")
}

async fn diffusion_interactive_mode(mistralrs: Arc<MistralRs>, do_search: bool) {
let sender = mistralrs.get_sender().unwrap();

Expand Down
6 changes: 3 additions & 3 deletions mistralrs/src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,13 @@ impl VisionMessages {
model: &Model,
) -> anyhow::Result<Self> {
let prefixer = match &model.config().category {
ModelCategory::Text | ModelCategory::Diffusion => {
anyhow::bail!("`add_image_message` expects a vision model.")
}
ModelCategory::Vision {
has_conv2d: _,
prefixer,
} => prefixer,
_ => {
anyhow::bail!("`add_image_message` expects a vision model.")
}
};
self.images.push(image);
self.messages.push(IndexMap::from([
Expand Down
Loading