Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 39 additions & 47 deletions sgl-router/src/routers/pd_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub struct PDRouter {
pub interval_secs: u64,
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
pub http_client: reqwest::Client,
pub http_client: Client,
_prefill_health_checker: Option<HealthChecker>,
_decode_health_checker: Option<HealthChecker>,
}
Expand Down Expand Up @@ -206,51 +206,17 @@ impl PDRouter {
}

// Initialize cache-aware components if needed for prefill policy
let prefill_tree = if prefill_policy.name() == "cache_aware" {
// Initialize the policy's internal tree with prefill workers
if let Some(cache_policy) = prefill_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_policy.init_workers(&prefill_workers);
}

let tree = Arc::new(Mutex::new(Tree::new()));
// Initialize tree with prefill workers
for worker in &prefill_workers {
tree.lock().unwrap().insert("", worker.url());
}
Some(tree)
} else {
None
};
let prefill_tree = Self::initialize_radix_tree(&prefill_policy, &prefill_workers)?;

// Initialize cache-aware components if needed for decode policy
let decode_tree = if decode_policy.name() == "cache_aware" {
// Initialize the policy's internal tree with decode workers
if let Some(cache_policy) = decode_policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_policy.init_workers(&decode_workers);
}

let tree = Arc::new(Mutex::new(Tree::new()));
// Initialize tree with decode workers
for worker in &decode_workers {
tree.lock().unwrap().insert("", worker.url());
}
Some(tree)
} else {
None
};
let decode_tree = Self::initialize_radix_tree(&decode_policy, &decode_workers)?;

// Set up background load monitoring for power-of-two selection
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
let worker_loads = Arc::new(rx);

// Create a shared HTTP client for all operations
let http_client = reqwest::Client::builder()
let http_client = Client::builder()
.timeout(Duration::from_secs(timeout_secs))
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
Expand Down Expand Up @@ -304,6 +270,35 @@ impl PDRouter {
})
}

// Set up cache-aware state for one pool of workers.
//
// When `policy` downcasts to `CacheAwarePolicy`, the policy's own
// `init_workers` is called with `workers`, and a new `Tree` is returned with
// every worker URL inserted under the empty string. Any other policy yields
// `Ok(None)`. An `Err(String)` is produced only if the tree mutex cannot be
// locked.
fn initialize_radix_tree(
    policy: &Arc<dyn LoadBalancingPolicy>,
    workers: &[Box<dyn Worker>],
) -> Result<Option<Arc<Mutex<Tree>>>, String> {
    // Guard clause: policies that are not cache-aware need no tree at all.
    let cache_aware = match policy
        .as_any()
        .downcast_ref::<crate::policies::CacheAwarePolicy>()
    {
        Some(found) => found,
        None => return Ok(None),
    };

    // Let the policy register the workers first, as the tree is built afterwards.
    cache_aware.init_workers(workers);

    // Populate the tree while it is still exclusively owned by this function,
    // then hand it out behind an `Arc` once the guard has been released.
    let shared_tree = Mutex::new(Tree::new());
    match shared_tree.lock() {
        Ok(guard) => {
            for entry in workers.iter() {
                guard.insert("", entry.url());
            }
        }
        Err(e) => return Err(format!("Failed to lock tree: {}", e)),
    };

    Ok(Some(Arc::new(shared_tree)))
}

// Route a typed generate request
pub async fn route_generate(
&self,
Expand All @@ -329,7 +324,7 @@ impl PDRouter {
});

// Select servers
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
let (prefill, decode) = match self.select_pd_pair(request_text).await {
Ok(pair) => pair,
Err(e) => {
error!("Failed to select PD pair error={}", e);
Expand Down Expand Up @@ -417,7 +412,7 @@ impl PDRouter {
.and_then(|content| content.as_str());

// Select servers
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
let (prefill, decode) = match self.select_pd_pair(request_text).await {
Ok(pair) => pair,
Err(e) => {
error!("Failed to select PD pair error={}", e);
Expand Down Expand Up @@ -498,7 +493,7 @@ impl PDRouter {
};

// Select servers
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
let (prefill, decode) = match self.select_pd_pair(request_text).await {
Ok(pair) => pair,
Err(e) => {
error!("Failed to select PD pair error={}", e);
Expand Down Expand Up @@ -833,7 +828,6 @@ impl PDRouter {
// Select a pair of prefill and decode servers
async fn select_pd_pair(
&self,
_client: &Client,
request_text: Option<&str>,
) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
// Get read locks for both worker lists
Expand Down Expand Up @@ -998,7 +992,7 @@ impl PDRouter {
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair

// Select a random worker pair using the policy
let (prefill, decode) = match self.select_pd_pair(client, None).await {
let (prefill, decode) = match self.select_pd_pair(None).await {
Ok(pair) => pair,
Err(e) => {
return (
Expand Down Expand Up @@ -1921,8 +1915,7 @@ mod tests {
router.prefill_workers.write().unwrap().push(healthy_worker);
router.decode_workers.write().unwrap().push(decode_worker);

let client = reqwest::Client::new();
let result = router.select_pd_pair(&client, None).await;
let result = router.select_pd_pair(None).await;

assert!(result.is_ok());
let (prefill, _decode) = result.unwrap();
Expand All @@ -1936,8 +1929,7 @@ mod tests {
async fn test_empty_worker_lists() {
let router = create_test_pd_router();

let client = reqwest::Client::new();
let result = router.select_pd_pair(&client, None).await;
let result = router.select_pd_pair(None).await;

assert!(result.is_err());
assert!(result.unwrap_err().contains("No prefill workers available"));
Expand Down
Loading