diff --git a/mistralrs-server-core/src/chat_completion.rs b/mistralrs-server-core/src/chat_completion.rs index 79c6b6709e..36523145a3 100644 --- a/mistralrs-server-core/src/chat_completion.rs +++ b/mistralrs-server-core/src/chat_completion.rs @@ -1,11 +1,11 @@ //! ## Chat Completions functionality and route handler. -use std::{env, error::Error, ops::Deref, pin::Pin, task::Poll, time::Duration}; +use std::{ops::Deref, pin::Pin, task::Poll, time::Duration}; use anyhow::{Context, Result}; use axum::{ extract::{Json, State}, - http::{self, StatusCode}, + http::{self}, response::{ sse::{Event, KeepAlive}, IntoResponse, Sse, @@ -15,21 +15,28 @@ use either::Either; use indexmap::IndexMap; use itertools::Itertools; use mistralrs_core::{ - ChatCompletionChunkResponse, ChatCompletionResponse, Constraint, DrySamplingParams, MistralRs, - NormalRequest, Request, RequestMessage, Response, SamplingParams, - StopTokens as InternalStopTokens, + ChatCompletionChunkResponse, ChatCompletionResponse, Constraint, MistralRs, NormalRequest, + Request, RequestMessage, Response, SamplingParams, }; -use serde::Serialize; use serde_json::Value; -use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::sync::mpsc::{Receiver, Sender}; use crate::{ + completion_core::{ + convert_stop_tokens, get_dry_sampling_params, handle_completion_error, + BaseCompletionResponder, + }, + handler_core::{ + base_process_non_streaming_response, create_response_channel, send_request, + BaseJsonModelError, ErrorToResponse, JsonError, ModelErrorMessage, + }, openai::{ ChatCompletionRequest, Grammar, JsonSchemaResponseFormat, MessageInnerContent, - ResponseFormat, StopTokens, + ResponseFormat, }, - types::{ExtractedMistralRsState, SharedMistralRsState}, - util, + streaming::{base_create_streamer, get_keep_alive_interval, BaseStreamer, DoneState}, + types::{ExtractedMistralRsState, OnChunkCallback, OnDoneCallback, SharedMistralRsState}, + util::{parse_audio_url, parse_image_url}, }; /// A callback function that processes streaming response chunks before they are sent to the client. @@ -41,16 +48,15 @@ use crate::{ /// ### Examples /// /// ```no_run -/// use mistralrs_server_core::chat_completion::OnChunkCallback; +/// use mistralrs_server_core::chat_completion::ChatCompletionOnChunkCallback; /// -/// let on_chunk: OnChunkCallback = Box::new(|mut chunk| { +/// let on_chunk: ChatCompletionOnChunkCallback = Box::new(|mut chunk| { /// // Log the chunk or modify its content /// println!("Processing chunk: {:?}", chunk); /// chunk /// }); /// ``` -pub type OnChunkCallback = - Box ChatCompletionChunkResponse + Send + Sync>; +pub type ChatCompletionOnChunkCallback = OnChunkCallback; /// A callback function that is executed when the streaming response completes. /// @@ -60,70 +66,26 @@ pub type OnChunkCallback = /// ### Examples /// /// ```no_run -/// use mistralrs_server_core::chat_completion::OnDoneCallback; +/// use mistralrs_server_core::chat_completion::ChatCompletionOnDoneCallback; /// -/// let on_done: OnDoneCallback = Box::new(|chunks| { +/// let on_done: ChatCompletionOnDoneCallback = Box::new(|chunks| { /// println!("Stream completed with {} chunks", chunks.len()); /// // Process all chunks for analytics /// }); /// ``` -pub type OnDoneCallback = Box; - -/// Default buffer size for the response channel used in streaming operations. -/// -/// This constant defines the maximum number of response messages that can be buffered -/// in the channel before backpressure is applied. 
A larger buffer reduces the likelihood -/// of blocking but uses more memory. -pub const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 10_000; - -/// Default keep-alive interval for Server-Sent Events (SSE) streams in milliseconds. -pub const DEFAULT_KEEP_ALIVE_INTERVAL_MS: u64 = 10_000; - -/// Internal error type for model-related errors with a descriptive message. -/// -/// This struct wraps error messages from the underlying model and implements -/// the standard error traits for proper error handling and display. -#[derive(Debug)] -struct ModelErrorMessage(String); -impl std::fmt::Display for ModelErrorMessage { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} -impl std::error::Error for ModelErrorMessage {} - -/// Represents the current state of a streaming response. -enum DoneState { - /// The stream is actively processing and sending response chunks - Running, - /// The stream has finished processing and is about to send the `[DONE]` message - SendingDone, - /// The stream has completed entirely - Done, -} +pub type ChatCompletionOnDoneCallback = OnDoneCallback; /// A streaming response handler. /// /// It processes incoming response chunks from a model and converts them /// into Server-Sent Events (SSE) format for real-time streaming to clients. -pub struct Streamer { - /// Channel receiver for incoming model responses - rx: Receiver, - /// Current state of the streaming operation - done_state: DoneState, - /// Underlying mistral.rs instance - state: SharedMistralRsState, - /// Whether to store chunks for the completion callback - store_chunks: bool, - /// All chunks received during streaming (if `store_chunks` is true) - chunks: Vec, - /// Optional callback to process each chunk before sending - on_chunk: Option, - /// Optional callback to execute when streaming completes - on_done: Option, -} +pub type ChatCompletionStreamer = BaseStreamer< + ChatCompletionChunkResponse, + ChatCompletionOnChunkCallback, + ChatCompletionOnDoneCallback, +>; -impl futures::Stream for Streamer { +impl futures::Stream for ChatCompletionStreamer { type Item = Result; /// Polls the stream for the next Server-Sent Event. @@ -202,61 +164,10 @@ impl futures::Stream for Streamer { } /// Represents different types of chat completion responses. -pub enum ChatCompletionResponder { - /// Server-Sent Events streaming response - Sse(Sse), - /// Complete JSON response for non-streaming requests - Json(ChatCompletionResponse), - /// Model error with partial response data - ModelError(String, ChatCompletionResponse), - /// Internal server error - InternalError(Box), - /// Request validation error - ValidationError(Box), -} - -/// Trait for converting errors to HTTP responses with appropriate status codes. -trait ErrorToResponse: Serialize { - /// Converts the error to an HTTP response with the specified status code. - fn to_response(&self, code: StatusCode) -> axum::response::Response { - let mut r = Json(self).into_response(); - *r.status_mut() = code; - r - } -} - -/// Standard JSON error response structure. -#[derive(Serialize)] -struct JsonError { - message: String, -} - -impl JsonError { - /// Creates a new JSON error with the specified message. - fn new(message: String) -> Self { - Self { message } - } -} -impl ErrorToResponse for JsonError {} - -/// JSON error response structure for model errors. 
-#[derive(Serialize)] -struct JsonModelError { - message: String, - /// Partial response data that was generated before the error occurred - partial_response: ChatCompletionResponse, -} - -impl JsonModelError { - /// Creates a new JSON model error with message and partial response. - fn new(message: String, partial_response: ChatCompletionResponse) -> Self { - Self { - message, - partial_response, - } - } -} +pub type ChatCompletionResponder = + BaseCompletionResponder; +type JsonModelError = BaseJsonModelError; impl ErrorToResponse for JsonModelError {} impl IntoResponse for ChatCompletionResponder { @@ -291,11 +202,8 @@ pub async fn parse_request( let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed."); MistralRs::maybe_log_request(state.clone(), repr); - let stop_toks = match oairequest.stop_seqs { - Some(StopTokens::Multi(m)) => Some(InternalStopTokens::Seqs(m)), - Some(StopTokens::Single(s)) => Some(InternalStopTokens::Seqs(vec![s])), - None => None, - }; + let stop_toks = convert_stop_tokens(oairequest.stop_seqs); + let messages = match oairequest.messages { Either::Left(req_messages) => { let mut messages = Vec::new(); @@ -469,7 +377,7 @@ pub async fn parse_request( // Parse images let mut images = Vec::new(); for url_unparsed in image_urls { - let image = util::parse_image_url(&url_unparsed) + let image = parse_image_url(&url_unparsed) .await .context(format!("Failed to parse image resource: {}", url_unparsed))?; images.push(image); @@ -478,7 +386,7 @@ pub async fn parse_request( // Parse audios let mut audios = Vec::new(); for url_unparsed in audio_urls { - let audio = util::parse_audio_url(&url_unparsed) + let audio = parse_audio_url(&url_unparsed) .await .context(format!("Failed to parse audio resource: {}", url_unparsed))?; audios.push(audio); @@ -511,16 +419,12 @@ pub async fn parse_request( } }; - let dry_params = if let Some(dry_multiplier) = oairequest.dry_multiplier { - Some(DrySamplingParams::new_with_defaults( - dry_multiplier, - oairequest.dry_sequence_breakers, - oairequest.dry_base, - oairequest.dry_allowed_length, - )?) - } else { - None - }; + let dry_params = get_dry_sampling_params( + oairequest.dry_multiplier, + oairequest.dry_sequence_breakers, + oairequest.dry_base, + oairequest.dry_allowed_length, + )?; let is_streaming = oairequest.stream.unwrap_or(false); @@ -591,99 +495,48 @@ pub async fn chatcompletions( let (request, is_streaming) = match parse_request(oairequest, state.clone(), tx).await { Ok(x) => x, - Err(e) => return handle_chat_completion_error(state, e.into()), + Err(e) => return handle_error(state, e.into()), }; if let Err(e) = send_request(&state, request).await { - return handle_chat_completion_error(state, e.into()); + return handle_error(state, e.into()); } if is_streaming { - ChatCompletionResponder::Sse(create_chat_streamer(rx, state, None, None)) + ChatCompletionResponder::Sse(create_streamer(rx, state, None, None)) } else { - process_non_streaming_chat_response(&mut rx, state).await + process_non_streaming_response(&mut rx, state).await } } -/// Helper function to handle chat completion errors and logging them. -pub fn handle_chat_completion_error( +/// Handle route / generation errors and logging them. +pub fn handle_error( state: SharedMistralRsState, e: Box, ) -> ChatCompletionResponder { - let e = anyhow::Error::msg(e.to_string()); - MistralRs::maybe_log_error(state, &*e); - ChatCompletionResponder::InternalError(e.into()) -} - -/// Creates a channel for response communication. 
-pub fn create_response_channel( - buffer_size: Option, -) -> (Sender, Receiver) { - let channel_buffer_size = buffer_size.unwrap_or(DEFAULT_CHANNEL_BUFFER_SIZE); - - channel(channel_buffer_size) -} - -/// Gets the keep-alive interval for SSE streams from environment or default. -pub fn get_keep_alive_interval() -> u64 { - env::var("KEEP_ALIVE_INTERVAL") - .map(|val| { - val.parse::().unwrap_or_else(|e| { - tracing::warn!("Failed to parse KEEP_ALIVE_INTERVAL: {}. Using default.", e); - DEFAULT_KEEP_ALIVE_INTERVAL_MS - }) - }) - .unwrap_or(DEFAULT_KEEP_ALIVE_INTERVAL_MS) -} - -/// Sends a request to the model processing pipeline. -pub async fn send_request(state: &SharedMistralRsState, request: Request) -> Result<()> { - let sender = state - .get_sender() - .context("mistral.rs sender not available.")?; - - sender.send(request).await.map_err(|e| e.into()) + handle_completion_error(state, e) } /// Creates a SSE streamer for chat completions with optional callbacks. -pub fn create_chat_streamer( +pub fn create_streamer( rx: Receiver, state: SharedMistralRsState, - on_chunk: Option, - on_done: Option, -) -> Sse { - let store_chunks = on_done.is_some(); - - let streamer = Streamer { - rx, - done_state: DoneState::Running, - store_chunks, - state, - chunks: Vec::new(), - on_chunk, - on_done, - }; - + on_chunk: Option, + on_done: Option, +) -> Sse { + let streamer = base_create_streamer(rx, state, on_chunk, on_done); let keep_alive_interval = get_keep_alive_interval(); Sse::new(streamer) .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval))) } -/// Processes non-streaming chat completion responses. -pub async fn process_non_streaming_chat_response( +/// Process non-streaming chat completion responses. +pub async fn process_non_streaming_response( rx: &mut Receiver, state: SharedMistralRsState, ) -> ChatCompletionResponder { - let response = match rx.recv().await { - Some(response) => response, - None => { - let e = anyhow::Error::msg("No response received from the model."); - return handle_chat_completion_error(state, e.into()); - } - }; - - match_responses(state, response) + base_process_non_streaming_response(rx, state, match_responses, handle_error).await } /// Matches and processes different types of model responses into appropriate chat completion responses. diff --git a/mistralrs-server-core/src/completion_core.rs b/mistralrs-server-core/src/completion_core.rs new file mode 100644 index 0000000000..fa8949e576 --- /dev/null +++ b/mistralrs-server-core/src/completion_core.rs @@ -0,0 +1,65 @@ +//! Core functionality for completions. + +use std::error::Error; + +use anyhow::Result; +use axum::response::Sse; +use mistralrs_core::{DrySamplingParams, MistralRs, StopTokens as InternalStopTokens}; + +use crate::{openai::StopTokens, types::SharedMistralRsState}; + +/// Generic responder enum for different completion types. +#[derive(Debug)] +pub enum BaseCompletionResponder { + /// Server-Sent Events streaming response + Sse(Sse), + /// Complete JSON response for non-streaming requests + Json(R), + /// Model error with partial response data + ModelError(String, R), + /// Internal server error + InternalError(Box), + /// Request validation error + ValidationError(Box), +} + +/// Generic function to handle completion errors and logging them. 
+pub(crate) fn handle_completion_error( + state: SharedMistralRsState, + e: Box, +) -> BaseCompletionResponder { + let error = anyhow::Error::msg(e.to_string()); + MistralRs::maybe_log_error(state, &*error); + BaseCompletionResponder::InternalError(error.into()) +} + +/// Helper function to convert from the OpenAI stop tokens to the mistral.rs +/// internal stop tokens. +pub(crate) fn convert_stop_tokens(stop_seqs: Option) -> Option { + match stop_seqs { + Some(StopTokens::Multi(sequences)) => Some(InternalStopTokens::Seqs(sequences)), + Some(StopTokens::Single(sequence)) => Some(InternalStopTokens::Seqs(vec![sequence])), + None => None, + } +} + +/// Helper function to get the dry sampling params. +pub(crate) fn get_dry_sampling_params( + dry_multiplier: Option, + dry_sequence_breakers: Option>, + dry_base: Option, + dry_allowed_length: Option, +) -> Result> { + match dry_multiplier { + Some(multiplier) => { + let params = DrySamplingParams::new_with_defaults( + multiplier, + dry_sequence_breakers, + dry_base, + dry_allowed_length, + )?; + Ok(Some(params)) + } + None => Ok(None), + } +} diff --git a/mistralrs-server-core/src/completions.rs b/mistralrs-server-core/src/completions.rs index 555628e10a..df6b194bf9 100644 --- a/mistralrs-server-core/src/completions.rs +++ b/mistralrs-server-core/src/completions.rs @@ -1,57 +1,95 @@ -use anyhow::Result; +//! ## Completions functionality and route handler. + use std::{ - env, - error::Error, pin::Pin, sync::Arc, task::{Context, Poll}, time::Duration, }; -use tokio::sync::mpsc::{channel, Receiver, Sender}; -use crate::{ - openai::{CompletionRequest, Grammar, StopTokens}, - types::ExtractedMistralRsState, -}; +use anyhow::Result; use axum::{ extract::{Json, State}, - http::{self, StatusCode}, + http::{self}, response::{ sse::{Event, KeepAlive}, IntoResponse, Sse, }, }; use mistralrs_core::{ - CompletionResponse, Constraint, DrySamplingParams, MistralRs, NormalRequest, Request, - RequestMessage, Response, SamplingParams, StopTokens as InternalStopTokens, + CompletionChunkResponse, CompletionResponse, Constraint, MistralRs, NormalRequest, Request, + RequestMessage, Response, SamplingParams, }; -use serde::Serialize; +use tokio::sync::mpsc::{Receiver, Sender}; use tracing::warn; -#[derive(Debug)] -struct ModelErrorMessage(String); -impl std::fmt::Display for ModelErrorMessage { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} -impl std::error::Error for ModelErrorMessage {} +use crate::{ + completion_core::{ + convert_stop_tokens, get_dry_sampling_params, handle_completion_error, + BaseCompletionResponder, + }, + handler_core::{ + base_process_non_streaming_response, create_response_channel, send_request, + BaseJsonModelError, ErrorToResponse, JsonError, ModelErrorMessage, + }, + openai::{CompletionRequest, Grammar}, + streaming::{base_create_streamer, get_keep_alive_interval, BaseStreamer, DoneState}, + types::{ExtractedMistralRsState, OnChunkCallback, OnDoneCallback, SharedMistralRsState}, +}; -enum DoneState { - Running, - SendingDone, - Done, -} +/// A callback function that processes streaming response chunks before they are sent to the client. +/// +/// This hook allows modification of each chunk in the streaming response, enabling features like +/// content filtering, transformation, or logging. The callback receives a chunk and must return +/// a (potentially modified) chunk. 
+/// +/// ### Examples +/// +/// ```no_run +/// use mistralrs_server_core::completions::CompletionOnChunkCallback; +/// +/// let on_chunk: CompletionOnChunkCallback = Box::new(|mut chunk| { +/// // Log the chunk or modify its content +/// println!("Processing chunk: {:?}", chunk); +/// chunk +/// }); +/// ``` +pub type CompletionOnChunkCallback = OnChunkCallback; -pub struct Streamer { - rx: Receiver, - done_state: DoneState, - state: Arc, -} +/// A callback function that is executed when the streaming response completes. +/// +/// This hook receives all chunks that were streamed during the response, allowing for +/// post-processing, analytics, or cleanup operations after the stream finishes. +/// +/// ### Examples +/// +/// ```no_run +/// use mistralrs_server_core::completions::CompletionOnDoneCallback; +/// +/// let on_done: CompletionOnDoneCallback = Box::new(|chunks| { +/// println!("Stream completed with {} chunks", chunks.len()); +/// // Process all chunks for analytics +/// }); +/// ``` +pub type CompletionOnDoneCallback = OnDoneCallback; -impl futures::Stream for Streamer { +/// A streaming response handler. +/// +/// It processes incoming response chunks from a model and converts them +/// into Server-Sent Events (SSE) format for real-time streaming to clients. +pub type CompletionStreamer = + BaseStreamer; + +impl futures::Stream for CompletionStreamer { type Item = Result; + /// Polls the stream for the next Server-Sent Event. + /// + /// This method implements the core streaming logic: + /// 1. Handles stream completion by sending `[DONE]` and executing callbacks + /// 2. Processes incoming model responses and converts them to SSE events + /// 3. Applies chunk modifications if a callback is provided + /// 4. Stores chunks if completion callback is configured fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.done_state { DoneState::SendingDone => { @@ -61,6 +99,9 @@ impl futures::Stream for Streamer { return Poll::Ready(Some(Ok(Event::default().data("[DONE]")))); } DoneState::Done => { + if let Some(on_done) = &self.on_done { + on_done(&self.chunks); + } return Poll::Ready(None); } DoneState::Running => (), @@ -84,12 +125,21 @@ impl futures::Stream for Streamer { MistralRs::maybe_log_error(self.state.clone(), &*e); Poll::Ready(Some(Ok(Event::default().data(e.to_string())))) } - Response::CompletionChunk(response) => { + Response::CompletionChunk(mut response) => { if response.choices.iter().all(|x| x.finish_reason.is_some()) { - // Done now, just need to send the [DONE] self.done_state = DoneState::SendingDone; } + // Done now, just need to send the [DONE] MistralRs::maybe_log_response(self.state.clone(), &response); + + if let Some(on_chunk) = &self.on_chunk { + response = on_chunk(response); + } + + if self.store_chunks { + self.chunks.push(response.clone()); + } + Poll::Ready(Some(Event::default().json_data(response))) } Response::Done(_) => unreachable!(), @@ -105,52 +155,15 @@ impl futures::Stream for Streamer { } } -pub enum CompletionResponder { - Sse(Sse), - Json(CompletionResponse), - ModelError(String, CompletionResponse), - InternalError(Box), - ValidationError(Box), -} - -trait ErrorToResponse: Serialize { - fn to_response(&self, code: StatusCode) -> axum::response::Response { - let mut r = Json(self).into_response(); - *r.status_mut() = code; - r - } -} - -#[derive(Serialize)] -struct JsonError { - message: String, -} - -impl JsonError { - fn new(message: String) -> Self { - Self { message } - } -} -impl ErrorToResponse for 
JsonError {} - -#[derive(Serialize)] -struct JsonModelError { - message: String, - partial_response: CompletionResponse, -} - -impl JsonModelError { - fn new(message: String, partial_response: CompletionResponse) -> Self { - Self { - message, - partial_response, - } - } -} +/// Represents different types of completion responses. +pub type CompletionResponder = BaseCompletionResponder; +/// JSON error response structure for model errors. +type JsonModelError = BaseJsonModelError; impl ErrorToResponse for JsonModelError {} impl IntoResponse for CompletionResponder { + /// Converts the completion responder into an HTTP response. fn into_response(self) -> axum::response::Response { match self { CompletionResponder::Sse(s) => s.into_response(), @@ -167,7 +180,11 @@ impl IntoResponse for CompletionResponder { } } -fn parse_request( +/// Parses and validates a completion request. +/// +/// This function transforms an OpenAI-compatible completion request into the +/// request format used by mistral.rs. +pub fn parse_request( oairequest: CompletionRequest, state: Arc, tx: Sender, @@ -175,11 +192,7 @@ fn parse_request( let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed."); MistralRs::maybe_log_request(state.clone(), repr); - let stop_toks = match oairequest.stop_seqs { - Some(StopTokens::Multi(m)) => Some(InternalStopTokens::Seqs(m)), - Some(StopTokens::Single(s)) => Some(InternalStopTokens::Seqs(vec![s])), - None => None, - }; + let stop_toks = convert_stop_tokens(oairequest.stop_seqs); if oairequest.logprobs.is_some() { warn!("Completion requests do not support logprobs."); @@ -187,16 +200,13 @@ fn parse_request( let is_streaming = oairequest.stream.unwrap_or(false); - let dry_params = if let Some(dry_multiplier) = oairequest.dry_multiplier { - Some(DrySamplingParams::new_with_defaults( - dry_multiplier, - oairequest.dry_sequence_breakers, - oairequest.dry_base, - oairequest.dry_allowed_length, - )?) - } else { - None - }; + let dry_params = get_dry_sampling_params( + oairequest.dry_multiplier, + oairequest.dry_sequence_breakers, + oairequest.dry_base, + oairequest.dry_allowed_length, + )?; + Ok(( Request::Normal(Box::new(NormalRequest { id: state.next_request_id(), @@ -240,6 +250,7 @@ fn parse_request( )) } +/// OpenAI-compatible completions endpoint handler. 
 #[utoipa::path(
     post,
     tag = "Mistral.rs",
