Skip to content

Commit 08db9f7

Browse files
committed
add Speech to ModelCategory
1 parent f47aac5 commit 08db9f7

File tree

3 files changed

+16
-9
lines changed

3 files changed

+16
-9
lines changed

mistralrs-core/src/pipeline/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -218,8 +218,9 @@ pub enum ModelCategory {
218218
has_conv2d: bool,
219219
prefixer: Arc<dyn VisionPromptPrefixer>,
220220
},
221-
Audio,
222221
Diffusion,
222+
Audio,
223+
Speech,
223224
}
224225

225226
impl PartialEq for ModelCategory {
@@ -228,6 +229,7 @@ impl PartialEq for ModelCategory {
228229
(Self::Text, Self::Text) => true,
229230
(Self::Vision { .. }, Self::Vision { .. }) => true,
230231
(Self::Audio, Self::Audio) => true,
232+
(Self::Speech, Self::Speech) => true,
231233
(Self::Diffusion, Self::Diffusion) => true,
232234
(_, _) => false,
233235
}

mistralrs-server/src/interactive_mode.rs

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ fn read_line<H: Helper, I: History>(editor: &mut Editor<H, I>) -> String {
6161

6262
Err(e) => {
6363
editor.save_history(&history_file_path()).unwrap();
64-
eprintln!("Error reading input: {:?}", e);
64+
eprintln!("Error reading input: {e:?}");
6565
std::process::exit(1);
6666
}
6767
Ok(prompt) => {
@@ -84,8 +84,9 @@ pub async fn interactive_mode(
8484
ModelCategory::Vision { .. } => {
8585
vision_interactive_mode(mistralrs, do_search, enable_thinking).await
8686
}
87-
ModelCategory::Audio => audio_interactive_mode(mistralrs, do_search, enable_thinking).await,
8887
ModelCategory::Diffusion => diffusion_interactive_mode(mistralrs, do_search).await,
88+
ModelCategory::Audio => audio_interactive_mode(mistralrs, do_search, enable_thinking).await,
89+
ModelCategory::Speech => speech_interactive_mode(mistralrs, do_search).await,
8990
}
9091
}
9192

@@ -258,7 +259,7 @@ async fn text_interactive_mode(
258259
} = &chunk.choices[0]
259260
{
260261
assistant_output.push_str(content);
261-
print!("{}", content);
262+
print!("{content}");
262263
io::stdout().flush().unwrap();
263264
if finish_reason.is_some() {
264265
if matches!(finish_reason.as_ref().unwrap().as_str(), "length") {
@@ -499,7 +500,7 @@ async fn vision_interactive_mode(
499500
} = &chunk.choices[0]
500501
{
501502
assistant_output.push_str(content);
502-
print!("{}", content);
503+
print!("{content}");
503504
io::stdout().flush().unwrap();
504505
if finish_reason.is_some() {
505506
if matches!(finish_reason.as_ref().unwrap().as_str(), "length") {
@@ -559,7 +560,11 @@ async fn audio_interactive_mode(
559560
_do_search: bool,
560561
_enable_thinking: Option<bool>,
561562
) {
562-
unimplemented!("Using audio models interactively isn't supported yet")
563+
unimplemented!("Using audio models isn't supported yet")
564+
}
565+
566+
async fn speech_interactive_mode(_mistralrs: Arc<MistralRs>, _do_search: bool) {
567+
unimplemented!("Using speech models isn't supported yet")
563568
}
564569

565570
async fn diffusion_interactive_mode(mistralrs: Arc<MistralRs>, do_search: bool) {

mistralrs/src/messages.rs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -159,13 +159,13 @@ impl VisionMessages {
159159
model: &Model,
160160
) -> anyhow::Result<Self> {
161161
let prefixer = match &model.config().category {
162-
ModelCategory::Text | ModelCategory::Audio | ModelCategory::Diffusion => {
163-
anyhow::bail!("`add_image_message` expects a vision model.")
164-
}
165162
ModelCategory::Vision {
166163
has_conv2d: _,
167164
prefixer,
168165
} => prefixer,
166+
_ => {
167+
anyhow::bail!("`add_image_message` expects a vision model.")
168+
}
169169
};
170170
self.images.push(image);
171171
self.messages.push(IndexMap::from([

0 commit comments

Comments
 (0)