Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sgl-router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ pub mod config;
pub mod logging;
use std::collections::HashMap;
pub mod core;
pub mod metrics;
pub mod openai_api_types;
pub mod policies;
pub mod prometheus;
pub mod routers;
pub mod server;
pub mod service_discovery;
pub mod tree;
use crate::prometheus::PrometheusConfig;
use crate::metrics::PrometheusConfig;

#[pyclass(eq)]
#[derive(Clone, PartialEq, Debug)]
Expand Down
324 changes: 324 additions & 0 deletions sgl-router/src/metrics.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,324 @@
use metrics::{counter, describe_counter, describe_gauge, describe_histogram, gauge, histogram};
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::time::Duration;

#[derive(Debug, Clone)]
pub struct PrometheusConfig {
pub port: u16,
pub host: String,
}

impl Default for PrometheusConfig {
fn default() -> Self {
Self {
port: 29000,
host: "0.0.0.0".to_string(),
}
}
}

pub fn init_metrics() {
// Request metrics
describe_counter!(
"sgl_router_requests_total",
"Total number of requests by route and method"
);
describe_histogram!(
"sgl_router_request_duration_seconds",
"Request duration in seconds by route"
);
describe_counter!(
"sgl_router_request_errors_total",
"Total number of request errors by route and error type"
);
describe_counter!(
"sgl_router_retries_total",
"Total number of request retries by route"
);

// Worker metrics
describe_gauge!(
"sgl_router_active_workers",
"Number of currently active workers"
);
describe_gauge!(
"sgl_router_worker_health",
"Worker health status (1=healthy, 0=unhealthy)"
);
describe_gauge!("sgl_router_worker_load", "Current load on each worker");
describe_counter!(
"sgl_router_processed_requests_total",
"Total requests processed by each worker"
);

// Policy metrics
describe_counter!(
"sgl_router_policy_decisions_total",
"Total routing policy decisions by policy and worker"
);
describe_counter!("sgl_router_cache_hits_total", "Total cache hits");
describe_counter!("sgl_router_cache_misses_total", "Total cache misses");
describe_gauge!(
"sgl_router_tree_size",
"Current tree size for cache-aware routing"
);
describe_counter!(
"sgl_router_load_balancing_events_total",
"Total load balancing trigger events"
);
describe_gauge!("sgl_router_max_load", "Maximum worker load");
describe_gauge!("sgl_router_min_load", "Minimum worker load");

// PD-specific metrics
describe_counter!("sgl_router_pd_requests_total", "Total PD requests by route");
describe_counter!(
"sgl_router_pd_prefill_requests_total",
"Total prefill requests per worker"
);
describe_counter!(
"sgl_router_pd_decode_requests_total",
"Total decode requests per worker"
);
describe_counter!(
"sgl_router_pd_errors_total",
"Total PD errors by error type"
);
describe_counter!(
"sgl_router_pd_prefill_errors_total",
"Total prefill server errors"
);
describe_counter!(
"sgl_router_pd_decode_errors_total",
"Total decode server errors"
);
describe_counter!(
"sgl_router_pd_stream_errors_total",
"Total streaming errors per worker"
);
describe_histogram!(
"sgl_router_pd_request_duration_seconds",
"PD request duration by route"
);

// Service discovery metrics
describe_counter!(
"sgl_router_discovery_updates_total",
"Total service discovery update events"
);
describe_gauge!(
"sgl_router_discovery_workers_added",
"Number of workers added in last discovery update"
);
describe_gauge!(
"sgl_router_discovery_workers_removed",
"Number of workers removed in last discovery update"
);

// Generate request specific metrics
describe_histogram!(
"sgl_router_generate_duration_seconds",
"Generate request duration"
);

// Running requests gauge for cache-aware policy
describe_gauge!(
"sgl_router_running_requests",
"Number of running requests per worker"
);
}