@@ -247,80 +258,81 @@ fn parse_request(
     request_body = CompletionRequest,
     responses((status = 200, description = "Completions"))
 )]
-
 pub async fn completions(
     State(state): ExtractedMistralRsState,
     Json(oairequest): Json<CompletionRequest>,
 ) -> CompletionResponder {
-    let (tx, mut rx) = channel(10_000);
-    if oairequest.logprobs.is_some() {
-        return CompletionResponder::ValidationError(
-            "Completion requests do not support logprobs.".into(),
-        );
-    }
+    let (tx, mut rx) = create_response_channel(None);

     let (request, is_streaming) = match parse_request(oairequest, state.clone(), tx) {
         Ok(x) => x,
-        Err(e) => {
-            let e = anyhow::Error::msg(e.to_string());
-            MistralRs::maybe_log_error(state, &*e);
-            return CompletionResponder::InternalError(e.into());
-        }
+        Err(e) => return handle_error(state, e.into()),
     };

-    let sender = state.get_sender().unwrap();
-    if let Err(e) = sender.send(request).await {
-        let e = anyhow::Error::msg(e.to_string());
-        MistralRs::maybe_log_error(state, &*e);
-        return CompletionResponder::InternalError(e.into());
+    if let Err(e) = send_request(&state, request).await {
+        return handle_error(state, e.into());
     }

     if is_streaming {
-        let streamer = Streamer {
-            rx,
-            done_state: DoneState::Running,
-            state,
-        };
-
-        let keep_alive_interval = env::var("KEEP_ALIVE_INTERVAL")
-            .map(|val| val.parse::<u64>().unwrap_or(10000))
-            .unwrap_or(10000);
-        CompletionResponder::Sse(
-            Sse::new(streamer)
-                .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval))),
-        )
+        CompletionResponder::Sse(create_streamer(rx, state, None, None))
     } else {
-        let response = match rx.recv().await {
-            Some(response) => response,
-            None => {
-                let e = anyhow::Error::msg("No response received from the model.");
-                MistralRs::maybe_log_error(state, &*e);
-                return CompletionResponder::InternalError(e.into());
-            }
-        };
+        process_non_streaming_response(&mut rx, state).await
+    }
+}

