
Commit 1b29479

More WIP consolidation of server core handlers
1 parent b9c9604 commit 1b29479

7 files changed: +366 −278 lines changed

mistralrs-server-core/src/chat_completion.rs

Lines changed: 23 additions & 35 deletions
@@ -23,17 +23,18 @@ use serde_json::Value;
 use tokio::sync::mpsc::{Receiver, Sender};
 
 use crate::{
-    completion_base::{base_handle_completion_error, BaseCompletionResponder},
+    completion_base::{
+        base_handle_completion_error, base_process_non_streaming_response, create_response_channel,
+        send_model_request, BaseCompletionResponder, BaseJsonModelError, ErrorToResponse,
+        JsonError, ModelErrorMessage,
+    },
     openai::{
         ChatCompletionRequest, Grammar, JsonSchemaResponseFormat, MessageInnerContent,
         ResponseFormat, StopTokens,
     },
-    streaming::{get_keep_alive_interval, BaseStreamer, DoneState},
+    streaming::{base_create_streamer, get_keep_alive_interval, BaseStreamer, DoneState},
     types::{ExtractedMistralRsState, SharedMistralRsState},
-    util::{
-        create_response_channel, parse_image_url, send_model_request, BaseJsonModelError,
-        ErrorToResponse, JsonError, ModelErrorMessage,
-    },
+    util::parse_image_url,
 };
@@ -53,7 +54,7 @@ use crate::{
 /// chunk
 /// });
 /// ```
-pub type OnChunkCallback =
+pub type ChatCompletionOnChunkCallback =
     Box<dyn Fn(ChatCompletionChunkResponse) -> ChatCompletionChunkResponse + Send + Sync>;
 
 /// A callback function that is executed when the streaming response completes.
@@ -71,15 +72,19 @@ pub type OnChunkCallback =
 /// // Process all chunks for analytics
 /// });
 /// ```
-pub type OnDoneCallback = Box<dyn Fn(&[ChatCompletionChunkResponse]) + Send + Sync>;
+pub type ChatCompletionOnDoneCallback = Box<dyn Fn(&[ChatCompletionChunkResponse]) + Send + Sync>;
 
 /// A streaming response handler.
 ///
 /// It processes incoming response chunks from a model and converts them
 /// into Server-Sent Events (SSE) format for real-time streaming to clients.
-pub type Streamer = BaseStreamer<ChatCompletionChunkResponse, OnChunkCallback, OnDoneCallback>;
+pub type ChatCompletionStreamer = BaseStreamer<
+    ChatCompletionChunkResponse,
+    ChatCompletionOnChunkCallback,
+    ChatCompletionOnDoneCallback,
+>;
 
-impl futures::Stream for Streamer {
+impl futures::Stream for ChatCompletionStreamer {
     type Item = Result<Event, axum::Error>;
 
     /// Polls the stream for the next Server-Sent Event.
@@ -158,7 +163,8 @@ impl futures::Stream for Streamer {
 }
 
 /// Represents different types of chat completion responses.
-pub type ChatCompletionResponder = BaseCompletionResponder<ChatCompletionResponse, Streamer>;
+pub type ChatCompletionResponder =
+    BaseCompletionResponder<ChatCompletionResponse, ChatCompletionStreamer>;
 
 type JsonModelError = BaseJsonModelError<ChatCompletionResponse>;
 impl ErrorToResponse for JsonModelError {}
@@ -475,21 +481,10 @@ pub fn handle_chat_completion_error(
 pub fn create_chat_streamer(
     rx: Receiver<Response>,
     state: SharedMistralRsState,
-    on_chunk: Option<OnChunkCallback>,
-    on_done: Option<OnDoneCallback>,
-) -> Sse<Streamer> {
-    let store_chunks = on_done.is_some();
-
-    let streamer = Streamer {
-        rx,
-        done_state: DoneState::Running,
-        store_chunks,
-        state,
-        chunks: Vec::new(),
-        on_chunk,
-        on_done,
-    };
-
+    on_chunk: Option<ChatCompletionOnChunkCallback>,
+    on_done: Option<ChatCompletionOnDoneCallback>,
+) -> Sse<ChatCompletionStreamer> {
+    let streamer = base_create_streamer(rx, state, on_chunk, on_done);
     let keep_alive_interval = get_keep_alive_interval();
 
     Sse::new(streamer)
@@ -501,15 +496,8 @@ pub async fn process_non_streaming_chat_response(
     rx: &mut Receiver<Response>,
     state: SharedMistralRsState,
 ) -> ChatCompletionResponder {
-    let response = match rx.recv().await {
-        Some(response) => response,
-        None => {
-            let e = anyhow::Error::msg("No response received from the model.");
-            return handle_chat_completion_error(state, e.into());
-        }
-    };
-
-    match_responses(state, response)
+    base_process_non_streaming_response(rx, state, match_responses, handle_chat_completion_error)
+        .await
 }
 
 /// Matches and processes different types of model responses into appropriate chat completion responses.
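
For orientation, the renamed aliases are drop-in replacements for the old `OnChunkCallback`/`OnDoneCallback`/`Streamer` names. A minimal sketch of a caller after this commit — assuming the crate's public module paths, and with the `make_sse` wrapper and both closure bodies invented for illustration:

```rust
use axum::response::Sse;
use mistralrs_core::Response;
use tokio::sync::mpsc::Receiver;

use mistralrs_server_core::{
    chat_completion::{
        create_chat_streamer, ChatCompletionOnChunkCallback, ChatCompletionOnDoneCallback,
        ChatCompletionStreamer,
    },
    types::SharedMistralRsState,
};

// Hypothetical helper: wire illustrative callbacks into the SSE streamer.
fn make_sse(rx: Receiver<Response>, state: SharedMistralRsState) -> Sse<ChatCompletionStreamer> {
    // Pass each chunk through unchanged; a real callback could redact or
    // transform it before it reaches the client.
    let on_chunk: ChatCompletionOnChunkCallback = Box::new(|chunk| chunk);

    // Supplying an on_done callback makes the streamer buffer chunks
    // (store_chunks = on_done.is_some()), so the full sequence is available here.
    let on_done: ChatCompletionOnDoneCallback =
        Box::new(|chunks| println!("stream completed with {} chunks", chunks.len()));

    create_chat_streamer(rx, state, Some(on_chunk), Some(on_done))
}
```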

mistralrs-server-core/src/completion_base.rs

Lines changed: 117 additions & 2 deletions
@@ -2,11 +2,87 @@
 
 use std::error::Error;
 
-use axum::response::Sse;
-use mistralrs_core::MistralRs;
+use anyhow::{Context, Result};
+use axum::{
+    extract::Json,
+    http::StatusCode,
+    response::{IntoResponse, Sse},
+};
+use mistralrs_core::{MistralRs, Request, Response};
+use serde::Serialize;
+use tokio::sync::mpsc::{channel, Receiver, Sender};
 
 use crate::types::SharedMistralRsState;
 
+/// Default buffer size for the response channel used in streaming operations.
+///
+/// This constant defines the maximum number of response messages that can be buffered
+/// in the channel before backpressure is applied. A larger buffer reduces the likelihood
+/// of blocking but uses more memory.
+pub const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 10_000;
+
+/// Trait for converting errors to HTTP responses with appropriate status codes.
+pub(crate) trait ErrorToResponse: Serialize {
+    /// Converts the error to an HTTP response with the specified status code.
+    fn to_response(&self, code: StatusCode) -> axum::response::Response {
+        let mut r = Json(self).into_response();
+        *r.status_mut() = code;
+        r
+    }
+}
+
+/// Standard JSON error response structure.
+#[derive(Serialize, Debug)]
+pub(crate) struct JsonError {
+    pub(crate) message: String,
+}
+
+impl JsonError {
+    /// Creates a new JSON error with the specified message.
+    pub(crate) fn new(message: String) -> Self {
+        Self { message }
+    }
+}
+
+impl std::fmt::Display for JsonError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.message)
+    }
+}
+
+impl std::error::Error for JsonError {}
+
+impl ErrorToResponse for JsonError {}
+
+/// Internal error type for model-related errors with a descriptive message.
+///
+/// This struct wraps error messages from the underlying model and implements
+/// the standard error traits for proper error handling and display.
+#[derive(Debug)]
+pub(crate) struct ModelErrorMessage(pub(crate) String);
+impl std::fmt::Display for ModelErrorMessage {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+impl std::error::Error for ModelErrorMessage {}
+
+/// Generic JSON error response structure
+#[derive(Serialize, Debug)]
+pub(crate) struct BaseJsonModelError<T> {
+    pub(crate) message: String,
+    pub(crate) partial_response: T,
+}
+
+impl<T> BaseJsonModelError<T> {
+    pub(crate) fn new(message: String, partial_response: T) -> Self {
+        Self {
+            message,
+            partial_response,
+        }
+    }
+}
+
 /// Generic responder enum for different completion types
 #[derive(Debug)]
 pub enum BaseCompletionResponder<R, S> {
@@ -22,6 +98,15 @@ pub enum BaseCompletionResponder<R, S> {
     ValidationError(Box<dyn Error>),
 }
 
+/// Creates a channel for response communication.
+pub fn create_response_channel(
+    buffer_size: Option<usize>,
+) -> (Sender<Response>, Receiver<Response>) {
+    let channel_buffer_size = buffer_size.unwrap_or(DEFAULT_CHANNEL_BUFFER_SIZE);
+
+    channel(channel_buffer_size)
+}
+
 /// Generic function to handle completion errors and logging them.
 pub(crate) fn base_handle_completion_error<R, S>(
     state: SharedMistralRsState,
@@ -31,3 +116,33 @@ pub(crate) fn base_handle_completion_error<R, S>(
     MistralRs::maybe_log_error(state, &*e);
     BaseCompletionResponder::InternalError(e.into())
 }
+
+/// Sends a request to the model processing pipeline.
+pub async fn send_model_request(state: &SharedMistralRsState, request: Request) -> Result<()> {
+    let sender = state
+        .get_sender()
+        .context("mistral.rs sender not available.")?;
+
+    sender.send(request).await.map_err(|e| e.into())
+}
+
+/// Generic function to process non-streaming responses.
+pub(crate) async fn base_process_non_streaming_response<R>(
+    rx: &mut Receiver<Response>,
+    state: SharedMistralRsState,
+    match_fn: fn(SharedMistralRsState, Response) -> R,
+    error_handler: fn(
+        SharedMistralRsState,
+        Box<dyn std::error::Error + Send + Sync + 'static>,
+    ) -> R,
+) -> R {
+    let response = match rx.recv().await {
+        Some(response) => response,
+        None => {
+            let e = anyhow::Error::msg("No response received from the model.");
+            return error_handler(state, e.into());
+        }
+    };
+
+    match_fn(state, response)
+}
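
With these helpers in place, a non-streaming endpoint handler inside the crate collapses to one shape: open a channel, send the request, then delegate to the shared processor with endpoint-specific match and error functions. A hedged sketch of that pattern — `MyResponder`, `build_request`, `match_my_responses`, and `handle_my_error` are placeholders, not items from this commit:

```rust
use crate::{
    completion_base::{
        base_process_non_streaming_response, create_response_channel, send_model_request,
    },
    types::SharedMistralRsState,
};

// Placeholder sketch of an endpoint handler built on the consolidated helpers.
async fn respond_non_streaming(state: SharedMistralRsState) -> MyResponder {
    // Channel with the default 10_000-message buffer; pass Some(n) to override.
    let (tx, mut rx) = create_response_channel(None);

    // Endpoint-specific (placeholder): build the engine request, embedding `tx`
    // so the engine can send responses back on this channel.
    let request = build_request(tx);

    if let Err(e) = send_model_request(&state, request).await {
        return handle_my_error(state, e.into());
    }

    base_process_non_streaming_response(&mut rx, state, match_my_responses, handle_my_error).await
}
```

This mirrors what `process_non_streaming_chat_response` now does for the chat endpoint, with `match_responses` and `handle_chat_completion_error` as the concrete functions.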
