mistralrs-server-core/src/chat_completion.rs (57 additions, 204 deletions)
@@ -1,11 +1,11 @@
 //! ## Chat Completions functionality and route handler.

-use std::{env, error::Error, ops::Deref, pin::Pin, task::Poll, time::Duration};
+use std::{ops::Deref, pin::Pin, task::Poll, time::Duration};

 use anyhow::{Context, Result};
 use axum::{
     extract::{Json, State},
-    http::{self, StatusCode},
+    http::{self},
     response::{
         sse::{Event, KeepAlive},
         IntoResponse, Sse,
@@ -15,21 +15,28 @@ use either::Either;
 use indexmap::IndexMap;
 use itertools::Itertools;
 use mistralrs_core::{
-    ChatCompletionChunkResponse, ChatCompletionResponse, Constraint, DrySamplingParams, MistralRs,
-    NormalRequest, Request, RequestMessage, Response, SamplingParams,
-    StopTokens as InternalStopTokens,
+    ChatCompletionChunkResponse, ChatCompletionResponse, Constraint, MistralRs, NormalRequest,
+    Request, RequestMessage, Response, SamplingParams,
 };
 use serde::Serialize;
 use serde_json::Value;
-use tokio::sync::mpsc::{channel, Receiver, Sender};
+use tokio::sync::mpsc::{Receiver, Sender};

 use crate::{
+    completion_core::{
+        convert_stop_tokens, get_dry_sampling_params, handle_completion_error,
+        BaseCompletionResponder,
+    },
+    handler_core::{
+        base_process_non_streaming_response, create_response_channel, send_request,
+        BaseJsonModelError, ErrorToResponse, JsonError, ModelErrorMessage,
+    },
     openai::{
         ChatCompletionRequest, Grammar, JsonSchemaResponseFormat, MessageInnerContent,
-        ResponseFormat, StopTokens,
+        ResponseFormat,
     },
-    types::{ExtractedMistralRsState, SharedMistralRsState},
-    util,
+    streaming::{base_create_streamer, get_keep_alive_interval, BaseStreamer, DoneState},
+    types::{ExtractedMistralRsState, OnChunkCallback, OnDoneCallback, SharedMistralRsState},
+    util::{parse_audio_url, parse_image_url},
 };

 /// A callback function that processes streaming response chunks before they are sent to the client.
@@ -41,16 +48,15 @@ use crate::{
 /// ### Examples
 ///
 /// ```no_run
-/// use mistralrs_server_core::chat_completion::OnChunkCallback;
+/// use mistralrs_server_core::chat_completion::ChatCompletionOnChunkCallback;
 ///
-/// let on_chunk: OnChunkCallback = Box::new(|mut chunk| {
+/// let on_chunk: ChatCompletionOnChunkCallback = Box::new(|mut chunk| {
 ///     // Log the chunk or modify its content
 ///     println!("Processing chunk: {:?}", chunk);
 ///     chunk
 /// });
 /// ```
-pub type OnChunkCallback =
-    Box<dyn Fn(ChatCompletionChunkResponse) -> ChatCompletionChunkResponse + Send + Sync>;
+pub type ChatCompletionOnChunkCallback = OnChunkCallback<ChatCompletionChunkResponse>;

 /// A callback function that is executed when the streaming response completes.
 ///
@@ -60,70 +66,26 @@ pub type OnChunkCallback =
 /// ### Examples
 ///
 /// ```no_run
-/// use mistralrs_server_core::chat_completion::OnDoneCallback;
+/// use mistralrs_server_core::chat_completion::ChatCompletionOnDoneCallback;
 ///
-/// let on_done: OnDoneCallback = Box::new(|chunks| {
+/// let on_done: ChatCompletionOnDoneCallback = Box::new(|chunks| {
 ///     println!("Stream completed with {} chunks", chunks.len());
 ///     // Process all chunks for analytics
 /// });
 /// ```
-pub type OnDoneCallback = Box<dyn Fn(&[ChatCompletionChunkResponse]) + Send + Sync>;
-
-/// Default buffer size for the response channel used in streaming operations.
-///
-/// This constant defines the maximum number of response messages that can be buffered
-/// in the channel before backpressure is applied. A larger buffer reduces the likelihood
-/// of blocking but uses more memory.
-pub const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 10_000;
-
-/// Default keep-alive interval for Server-Sent Events (SSE) streams in milliseconds.
-pub const DEFAULT_KEEP_ALIVE_INTERVAL_MS: u64 = 10_000;
-
-/// Internal error type for model-related errors with a descriptive message.
-///
-/// This struct wraps error messages from the underlying model and implements
-/// the standard error traits for proper error handling and display.
-#[derive(Debug)]
-struct ModelErrorMessage(String);
-impl std::fmt::Display for ModelErrorMessage {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{}", self.0)
-    }
-}
-impl std::error::Error for ModelErrorMessage {}
-
-/// Represents the current state of a streaming response.
-enum DoneState {
-    /// The stream is actively processing and sending response chunks
-    Running,
-    /// The stream has finished processing and is about to send the `[DONE]` message
-    SendingDone,
-    /// The stream has completed entirely
-    Done,
-}
+pub type ChatCompletionOnDoneCallback = OnDoneCallback<ChatCompletionChunkResponse>;

 /// A streaming response handler.
 ///
 /// It processes incoming response chunks from a model and converts them
 /// into Server-Sent Events (SSE) format for real-time streaming to clients.
-pub struct Streamer {
-    /// Channel receiver for incoming model responses
-    rx: Receiver<Response>,
-    /// Current state of the streaming operation
-    done_state: DoneState,
-    /// Underlying mistral.rs instance
-    state: SharedMistralRsState,
-    /// Whether to store chunks for the completion callback
-    store_chunks: bool,
-    /// All chunks received during streaming (if `store_chunks` is true)
-    chunks: Vec<ChatCompletionChunkResponse>,
-    /// Optional callback to process each chunk before sending
-    on_chunk: Option<OnChunkCallback>,
-    /// Optional callback to execute when streaming completes
-    on_done: Option<OnDoneCallback>,
-}
+pub type ChatCompletionStreamer = BaseStreamer<
+    ChatCompletionChunkResponse,
+    ChatCompletionOnChunkCallback,
+    ChatCompletionOnDoneCallback,
+>;

-impl futures::Stream for Streamer {
+impl futures::Stream for ChatCompletionStreamer {
     type Item = Result<Event, axum::Error>;

     /// Polls the stream for the next Server-Sent Event.
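The generic machinery this hunk leans on (`OnChunkCallback` and `OnDoneCallback` in `types.rs`, `BaseStreamer` and `DoneState` in `streaming.rs`) is not part of this diff. A rough sketch of what those definitions presumably look like, generalizing the concrete types removed above — the field layout and the `SharedMistralRsState` alias are assumptions, not shown by the PR:

```rust
use mistralrs_core::{MistralRs, Response};
use std::sync::Arc;
use tokio::sync::mpsc::Receiver;

// Assumption: the shared state alias is an Arc over the core handle.
pub type SharedMistralRsState = Arc<MistralRs>;

/// Chunk callback, generic over the chunk type `R` (previously hard-coded
/// to `ChatCompletionChunkResponse` in the removed alias).
pub type OnChunkCallback<R> = Box<dyn Fn(R) -> R + Send + Sync>;

/// Completion callback, generic over the chunk type `R`.
pub type OnDoneCallback<R> = Box<dyn Fn(&[R]) + Send + Sync>;

/// Streaming state machine, mirroring the enum removed above.
pub enum DoneState {
    Running,
    SendingDone,
    Done,
}

/// Generic streamer over a chunk type `C` and its callback types; the
/// fields are assumed to mirror the removed `Streamer` struct one-to-one.
pub struct BaseStreamer<C, OnChunk, OnDone> {
    pub rx: Receiver<Response>,
    pub done_state: DoneState,
    pub state: SharedMistralRsState,
    pub store_chunks: bool,
    pub chunks: Vec<C>,
    pub on_chunk: Option<OnChunk>,
    pub on_done: Option<OnDone>,
}
```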
@@ -202,61 +164,10 @@ impl futures::Stream for ChatCompletionStreamer {
 }

 /// Represents different types of chat completion responses.
-pub enum ChatCompletionResponder {
-    /// Server-Sent Events streaming response
-    Sse(Sse<Streamer>),
-    /// Complete JSON response for non-streaming requests
-    Json(ChatCompletionResponse),
-    /// Model error with partial response data
-    ModelError(String, ChatCompletionResponse),
-    /// Internal server error
-    InternalError(Box<dyn Error>),
-    /// Request validation error
-    ValidationError(Box<dyn Error>),
-}
-
-/// Trait for converting errors to HTTP responses with appropriate status codes.
-trait ErrorToResponse: Serialize {
-    /// Converts the error to an HTTP response with the specified status code.
-    fn to_response(&self, code: StatusCode) -> axum::response::Response {
-        let mut r = Json(self).into_response();
-        *r.status_mut() = code;
-        r
-    }
-}
-
-/// Standard JSON error response structure.
-#[derive(Serialize)]
-struct JsonError {
-    message: String,
-}
-
-impl JsonError {
-    /// Creates a new JSON error with the specified message.
-    fn new(message: String) -> Self {
-        Self { message }
-    }
-}
-impl ErrorToResponse for JsonError {}
-
-/// JSON error response structure for model errors.
-#[derive(Serialize)]
-struct JsonModelError {
-    message: String,
-    /// Partial response data that was generated before the error occurred
-    partial_response: ChatCompletionResponse,
-}
-
-impl JsonModelError {
-    /// Creates a new JSON model error with message and partial response.
-    fn new(message: String, partial_response: ChatCompletionResponse) -> Self {
-        Self {
-            message,
-            partial_response,
-        }
-    }
-}
+pub type ChatCompletionResponder =
+    BaseCompletionResponder<ChatCompletionResponse, ChatCompletionStreamer>;

+type JsonModelError = BaseJsonModelError<ChatCompletionResponse>;
+impl ErrorToResponse for JsonModelError {}

 impl IntoResponse for ChatCompletionResponder {
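The generic responder the new alias points at lives in `completion_core.rs`, outside this diff. A minimal sketch, assuming it simply generalizes the removed enum over the JSON response type `R` and the streamer type `S`:

```rust
use axum::response::Sse;
use std::error::Error;

/// Generic completion responder; the variants mirror the removed
/// `ChatCompletionResponder` enum one-to-one.
pub enum BaseCompletionResponder<R, S> {
    /// Server-Sent Events streaming response
    Sse(Sse<S>),
    /// Complete JSON response for non-streaming requests
    Json(R),
    /// Model error with partial response data
    ModelError(String, R),
    /// Internal server error
    InternalError(Box<dyn Error>),
    /// Request validation error
    ValidationError(Box<dyn Error>),
}
```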
@@ -291,11 +202,8 @@ pub async fn parse_request(
     let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
     MistralRs::maybe_log_request(state.clone(), repr);

-    let stop_toks = match oairequest.stop_seqs {
-        Some(StopTokens::Multi(m)) => Some(InternalStopTokens::Seqs(m)),
-        Some(StopTokens::Single(s)) => Some(InternalStopTokens::Seqs(vec![s])),
-        None => None,
-    };
+    let stop_toks = convert_stop_tokens(oairequest.stop_seqs);

     let messages = match oairequest.messages {
         Either::Left(req_messages) => {
             let mut messages = Vec::new();
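The removed match moves into `completion_core.rs` as `convert_stop_tokens`. A minimal sketch of that helper, assuming it hoists the old inline logic unchanged (the `StopTokens` types are the ones this file imported before and after the change):

```rust
use mistralrs_core::StopTokens as InternalStopTokens;
use mistralrs_server_core::openai::StopTokens;

pub fn convert_stop_tokens(stop_seqs: Option<StopTokens>) -> Option<InternalStopTokens> {
    match stop_seqs {
        // Multiple stop sequences map directly onto the internal list form.
        Some(StopTokens::Multi(m)) => Some(InternalStopTokens::Seqs(m)),
        // A single stop sequence is wrapped in a one-element list.
        Some(StopTokens::Single(s)) => Some(InternalStopTokens::Seqs(vec![s])),
        None => None,
    }
}
```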
@@ -469,7 +377,7 @@ pub async fn parse_request(
                 // Parse images
                 let mut images = Vec::new();
                 for url_unparsed in image_urls {
-                    let image = util::parse_image_url(&url_unparsed)
+                    let image = parse_image_url(&url_unparsed)
                         .await
                         .context(format!("Failed to parse image resource: {}", url_unparsed))?;
                     images.push(image);
@@ -478,7 +386,7 @@ pub async fn parse_request(
                 // Parse audios
                 let mut audios = Vec::new();
                 for url_unparsed in audio_urls {
-                    let audio = util::parse_audio_url(&url_unparsed)
+                    let audio = parse_audio_url(&url_unparsed)
                         .await
                         .context(format!("Failed to parse audio resource: {}", url_unparsed))?;
                     audios.push(audio);
@@ -511,16 +419,12 @@
         }
     };

-    let dry_params = if let Some(dry_multiplier) = oairequest.dry_multiplier {
-        Some(DrySamplingParams::new_with_defaults(
-            dry_multiplier,
-            oairequest.dry_sequence_breakers,
-            oairequest.dry_base,
-            oairequest.dry_allowed_length,
-        )?)
-    } else {
-        None
-    };
+    let dry_params = get_dry_sampling_params(
+        oairequest.dry_multiplier,
+        oairequest.dry_sequence_breakers,
+        oairequest.dry_base,
+        oairequest.dry_allowed_length,
+    )?;

     let is_streaming = oairequest.stream.unwrap_or(false);
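Likewise, the dry-sampling construction moves behind `get_dry_sampling_params` in `completion_core.rs`. A sketch assuming it wraps the removed `if let` — the argument order comes from the old call site, but the parameter types (`f32`/`usize` options) are inferred, not shown by the diff:

```rust
use anyhow::Result;
use mistralrs_core::DrySamplingParams;

pub fn get_dry_sampling_params(
    dry_multiplier: Option<f32>,
    dry_sequence_breakers: Option<Vec<String>>,
    dry_base: Option<f32>,
    dry_allowed_length: Option<usize>,
) -> Result<Option<DrySamplingParams>> {
    // Dry sampling is only enabled when a multiplier is supplied.
    if let Some(dry_multiplier) = dry_multiplier {
        Ok(Some(DrySamplingParams::new_with_defaults(
            dry_multiplier,
            dry_sequence_breakers,
            dry_base,
            dry_allowed_length,
        )?))
    } else {
        Ok(None)
    }
}
```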

@@ -591,99 +495,48 @@ pub async fn chatcompletions(

     let (request, is_streaming) = match parse_request(oairequest, state.clone(), tx).await {
         Ok(x) => x,
-        Err(e) => return handle_chat_completion_error(state, e.into()),
+        Err(e) => return handle_error(state, e.into()),
     };

     if let Err(e) = send_request(&state, request).await {
-        return handle_chat_completion_error(state, e.into());
+        return handle_error(state, e.into());
     }

     if is_streaming {
-        ChatCompletionResponder::Sse(create_chat_streamer(rx, state, None, None))
+        ChatCompletionResponder::Sse(create_streamer(rx, state, None, None))
     } else {
-        process_non_streaming_chat_response(&mut rx, state).await
+        process_non_streaming_response(&mut rx, state).await
     }
 }

-/// Helper function to handle chat completion errors and logging them.
-pub fn handle_chat_completion_error(
+/// Handles route / generation errors and logs them.
+pub fn handle_error(
     state: SharedMistralRsState,
     e: Box<dyn std::error::Error + Send + Sync + 'static>,
 ) -> ChatCompletionResponder {
-    let e = anyhow::Error::msg(e.to_string());
-    MistralRs::maybe_log_error(state, &*e);
-    ChatCompletionResponder::InternalError(e.into())
-}
-
-/// Creates a channel for response communication.
-pub fn create_response_channel(
-    buffer_size: Option<usize>,
-) -> (Sender<Response>, Receiver<Response>) {
-    let channel_buffer_size = buffer_size.unwrap_or(DEFAULT_CHANNEL_BUFFER_SIZE);
-
-    channel(channel_buffer_size)
-}
-
-/// Gets the keep-alive interval for SSE streams from environment or default.
-pub fn get_keep_alive_interval() -> u64 {
-    env::var("KEEP_ALIVE_INTERVAL")
-        .map(|val| {
-            val.parse::<u64>().unwrap_or_else(|e| {
-                tracing::warn!("Failed to parse KEEP_ALIVE_INTERVAL: {}. Using default.", e);
-                DEFAULT_KEEP_ALIVE_INTERVAL_MS
-            })
-        })
-        .unwrap_or(DEFAULT_KEEP_ALIVE_INTERVAL_MS)
-}
-
-/// Sends a request to the model processing pipeline.
-pub async fn send_request(state: &SharedMistralRsState, request: Request) -> Result<()> {
-    let sender = state
-        .get_sender()
-        .context("mistral.rs sender not available.")?;
-
-    sender.send(request).await.map_err(|e| e.into())
+    handle_completion_error(state, e)
 }

 /// Creates a SSE streamer for chat completions with optional callbacks.
-pub fn create_chat_streamer(
+pub fn create_streamer(
     rx: Receiver<Response>,
     state: SharedMistralRsState,
-    on_chunk: Option<OnChunkCallback>,
-    on_done: Option<OnDoneCallback>,
-) -> Sse<Streamer> {
-    let store_chunks = on_done.is_some();
-
-    let streamer = Streamer {
-        rx,
-        done_state: DoneState::Running,
-        store_chunks,
-        state,
-        chunks: Vec::new(),
-        on_chunk,
-        on_done,
-    };
-
+    on_chunk: Option<ChatCompletionOnChunkCallback>,
+    on_done: Option<ChatCompletionOnDoneCallback>,
+) -> Sse<ChatCompletionStreamer> {
+    let streamer = base_create_streamer(rx, state, on_chunk, on_done);
     let keep_alive_interval = get_keep_alive_interval();

     Sse::new(streamer)
         .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
 }

 /// Processes non-streaming chat completion responses.
-pub async fn process_non_streaming_chat_response(
+pub async fn process_non_streaming_response(
     rx: &mut Receiver<Response>,
     state: SharedMistralRsState,
 ) -> ChatCompletionResponder {
-    let response = match rx.recv().await {
-        Some(response) => response,
-        None => {
-            let e = anyhow::Error::msg("No response received from the model.");
-            return handle_chat_completion_error(state, e.into());
-        }
-    };
-
-    match_responses(state, response)
+    base_process_non_streaming_response(rx, state, match_responses, handle_error).await
 }

 /// Matches and processes different types of model responses into appropriate chat completion responses.
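Taken together, the relocated helpers let embedding code attach callbacks to a stream. A hypothetical usage sketch, assuming `handler_core::create_response_channel` is publicly reachable as the import list suggests; `streaming_with_callbacks` and its body are illustrative, not part of this PR:

```rust
use axum::response::Sse;
use mistralrs_server_core::{
    chat_completion::{
        create_streamer, ChatCompletionOnChunkCallback, ChatCompletionOnDoneCallback,
        ChatCompletionStreamer,
    },
    handler_core::create_response_channel,
    types::SharedMistralRsState,
};

fn streaming_with_callbacks(state: SharedMistralRsState) -> Sse<ChatCompletionStreamer> {
    // Channel the core writes responses into; `None` selects the default buffer size.
    let (_tx, rx) = create_response_channel(None);
    // In a real handler, `_tx` would be passed to `parse_request` and the
    // request forwarded with `send_request`, as in `chatcompletions` above.

    // Runs on every chunk before it is serialized into an SSE event.
    let on_chunk: ChatCompletionOnChunkCallback = Box::new(|chunk| {
        tracing::debug!("chunk: {chunk:?}");
        chunk
    });

    // Runs once after the stream completes, with all stored chunks.
    let on_done: ChatCompletionOnDoneCallback = Box::new(|chunks| {
        tracing::info!("stream finished after {} chunks", chunks.len());
    });

    create_streamer(rx, state, Some(on_chunk), Some(on_done))
}
```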