-        match response {
-            Response::InternalError(e) => {
-                MistralRs::maybe_log_error(state, &*e);
-                CompletionResponder::InternalError(e)
-            }
-            Response::CompletionModelError(msg, response) => {
-                MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string()));
-                MistralRs::maybe_log_response(state, &response);
-                CompletionResponder::ModelError(msg, response)
-            }
-            Response::ValidationError(e) => CompletionResponder::ValidationError(e),
-            Response::CompletionDone(response) => {
-                MistralRs::maybe_log_response(state, &response);
-                CompletionResponder::Json(response)
-            }
-            Response::CompletionChunk(_) => unreachable!(),
-            Response::Chunk(_) => unreachable!(),
-            Response::Done(_) => unreachable!(),
-            Response::ModelError(_, _) => unreachable!(),
-            Response::ImageGeneration(_) => unreachable!(),
-            Response::Speech { .. } => unreachable!(),
-            Response::Raw { .. } => unreachable!(),
+/// Handle route / generation errors and logging them.
+pub fn handle_error(
+    state: SharedMistralRsState,
+    e: Box<dyn std::error::Error + Send + Sync + 'static>,
+) -> CompletionResponder {
+    handle_completion_error(state, e)
+}
+
+/// Creates an SSE streamer for completions with optional callbacks.
+pub fn create_streamer(
+    rx: Receiver<Response>,
+    state: SharedMistralRsState,
+    on_chunk: Option<CompletionOnChunkCallback>,
+    on_done: Option<CompletionOnDoneCallback>,
+) -> Sse<CompletionStreamer> {
+    let streamer = base_create_streamer(rx, state, on_chunk, on_done);
+    let keep_alive_interval = get_keep_alive_interval();
+
+    Sse::new(streamer)
+        .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
+}
+
+/// Process non-streaming completion responses.
+pub async fn process_non_streaming_response( + rx: &mut Receiver, + state: SharedMistralRsState, +) -> CompletionResponder { + base_process_non_streaming_response(rx, state, match_responses, handle_error).await +} + +/// Matches and processes different types of model responses into appropriate completion responses. +pub fn match_responses(state: SharedMistralRsState, response: Response) -> CompletionResponder { + match response { + Response::InternalError(e) => { + MistralRs::maybe_log_error(state, &*e); + CompletionResponder::InternalError(e) + } + Response::CompletionModelError(msg, response) => { + MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string())); + MistralRs::maybe_log_response(state, &response); + CompletionResponder::ModelError(msg, response) + } + Response::ValidationError(e) => CompletionResponder::ValidationError(e), + Response::CompletionDone(response) => { + MistralRs::maybe_log_response(state, &response); + CompletionResponder::Json(response) } + Response::CompletionChunk(_) => unreachable!(), + Response::Chunk(_) => unreachable!(), + Response::Done(_) => unreachable!(), + Response::ModelError(_, _) => unreachable!(), + Response::ImageGeneration(_) => unreachable!(), + Response::Speech { .. } => unreachable!(), + Response::Raw { .. } => unreachable!(), } } diff --git a/mistralrs-server-core/src/handler_core.rs b/mistralrs-server-core/src/handler_core.rs new file mode 100644 index 0000000000..3dce91c608 --- /dev/null +++ b/mistralrs-server-core/src/handler_core.rs @@ -0,0 +1,118 @@ +//! Core functionality for handlers. + +use anyhow::{Context, Result}; +use axum::{extract::Json, http::StatusCode, response::IntoResponse}; +use mistralrs_core::{Request, Response}; +use serde::Serialize; +use tokio::sync::mpsc::{channel, Receiver, Sender}; + +use crate::types::SharedMistralRsState; + +/// Default buffer size for the response channel used in streaming operations. +/// +/// This constant defines the maximum number of response messages that can be buffered +/// in the channel before backpressure is applied. A larger buffer reduces the likelihood +/// of blocking but uses more memory. +pub const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 10_000; + +/// Trait for converting errors to HTTP responses with appropriate status codes. +pub(crate) trait ErrorToResponse: Serialize { + /// Converts the error to an HTTP response with the specified status code. + fn to_response(&self, code: StatusCode) -> axum::response::Response { + let mut response = Json(self).into_response(); + *response.status_mut() = code; + response + } +} + +/// Standard JSON error response structure. +#[derive(Serialize, Debug)] +pub(crate) struct JsonError { + pub(crate) message: String, +} + +impl JsonError { + /// Creates a new JSON error with the specified message. + pub(crate) fn new(message: String) -> Self { + Self { message } + } +} + +impl std::fmt::Display for JsonError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.message) + } +} + +impl std::error::Error for JsonError {} +impl ErrorToResponse for JsonError {} + +/// Internal error type for model-related errors with a descriptive message. +/// +/// This struct wraps error messages from the underlying model and implements +/// the standard error traits for proper error handling and display. 
+#[derive(Debug)] +pub(crate) struct ModelErrorMessage(pub(crate) String); + +impl std::fmt::Display for ModelErrorMessage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl std::error::Error for ModelErrorMessage {} + +/// Generic JSON error response structure +#[derive(Serialize, Debug)] +pub(crate) struct BaseJsonModelError { + pub(crate) message: String, + pub(crate) partial_response: T, +} + +impl BaseJsonModelError { + pub(crate) fn new(message: String, partial_response: T) -> Self { + Self { + message, + partial_response, + } + } +} + +/// Creates a channel for response communication. +pub fn create_response_channel( + buffer_size: Option, +) -> (Sender, Receiver) { + let channel_buffer_size = buffer_size.unwrap_or(DEFAULT_CHANNEL_BUFFER_SIZE); + channel(channel_buffer_size) +} + +/// Sends a request to the model processing pipeline. +pub async fn send_request(state: &SharedMistralRsState, request: Request) -> Result<()> { + let sender = state + .get_sender() + .context("mistral.rs sender not available.")?; + + sender + .send(request) + .await + .context("Failed to send request to model pipeline") +} + +/// Generic function to process non-streaming responses. +pub(crate) async fn base_process_non_streaming_response( + rx: &mut Receiver, + state: SharedMistralRsState, + match_fn: fn(SharedMistralRsState, Response) -> R, + error_handler: fn( + SharedMistralRsState, + Box, + ) -> R, +) -> R { + match rx.recv().await { + Some(response) => match_fn(state, response), + None => { + let error = anyhow::Error::msg("No response received from the model."); + error_handler(state, error.into()) + } + } +} diff --git a/mistralrs-server-core/src/image_generation.rs b/mistralrs-server-core/src/image_generation.rs index 1f0f2b4b63..2521c1ac8e 100644 --- a/mistralrs-server-core/src/image_generation.rs +++ b/mistralrs-server-core/src/image_generation.rs @@ -1,46 +1,37 @@ -use anyhow::Result; +//! ## Image generation functionality and route handler. + use std::{error::Error, sync::Arc}; -use tokio::sync::mpsc::{channel, Sender}; -use crate::{openai::ImageGenerationRequest, types::ExtractedMistralRsState}; +use anyhow::Result; use axum::{ extract::{Json, State}, - http::{self, StatusCode}, + http::{self}, response::IntoResponse, }; use mistralrs_core::{ Constraint, DiffusionGenerationParams, ImageGenerationResponse, MistralRs, NormalRequest, Request, RequestMessage, Response, SamplingParams, }; -use serde::Serialize; +use tokio::sync::mpsc::{Receiver, Sender}; +use crate::{ + handler_core::{ + base_process_non_streaming_response, create_response_channel, send_request, + ErrorToResponse, JsonError, + }, + openai::ImageGenerationRequest, + types::{ExtractedMistralRsState, SharedMistralRsState}, +}; + +/// Represents different types of image generation responses. pub enum ImageGenerationResponder { Json(ImageGenerationResponse), InternalError(Box), ValidationError(Box), } -trait ErrorToResponse: Serialize { - fn to_response(&self, code: StatusCode) -> axum::response::Response { - let mut r = Json(self).into_response(); - *r.status_mut() = code; - r - } -} - -#[derive(Serialize)] -struct JsonError { - message: String, -} - -impl JsonError { - fn new(message: String) -> Self { - Self { message } - } -} -impl ErrorToResponse for JsonError {} - impl IntoResponse for ImageGenerationResponder { + /// Converts the image generation responder into an HTTP response. 
    fn into_response(self) -> axum::response::Response {
         match self {
             ImageGenerationResponder::Json(s) => Json(s).into_response(),
@@ -54,7 +45,11 @@ impl IntoResponse for ImageGenerationResponder {
     }
 }

-fn parse_request(
+/// Parses and validates an image generation request.
+///
+/// This function transforms an image generation request into the
+/// request format used by mistral.rs.
+pub fn parse_request(
     oairequest: ImageGenerationRequest,
     state: Arc<MistralRs>,
     tx: Sender<Response>,
@@ -86,6 +81,7 @@ fn parse_request(
     })))
 }

