@@ -23,17 +23,18 @@ use serde_json::Value;
 use tokio::sync::mpsc::{Receiver, Sender};
 
 use crate::{
-    completion_base::{base_handle_completion_error, BaseCompletionResponder},
+    completion_base::{
+        base_handle_completion_error, base_process_non_streaming_response, create_response_channel,
+        send_model_request, BaseCompletionResponder, BaseJsonModelError, ErrorToResponse,
+        JsonError, ModelErrorMessage,
+    },
     openai::{
         ChatCompletionRequest, Grammar, JsonSchemaResponseFormat, MessageInnerContent,
         ResponseFormat, StopTokens,
     },
-    streaming::{get_keep_alive_interval, BaseStreamer, DoneState},
+    streaming::{base_create_streamer, get_keep_alive_interval, BaseStreamer, DoneState},
     types::{ExtractedMistralRsState, SharedMistralRsState},
-    util::{
-        create_response_channel, parse_image_url, send_model_request, BaseJsonModelError,
-        ErrorToResponse, JsonError, ModelErrorMessage,
-    },
+    util::parse_image_url,
 };
 
 /// A callback function that processes streaming response chunks before they are sent to the client.
@@ -53,7 +54,7 @@ use crate::{
 /// chunk
 /// });
 /// ```
-pub type OnChunkCallback =
+pub type ChatCompletionOnChunkCallback =
     Box<dyn Fn(ChatCompletionChunkResponse) -> ChatCompletionChunkResponse + Send + Sync>;
 
 /// A callback function that is executed when the streaming response completes.
@@ -71,15 +72,19 @@ pub type OnChunkCallback =
 /// // Process all chunks for analytics
 /// });
 /// ```
-pub type OnDoneCallback = Box<dyn Fn(&[ChatCompletionChunkResponse]) + Send + Sync>;
+pub type ChatCompletionOnDoneCallback = Box<dyn Fn(&[ChatCompletionChunkResponse]) + Send + Sync>;
 
 /// A streaming response handler.
 ///
 /// It processes incoming response chunks from a model and converts them
 /// into Server-Sent Events (SSE) format for real-time streaming to clients.
-pub type Streamer = BaseStreamer<ChatCompletionChunkResponse, OnChunkCallback, OnDoneCallback>;
+pub type ChatCompletionStreamer = BaseStreamer<
+    ChatCompletionChunkResponse,
+    ChatCompletionOnChunkCallback,
+    ChatCompletionOnDoneCallback,
+>;
 
-impl futures::Stream for Streamer {
+impl futures::Stream for ChatCompletionStreamer {
     type Item = Result<Event, axum::Error>;
 
     /// Polls the stream for the next Server-Sent Event.
@@ -158,7 +163,8 @@ impl futures::Stream for Streamer {
 }
 
 /// Represents different types of chat completion responses.
-pub type ChatCompletionResponder = BaseCompletionResponder<ChatCompletionResponse, Streamer>;
+pub type ChatCompletionResponder =
+    BaseCompletionResponder<ChatCompletionResponse, ChatCompletionStreamer>;
 
 type JsonModelError = BaseJsonModelError<ChatCompletionResponse>;
 impl ErrorToResponse for JsonModelError {}
@@ -475,21 +481,10 @@ pub fn handle_chat_completion_error(
 pub fn create_chat_streamer(
     rx: Receiver<Response>,
     state: SharedMistralRsState,
-    on_chunk: Option<OnChunkCallback>,
-    on_done: Option<OnDoneCallback>,
-) -> Sse<Streamer> {
-    let store_chunks = on_done.is_some();
-
-    let streamer = Streamer {
-        rx,
-        done_state: DoneState::Running,
-        store_chunks,
-        state,
-        chunks: Vec::new(),
-        on_chunk,
-        on_done,
-    };
-
+    on_chunk: Option<ChatCompletionOnChunkCallback>,
+    on_done: Option<ChatCompletionOnDoneCallback>,
+) -> Sse<ChatCompletionStreamer> {
+    let streamer = base_create_streamer(rx, state, on_chunk, on_done);
     let keep_alive_interval = get_keep_alive_interval();
 
     Sse::new(streamer)
@@ -501,15 +496,8 @@ pub async fn process_non_streaming_chat_response(
     rx: &mut Receiver<Response>,
     state: SharedMistralRsState,
 ) -> ChatCompletionResponder {
-    let response = match rx.recv().await {
-        Some(response) => response,
-        None => {
-            let e = anyhow::Error::msg("No response received from the model.");
-            return handle_chat_completion_error(state, e.into());
-        }
-    };
-
-    match_responses(state, response)
+    base_process_non_streaming_response(rx, state, match_responses, handle_chat_completion_error)
+        .await
 }
 
 /// Matches and processes different types of model responses into appropriate chat completion responses.
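
For illustration only, a minimal sketch of how the renamed pieces fit together after this refactor, based solely on the type aliases and signatures shown in the hunks above; the `rx` receiver and `state` handle are assumed to come from the surrounding handler (e.g. via `create_response_channel`) and are not part of this diff:

    // Hypothetical caller-side usage of the renamed callback aliases.
    let on_chunk: ChatCompletionOnChunkCallback = Box::new(|chunk| {
        // Inspect or rewrite each ChatCompletionChunkResponse before it reaches the client.
        chunk
    });
    let on_done: ChatCompletionOnDoneCallback = Box::new(|chunks| {
        // Runs once with every stored chunk after the stream completes.
        println!("streamed {} chunks", chunks.len());
    });
    let sse = create_chat_streamer(rx, state, Some(on_chunk), Some(on_done));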