Skip to content

Commit f47aac5

Browse files
committed
add Audio to ModelCategory
1 parent ca5794d commit f47aac5

File tree

3 files changed

+48
-52
lines changed

3 files changed

+48
-52
lines changed

mistralrs-core/src/pipeline/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ pub enum ModelCategory {
218218
has_conv2d: bool,
219219
prefixer: Arc<dyn VisionPromptPrefixer>,
220220
},
221+
Audio,
221222
Diffusion,
222223
}
223224

@@ -226,10 +227,9 @@ impl PartialEq for ModelCategory {
226227
match (self, other) {
227228
(Self::Text, Self::Text) => true,
228229
(Self::Vision { .. }, Self::Vision { .. }) => true,
230+
(Self::Audio, Self::Audio) => true,
229231
(Self::Diffusion, Self::Diffusion) => true,
230-
(Self::Text, _) => false,
231-
(Self::Vision { .. }, _) => false,
232-
(Self::Diffusion, _) => false,
232+
(_, _) => false,
233233
}
234234
}
235235
}

mistralrs-server/src/interactive_mode.rs

Lines changed: 44 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,12 @@ pub async fn interactive_mode(
8484
ModelCategory::Vision { .. } => {
8585
vision_interactive_mode(mistralrs, do_search, enable_thinking).await
8686
}
87+
ModelCategory::Audio => audio_interactive_mode(mistralrs, do_search, enable_thinking).await,
8788
ModelCategory::Diffusion => diffusion_interactive_mode(mistralrs, do_search).await,
8889
}
8990
}
9091

91-
const TEXT_INTERACTIVE_HELP: &str = r#"
92-
Welcome to interactive mode! Because this model is a text model, you can enter prompts and chat with the model.
93-
92+
const COMMAND_COMMANDS: &str = r#"
9493
Commands:
9594
- `\help`: Display this message.
9695
- `\exit`: Quit interactive mode.
@@ -100,21 +99,17 @@ Commands:
10099
- `\clear`: Clear the chat history.
101100
"#;
102101

102+
const TEXT_INTERACTIVE_HELP: &str = r#"
103+
Welcome to interactive mode! Because this model is a text model, you can enter prompts and chat with the model.
104+
"#;
105+
103106
const VISION_INTERACTIVE_HELP: &str = r#"
104107
Welcome to interactive mode! Because this model is a vision model, you can enter prompts and chat with the model.
105108
106109
To specify a message with one or more images, simply include the image URL or path:
107110
108111
- `Please describe this image: path/to/image1.jpg path/to/image2.png`
109112
- `What is in this image: <url here>`
110-
111-
Commands:
112-
- `\help`: Display this message.
113-
- `\exit`: Quit interactive mode.
114-
- `\system <system message here>`:
115-
Add a system message to the chat without running the model.
116-
Ex: `\system Always respond as a pirate.`
117-
- `\clear`: Clear the chat history.
118113
"#;
119114

120115
const DIFFUSION_INTERACTIVE_HELP: &str = r#"
@@ -130,15 +125,8 @@ const EXIT_CMD: &str = "\\exit";
130125
const SYSTEM_CMD: &str = "\\system";
131126
const CLEAR_CMD: &str = "\\clear";
132127