+/// Image generation endpoint handler.
 #[utoipa::path(
     post,
     tag = "Mistral.rs",
@@ -93,38 +89,47 @@ fn parse_request(
     request_body = ImageGenerationRequest,
     responses((status = 200, description = "Image generation"))
 )]
-
 pub async fn image_generation(
     State(state): ExtractedMistralRsState,
     Json(oairequest): Json<ImageGenerationRequest>,
 ) -> ImageGenerationResponder {
-    let (tx, mut rx) = channel(10_000);
+    let (tx, mut rx) = create_response_channel(None);

     let request = match parse_request(oairequest, state.clone(), tx) {
         Ok(x) => x,
-        Err(e) => {
-            let e = anyhow::Error::msg(e.to_string());
-            MistralRs::maybe_log_error(state, &*e);
-            return ImageGenerationResponder::InternalError(e.into());
-        }
+        Err(e) => return handle_error(state, e.into()),
     };

-    let sender = state.get_sender().unwrap();
-    if let Err(e) = sender.send(request).await {
-        let e = anyhow::Error::msg(e.to_string());
-        MistralRs::maybe_log_error(state, &*e);
-        return ImageGenerationResponder::InternalError(e.into());
+    if let Err(e) = send_request(&state, request).await {
+        return handle_error(state, e.into());
     }

-    let response = match rx.recv().await {
-        Some(response) => response,
-        None => {
-            let e = anyhow::Error::msg("No response received from the model.");
-            MistralRs::maybe_log_error(state, &*e);
-            return ImageGenerationResponder::InternalError(e.into());
-        }
-    };
+    process_non_streaming_response(&mut rx, state).await
+}

+/// Helper function to handle image generation errors and logging them.
+pub fn handle_error(
+    state: SharedMistralRsState,
+    e: Box<dyn std::error::Error + Send + Sync + 'static>,
+) -> ImageGenerationResponder {
+    let e = anyhow::Error::msg(e.to_string());
+    MistralRs::maybe_log_error(state, &*e);
+    ImageGenerationResponder::InternalError(e.into())
+}
+
+/// Process non-streaming image generation responses.
+pub async fn process_non_streaming_response(
+    rx: &mut Receiver<Response>,
+    state: SharedMistralRsState,
+) -> ImageGenerationResponder {
+    base_process_non_streaming_response(rx, state, match_responses, handle_error).await
+}
+
+/// Matches and processes different types of model responses into appropriate image generation responses.
+pub fn match_responses(
+    state: SharedMistralRsState,
+    response: Response,
+) -> ImageGenerationResponder {
     match response {
         Response::InternalError(e) => {
             MistralRs::maybe_log_error(state, &*e);
diff --git a/mistralrs-server-core/src/lib.rs b/mistralrs-server-core/src/lib.rs
index 7d3c247195..0c80a3495e 100644
--- a/mistralrs-server-core/src/lib.rs
+++ b/mistralrs-server-core/src/lib.rs
@@ -10,26 +10,26 @@
 //! 2. Hook into the mistral.rs server lifecycle.
 //!
 //! ### Example
-//! ```ignore
+//! ```no_run
 //! use std::sync::Arc;
 //!
 //! use axum::{
-//!     Json, Router,
 //!     extract::State,
 //!     routing::{get, post},
+//!     Json, Router,
 //! };
 //! use utoipa::OpenApi;
 //! use utoipa_swagger_ui::SwaggerUi;
 //!
-//! use mistralrs::{
-//!     AutoDeviceMapParams, ChatCompletionChunkResponse, ModelDType, ModelSelected, initialize_logging,
+//! use mistralrs_core::{
+//!
initialize_logging, AutoDeviceMapParams, ChatCompletionChunkResponse, ModelDType, ModelSelected, //! }; //! use mistralrs_server_core::{ //! chat_completion::{ -//! ChatCompletionResponder, OnChunkCallback, OnDoneCallback, create_chat_streamer, -//! create_response_channel, handle_chat_completion_error, parse_request, -//! process_non_streaming_chat_response, send_request, +//! create_streamer, handle_error, parse_request, process_non_streaming_response, +//! ChatCompletionOnChunkCallback, ChatCompletionOnDoneCallback, ChatCompletionResponder, //! }, +//! handler_core::{create_response_channel, send_request}, //! mistralrs_for_server_builder::MistralRsForServerBuilder, //! mistralrs_server_router_builder::MistralRsServerRouterBuilder, //! openai::ChatCompletionRequest, @@ -60,7 +60,7 @@ //! #[tokio::main] //! async fn main() { //! initialize_logging(); -//! +//! //! let plain_model_id = String::from("meta-llama/Llama-3.2-1B-Instruct"); //! let tokenizer_json = None; //! let arch = None; @@ -145,11 +145,11 @@ //! } //! //! #[utoipa::path( -//! post, -//! tag = "Custom", -//! path = "/chat", -//! request_body = ChatCompletionRequest, -//! responses((status = 200, description = "Chat completions")) +//! post, +//! tag = "Custom", +//! path = "/chat", +//! request_body = ChatCompletionRequest, +//! responses((status = 200, description = "Chat completions")) //! )] //! pub async fn custom_chat( //! State(state): State>, @@ -158,42 +158,43 @@ //! let mistralrs_state = state.mistralrs_state.clone(); //! let (tx, mut rx) = create_response_channel(None); //! -//! let (request, is_streaming) = match parse_request(oai_request, mistralrs_state.clone(), tx).await -//! { -//! Ok(x) => x, -//! Err(e) => return handle_chat_completion_error(mistralrs_state, e.into()), -//! }; +//! let (request, is_streaming) = +//! match parse_request(oai_request, mistralrs_state.clone(), tx).await { +//! Ok(x) => x, +//! Err(e) => return handle_error(mistralrs_state, e.into()), +//! }; //! //! dbg!(request.clone()); //! //! if let Err(e) = send_request(&mistralrs_state, request).await { -//! return handle_chat_completion_error(mistralrs_state, e.into()); +//! return handle_error(mistralrs_state, e.into()); //! } //! //! if is_streaming { //! let db_fn = state.db_create; //! -//! let on_chunk: OnChunkCallback = Box::new(move |mut chunk: ChatCompletionChunkResponse| { -//! dbg!(&chunk); +//! let on_chunk: ChatCompletionOnChunkCallback = +//! Box::new(move |mut chunk: ChatCompletionChunkResponse| { +//! dbg!(&chunk); //! -//! if let Some(original_content) = &chunk.choices[0].delta.content { -//! chunk.choices[0].delta.content = Some(format!("CHANGED! {}", original_content)); -//! } +//! if let Some(original_content) = &chunk.choices[0].delta.content { +//! chunk.choices[0].delta.content = Some(format!("CHANGED! {}", original_content)); +//! } //! -//! chunk.clone() -//! }); +//! chunk.clone() +//! }); //! -//! let on_done: OnDoneCallback = Box::new(move |chunks: &[ChatCompletionChunkResponse]| { -//! dbg!(chunks); -//! (db_fn)(); -//! }); +//! let on_done: ChatCompletionOnDoneCallback = +//! Box::new(move |chunks: &[ChatCompletionChunkResponse]| { +//! dbg!(chunks); +//! (db_fn)(); +//! }); //! -//! let streamer = -//! create_chat_streamer(rx, mistralrs_state.clone(), Some(on_chunk), Some(on_done)); +//! let streamer = create_streamer(rx, mistralrs_state.clone(), Some(on_chunk), Some(on_done)); //! //! ChatCompletionResponder::Sse(streamer) //! } else { -//! 
let response = process_non_streaming_chat_response(&mut rx, mistralrs_state.clone()).await; +//! let response = process_non_streaming_response(&mut rx, mistralrs_state.clone()).await; //! //! match &response { //! ChatCompletionResponder::Json(json_response) => { @@ -215,13 +216,16 @@ //! ``` pub mod chat_completion; -mod completions; +mod completion_core; +pub mod completions; +pub mod handler_core; mod handlers; -mod image_generation; +pub mod image_generation; pub mod mistralrs_for_server_builder; pub mod mistralrs_server_router_builder; pub mod openai; pub mod openapi_doc; -mod speech_generation; +pub mod speech_generation; +pub mod streaming; pub mod types; pub mod util; diff --git a/mistralrs-server-core/src/speech_generation.rs b/mistralrs-server-core/src/speech_generation.rs index d430dc11b7..8fb58a4320 100644 --- a/mistralrs-server-core/src/speech_generation.rs +++ b/mistralrs-server-core/src/speech_generation.rs @@ -1,8 +1,8 @@ -use anyhow::Result; +//! ## Speech generation functionality and route handler. + use std::{error::Error, sync::Arc}; -use tokio::sync::mpsc::{channel, Sender}; -use crate::openai::{AudioResponseFormat, SpeechGenerationRequest}; +use anyhow::Result; use axum::{ body::Bytes, extract::{Json, State}, @@ -13,44 +13,23 @@ use mistralrs_core::{ speech_utils::{self, Sample}, Constraint, MistralRs, NormalRequest, Request, RequestMessage, Response, SamplingParams, }; -use serde::Serialize; +use tokio::sync::mpsc::{Receiver, Sender}; + +use crate::{ + handler_core::{create_response_channel, send_request, ErrorToResponse, JsonError}, + openai::{AudioResponseFormat, SpeechGenerationRequest}, + types::SharedMistralRsState, +}; +/// Represents different types of speech generation responses. pub enum SpeechGenerationResponder { InternalError(Box), ValidationError(Box), RawResponse(axum::response::Response), } -trait ErrorToResponse: Serialize { - fn to_response(&self, code: StatusCode) -> axum::response::Response { - let mut r = Json(self).into_response(); - *r.status_mut() = code; - r - } -} - -#[derive(Serialize, Debug)] -struct JsonError { - message: String, -} - -impl JsonError { - fn new(message: String) -> Self { - Self { message } - } -} - -impl std::fmt::Display for JsonError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.message) - } -} - -impl std::error::Error for JsonError {} - -impl ErrorToResponse for JsonError {} - impl IntoResponse for SpeechGenerationResponder { + /// Converts the speech generation responder into an HTTP response. fn into_response(self) -> axum::response::Response { match self { SpeechGenerationResponder::InternalError(e) => { @@ -64,7 +43,11 @@ impl IntoResponse for SpeechGenerationResponder { } } -fn parse_request( +/// Parses and validates a speech generation request. +/// +/// This function transforms a speech generation request into the +/// request format used by mistral.rs. +pub fn parse_request( oairequest: SpeechGenerationRequest, state: Arc, tx: Sender, @@ -93,6 +76,7 @@ fn parse_request( Ok((request, oairequest.response_format)) } +/// Speech generation endpoint handler. 
#[utoipa::path( post, tag = "Mistral.rs", @@ -104,15 +88,11 @@ pub async fn speech_generation( State(state): State>, Json(oairequest): Json, ) -> SpeechGenerationResponder { - let (tx, mut rx) = channel(10_000); + let (tx, mut rx) = create_response_channel(None); let (request, response_format) = match parse_request(oairequest, state.clone(), tx) { Ok(x) => x, - Err(e) => { - let e = anyhow::Error::msg(e.to_string()); - MistralRs::maybe_log_error(state, &*e); - return SpeechGenerationResponder::InternalError(e.into()); - } + Err(e) => return handle_error(state, e.into()), }; // Validate response format here @@ -125,23 +105,46 @@ pub async fn speech_generation( ))); } - let sender = state.get_sender().unwrap(); - - if let Err(e) = sender.send(request).await { - let e = anyhow::Error::msg(e.to_string()); - MistralRs::maybe_log_error(state, &*e); - return SpeechGenerationResponder::InternalError(e.into()); + if let Err(e) = send_request(&state, request).await { + return handle_error(state, e.into()); } + process_non_streaming_response(&mut rx, state, response_format).await +} + +/// Helper function to handle speech generation errors and logging them. +pub fn handle_error( + state: SharedMistralRsState, + e: Box, +) -> SpeechGenerationResponder { + let e = anyhow::Error::msg(e.to_string()); + MistralRs::maybe_log_error(state, &*e); + SpeechGenerationResponder::InternalError(e.into()) +} + +/// Process non-streaming speech generation responses. +pub async fn process_non_streaming_response( + rx: &mut Receiver, + state: SharedMistralRsState, + response_format: AudioResponseFormat, +) -> SpeechGenerationResponder { let response = match rx.recv().await { Some(response) => response, None => { let e = anyhow::Error::msg("No response received from the model."); - MistralRs::maybe_log_error(state, &*e); - return SpeechGenerationResponder::InternalError(e.into()); + return handle_error(state, e.into()); } }; + match_responses(state, response, response_format) +} + +/// Matches and processes different types of model responses into appropriate speech generation responses. +pub fn match_responses( + state: SharedMistralRsState, + response: Response, + response_format: AudioResponseFormat, +) -> SpeechGenerationResponder { match response { Response::InternalError(e) => { MistralRs::maybe_log_error(state, &*e); diff --git a/mistralrs-server-core/src/streaming.rs b/mistralrs-server-core/src/streaming.rs new file mode 100644 index 0000000000..8250252a88 --- /dev/null +++ b/mistralrs-server-core/src/streaming.rs @@ -0,0 +1,74 @@ +//! SSE streaming utilities. + +use std::env; + +use mistralrs_core::Response; +use tokio::sync::mpsc::Receiver; + +use crate::types::SharedMistralRsState; + +/// Default keep-alive interval for Server-Sent Events (SSE) streams in milliseconds. +pub const DEFAULT_KEEP_ALIVE_INTERVAL_MS: u64 = 10_000; + +/// Represents the current state of a streaming response. +pub enum DoneState { + /// The stream is actively processing and sending response chunks + Running, + /// The stream has finished processing and is about to send the `[DONE]` message + SendingDone, + /// The stream has completed entirely + Done, +} + +/// A streaming response handler. +/// +/// It processes incoming response chunks from a model and converts them +/// into Server-Sent Events (SSE) format for real-time streaming to clients. 
+pub struct BaseStreamer { + /// Channel receiver for incoming model responses + pub rx: Receiver, + /// Current state of the streaming operation + pub done_state: DoneState, + /// Underlying mistral.rs instance + pub state: SharedMistralRsState, + /// Whether to store chunks for the completion callback + pub store_chunks: bool, + /// All chunks received during streaming (if `store_chunks` is true) + pub chunks: Vec, + /// Optional callback to process each chunk before sending + pub on_chunk: Option, + /// Optional callback to execute when streaming completes + pub on_done: Option, +} + +/// Generic function to create a SSE streamer with optional callbacks. +pub(crate) fn base_create_streamer( + rx: Receiver, + state: SharedMistralRsState, + on_chunk: Option, + on_done: Option, +) -> BaseStreamer { + let store_chunks = on_done.is_some(); + + BaseStreamer { + rx, + done_state: DoneState::Running, + store_chunks, + state, + chunks: Vec::new(), + on_chunk, + on_done, + } +} + +/// Gets the keep-alive interval for SSE streams from environment or default. +pub fn get_keep_alive_interval() -> u64 { + env::var("KEEP_ALIVE_INTERVAL") + .map(|val| { + val.parse::().unwrap_or_else(|e| { + tracing::warn!("Failed to parse KEEP_ALIVE_INTERVAL: {}. Using default.", e); + DEFAULT_KEEP_ALIVE_INTERVAL_MS + }) + }) + .unwrap_or(DEFAULT_KEEP_ALIVE_INTERVAL_MS) +} diff --git a/mistralrs-server-core/src/types.rs b/mistralrs-server-core/src/types.rs index 99d4173eb0..471c26def8 100644 --- a/mistralrs-server-core/src/types.rs +++ b/mistralrs-server-core/src/types.rs @@ -12,3 +12,41 @@ pub type SharedMistralRsState = Arc; pub type ExtractedMistralRsState = State; pub(crate) type LoadedPipeline = Arc>; + +/// A callback function that processes streaming response chunks before they are sent to the client. +/// +/// This hook allows modification of each chunk in the streaming response, enabling features like +/// content filtering, transformation, or logging. The callback receives a chunk and must return +/// a (potentially modified) chunk. +/// +/// ### Examples +/// +/// ```no_run +/// use mistralrs_core::ChatCompletionChunkResponse; +/// use mistralrs_server_core::types::OnChunkCallback; +/// +/// let on_chunk: OnChunkCallback = Box::new(|mut chunk| { +/// // Log the chunk or modify its content +/// println!("Processing chunk: {:?}", chunk); +/// chunk +/// }); +/// ``` +pub type OnChunkCallback = Box R + Send + Sync>; + +/// A callback function that is executed when the streaming response completes. +/// +/// This hook receives all chunks that were streamed during the response, allowing for +/// post-processing, analytics, or cleanup operations after the stream finishes. 
+///
+/// ### Examples
+///
+/// ```no_run
+/// use mistralrs_core::ChatCompletionChunkResponse;
+/// use mistralrs_server_core::types::OnDoneCallback;
+///
+/// let on_done: OnDoneCallback<ChatCompletionChunkResponse> = Box::new(|chunks| {
+///     println!("Stream completed with {} chunks", chunks.len());
+///     // Process all chunks for analytics
+/// });
+/// ```
+pub type OnDoneCallback<R> = Box<dyn Fn(&[R]) + Send + Sync>;
diff --git a/mistralrs-server/src/mcp_server.rs b/mistralrs-server/src/mcp_server.rs
index 1d6b61f23b..50eb441689 100644
--- a/mistralrs-server/src/mcp_server.rs
+++ b/mistralrs-server/src/mcp_server.rs
@@ -8,7 +8,7 @@ use std::sync::Arc;
 use tokio::net::TcpListener;

 use mistralrs_server_core::{
-    chat_completion::{create_response_channel, parse_request},
+    chat_completion::parse_request, handler_core::create_response_channel,
     types::SharedMistralRsState,
 };
@@ -164,7 +164,7 @@ impl McpTool for ChatTool {
             .await
             .map_err(|e| CallToolError::new(io::Error::other(e.to_string())))?;

-        mistralrs_server_core::chat_completion::send_request(state, request)
+        mistralrs_server_core::handler_core::send_request(state, request)
             .await
             .map_err(|e| CallToolError::new(io::Error::other(e.to_string())))?;