pub fn start_prometheus(config: PrometheusConfig) {
// Initialize metric descriptions
init_metrics();

let duration_matcher = Matcher::Suffix(String::from("duration"));
let duration_bucket = [
0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 15.0, 30.0, 45.0,
60.0, 90.0, 120.0, 180.0, 240.0,
];

let ip_addr: IpAddr = config
.host
.parse()
.unwrap_or(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
let socket_addr = SocketAddr::new(ip_addr, config.port);

PrometheusBuilder::new()
.with_http_listener(socket_addr)
.upkeep_timeout(Duration::from_secs(5 * 60))
.set_buckets_for_metric(duration_matcher, &duration_bucket)
.expect("failed to set duration bucket")
.install()
.expect("failed to install Prometheus metrics exporter");
}

pub struct RouterMetrics;

impl RouterMetrics {
// Request metrics
pub fn record_request(route: &str) {
counter!("sgl_router_requests_total",
"route" => route.to_string()
)
.increment(1);
}

pub fn record_request_duration(route: &str, duration: Duration) {
histogram!("sgl_router_request_duration_seconds",
"route" => route.to_string()
)
.record(duration.as_secs_f64());
}

pub fn record_request_error(route: &str, error_type: &str) {
counter!("sgl_router_request_errors_total",
"route" => route.to_string(),
"error_type" => error_type.to_string()
)
.increment(1);
}

pub fn record_retry(route: &str) {
counter!("sgl_router_retries_total",
"route" => route.to_string()
)
.increment(1);
}

// Worker metrics
pub fn set_active_workers(count: usize) {
gauge!("sgl_router_active_workers").set(count as f64);
}

pub fn set_worker_health(worker_url: &str, healthy: bool) {
gauge!("sgl_router_worker_health",
"worker" => worker_url.to_string()
)
.set(if healthy { 1.0 } else { 0.0 });
}

pub fn set_worker_load(worker_url: &str, load: usize) {
gauge!("sgl_router_worker_load",
"worker" => worker_url.to_string()
)
.set(load as f64);
}

pub fn record_processed_request(worker_url: &str) {
counter!("sgl_router_processed_requests_total",
"worker" => worker_url.to_string()
)
.increment(1);
}

// Policy metrics
pub fn record_policy_decision(policy: &str, worker: &str) {
counter!("sgl_router_policy_decisions_total",
"policy" => policy.to_string(),
"worker" => worker.to_string()
)
.increment(1);
}

pub fn record_cache_hit() {
counter!("sgl_router_cache_hits_total").increment(1);
}

pub fn record_cache_miss() {
counter!("sgl_router_cache_misses_total").increment(1);
}

pub fn set_tree_size(worker: &str, size: usize) {
gauge!("sgl_router_tree_size",
"worker" => worker.to_string()
)
.set(size as f64);
}

pub fn record_load_balancing_event() {
counter!("sgl_router_load_balancing_events_total").increment(1);
}

pub fn set_load_range(max_load: usize, min_load: usize) {
gauge!("sgl_router_max_load").set(max_load as f64);
gauge!("sgl_router_min_load").set(min_load as f64);
}

// PD-specific metrics
pub fn record_pd_request(route: &str) {
counter!("sgl_router_pd_requests_total",
"route" => route.to_string()
)
.increment(1);
}

pub fn record_pd_request_duration(route: &str, duration: Duration) {
histogram!("sgl_router_pd_request_duration_seconds",
"route" => route.to_string()
)
.record(duration.as_secs_f64());
}

pub fn record_pd_prefill_request(worker: &str) {
counter!("sgl_router_pd_prefill_requests_total",
"worker" => worker.to_string()
)
.increment(1);
}

pub fn record_pd_decode_request(worker: &str) {
counter!("sgl_router_pd_decode_requests_total",
"worker" => worker.to_string()
)
.increment(1);
}

pub fn record_pd_error(error_type: &str) {
counter!("sgl_router_pd_errors_total",
"error_type" => error_type.to_string()
)
.increment(1);
}

pub fn record_pd_prefill_error(worker: &str) {
counter!("sgl_router_pd_prefill_errors_total",
"worker" => worker.to_string()
)
.increment(1);
}

pub fn record_pd_decode_error(worker: &str) {
counter!("sgl_router_pd_decode_errors_total",
"worker" => worker.to_string()
)
.increment(1);
}

pub fn record_pd_stream_error(worker: &str) {
counter!("sgl_router_pd_stream_errors_total",
"worker" => worker.to_string()
)
.increment(1);
}

// Service discovery metrics
pub fn record_discovery_update(added: usize, removed: usize) {
counter!("sgl_router_discovery_updates_total").increment(1);
gauge!("sgl_router_discovery_workers_added").set(added as f64);
gauge!("sgl_router_discovery_workers_removed").set(removed as f64);
}

// Generate request metrics
pub fn record_generate_duration(duration: Duration) {
histogram!("sgl_router_generate_duration_seconds").record(duration.as_secs_f64());
}

// Running requests for cache-aware policy
pub fn set_running_requests(worker: &str, count: usize) {
gauge!("sgl_router_running_requests",
"worker" => worker.to_string()
)
.set(count as f64);
}
}
16 changes: 7 additions & 9 deletions sgl-router/src/policies/cache_aware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@

use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use crate::tree::Tree;
use metrics::{counter, gauge};
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
Expand Down Expand Up @@ -171,9 +171,8 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
max_load, min_load, worker_loads
);

counter!("sgl_router_load_balancing_events_total").increment(1);
gauge!("sgl_router_max_load").set(max_load as f64);
gauge!("sgl_router_min_load").set(min_load as f64);
RouterMetrics::record_load_balancing_event();
RouterMetrics::set_load_range(max_load, min_load);

// Use shortest queue when imbalanced
let min_load_idx = healthy_indices
Expand All @@ -183,8 +182,7 @@ impl LoadBalancingPolicy for CacheAwarePolicy {

// Increment processed counter
workers[min_load_idx].increment_processed();
counter!("sgl_router_processed_requests_total", "worker" => workers[min_load_idx].url().to_string())
.increment(1);
RouterMetrics::record_processed_request(workers[min_load_idx].url());

return Some(min_load_idx);
}
Expand All @@ -201,10 +199,10 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
};

let selected_url = if match_rate > self.config.cache_threshold {
counter!("sgl_router_cache_hits_total").increment(1);
RouterMetrics::record_cache_hit();
matched_worker.to_string()
} else {
counter!("sgl_router_cache_misses_total").increment(1);
RouterMetrics::record_cache_miss();
tree.get_smallest_tenant()
};

Expand All @@ -221,7 +219,7 @@ impl LoadBalancingPolicy for CacheAwarePolicy {

// Increment processed counter
workers[selected_idx].increment_processed();
counter!("sgl_router_processed_requests_total", "worker" => selected_url).increment(1);
RouterMetrics::record_processed_request(&selected_url);

return Some(selected_idx);
}
Expand Down
Loading
Loading