@@ -420,6 +420,77 @@ impl PDRouter {
420420 . await
421421 }
422422
423+ // Route a completion request while preserving OpenAI format
424+ pub async fn route_completion (
425+ & self ,
426+ client : & reqwest:: Client ,
427+ req : & HttpRequest ,
428+ mut typed_req : CompletionRequest ,
429+ route : & str ,
430+ ) -> HttpResponse {
431+ let start = Instant :: now ( ) ;
432+
433+ // Get stream flag and return_logprob flag before moving the request
434+ let is_stream = typed_req. stream ;
435+ let return_logprob = typed_req. logprobs . is_some ( ) ;
436+
437+ // Extract text for cache-aware routing from the typed request
438+ let request_text = match & typed_req. prompt {
439+ crate :: openai_api_types:: StringOrArray :: String ( s) => Some ( s. as_str ( ) ) ,
440+ crate :: openai_api_types:: StringOrArray :: Array ( arr) => arr. first ( ) . map ( |s| s. as_str ( ) ) ,
441+ } ;
442+
443+ // Select servers
444+ let ( prefill, decode) = match self . select_pd_pair ( client, request_text) . await {
445+ Ok ( pair) => pair,
446+ Err ( e) => {
447+ error ! ( "Failed to select PD pair: {}" , e) ;
448+ RouterMetrics :: record_pd_error ( "server_selection" ) ;
449+ return HttpResponse :: ServiceUnavailable ( )
450+ . body ( format ! ( "No available servers: {}" , e) ) ;
451+ }
452+ } ;
453+
454+ // Log routing decision
455+ info ! (
456+ "PD routing: {} -> prefill={}, decode={}" ,
457+ route,
458+ prefill. url( ) ,
459+ decode. url( )
460+ ) ;
461+
462+ // Add bootstrap info using the trait method
463+ if let Err ( e) = typed_req. add_bootstrap_info ( prefill. as_ref ( ) ) {
464+ error ! ( "Failed to add bootstrap info: {}" , e) ;
465+ RouterMetrics :: record_pd_error ( "bootstrap_injection" ) ;
466+ return HttpResponse :: InternalServerError ( )
467+ . body ( format ! ( "Bootstrap injection failed: {}" , e) ) ;
468+ }
469+
470+ // Convert to JSON after bootstrap injection
471+ let json_with_bootstrap = match serde_json:: to_value ( & typed_req) {
472+ Ok ( json) => json,
473+ Err ( e) => {
474+ error ! ( "Failed to serialize request: {}" , e) ;
475+ return HttpResponse :: InternalServerError ( ) . body ( "Failed to serialize request" ) ;
476+ }
477+ } ;
478+
479+ // Execute dual dispatch
480+ self . execute_dual_dispatch (
481+ client,
482+ req,
483+ json_with_bootstrap,
484+ route,
485+ prefill. as_ref ( ) ,
486+ decode. as_ref ( ) ,
487+ is_stream,
488+ return_logprob,
489+ start,
490+ )
491+ . await
492+ }
493+
423494 // Execute the dual dispatch to prefill and decode servers
424495 #[ allow( clippy:: too_many_arguments) ]
425496 async fn execute_dual_dispatch (
@@ -1302,23 +1373,12 @@ impl RouterTrait for PDRouter {
13021373 req : & HttpRequest ,
13031374 body : serde_json:: Value ,
13041375 ) -> HttpResponse {
1305- match serde_json:: from_value :: < CompletionRequest > ( body. clone ( ) ) {
1376+ match serde_json:: from_value :: < CompletionRequest > ( body) {
13061377 Ok ( openai_req) => {
1307- // Convert OpenAI format to PD format (CompletionRequest -> GenerateReqInput)
1308- let pd_req = openai_req. to_pd_request ( ) ;
1309- PDRouter :: route_generate ( self , client, req, pd_req, "/v1/completions" ) . await
1310- }
1311- Err ( _) => {
1312- // If that fails, try to deserialize directly as PD format (for backwards compatibility)
1313- match serde_json:: from_value :: < GenerateReqInput > ( body) {
1314- Ok ( pd_req) => {
1315- PDRouter :: route_generate ( self , client, req, pd_req, "/v1/completions" ) . await
1316- }
1317- Err ( e) => {
1318- HttpResponse :: BadRequest ( ) . body ( format ! ( "Invalid request format: {}" , e) )
1319- }
1320- }
1378+ // Use the new method that preserves OpenAI format
1379+ PDRouter :: route_completion ( self , client, req, openai_req, "/v1/completions" ) . await
13211380 }
1381+ Err ( e) => HttpResponse :: BadRequest ( ) . body ( format ! ( "Invalid request format: {}" , e) ) ,
13221382 }
13231383 }
13241384
0 commit comments