Skip to content

Commit cf81452

Browse files
committed
[router] fix pd model completion request
1 parent 6f8f4ae commit cf81452

File tree

6 files changed

+124
-15
lines changed

6 files changed

+124
-15
lines changed

sgl-router/benches/request_processing.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ fn create_sample_completion_request() -> CompletionRequest {
9797
logit_bias: None,
9898
user: None,
9999
seed: None,
100+
other: serde_json::Map::new(),
100101
}
101102
}
102103

sgl-router/src/openai_api_types.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ pub struct CompletionRequest {
9191
/// If specified, our system will make a best effort to sample deterministically
9292
#[serde(skip_serializing_if = "Option::is_none")]
9393
pub seed: Option<i64>,
94+
95+
/// Additional fields including bootstrap info for PD routing
96+
#[serde(flatten)]
97+
pub other: serde_json::Map<String, serde_json::Value>,
9498
}
9599

96100
impl GenerationRequest for CompletionRequest {

sgl-router/src/routers/pd_router.rs

Lines changed: 75 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,77 @@ impl PDRouter {
420420
.await
421421
}
422422

423+
// Route a completion request while preserving OpenAI format
424+
pub async fn route_completion(
425+
&self,
426+
client: &reqwest::Client,
427+
req: &HttpRequest,
428+
mut typed_req: CompletionRequest,
429+
route: &str,
430+
) -> HttpResponse {
431+
let start = Instant::now();
432+
433+
// Get stream flag and return_logprob flag before moving the request
434+
let is_stream = typed_req.stream;
435+
let return_logprob = typed_req.logprobs.is_some();
436+
437+
// Extract text for cache-aware routing from the typed request
438+
let request_text = match &typed_req.prompt {
439+
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
440+
crate::openai_api_types::StringOrArray::Array(arr) => arr.first().map(|s| s.as_str()),
441+
};
442+
443+
// Select servers
444+
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
445+
Ok(pair) => pair,
446+
Err(e) => {
447+
error!("Failed to select PD pair: {}", e);
448+
RouterMetrics::record_pd_error("server_selection");
449+
return HttpResponse::ServiceUnavailable()
450+
.body(format!("No available servers: {}", e));
451+
}
452+
};
453+
454+
// Log routing decision
455+
info!(
456+
"PD routing: {} -> prefill={}, decode={}",
457+
route,
458+
prefill.url(),
459+
decode.url()
460+
);
461+
462+
// Add bootstrap info using the trait method
463+
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
464+
error!("Failed to add bootstrap info: {}", e);
465+
RouterMetrics::record_pd_error("bootstrap_injection");
466+
return HttpResponse::InternalServerError()
467+
.body(format!("Bootstrap injection failed: {}", e));
468+
}
469+
470+
// Convert to JSON after bootstrap injection
471+
let json_with_bootstrap = match serde_json::to_value(&typed_req) {
472+
Ok(json) => json,
473+
Err(e) => {
474+
error!("Failed to serialize request: {}", e);
475+
return HttpResponse::InternalServerError().body("Failed to serialize request");
476+
}
477+
};
478+
479+
// Execute dual dispatch
480+
self.execute_dual_dispatch(
481+
client,
482+
req,
483+
json_with_bootstrap,
484+
route,
485+
prefill.as_ref(),
486+
decode.as_ref(),
487+
is_stream,
488+
return_logprob,
489+
start,
490+
)
491+
.await
492+
}
493+
423494
// Execute the dual dispatch to prefill and decode servers
424495
#[allow(clippy::too_many_arguments)]
425496
async fn execute_dual_dispatch(
@@ -1302,23 +1373,12 @@ impl RouterTrait for PDRouter {
13021373
req: &HttpRequest,
13031374
body: serde_json::Value,
13041375
) -> HttpResponse {
1305-
match serde_json::from_value::<CompletionRequest>(body.clone()) {
1376+
match serde_json::from_value::<CompletionRequest>(body) {
13061377
Ok(openai_req) => {
1307-
// Convert OpenAI format to PD format (CompletionRequest -> GenerateReqInput)
1308-
let pd_req = openai_req.to_pd_request();
1309-
PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await
1310-
}
1311-
Err(_) => {
1312-
// If that fails, try to deserialize directly as PD format (for backwards compatibility)
1313-
match serde_json::from_value::<GenerateReqInput>(body) {
1314-
Ok(pd_req) => {
1315-
PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await
1316-
}
1317-
Err(e) => {
1318-
HttpResponse::BadRequest().body(format!("Invalid request format: {}", e))
1319-
}
1320-
}
1378+
// Use the new method that preserves OpenAI format
1379+
PDRouter::route_completion(self, client, req, openai_req, "/v1/completions").await
13211380
}
1381+
Err(e) => HttpResponse::BadRequest().body(format!("Invalid request format: {}", e)),
13221382
}
13231383
}
13241384

sgl-router/src/routers/pd_types.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Essential PDLB types extracted for PD routing
22

33
use crate::core::{Worker, WorkerType};
4+
use crate::openai_api_types::CompletionRequest;
45
use serde::{Deserialize, Serialize};
56
use serde_json::Value;
67

@@ -233,3 +234,39 @@ impl Bootstrap for ChatReqInput {
233234
self.bootstrap_room = Some(bootstrap_room);
234235
}
235236
}
237+
238+
// Bootstrap implementation for CompletionRequest to preserve OpenAI format
239+
impl Bootstrap for CompletionRequest {
240+
fn is_stream(&self) -> bool {
241+
self.stream
242+
}
243+
244+
fn get_batch_size(&self) -> Result<Option<usize>, String> {
245+
// Check if 'n' parameter is present and > 1
246+
if let Some(n) = self.n {
247+
if n > 1 {
248+
return Ok(Some(n as usize));
249+
}
250+
}
251+
Ok(None)
252+
}
253+
254+
fn set_bootstrap_info(
255+
&mut self,
256+
bootstrap_host: BootstrapHost,
257+
bootstrap_port: BootstrapPort,
258+
bootstrap_room: BootstrapRoom,
259+
) {
260+
// Add bootstrap info to the 'other' field to preserve OpenAI format
261+
// This follows the same pattern as ChatReqInput
262+
if let Ok(host_value) = serde_json::to_value(bootstrap_host) {
263+
self.other.insert("bootstrap_host".to_string(), host_value);
264+
}
265+
if let Ok(port_value) = serde_json::to_value(bootstrap_port) {
266+
self.other.insert("bootstrap_port".to_string(), port_value);
267+
}
268+
if let Ok(room_value) = serde_json::to_value(bootstrap_room) {
269+
self.other.insert("bootstrap_room".to_string(), room_value);
270+
}
271+
}
272+
}

sgl-router/src/routers/request_adapter.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,7 @@ mod tests {
648648
user: None,
649649
seed: None,
650650
suffix: None,
651+
other: serde_json::Map::new(),
651652
};
652653

653654
let pd_req = req.to_pd_request();
@@ -687,6 +688,7 @@ mod tests {
687688
user: None,
688689
seed: None,
689690
suffix: None,
691+
other: serde_json::Map::new(),
690692
};
691693

692694
let pd_req = req.to_pd_request();
@@ -725,6 +727,7 @@ mod tests {
725727
user: Some("user123".to_string()),
726728
seed: Some(42),
727729
suffix: Some("...".to_string()),
730+
other: serde_json::Map::new(),
728731
};
729732

730733
let pd_req = req.to_pd_request();
@@ -768,6 +771,7 @@ mod tests {
768771
user: None,
769772
seed: None,
770773
suffix: None,
774+
other: serde_json::Map::new(),
771775
};
772776

773777
let pd_req = req.to_pd_request();
@@ -799,6 +803,7 @@ mod tests {
799803
user: None,
800804
seed: None,
801805
suffix: None,
806+
other: serde_json::Map::new(),
802807
};
803808

804809
let pd_req = req.to_pd_request();

sgl-router/tests/benchmark_integration.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ fn test_benchmark_request_creation() {
8686
logit_bias: None,
8787
user: None,
8888
seed: None,
89+
other: serde_json::Map::new(),
8990
};
9091

9192
// Test serialization works
@@ -181,6 +182,7 @@ fn test_benchmark_request_adaptation() {
181182
logit_bias: None,
182183
user: None,
183184
seed: None,
185+
other: serde_json::Map::new(),
184186
};
185187

186188
// Test PD adaptation (should not panic)

0 commit comments

Comments
 (0)