diff --git a/crates/slipstream-client/src/dns/debug.rs b/crates/slipstream-client/src/dns/debug.rs index 582b393a..a00876c8 100644 --- a/crates/slipstream-client/src/dns/debug.rs +++ b/crates/slipstream-client/src/dns/debug.rs @@ -9,6 +9,7 @@ pub(crate) struct DebugMetrics { pub(crate) enabled: bool, pub(crate) last_report_at: u64, pub(crate) dns_responses: u64, + pub(crate) poll_completions: u64, pub(crate) zero_send_loops: u64, pub(crate) zero_send_with_streams: u64, pub(crate) enqueued_bytes: u64, @@ -31,6 +32,7 @@ impl DebugMetrics { enabled, last_report_at: 0, dns_responses: 0, + poll_completions: 0, zero_send_loops: 0, zero_send_with_streams: 0, enqueued_bytes: 0, diff --git a/crates/slipstream-client/src/dns/poll.rs b/crates/slipstream-client/src/dns/poll.rs index 83e10cd4..3a6ed5b3 100644 --- a/crates/slipstream-client/src/dns/poll.rs +++ b/crates/slipstream-client/src/dns/poll.rs @@ -84,7 +84,6 @@ pub(crate) async fn send_poll_queries( resolver.local_addr_storage = Some(unsafe { std::ptr::read(local_addr_storage) }); resolver.debug.send_packets = resolver.debug.send_packets.saturating_add(1); resolver.debug.send_bytes = resolver.debug.send_bytes.saturating_add(send_length as u64); - resolver.debug.polls_sent = resolver.debug.polls_sent.saturating_add(1); let poll_id = *dns_id; let qname = build_qname(&send_buf[..send_length], config.domain) @@ -112,6 +111,7 @@ pub(crate) async fn send_poll_queries( } return Err(ClientError::new(err.to_string())); } + resolver.debug.polls_sent = resolver.debug.polls_sent.saturating_add(1); if resolver.mode == ResolverMode::Authoritative { resolver.inflight_poll_ids.insert(poll_id, current_time); } diff --git a/crates/slipstream-client/src/dns/response.rs b/crates/slipstream-client/src/dns/response.rs index a0f7152a..f40e1790 100644 --- a/crates/slipstream-client/src/dns/response.rs +++ b/crates/slipstream-client/src/dns/response.rs @@ -72,8 +72,11 @@ pub(crate) fn handle_dns_response( } resolver.debug.dns_responses = resolver.debug.dns_responses.saturating_add(1); if let Some(response_id) = response_id { - if resolver.mode == ResolverMode::Authoritative { - resolver.inflight_poll_ids.remove(&response_id); + if resolver.mode == ResolverMode::Authoritative + && resolver.inflight_poll_ids.remove(&response_id).is_some() + { + resolver.debug.poll_completions = + resolver.debug.poll_completions.saturating_add(1); } } if resolver.mode == ResolverMode::Recursive { @@ -84,8 +87,10 @@ pub(crate) fn handle_dns_response( } else if let Some(response_id) = response_id { if let Some(resolver) = find_resolver_by_addr(ctx.resolvers, peer) { resolver.debug.dns_responses = resolver.debug.dns_responses.saturating_add(1); - if resolver.mode == ResolverMode::Authoritative { - resolver.inflight_poll_ids.remove(&response_id); + if resolver.mode == ResolverMode::Authoritative + && resolver.inflight_poll_ids.remove(&response_id).is_some() + { + resolver.debug.poll_completions = resolver.debug.poll_completions.saturating_add(1); } } } diff --git a/crates/slipstream-client/src/main.rs b/crates/slipstream-client/src/main.rs index 0bb320f6..2550fd74 100644 --- a/crates/slipstream-client/src/main.rs +++ b/crates/slipstream-client/src/main.rs @@ -54,6 +54,12 @@ struct Args { cert: Option, #[arg(long = "keep-alive-interval", short = 't', default_value_t = 400)] keep_alive_interval: u16, + #[arg( + long = "active-poll-cap-ms", + default_value_t = 10_000u64, + value_parser = clap::value_parser!(u64).range(1..) + )] + active_poll_cap_ms: u64, #[arg(long = "debug-poll")] debug_poll: bool, #[arg(long = "debug-streams")] @@ -177,6 +183,16 @@ fn main() { }); keep_alive_override.unwrap_or(args.keep_alive_interval) }; + let active_poll_cap_ms = if cli_provided(&matches, "active_poll_cap_ms") { + args.active_poll_cap_ms + } else { + let active_poll_cap_override = parse_active_poll_cap_ms(&sip003_env.plugin_options) + .unwrap_or_else(|err| { + tracing::error!("SIP003 env error: {}", err); + std::process::exit(2); + }); + active_poll_cap_override.unwrap_or(args.active_poll_cap_ms) + }; let config = ClientConfig { tcp_listen_host: &tcp_listen_host, @@ -187,6 +203,7 @@ fn main() { domain: &domain, cert: cert.as_deref(), keep_alive_interval: keep_alive_interval as usize, + active_poll_cap_ms, debug_poll: args.debug_poll, debug_streams: args.debug_streams, }; @@ -363,6 +380,23 @@ fn parse_keep_alive_interval(options: &[sip003::Sip003Option]) -> Result Result, String> { + let mut last = None; + for option in options { + if option.key == "active-poll-cap-ms" { + let value = option.value.trim(); + let parsed = value + .parse::() + .map_err(|_| format!("Invalid active-poll-cap-ms value: {}", value))?; + if parsed == 0 { + return Err("active-poll-cap-ms must be >= 1".to_string()); + } + last = Some(parsed); + } + } + Ok(last) +} + #[cfg(test)] mod tests { use super::*; @@ -488,4 +522,29 @@ mod tests { assert!(parsed.resolvers.is_empty()); assert!(parsed.authoritative_remote); } + + #[test] + fn active_poll_cap_uses_last_value() { + let options = vec![ + sip003::Sip003Option { + key: "active-poll-cap-ms".to_string(), + value: "5000".to_string(), + }, + sip003::Sip003Option { + key: "active-poll-cap-ms".to_string(), + value: "12000".to_string(), + }, + ]; + let parsed = parse_active_poll_cap_ms(&options).expect("options should parse"); + assert_eq!(parsed, Some(12_000)); + } + + #[test] + fn active_poll_cap_rejects_zero() { + let options = vec![sip003::Sip003Option { + key: "active-poll-cap-ms".to_string(), + value: "0".to_string(), + }]; + assert!(parse_active_poll_cap_ms(&options).is_err()); + } } diff --git a/crates/slipstream-client/src/runtime.rs b/crates/slipstream-client/src/runtime.rs index 05bc1f60..79c22ce8 100644 --- a/crates/slipstream-client/src/runtime.rs +++ b/crates/slipstream-client/src/runtime.rs @@ -9,7 +9,7 @@ use self::setup::{bind_tcp_listener, bind_udp_socket, compute_mtu, map_io}; use crate::dns::{ add_paths, expire_inflight_polls, handle_dns_response, maybe_report_debug, refresh_resolver_path, resolve_resolvers, resolver_mode_to_c, send_poll_queries, - sockaddr_storage_to_socket_addr, DnsResponseContext, + sockaddr_storage_to_socket_addr, DnsResponseContext, ResolverState, }; use crate::error::ClientError; use crate::pacing::{cwnd_target_polls, inflight_packet_estimate}; @@ -68,6 +68,63 @@ fn drain_disconnected_commands(command_rx: &mut mpsc::UnboundedReceiver dropped } +fn active_poll_work(cnx: *mut picoquic_cnx_t, resolvers: &mut [ResolverState]) -> (usize, usize) { + let mut pending = 0usize; + let mut inflight = 0usize; + for resolver in resolvers.iter_mut() { + let reachable = refresh_resolver_path(cnx, resolver); + if !reachable { + // Late responses can repopulate queue state after a path drop; keep them + // from blocking global active polling while the resolver is unreachable. + resolver.pending_polls = 0; + resolver.inflight_poll_ids.clear(); + continue; + } + pending = pending.saturating_add(resolver.pending_polls); + inflight = inflight.saturating_add(resolver.inflight_poll_ids.len()); + } + (pending, inflight) +} + +fn total_dns_responses(resolvers: &[ResolverState]) -> u64 { + resolvers + .iter() + .map(|resolver| resolver.debug.dns_responses) + .sum() +} + +fn total_poll_completions(resolvers: &[ResolverState]) -> u64 { + resolvers + .iter() + .map(|resolver| resolver.debug.poll_completions) + .sum() +} + +fn total_polls_sent(resolvers: &[ResolverState]) -> u64 { + resolvers + .iter() + .map(|resolver| resolver.debug.polls_sent) + .sum() +} + +fn select_active_poll_target( + cnx: *mut picoquic_cnx_t, + resolvers: &mut [ResolverState], +) -> Option { + let modes = [ResolverMode::Recursive, ResolverMode::Authoritative]; + for mode in modes { + for (idx, resolver) in resolvers.iter_mut().enumerate() { + if resolver.mode != mode { + continue; + } + if refresh_resolver_path(cnx, resolver) { + return Some(idx); + } + } + } + None +} + pub async fn run_client(config: &ClientConfig<'_>) -> Result { let domain_len = config.domain.len(); let mtu = compute_mtu(domain_len)?; @@ -232,6 +289,12 @@ pub async fn run_client(config: &ClientConfig<'_>) -> Result { let mut zero_send_loops = 0u64; let mut zero_send_with_streams = 0u64; let mut last_flow_block_log_at = 0u64; + let active_poll_cap_us = config.active_poll_cap_ms.saturating_mul(1_000).max(1); + let active_poll_base_us = DNS_POLL_SLICE_US.min(active_poll_cap_us); + let mut active_poll_backoff_us = active_poll_base_us; + let mut next_active_poll_at = current_time; + let mut last_dns_responses_total = 0u64; + let mut last_poll_completions_total = 0u64; loop { let current_time = unsafe { picoquic_current_time() }; @@ -475,6 +538,43 @@ pub async fn run_client(config: &ClientConfig<'_>) -> Result { last_flow_block_log_at = now; } } + let mut force_authoritative_poll_path = None; + let now = unsafe { picoquic_current_time() }; + let (pending_polls_sum, inflight_polls_sum) = active_poll_work(cnx, &mut resolvers); + let polls_sent_before = total_polls_sent(&resolvers); + let mut scheduled_active_poll = false; + let dns_responses_total = total_dns_responses(&resolvers); + let poll_completions_total = total_poll_completions(&resolvers); + let needs_active_polling = streams_len > 0 && !has_ready_stream; + let no_poll_work = pending_polls_sum == 0 && inflight_polls_sum == 0; + let has_useful_progress = dns_responses_total > last_dns_responses_total; + let poll_response_completed = poll_completions_total > last_poll_completions_total; + if has_useful_progress + && (!needs_active_polling || !no_poll_work || poll_response_completed) + { + active_poll_backoff_us = active_poll_base_us; + next_active_poll_at = now.saturating_add(active_poll_backoff_us); + } + + if needs_active_polling && no_poll_work && now >= next_active_poll_at { + if let Some(target_idx) = select_active_poll_target(cnx, &mut resolvers) { + scheduled_active_poll = true; + if resolvers[target_idx].mode == ResolverMode::Recursive { + resolvers[target_idx].pending_polls = + resolvers[target_idx].pending_polls.max(1); + } else { + force_authoritative_poll_path = Some(resolvers[target_idx].path_id); + } + } else { + next_active_poll_at = now.saturating_add(active_poll_base_us); + } + } + if !needs_active_polling { + active_poll_backoff_us = active_poll_base_us; + next_active_poll_at = now; + } + last_dns_responses_total = dns_responses_total; + last_poll_completions_total = poll_completions_total; for resolver in resolvers.iter_mut() { if !refresh_resolver_path(cnx, resolver) { continue; @@ -492,6 +592,9 @@ pub async fn run_client(config: &ClientConfig<'_>) -> Result { if has_ready_stream && !flow_blocked { poll_deficit = 0; } + if force_authoritative_poll_path == Some(resolver.path_id) { + poll_deficit = poll_deficit.max(1); + } if poll_deficit > 0 && resolver.debug.enabled { debug!( "cc_state: {} cwnd={} in_transit={} rtt_us={} flow_blocked={} deficit={}", @@ -559,6 +662,18 @@ pub async fn run_client(config: &ClientConfig<'_>) -> Result { } } } + if needs_active_polling && scheduled_active_poll { + let poll_retry_now = unsafe { picoquic_current_time() }; + let polls_sent_after = total_polls_sent(&resolvers); + if polls_sent_after > polls_sent_before { + active_poll_backoff_us = active_poll_backoff_us + .saturating_mul(2) + .min(active_poll_cap_us); + } else { + active_poll_backoff_us = active_poll_base_us; + } + next_active_poll_at = poll_retry_now.saturating_add(active_poll_backoff_us); + } let report_time = unsafe { picoquic_current_time() }; let (enqueued_bytes, last_enqueue_at) = unsafe { (*state_ptr).debug_snapshot() }; diff --git a/crates/slipstream-ffi/src/lib.rs b/crates/slipstream-ffi/src/lib.rs index a54df7b5..35731c61 100644 --- a/crates/slipstream-ffi/src/lib.rs +++ b/crates/slipstream-ffi/src/lib.rs @@ -32,6 +32,7 @@ pub struct ClientConfig<'a> { pub congestion_control: Option<&'a str>, pub gso: bool, pub keep_alive_interval: usize, + pub active_poll_cap_ms: u64, pub debug_poll: bool, pub debug_streams: bool, }