Commit b9c9604

More work on consolidating completions and chat completions
1 parent f83952e · commit b9c9604

File tree: 6 files changed, 91 additions & 79 deletions

mistralrs-server-core/src/chat_completion.rs

Lines changed: 9 additions & 54 deletions
@@ -1,6 +1,6 @@
 //! ## Chat Completions functionality and route handler.
 
-use std::{error::Error, ops::Deref, pin::Pin, task::Poll, time::Duration};
+use std::{ops::Deref, pin::Pin, task::Poll, time::Duration};
 
 use anyhow::{Context, Result};
 use axum::{
@@ -19,20 +19,20 @@ use mistralrs_core::{
     NormalRequest, Request, RequestMessage, Response, SamplingParams,
     StopTokens as InternalStopTokens,
 };
-use serde::Serialize;
 use serde_json::Value;
 use tokio::sync::mpsc::{Receiver, Sender};
 
 use crate::{
+    completion_base::{base_handle_completion_error, BaseCompletionResponder},
     openai::{
         ChatCompletionRequest, Grammar, JsonSchemaResponseFormat, MessageInnerContent,
         ResponseFormat, StopTokens,
     },
-    streaming::{get_keep_alive_interval, DoneState},
+    streaming::{get_keep_alive_interval, BaseStreamer, DoneState},
     types::{ExtractedMistralRsState, SharedMistralRsState},
     util::{
-        create_response_channel, parse_image_url, send_model_request, ErrorToResponse, JsonError,
-        ModelErrorMessage,
+        create_response_channel, parse_image_url, send_model_request, BaseJsonModelError,
+        ErrorToResponse, JsonError, ModelErrorMessage,
     },
 };
 
@@ -77,22 +77,7 @@ pub type OnDoneCallback = Box<dyn Fn(&[ChatCompletionChunkResponse]) + Send + Sy
 ///
 /// It processes incoming response chunks from a model and converts them
 /// into Server-Sent Events (SSE) format for real-time streaming to clients.
-pub struct Streamer {
-    /// Channel receiver for incoming model responses
-    rx: Receiver<Response>,
-    /// Current state of the streaming operation
-    done_state: DoneState,
-    /// Underlying mistral.rs instance
-    state: SharedMistralRsState,
-    /// Whether to store chunks for the completion callback
-    store_chunks: bool,
-    /// All chunks received during streaming (if `store_chunks` is true)
-    chunks: Vec<ChatCompletionChunkResponse>,
-    /// Optional callback to process each chunk before sending
-    on_chunk: Option<OnChunkCallback>,
-    /// Optional callback to execute when streaming completes
-    on_done: Option<OnDoneCallback>,
-}
+pub type Streamer = BaseStreamer<ChatCompletionChunkResponse, OnChunkCallback, OnDoneCallback>;
 
 impl futures::Stream for Streamer {
     type Item = Result<Event, axum::Error>;
@@ -173,37 +158,9 @@ impl futures::Stream for Streamer {
 }
 
 /// Represents different types of chat completion responses.
-pub enum ChatCompletionResponder {
-    /// Server-Sent Events streaming response
-    Sse(Sse<Streamer>),
-    /// Complete JSON response for non-streaming requests
-    Json(ChatCompletionResponse),
-    /// Model error with partial response data
-    ModelError(String, ChatCompletionResponse),
-    /// Internal server error
-    InternalError(Box<dyn Error>),
-    /// Request validation error
-    ValidationError(Box<dyn Error>),
-}
-
-/// JSON error response structure for model errors.
-#[derive(Serialize)]
-struct JsonModelError {
-    message: String,
-    /// Partial response data that was generated before the error occurred
-    partial_response: ChatCompletionResponse,
-}
-
-impl JsonModelError {
-    /// Creates a new JSON model error with message and partial response.
-    fn new(message: String, partial_response: ChatCompletionResponse) -> Self {
-        Self {
-            message,
-            partial_response,
-        }
-    }
-}
+pub type ChatCompletionResponder = BaseCompletionResponder<ChatCompletionResponse, Streamer>;
 
+type JsonModelError = BaseJsonModelError<ChatCompletionResponse>;
 impl ErrorToResponse for JsonModelError {}
 
 impl IntoResponse for ChatCompletionResponder {
@@ -511,9 +468,7 @@ pub fn handle_chat_completion_error(
     state: SharedMistralRsState,
     e: Box<dyn std::error::Error + Send + Sync + 'static>,
 ) -> ChatCompletionResponder {
-    let e = anyhow::Error::msg(e.to_string());
-    MistralRs::maybe_log_error(state, &*e);
-    ChatCompletionResponder::InternalError(e.into())
+    base_handle_completion_error(state, e)
 }
 
 /// Creates a SSE streamer for chat completions with optional callbacks.

mistralrs-server-core/src/completion_base.rs

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+//! Base functionality for completions.
+
+use std::error::Error;
+
+use axum::response::Sse;
+use mistralrs_core::MistralRs;
+
+use crate::types::SharedMistralRsState;
+
+/// Generic responder enum for different completion types
+#[derive(Debug)]
+pub enum BaseCompletionResponder<R, S> {
+    /// Server-Sent Events streaming response
+    Sse(Sse<S>),
+    /// Complete JSON response for non-streaming requests
+    Json(R),
+    /// Model error with partial response data
+    ModelError(String, R),
+    /// Internal server error
+    InternalError(Box<dyn Error>),
+    /// Request validation error
+    ValidationError(Box<dyn Error>),
+}
+
+/// Generic function to handle completion errors and logging them.
+pub(crate) fn base_handle_completion_error<R, S>(
+    state: SharedMistralRsState,
+    e: Box<dyn std::error::Error + Send + Sync + 'static>,
+) -> BaseCompletionResponder<R, S> {
+    let e = anyhow::Error::msg(e.to_string());
+    MistralRs::maybe_log_error(state, &*e);
+    BaseCompletionResponder::InternalError(e.into())
+}
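
What this new module buys is visible in the other files of this commit: each route module now instantiates these generics instead of carrying its own responder enum and error path. A minimal sketch of the pattern, taken from the chat_completion.rs hunks above (completions.rs follows the same shape):

// The route-specific responder becomes a thin alias over the generic enum:
pub type ChatCompletionResponder = BaseCompletionResponder<ChatCompletionResponse, Streamer>;

// ...and the route-specific error handler delegates to the shared helper, which
// logs via MistralRs::maybe_log_error and wraps the error as InternalError:
pub fn handle_chat_completion_error(
    state: SharedMistralRsState,
    e: Box<dyn std::error::Error + Send + Sync + 'static>,
) -> ChatCompletionResponder {
    base_handle_completion_error(state, e)
}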

mistralrs-server-core/src/completions.rs

Lines changed: 6 additions & 25 deletions
@@ -1,5 +1,4 @@
 use std::{
-    error::Error,
     pin::Pin,
     sync::Arc,
     task::{Context, Poll},
@@ -19,16 +18,17 @@ use mistralrs_core::{
     CompletionResponse, Constraint, DrySamplingParams, MistralRs, NormalRequest, Request,
     RequestMessage, Response, SamplingParams, StopTokens as InternalStopTokens,
 };
-use serde::Serialize;
 use tokio::sync::mpsc::{Receiver, Sender};
 use tracing::warn;
 
 use crate::{
+    completion_base::BaseCompletionResponder,
     openai::{CompletionRequest, Grammar, StopTokens},
     streaming::{get_keep_alive_interval, DoneState},
     types::ExtractedMistralRsState,
     util::{
-        create_response_channel, send_model_request, ErrorToResponse, JsonError, ModelErrorMessage,
+        create_response_channel, send_model_request, BaseJsonModelError, ErrorToResponse,
+        JsonError, ModelErrorMessage,
     },
 };
 
@@ -94,29 +94,10 @@ impl futures::Stream for Streamer {
     }
 }
 
-pub enum CompletionResponder {
-    Sse(Sse<Streamer>),
-    Json(CompletionResponse),
-    ModelError(String, CompletionResponse),
-    InternalError(Box<dyn Error>),
-    ValidationError(Box<dyn Error>),
-}
-
-#[derive(Serialize)]
-struct JsonModelError {
-    message: String,
-    partial_response: CompletionResponse,
-}
-
-impl JsonModelError {
-    fn new(message: String, partial_response: CompletionResponse) -> Self {
-        Self {
-            message,
-            partial_response,
-        }
-    }
-}
+pub type CompletionResponder = BaseCompletionResponder<CompletionResponse, Streamer>;
 
+/// JSON error response structure for model errors.
+type JsonModelError = BaseJsonModelError<CompletionResponse>;
 impl ErrorToResponse for JsonModelError {}
 
 impl IntoResponse for CompletionResponder {

mistralrs-server-core/src/lib.rs

Lines changed: 1 addition & 0 deletions
@@ -215,6 +215,7 @@
 //! ```
 
 pub mod chat_completion;
+pub mod completion_base;
 mod completions;
 mod handlers;
 mod image_generation;

mistralrs-server-core/src/streaming.rs

Lines changed: 26 additions & 0 deletions
@@ -2,6 +2,11 @@
 
 use std::env;
 
+use mistralrs_core::Response;
+use tokio::sync::mpsc::Receiver;
+
+use crate::types::SharedMistralRsState;
+
 /// Default keep-alive interval for Server-Sent Events (SSE) streams in milliseconds.
 pub const DEFAULT_KEEP_ALIVE_INTERVAL_MS: u64 = 10_000;
 
@@ -15,6 +20,27 @@ pub enum DoneState {
     Done,
 }
 
+/// A streaming response handler.
+///
+/// It processes incoming response chunks from a model and converts them
+/// into Server-Sent Events (SSE) format for real-time streaming to clients.
+pub struct BaseStreamer<R, C, D> {
+    /// Channel receiver for incoming model responses
+    pub rx: Receiver<Response>,
+    /// Current state of the streaming operation
+    pub done_state: DoneState,
+    /// Underlying mistral.rs instance
+    pub state: SharedMistralRsState,
+    /// Whether to store chunks for the completion callback
+    pub store_chunks: bool,
+    /// All chunks received during streaming (if `store_chunks` is true)
+    pub chunks: Vec<R>,
+    /// Optional callback to process each chunk before sending
+    pub on_chunk: Option<C>,
+    /// Optional callback to execute when streaming completes
+    pub on_done: Option<D>,
+}
+
 /// Gets the keep-alive interval for SSE streams from environment or default.
 pub fn get_keep_alive_interval() -> u64 {
     env::var("KEEP_ALIVE_INTERVAL")
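
BaseStreamer is the chat Streamer struct that previously lived in chat_completion.rs, with the chunk and callback types made generic (R is the chunk type collected in `chunks`, C the per-chunk callback, D the completion callback) and the fields made public so route modules can build it themselves. A minimal construction sketch under that assumption; the helper name, parameter names, and the idea of passing in the starting DoneState are hypothetical, since the actual construction sites are not part of this diff:

// Hypothetical helper, not in this commit. `Streamer` is the chat alias
// BaseStreamer<ChatCompletionChunkResponse, OnChunkCallback, OnDoneCallback>.
fn make_chat_streamer(
    rx: Receiver<Response>,
    state: SharedMistralRsState,
    initial_state: DoneState, // stand-in: which variant to start from is not shown here
) -> Streamer {
    Streamer {
        rx,
        done_state: initial_state,
        state,
        store_chunks: false,
        chunks: Vec::new(),
        on_chunk: None,
        on_done: None,
    }
}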

mistralrs-server-core/src/util.rs

Lines changed: 16 additions & 0 deletions
@@ -57,6 +57,22 @@ impl std::fmt::Display for ModelErrorMessage {
 }
 impl std::error::Error for ModelErrorMessage {}
 
+/// Generic JSON error response structure
+#[derive(Serialize, Debug)]
+pub(crate) struct BaseJsonModelError<T> {
+    pub(crate) message: String,
+    pub(crate) partial_response: T,
+}
+
+impl<T> BaseJsonModelError<T> {
+    pub(crate) fn new(message: String, partial_response: T) -> Self {
+        Self {
+            message,
+            partial_response,
+        }
+    }
+}
+
 /// Creates a channel for response communication.
 pub fn create_response_channel(
     buffer_size: Option<usize>,
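
BaseJsonModelError<T> replaces the two per-module JsonModelError structs deleted above; the serialized JSON shape is unchanged. A minimal sketch of that shape, using a String payload purely for illustration (in the server, T is ChatCompletionResponse or CompletionResponse, per the type aliases in the other files):

// Hypothetical within-crate usage; the payload value here is a stand-in.
let err = BaseJsonModelError::new(
    "model failed".to_string(),
    "partial output so far".to_string(),
);
// serde_json::to_string(&err) serializes both fields:
// {"message":"model failed","partial_response":"partial output so far"}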