133-
async fn text_interactive_mode(
134-
mistralrs: Arc<MistralRs>,
135-
do_search: bool,
136-
enable_thinking: Option<bool>,
137-
) {
138-
let sender = mistralrs.get_sender().unwrap();
139-
let mut messages: Vec<IndexMap<String, MessageContent>> = Vec::new();
140-
141-
let sampling_params = SamplingParams {
128+
fn interactive_sample_parameters() -> SamplingParams {
129+
SamplingParams {
142130
temperature: Some(0.1),
143131
top_k: Some(32),
144132
top_p: Some(0.1),
@@ -151,11 +139,22 @@ async fn text_interactive_mode(
151139
logits_bias: None,
152140
n_choices: 1,
153141
dry_params: Some(DrySamplingParams::default()),
154-
};
142+
}
143+
}
144+
145+
async fn text_interactive_mode(
146+
mistralrs: Arc<MistralRs>,
147+
do_search: bool,
148+
enable_thinking: Option<bool>,
149+
) {
150+
let sender = mistralrs.get_sender().unwrap();
151+
let mut messages: Vec<IndexMap<String, MessageContent>> = Vec::new();
152+
153+
let sampling_params = interactive_sample_parameters();
155154

156155
info!("Starting interactive loop with sampling params: {sampling_params:?}");
157156
println!(
158-
"{}{TEXT_INTERACTIVE_HELP}{}",
157+
"{}{TEXT_INTERACTIVE_HELP}{COMMAND_COMMANDS}{}",
159158
"=".repeat(20),
160159
"=".repeat(20)
161160
);
@@ -178,7 +177,7 @@ async fn text_interactive_mode(
178177
"" => continue,
179178
HELP_CMD => {
180179
println!(
181-
"{}{TEXT_INTERACTIVE_HELP}{}",
180+
"{}{TEXT_INTERACTIVE_HELP}{COMMAND_COMMANDS}{}",
182181
"=".repeat(20),
183182
"=".repeat(20)
184183
);
@@ -314,16 +313,14 @@ async fn text_interactive_mode(
314313
rl.save_history(&history_file_path()).unwrap();
315314
}
316315

317-
fn parse_image_urls_and_message(input: &str) -> (Vec<String>, String) {
318-
// Capture HTTP/HTTPS URLs and local file paths ending with common image extensions
319-
let re = Regex::new(r#"((?:https?://|file://)?\S+\.(?:png|jpe?g|bmp|gif|webp))"#).unwrap();
316+
fn parse_files_and_message(input: &str, regex: &Regex) -> (Vec<String>, String) {
320317
// Collect all URLs
321-
let urls: Vec<String> = re
318+
let urls: Vec<String> = regex
322319
.captures_iter(input)
323320
.filter_map(|cap| cap.get(1).map(|m| m.as_str().to_string()))
324321
.collect();
325322
// Remove the URLs from the input to get the message text
326-
let text = re.replace_all(input, "").trim().to_string();
323+
let text = regex.replace_all(input, "").trim().to_string();
327324
(urls, text)
328325
}
329326

@@ -332,38 +329,29 @@ async fn vision_interactive_mode(
332329
do_search: bool,
333330
enable_thinking: Option<bool>,
334331
) {
332+
// Capture HTTP/HTTPS URLs and local file paths ending with common image extensions
333+
let image_regex =
334+
Regex::new(r#"((?:https?://|file://)?\S+\.(?:png|jpe?g|bmp|gif|webp))"#).unwrap();
335+
335336
let sender = mistralrs.get_sender().unwrap();
336337
let mut messages: Vec<IndexMap<String, MessageContent>> = Vec::new();
337338
let mut images = Vec::new();
338339

339340
let prefixer = match &mistralrs.config().category {
340-
ModelCategory::Text | ModelCategory::Diffusion => {
341-
panic!("`add_image_message` expects a vision model.")
342-
}
343341
ModelCategory::Vision {
344342
has_conv2d: _,
345343
prefixer,
346344
} => prefixer,
345+
_ => {
346+
panic!("`add_image_message` expects a vision model.")
347+
}
347348
};
348349

349-
let sampling_params = SamplingParams {
350-
temperature: Some(0.1),
351-
top_k: Some(32),
352-
top_p: Some(0.1),
353-
min_p: Some(0.05),
354-
top_n_logprobs: 0,
355-
frequency_penalty: Some(0.1),
356-
presence_penalty: Some(0.1),
357-
max_len: None,
358-
stop_toks: None,
359-
logits_bias: None,
360-
n_choices: 1,
361-
dry_params: Some(DrySamplingParams::default()),
362-
};
350+
let sampling_params = interactive_sample_parameters();
363351

364352
info!("Starting interactive loop with sampling params: {sampling_params:?}");
365353
println!(
366-
"{}{VISION_INTERACTIVE_HELP}{}",
354+
"{}{VISION_INTERACTIVE_HELP}{COMMAND_COMMANDS}{}",
367355
"=".repeat(20),
368356
"=".repeat(20)
369357
);
@@ -386,7 +374,7 @@ async fn vision_interactive_mode(
386374
"" => continue,
387375
HELP_CMD => {
388376
println!(
389-
"{}{VISION_INTERACTIVE_HELP}{}",
377+
"{}{VISION_INTERACTIVE_HELP}{COMMAND_COMMANDS}{}",
390378
"=".repeat(20),
391379
"=".repeat(20)
392380
);
@@ -418,7 +406,7 @@ async fn vision_interactive_mode(
418406
}
419407
// Extract any image URLs and the remaining text
420408
_ => {
421-
let (urls, text) = parse_image_urls_and_message(prompt.trim());
409+
let (urls, text) = parse_files_and_message(prompt.trim(), &image_regex);
422410
if !urls.is_empty() {
423411
let mut image_indexes = Vec::new();
424412
// Load all images first
@@ -566,6 +554,14 @@ async fn vision_interactive_mode(
566554
rl.save_history(&history_file_path()).unwrap();
567555
}
568556

557+
async fn audio_interactive_mode(
558+
_mistralrs: Arc<MistralRs>,
559+
_do_search: bool,
560+
_enable_thinking: Option<bool>,
561+
) {
562+
unimplemented!("Using audio models interactively isn't supported yet")
563+
}
564+
569565
async fn diffusion_interactive_mode(mistralrs: Arc<MistralRs>, do_search: bool) {
570566
let sender = mistralrs.get_sender().unwrap();
571567

mistralrs/src/messages.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ impl VisionMessages {
159159
model: &Model,
160160
) -> anyhow::Result<Self> {
161161
let prefixer = match &model.config().category {
162-
ModelCategory::Text | ModelCategory::Diffusion => {
162+
ModelCategory::Text | ModelCategory::Audio | ModelCategory::Diffusion => {
163163
anyhow::bail!("`add_image_message` expects a vision model.")
164164
}
165165
ModelCategory::Vision {

0 commit comments

Comments
 (0)